From 201281554f5ff65917d3386c934c38476d1ddffb Mon Sep 17 00:00:00 2001 From: Windel Bouwman Date: Wed, 29 Aug 2018 12:44:45 +0200 Subject: [PATCH] Move common arithmatic into functions which check on object type --- vm/src/objfloat.rs | 45 +++++++++++++- vm/src/objint.rs | 20 +++++- vm/src/objlist.rs | 26 ++++++++ vm/src/objstr.rs | 25 +++++++- vm/src/pyobject.rs | 149 +-------------------------------------------- vm/src/vm.rs | 60 ++++++++++++------ 6 files changed, 155 insertions(+), 170 deletions(-) diff --git a/vm/src/objfloat.rs b/vm/src/objfloat.rs index 27b49f080..7b0e6bd7a 100644 --- a/vm/src/objfloat.rs +++ b/vm/src/objfloat.rs @@ -1,6 +1,7 @@ +use super::objint; use super::objtype; use super::pyobject::{ - AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, TypeProtocol, + AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol, }; use super::vm::VirtualMachine; @@ -19,8 +20,50 @@ pub fn get_value(obj: PyObjectRef) -> f64 { } } +fn float_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(i, Some(vm.ctx.float_type())), (i2, None)] + ); + + if objtype::isinstance(i2.clone(), vm.ctx.float_type()) { + Ok(vm + .ctx + .new_float(get_value(i.clone()) + get_value(i2.clone()))) + } else if objtype::isinstance(i2.clone(), vm.ctx.int_type()) { + Ok(vm + .ctx + .new_float(get_value(i.clone()) + objint::get_value(i2.clone()) as f64)) + } else { + Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2))) + } +} + +fn float_sub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(i, Some(vm.ctx.float_type())), (i2, None)] + ); + + if objtype::isinstance(i2.clone(), vm.ctx.float_type()) { + Ok(vm + .ctx + .new_float(get_value(i.clone()) - get_value(i2.clone()))) + } else if objtype::isinstance(i2.clone(), vm.ctx.int_type()) { + Ok(vm + .ctx + .new_float(get_value(i.clone()) - objint::get_value(i2.clone()) as f64)) + } else { + Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2))) + } +} + pub fn init(context: &PyContext) { let ref float_type = context.float_type; + float_type.set_attr("__add__", context.new_rustfunc(float_add)); float_type.set_attr("__str__", context.new_rustfunc(str)); + float_type.set_attr("__sub__", context.new_rustfunc(float_sub)); float_type.set_attr("__repr__", context.new_rustfunc(str)); } diff --git a/vm/src/objint.rs b/vm/src/objint.rs index c332dcd1e..de3aeb2b1 100644 --- a/vm/src/objint.rs +++ b/vm/src/objint.rs @@ -12,7 +12,7 @@ fn str(vm: &mut VirtualMachine, args: PyFuncArgs) -> Result i32 { +pub fn get_value(obj: PyObjectRef) -> i32 { if let PyObjectKind::Integer { value } = &obj.borrow().kind { *value } else { @@ -28,6 +28,10 @@ fn int_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ); if objtype::isinstance(i2.clone(), vm.ctx.int_type()) { Ok(vm.ctx.new_int(get_value(i.clone()) + get_value(i2.clone()))) + } else if objtype::isinstance(i2.clone(), vm.ctx.float_type()) { + Ok(vm + .ctx + .new_float(get_value(i.clone()) as f64 + objfloat::get_value(i2.clone()))) } else { Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2))) } @@ -78,9 +82,23 @@ fn int_truediv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +fn int_mod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(i, Some(vm.ctx.int_type())), (i2, None)] + ); + if objtype::isinstance(i2.clone(), vm.ctx.int_type()) { + Ok(vm.ctx.new_int(get_value(i.clone()) % get_value(i2.clone()))) + } else { + Err(vm.new_type_error(format!("Cannot modulo {:?} and {:?}", i, i2))) + } +} + pub fn init(context: &PyContext) { let ref int_type = context.int_type; int_type.set_attr("__add__", context.new_rustfunc(int_add)); + int_type.set_attr("__mod__", context.new_rustfunc(int_mod)); int_type.set_attr("__mul__", context.new_rustfunc(int_mul)); int_type.set_attr("__repr__", context.new_rustfunc(str)); int_type.set_attr("__str__", context.new_rustfunc(str)); diff --git a/vm/src/objlist.rs b/vm/src/objlist.rs index bedc6e993..9c6735407 100644 --- a/vm/src/objlist.rs +++ b/vm/src/objlist.rs @@ -25,6 +25,31 @@ pub fn set_item( } } +fn get_elements(obj: PyObjectRef) -> Vec { + if let PyObjectKind::List { elements } = &obj.borrow().kind { + elements.to_vec() + } else { + panic!("Cannot extract list elements from non-list"); + } +} + +fn list_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(o, Some(vm.ctx.list_type())), (o2, None)] + ); + + if objtype::isinstance(o2.clone(), vm.ctx.list_type()) { + let e1 = get_elements(o.clone()); + let e2 = get_elements(o2.clone()); + let elements = e1.iter().chain(e2.iter()).map(|e| e.clone()).collect(); + Ok(vm.ctx.new_list(elements)) + } else { + Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", o, o2))) + } +} + pub fn append(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { trace!("list.append called with: {:?}", args); arg_check!( @@ -78,6 +103,7 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let ref list_type = context.list_type; + list_type.set_attr("__add__", context.new_rustfunc(list_add)); list_type.set_attr("__len__", context.new_rustfunc(len)); list_type.set_attr("append", context.new_rustfunc(append)); list_type.set_attr("clear", context.new_rustfunc(clear)); diff --git a/vm/src/objstr.rs b/vm/src/objstr.rs index 8786609e0..9fc257fea 100644 --- a/vm/src/objstr.rs +++ b/vm/src/objstr.rs @@ -1,3 +1,4 @@ +use super::objint; use super::objsequence::PySliceableSequence; use super::objtype; use super::pyobject::{ @@ -7,12 +8,13 @@ use super::vm::VirtualMachine; pub fn init(context: &PyContext) { let ref str_type = context.str_type; + str_type.set_attr("__add__", context.new_rustfunc(str_add)); + str_type.set_attr("__mul__", context.new_rustfunc(str_mul)); str_type.set_attr("__new__", context.new_rustfunc(str_new)); str_type.set_attr("__str__", context.new_rustfunc(str_str)); - str_type.set_attr("__add__", context.new_rustfunc(str_add)); } -fn get_value(obj: PyObjectRef) -> String { +pub fn get_value(obj: PyObjectRef) -> String { if let PyObjectKind::String { value } = &obj.borrow().kind { value.to_string() } else { @@ -40,6 +42,25 @@ fn str_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +fn str_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(s, Some(vm.ctx.str_type())), (s2, None)] + ); + if objtype::isinstance(s2.clone(), vm.ctx.int_type()) { + let value1 = get_value(s.clone()); + let value2 = objint::get_value(s2.clone()); + let mut result = String::new(); + for _x in 0..value2 { + result.push_str(value1.as_str()); + } + Ok(vm.ctx.new_str(result)) + } else { + Err(vm.new_type_error(format!("Cannot multiply {:?} and {:?}", s, s2))) + } +} + // TODO: should with following format // class str(object='') // class str(object=b'', encoding='utf-8', errors='strict') diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 313317ee7..c00dc5b81 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -14,7 +14,6 @@ use std::cell::RefCell; use std::cmp::Ordering; use std::collections::HashMap; use std::fmt; -use std::ops::{Add, Mul, Rem, Sub}; use std::rc::Rc; /* Python objects and references. @@ -740,126 +739,6 @@ impl PyObject { } } -impl<'a> Add<&'a PyObject> for &'a PyObject { - type Output = PyObjectKind; - - fn add(self, rhs: &'a PyObject) -> Self::Output { - match self.kind { - PyObjectKind::Integer { value: ref value1 } => match &rhs.kind { - PyObjectKind::Integer { value: ref value2 } => PyObjectKind::Integer { - value: value1 + value2, - }, - PyObjectKind::Float { value: ref value2 } => PyObjectKind::Float { - value: (*value1 as f64) + value2, - }, - _ => { - panic!("NOT IMPL"); - } - }, - PyObjectKind::Float { value: ref value1 } => match &rhs.kind { - PyObjectKind::Float { value: ref value2 } => PyObjectKind::Float { - value: value1 + value2, - }, - PyObjectKind::Integer { value: ref value2 } => PyObjectKind::Float { - value: value1 + (*value2 as f64), - }, - _ => { - panic!("NOT IMPL"); - } - }, - PyObjectKind::String { value: ref value1 } => match rhs.kind { - PyObjectKind::String { value: ref value2 } => PyObjectKind::String { - value: format!("{}{}", value1, value2), - }, - _ => { - panic!("NOT IMPL"); - } - }, - PyObjectKind::List { elements: ref e1 } => match rhs.kind { - PyObjectKind::List { elements: ref e2 } => PyObjectKind::List { - elements: e1.iter().chain(e2.iter()).map(|e| e.clone()).collect(), - }, - _ => { - panic!("NOT IMPL"); - } - }, - _ => { - // TODO: Lookup __add__ method in dictionary? - panic!("NOT IMPL"); - } - } - } -} - -impl<'a> Sub<&'a PyObject> for &'a PyObject { - type Output = PyObjectKind; - - fn sub(self, rhs: &'a PyObject) -> Self::Output { - match self.kind { - PyObjectKind::Integer { value: value1 } => match rhs.kind { - PyObjectKind::Integer { value: value2 } => PyObjectKind::Integer { - value: value1 - value2, - }, - _ => { - panic!("NOT IMPL"); - } - }, - _ => { - panic!("NOT IMPL"); - } - } - } -} - -impl<'a> Mul<&'a PyObject> for &'a PyObject { - type Output = PyObjectKind; - - fn mul(self, rhs: &'a PyObject) -> Self::Output { - match self.kind { - PyObjectKind::Integer { value: value1 } => match rhs.kind { - PyObjectKind::Integer { value: value2 } => PyObjectKind::Integer { - value: value1 * value2, - }, - _ => { - panic!("NOT IMPL"); - } - }, - PyObjectKind::String { value: ref value1 } => match rhs.kind { - PyObjectKind::Integer { value: value2 } => { - let mut result = String::new(); - for _x in 0..value2 { - result.push_str(value1.as_str()); - } - PyObjectKind::String { value: result } - } - _ => { - panic!("NOT IMPL"); - } - }, - _ => { - panic!("NOT IMPL"); - } - } - } -} - -impl<'a> Rem<&'a PyObject> for &'a PyObject { - type Output = PyObjectKind; - - fn rem(self, rhs: &'a PyObject) -> Self::Output { - match (&self.kind, &rhs.kind) { - (PyObjectKind::Integer { value: value1 }, PyObjectKind::Integer { value: value2 }) => { - PyObjectKind::Integer { - value: value1 % value2, - } - } - (kind1, kind2) => { - unimplemented!("% not implemented for kinds: {:?} {:?}", kind1, kind2); - } - } - } -} - // impl<'a> PartialEq<&'a PyObject> for &'a PyObject { impl PartialEq for PyObject { fn eq(&self, other: &PyObject) -> bool { @@ -924,33 +803,7 @@ impl Ord for PyObject { #[cfg(test)] mod tests { - use super::{PyContext, PyObjectKind}; - - #[test] - fn test_add_py_integers() { - let ctx = PyContext::new(); - let a = ctx.new_int(33); - let b = ctx.new_int(12); - let c = &*a.borrow() + &*b.borrow(); - match c { - PyObjectKind::Integer { value } => assert_eq!(value, 45), - _ => assert!(false), - } - } - - #[test] - fn test_multiply_str() { - let ctx = PyContext::new(); - let a = ctx.new_str(String::from("Hello ")); - let b = ctx.new_int(4); - let c = &*a.borrow() * &*b.borrow(); - match c { - PyObjectKind::String { value } => { - assert_eq!(value, String::from("Hello Hello Hello Hello ")) - } - _ => assert!(false), - } - } + use super::PyContext; #[test] fn test_type_type() { diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 59f9d9840..7c27976ac 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -319,33 +319,32 @@ impl VirtualMachine { } fn _sub(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - let b2 = &*b.borrow(); - let a2 = &*a.borrow(); - // TODO: Fix this correctly, and for all arithmetic operations - Ok(PyObject::new( - a2 - b2, - a2.typ.clone().unwrap_or(self.get_type()), - )) + self.call_method(a, "__sub__".to_string(), vec![b]) } fn _add(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - let b2 = &*b.borrow(); - let a2 = &*a.borrow(); - Ok(PyObject::new(a2 + b2, self.get_type())) + self.call_method(a, "__add__".to_string(), vec![b]) } fn _mul(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - let b2 = &*b.borrow(); - let a2 = &*a.borrow(); - Ok(PyObject::new(a2 * b2, self.get_type())) + self.call_method(a, "__mul__".to_string(), vec![b]) } fn _div(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - let func = match self.get_attribute(a.clone(), &"__truediv__".to_string()) { + self.call_method(a, "__truediv__".to_string(), vec![b]) + } + + fn call_method( + &mut self, + obj: PyObjectRef, + method_name: String, + args: Vec, + ) -> PyResult { + let func = match self.get_attribute(obj, &method_name) { Ok(v) => v, Err(err) => return Err(err), }; - let args = PyFuncArgs { args: vec![b] }; + let args = PyFuncArgs { args: args }; self.invoke(func, args) } @@ -371,9 +370,7 @@ impl VirtualMachine { } fn _modulo(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - let b2 = &*b.borrow(); - let a2 = &*a.borrow(); - Ok(PyObject::new(a2 % b2, self.get_type())) + self.call_method(a, "__mod__".to_string(), vec![b]) } fn execute_binop(&mut self, op: &bytecode::BinaryOperator) -> Option { @@ -931,3 +928,30 @@ impl VirtualMachine { self.current_frame().get_lineno() } } + +#[cfg(test)] +mod tests { + use super::super::objint; + use super::objstr; + use super::VirtualMachine; + + #[test] + fn test_add_py_integers() { + let mut vm = VirtualMachine::new(); + let a = vm.ctx.new_int(33); + let b = vm.ctx.new_int(12); + let res = vm._add(a, b).unwrap(); + let value = objint::get_value(res); + assert_eq!(value, 45); + } + + #[test] + fn test_multiply_str() { + let mut vm = VirtualMachine::new(); + let a = vm.ctx.new_str(String::from("Hello ")); + let b = vm.ctx.new_int(4); + let res = vm._mul(a, b).unwrap(); + let value = objstr::get_value(res); + assert_eq!(value, String::from("Hello Hello Hello Hello ")) + } +}