Move common arithmatic into functions which check on object type

This commit is contained in:
Windel Bouwman
2018-08-29 12:44:45 +02:00
parent fb47ce36e9
commit 201281554f
6 changed files with 155 additions and 170 deletions

View File

@@ -1,6 +1,7 @@
use super::objint;
use super::objtype;
use super::pyobject::{
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, TypeProtocol,
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol,
};
use super::vm::VirtualMachine;
@@ -19,8 +20,50 @@ pub fn get_value(obj: PyObjectRef) -> f64 {
}
}
fn float_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
);
if objtype::isinstance(i2.clone(), vm.ctx.float_type()) {
Ok(vm
.ctx
.new_float(get_value(i.clone()) + get_value(i2.clone())))
} else if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
Ok(vm
.ctx
.new_float(get_value(i.clone()) + objint::get_value(i2.clone()) as f64))
} else {
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2)))
}
}
fn float_sub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
);
if objtype::isinstance(i2.clone(), vm.ctx.float_type()) {
Ok(vm
.ctx
.new_float(get_value(i.clone()) - get_value(i2.clone())))
} else if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
Ok(vm
.ctx
.new_float(get_value(i.clone()) - objint::get_value(i2.clone()) as f64))
} else {
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2)))
}
}
pub fn init(context: &PyContext) {
let ref float_type = context.float_type;
float_type.set_attr("__add__", context.new_rustfunc(float_add));
float_type.set_attr("__str__", context.new_rustfunc(str));
float_type.set_attr("__sub__", context.new_rustfunc(float_sub));
float_type.set_attr("__repr__", context.new_rustfunc(str));
}

View File

@@ -12,7 +12,7 @@ fn str(vm: &mut VirtualMachine, args: PyFuncArgs) -> Result<PyObjectRef, PyObjec
}
// Retrieve inner int value:
fn get_value(obj: PyObjectRef) -> i32 {
pub fn get_value(obj: PyObjectRef) -> i32 {
if let PyObjectKind::Integer { value } = &obj.borrow().kind {
*value
} else {
@@ -28,6 +28,10 @@ fn int_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
);
if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
Ok(vm.ctx.new_int(get_value(i.clone()) + get_value(i2.clone())))
} else if objtype::isinstance(i2.clone(), vm.ctx.float_type()) {
Ok(vm
.ctx
.new_float(get_value(i.clone()) as f64 + objfloat::get_value(i2.clone())))
} else {
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2)))
}
@@ -78,9 +82,23 @@ fn int_truediv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
}
}
fn int_mod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(i, Some(vm.ctx.int_type())), (i2, None)]
);
if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
Ok(vm.ctx.new_int(get_value(i.clone()) % get_value(i2.clone())))
} else {
Err(vm.new_type_error(format!("Cannot modulo {:?} and {:?}", i, i2)))
}
}
pub fn init(context: &PyContext) {
let ref int_type = context.int_type;
int_type.set_attr("__add__", context.new_rustfunc(int_add));
int_type.set_attr("__mod__", context.new_rustfunc(int_mod));
int_type.set_attr("__mul__", context.new_rustfunc(int_mul));
int_type.set_attr("__repr__", context.new_rustfunc(str));
int_type.set_attr("__str__", context.new_rustfunc(str));

View File

@@ -25,6 +25,31 @@ pub fn set_item(
}
}
fn get_elements(obj: PyObjectRef) -> Vec<PyObjectRef> {
if let PyObjectKind::List { elements } = &obj.borrow().kind {
elements.to_vec()
} else {
panic!("Cannot extract list elements from non-list");
}
}
fn list_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(o, Some(vm.ctx.list_type())), (o2, None)]
);
if objtype::isinstance(o2.clone(), vm.ctx.list_type()) {
let e1 = get_elements(o.clone());
let e2 = get_elements(o2.clone());
let elements = e1.iter().chain(e2.iter()).map(|e| e.clone()).collect();
Ok(vm.ctx.new_list(elements))
} else {
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", o, o2)))
}
}
pub fn append(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
trace!("list.append called with: {:?}", args);
arg_check!(
@@ -78,6 +103,7 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
pub fn init(context: &PyContext) {
let ref list_type = context.list_type;
list_type.set_attr("__add__", context.new_rustfunc(list_add));
list_type.set_attr("__len__", context.new_rustfunc(len));
list_type.set_attr("append", context.new_rustfunc(append));
list_type.set_attr("clear", context.new_rustfunc(clear));

View File

@@ -1,3 +1,4 @@
use super::objint;
use super::objsequence::PySliceableSequence;
use super::objtype;
use super::pyobject::{
@@ -7,12 +8,13 @@ use super::vm::VirtualMachine;
pub fn init(context: &PyContext) {
let ref str_type = context.str_type;
str_type.set_attr("__add__", context.new_rustfunc(str_add));
str_type.set_attr("__mul__", context.new_rustfunc(str_mul));
str_type.set_attr("__new__", context.new_rustfunc(str_new));
str_type.set_attr("__str__", context.new_rustfunc(str_str));
str_type.set_attr("__add__", context.new_rustfunc(str_add));
}
fn get_value(obj: PyObjectRef) -> String {
pub fn get_value(obj: PyObjectRef) -> String {
if let PyObjectKind::String { value } = &obj.borrow().kind {
value.to_string()
} else {
@@ -40,6 +42,25 @@ fn str_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
}
}
fn str_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(s, Some(vm.ctx.str_type())), (s2, None)]
);
if objtype::isinstance(s2.clone(), vm.ctx.int_type()) {
let value1 = get_value(s.clone());
let value2 = objint::get_value(s2.clone());
let mut result = String::new();
for _x in 0..value2 {
result.push_str(value1.as_str());
}
Ok(vm.ctx.new_str(result))
} else {
Err(vm.new_type_error(format!("Cannot multiply {:?} and {:?}", s, s2)))
}
}
// TODO: should with following format
// class str(object='')
// class str(object=b'', encoding='utf-8', errors='strict')

View File

@@ -14,7 +14,6 @@ use std::cell::RefCell;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::fmt;
use std::ops::{Add, Mul, Rem, Sub};
use std::rc::Rc;
/* Python objects and references.
@@ -740,126 +739,6 @@ impl PyObject {
}
}
impl<'a> Add<&'a PyObject> for &'a PyObject {
type Output = PyObjectKind;
fn add(self, rhs: &'a PyObject) -> Self::Output {
match self.kind {
PyObjectKind::Integer { value: ref value1 } => match &rhs.kind {
PyObjectKind::Integer { value: ref value2 } => PyObjectKind::Integer {
value: value1 + value2,
},
PyObjectKind::Float { value: ref value2 } => PyObjectKind::Float {
value: (*value1 as f64) + value2,
},
_ => {
panic!("NOT IMPL");
}
},
PyObjectKind::Float { value: ref value1 } => match &rhs.kind {
PyObjectKind::Float { value: ref value2 } => PyObjectKind::Float {
value: value1 + value2,
},
PyObjectKind::Integer { value: ref value2 } => PyObjectKind::Float {
value: value1 + (*value2 as f64),
},
_ => {
panic!("NOT IMPL");
}
},
PyObjectKind::String { value: ref value1 } => match rhs.kind {
PyObjectKind::String { value: ref value2 } => PyObjectKind::String {
value: format!("{}{}", value1, value2),
},
_ => {
panic!("NOT IMPL");
}
},
PyObjectKind::List { elements: ref e1 } => match rhs.kind {
PyObjectKind::List { elements: ref e2 } => PyObjectKind::List {
elements: e1.iter().chain(e2.iter()).map(|e| e.clone()).collect(),
},
_ => {
panic!("NOT IMPL");
}
},
_ => {
// TODO: Lookup __add__ method in dictionary?
panic!("NOT IMPL");
}
}
}
}
impl<'a> Sub<&'a PyObject> for &'a PyObject {
type Output = PyObjectKind;
fn sub(self, rhs: &'a PyObject) -> Self::Output {
match self.kind {
PyObjectKind::Integer { value: value1 } => match rhs.kind {
PyObjectKind::Integer { value: value2 } => PyObjectKind::Integer {
value: value1 - value2,
},
_ => {
panic!("NOT IMPL");
}
},
_ => {
panic!("NOT IMPL");
}
}
}
}
impl<'a> Mul<&'a PyObject> for &'a PyObject {
type Output = PyObjectKind;
fn mul(self, rhs: &'a PyObject) -> Self::Output {
match self.kind {
PyObjectKind::Integer { value: value1 } => match rhs.kind {
PyObjectKind::Integer { value: value2 } => PyObjectKind::Integer {
value: value1 * value2,
},
_ => {
panic!("NOT IMPL");
}
},
PyObjectKind::String { value: ref value1 } => match rhs.kind {
PyObjectKind::Integer { value: value2 } => {
let mut result = String::new();
for _x in 0..value2 {
result.push_str(value1.as_str());
}
PyObjectKind::String { value: result }
}
_ => {
panic!("NOT IMPL");
}
},
_ => {
panic!("NOT IMPL");
}
}
}
}
impl<'a> Rem<&'a PyObject> for &'a PyObject {
type Output = PyObjectKind;
fn rem(self, rhs: &'a PyObject) -> Self::Output {
match (&self.kind, &rhs.kind) {
(PyObjectKind::Integer { value: value1 }, PyObjectKind::Integer { value: value2 }) => {
PyObjectKind::Integer {
value: value1 % value2,
}
}
(kind1, kind2) => {
unimplemented!("% not implemented for kinds: {:?} {:?}", kind1, kind2);
}
}
}
}
// impl<'a> PartialEq<&'a PyObject> for &'a PyObject {
impl PartialEq for PyObject {
fn eq(&self, other: &PyObject) -> bool {
@@ -924,33 +803,7 @@ impl Ord for PyObject {
#[cfg(test)]
mod tests {
use super::{PyContext, PyObjectKind};
#[test]
fn test_add_py_integers() {
let ctx = PyContext::new();
let a = ctx.new_int(33);
let b = ctx.new_int(12);
let c = &*a.borrow() + &*b.borrow();
match c {
PyObjectKind::Integer { value } => assert_eq!(value, 45),
_ => assert!(false),
}
}
#[test]
fn test_multiply_str() {
let ctx = PyContext::new();
let a = ctx.new_str(String::from("Hello "));
let b = ctx.new_int(4);
let c = &*a.borrow() * &*b.borrow();
match c {
PyObjectKind::String { value } => {
assert_eq!(value, String::from("Hello Hello Hello Hello "))
}
_ => assert!(false),
}
}
use super::PyContext;
#[test]
fn test_type_type() {

View File

@@ -319,33 +319,32 @@ impl VirtualMachine {
}
fn _sub(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
let b2 = &*b.borrow();
let a2 = &*a.borrow();
// TODO: Fix this correctly, and for all arithmetic operations
Ok(PyObject::new(
a2 - b2,
a2.typ.clone().unwrap_or(self.get_type()),
))
self.call_method(a, "__sub__".to_string(), vec![b])
}
fn _add(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
let b2 = &*b.borrow();
let a2 = &*a.borrow();
Ok(PyObject::new(a2 + b2, self.get_type()))
self.call_method(a, "__add__".to_string(), vec![b])
}
fn _mul(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
let b2 = &*b.borrow();
let a2 = &*a.borrow();
Ok(PyObject::new(a2 * b2, self.get_type()))
self.call_method(a, "__mul__".to_string(), vec![b])
}
fn _div(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
let func = match self.get_attribute(a.clone(), &"__truediv__".to_string()) {
self.call_method(a, "__truediv__".to_string(), vec![b])
}
fn call_method(
&mut self,
obj: PyObjectRef,
method_name: String,
args: Vec<PyObjectRef>,
) -> PyResult {
let func = match self.get_attribute(obj, &method_name) {
Ok(v) => v,
Err(err) => return Err(err),
};
let args = PyFuncArgs { args: vec![b] };
let args = PyFuncArgs { args: args };
self.invoke(func, args)
}
@@ -371,9 +370,7 @@ impl VirtualMachine {
}
fn _modulo(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
let b2 = &*b.borrow();
let a2 = &*a.borrow();
Ok(PyObject::new(a2 % b2, self.get_type()))
self.call_method(a, "__mod__".to_string(), vec![b])
}
fn execute_binop(&mut self, op: &bytecode::BinaryOperator) -> Option<PyResult> {
@@ -931,3 +928,30 @@ impl VirtualMachine {
self.current_frame().get_lineno()
}
}
#[cfg(test)]
mod tests {
use super::super::objint;
use super::objstr;
use super::VirtualMachine;
#[test]
fn test_add_py_integers() {
let mut vm = VirtualMachine::new();
let a = vm.ctx.new_int(33);
let b = vm.ctx.new_int(12);
let res = vm._add(a, b).unwrap();
let value = objint::get_value(res);
assert_eq!(value, 45);
}
#[test]
fn test_multiply_str() {
let mut vm = VirtualMachine::new();
let a = vm.ctx.new_str(String::from("Hello "));
let b = vm.ctx.new_int(4);
let res = vm._mul(a, b).unwrap();
let value = objstr::get_value(res);
assert_eq!(value, String::from("Hello Hello Hello Hello "))
}
}