mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-09 22:49:57 +09:00
Move common arithmatic into functions which check on object type
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
use super::objint;
|
||||
use super::objtype;
|
||||
use super::pyobject::{
|
||||
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, TypeProtocol,
|
||||
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol,
|
||||
};
|
||||
use super::vm::VirtualMachine;
|
||||
|
||||
@@ -19,8 +20,50 @@ pub fn get_value(obj: PyObjectRef) -> f64 {
|
||||
}
|
||||
}
|
||||
|
||||
fn float_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
|
||||
);
|
||||
|
||||
if objtype::isinstance(i2.clone(), vm.ctx.float_type()) {
|
||||
Ok(vm
|
||||
.ctx
|
||||
.new_float(get_value(i.clone()) + get_value(i2.clone())))
|
||||
} else if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
|
||||
Ok(vm
|
||||
.ctx
|
||||
.new_float(get_value(i.clone()) + objint::get_value(i2.clone()) as f64))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2)))
|
||||
}
|
||||
}
|
||||
|
||||
fn float_sub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(i, Some(vm.ctx.float_type())), (i2, None)]
|
||||
);
|
||||
|
||||
if objtype::isinstance(i2.clone(), vm.ctx.float_type()) {
|
||||
Ok(vm
|
||||
.ctx
|
||||
.new_float(get_value(i.clone()) - get_value(i2.clone())))
|
||||
} else if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
|
||||
Ok(vm
|
||||
.ctx
|
||||
.new_float(get_value(i.clone()) - objint::get_value(i2.clone()) as f64))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref float_type = context.float_type;
|
||||
float_type.set_attr("__add__", context.new_rustfunc(float_add));
|
||||
float_type.set_attr("__str__", context.new_rustfunc(str));
|
||||
float_type.set_attr("__sub__", context.new_rustfunc(float_sub));
|
||||
float_type.set_attr("__repr__", context.new_rustfunc(str));
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ fn str(vm: &mut VirtualMachine, args: PyFuncArgs) -> Result<PyObjectRef, PyObjec
|
||||
}
|
||||
|
||||
// Retrieve inner int value:
|
||||
fn get_value(obj: PyObjectRef) -> i32 {
|
||||
pub fn get_value(obj: PyObjectRef) -> i32 {
|
||||
if let PyObjectKind::Integer { value } = &obj.borrow().kind {
|
||||
*value
|
||||
} else {
|
||||
@@ -28,6 +28,10 @@ fn int_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
);
|
||||
if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
|
||||
Ok(vm.ctx.new_int(get_value(i.clone()) + get_value(i2.clone())))
|
||||
} else if objtype::isinstance(i2.clone(), vm.ctx.float_type()) {
|
||||
Ok(vm
|
||||
.ctx
|
||||
.new_float(get_value(i.clone()) as f64 + objfloat::get_value(i2.clone())))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", i, i2)))
|
||||
}
|
||||
@@ -78,9 +82,23 @@ fn int_truediv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
}
|
||||
}
|
||||
|
||||
fn int_mod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(i, Some(vm.ctx.int_type())), (i2, None)]
|
||||
);
|
||||
if objtype::isinstance(i2.clone(), vm.ctx.int_type()) {
|
||||
Ok(vm.ctx.new_int(get_value(i.clone()) % get_value(i2.clone())))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Cannot modulo {:?} and {:?}", i, i2)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref int_type = context.int_type;
|
||||
int_type.set_attr("__add__", context.new_rustfunc(int_add));
|
||||
int_type.set_attr("__mod__", context.new_rustfunc(int_mod));
|
||||
int_type.set_attr("__mul__", context.new_rustfunc(int_mul));
|
||||
int_type.set_attr("__repr__", context.new_rustfunc(str));
|
||||
int_type.set_attr("__str__", context.new_rustfunc(str));
|
||||
|
||||
@@ -25,6 +25,31 @@ pub fn set_item(
|
||||
}
|
||||
}
|
||||
|
||||
fn get_elements(obj: PyObjectRef) -> Vec<PyObjectRef> {
|
||||
if let PyObjectKind::List { elements } = &obj.borrow().kind {
|
||||
elements.to_vec()
|
||||
} else {
|
||||
panic!("Cannot extract list elements from non-list");
|
||||
}
|
||||
}
|
||||
|
||||
fn list_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(o, Some(vm.ctx.list_type())), (o2, None)]
|
||||
);
|
||||
|
||||
if objtype::isinstance(o2.clone(), vm.ctx.list_type()) {
|
||||
let e1 = get_elements(o.clone());
|
||||
let e2 = get_elements(o2.clone());
|
||||
let elements = e1.iter().chain(e2.iter()).map(|e| e.clone()).collect();
|
||||
Ok(vm.ctx.new_list(elements))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Cannot add {:?} and {:?}", o, o2)))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn append(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
trace!("list.append called with: {:?}", args);
|
||||
arg_check!(
|
||||
@@ -78,6 +103,7 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref list_type = context.list_type;
|
||||
list_type.set_attr("__add__", context.new_rustfunc(list_add));
|
||||
list_type.set_attr("__len__", context.new_rustfunc(len));
|
||||
list_type.set_attr("append", context.new_rustfunc(append));
|
||||
list_type.set_attr("clear", context.new_rustfunc(clear));
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use super::objint;
|
||||
use super::objsequence::PySliceableSequence;
|
||||
use super::objtype;
|
||||
use super::pyobject::{
|
||||
@@ -7,12 +8,13 @@ use super::vm::VirtualMachine;
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref str_type = context.str_type;
|
||||
str_type.set_attr("__add__", context.new_rustfunc(str_add));
|
||||
str_type.set_attr("__mul__", context.new_rustfunc(str_mul));
|
||||
str_type.set_attr("__new__", context.new_rustfunc(str_new));
|
||||
str_type.set_attr("__str__", context.new_rustfunc(str_str));
|
||||
str_type.set_attr("__add__", context.new_rustfunc(str_add));
|
||||
}
|
||||
|
||||
fn get_value(obj: PyObjectRef) -> String {
|
||||
pub fn get_value(obj: PyObjectRef) -> String {
|
||||
if let PyObjectKind::String { value } = &obj.borrow().kind {
|
||||
value.to_string()
|
||||
} else {
|
||||
@@ -40,6 +42,25 @@ fn str_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
}
|
||||
}
|
||||
|
||||
fn str_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(s, Some(vm.ctx.str_type())), (s2, None)]
|
||||
);
|
||||
if objtype::isinstance(s2.clone(), vm.ctx.int_type()) {
|
||||
let value1 = get_value(s.clone());
|
||||
let value2 = objint::get_value(s2.clone());
|
||||
let mut result = String::new();
|
||||
for _x in 0..value2 {
|
||||
result.push_str(value1.as_str());
|
||||
}
|
||||
Ok(vm.ctx.new_str(result))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Cannot multiply {:?} and {:?}", s, s2)))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: should with following format
|
||||
// class str(object='')
|
||||
// class str(object=b'', encoding='utf-8', errors='strict')
|
||||
|
||||
@@ -14,7 +14,6 @@ use std::cell::RefCell;
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::ops::{Add, Mul, Rem, Sub};
|
||||
use std::rc::Rc;
|
||||
|
||||
/* Python objects and references.
|
||||
@@ -740,126 +739,6 @@ impl PyObject {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Add<&'a PyObject> for &'a PyObject {
|
||||
type Output = PyObjectKind;
|
||||
|
||||
fn add(self, rhs: &'a PyObject) -> Self::Output {
|
||||
match self.kind {
|
||||
PyObjectKind::Integer { value: ref value1 } => match &rhs.kind {
|
||||
PyObjectKind::Integer { value: ref value2 } => PyObjectKind::Integer {
|
||||
value: value1 + value2,
|
||||
},
|
||||
PyObjectKind::Float { value: ref value2 } => PyObjectKind::Float {
|
||||
value: (*value1 as f64) + value2,
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
},
|
||||
PyObjectKind::Float { value: ref value1 } => match &rhs.kind {
|
||||
PyObjectKind::Float { value: ref value2 } => PyObjectKind::Float {
|
||||
value: value1 + value2,
|
||||
},
|
||||
PyObjectKind::Integer { value: ref value2 } => PyObjectKind::Float {
|
||||
value: value1 + (*value2 as f64),
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
},
|
||||
PyObjectKind::String { value: ref value1 } => match rhs.kind {
|
||||
PyObjectKind::String { value: ref value2 } => PyObjectKind::String {
|
||||
value: format!("{}{}", value1, value2),
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
},
|
||||
PyObjectKind::List { elements: ref e1 } => match rhs.kind {
|
||||
PyObjectKind::List { elements: ref e2 } => PyObjectKind::List {
|
||||
elements: e1.iter().chain(e2.iter()).map(|e| e.clone()).collect(),
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
// TODO: Lookup __add__ method in dictionary?
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Sub<&'a PyObject> for &'a PyObject {
|
||||
type Output = PyObjectKind;
|
||||
|
||||
fn sub(self, rhs: &'a PyObject) -> Self::Output {
|
||||
match self.kind {
|
||||
PyObjectKind::Integer { value: value1 } => match rhs.kind {
|
||||
PyObjectKind::Integer { value: value2 } => PyObjectKind::Integer {
|
||||
value: value1 - value2,
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Mul<&'a PyObject> for &'a PyObject {
|
||||
type Output = PyObjectKind;
|
||||
|
||||
fn mul(self, rhs: &'a PyObject) -> Self::Output {
|
||||
match self.kind {
|
||||
PyObjectKind::Integer { value: value1 } => match rhs.kind {
|
||||
PyObjectKind::Integer { value: value2 } => PyObjectKind::Integer {
|
||||
value: value1 * value2,
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
},
|
||||
PyObjectKind::String { value: ref value1 } => match rhs.kind {
|
||||
PyObjectKind::Integer { value: value2 } => {
|
||||
let mut result = String::new();
|
||||
for _x in 0..value2 {
|
||||
result.push_str(value1.as_str());
|
||||
}
|
||||
PyObjectKind::String { value: result }
|
||||
}
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
panic!("NOT IMPL");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Rem<&'a PyObject> for &'a PyObject {
|
||||
type Output = PyObjectKind;
|
||||
|
||||
fn rem(self, rhs: &'a PyObject) -> Self::Output {
|
||||
match (&self.kind, &rhs.kind) {
|
||||
(PyObjectKind::Integer { value: value1 }, PyObjectKind::Integer { value: value2 }) => {
|
||||
PyObjectKind::Integer {
|
||||
value: value1 % value2,
|
||||
}
|
||||
}
|
||||
(kind1, kind2) => {
|
||||
unimplemented!("% not implemented for kinds: {:?} {:?}", kind1, kind2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// impl<'a> PartialEq<&'a PyObject> for &'a PyObject {
|
||||
impl PartialEq for PyObject {
|
||||
fn eq(&self, other: &PyObject) -> bool {
|
||||
@@ -924,33 +803,7 @@ impl Ord for PyObject {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{PyContext, PyObjectKind};
|
||||
|
||||
#[test]
|
||||
fn test_add_py_integers() {
|
||||
let ctx = PyContext::new();
|
||||
let a = ctx.new_int(33);
|
||||
let b = ctx.new_int(12);
|
||||
let c = &*a.borrow() + &*b.borrow();
|
||||
match c {
|
||||
PyObjectKind::Integer { value } => assert_eq!(value, 45),
|
||||
_ => assert!(false),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiply_str() {
|
||||
let ctx = PyContext::new();
|
||||
let a = ctx.new_str(String::from("Hello "));
|
||||
let b = ctx.new_int(4);
|
||||
let c = &*a.borrow() * &*b.borrow();
|
||||
match c {
|
||||
PyObjectKind::String { value } => {
|
||||
assert_eq!(value, String::from("Hello Hello Hello Hello "))
|
||||
}
|
||||
_ => assert!(false),
|
||||
}
|
||||
}
|
||||
use super::PyContext;
|
||||
|
||||
#[test]
|
||||
fn test_type_type() {
|
||||
|
||||
60
vm/src/vm.rs
60
vm/src/vm.rs
@@ -319,33 +319,32 @@ impl VirtualMachine {
|
||||
}
|
||||
|
||||
fn _sub(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
|
||||
let b2 = &*b.borrow();
|
||||
let a2 = &*a.borrow();
|
||||
// TODO: Fix this correctly, and for all arithmetic operations
|
||||
Ok(PyObject::new(
|
||||
a2 - b2,
|
||||
a2.typ.clone().unwrap_or(self.get_type()),
|
||||
))
|
||||
self.call_method(a, "__sub__".to_string(), vec![b])
|
||||
}
|
||||
|
||||
fn _add(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
|
||||
let b2 = &*b.borrow();
|
||||
let a2 = &*a.borrow();
|
||||
Ok(PyObject::new(a2 + b2, self.get_type()))
|
||||
self.call_method(a, "__add__".to_string(), vec![b])
|
||||
}
|
||||
|
||||
fn _mul(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
|
||||
let b2 = &*b.borrow();
|
||||
let a2 = &*a.borrow();
|
||||
Ok(PyObject::new(a2 * b2, self.get_type()))
|
||||
self.call_method(a, "__mul__".to_string(), vec![b])
|
||||
}
|
||||
|
||||
fn _div(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
|
||||
let func = match self.get_attribute(a.clone(), &"__truediv__".to_string()) {
|
||||
self.call_method(a, "__truediv__".to_string(), vec![b])
|
||||
}
|
||||
|
||||
fn call_method(
|
||||
&mut self,
|
||||
obj: PyObjectRef,
|
||||
method_name: String,
|
||||
args: Vec<PyObjectRef>,
|
||||
) -> PyResult {
|
||||
let func = match self.get_attribute(obj, &method_name) {
|
||||
Ok(v) => v,
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
let args = PyFuncArgs { args: vec![b] };
|
||||
let args = PyFuncArgs { args: args };
|
||||
self.invoke(func, args)
|
||||
}
|
||||
|
||||
@@ -371,9 +370,7 @@ impl VirtualMachine {
|
||||
}
|
||||
|
||||
fn _modulo(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
|
||||
let b2 = &*b.borrow();
|
||||
let a2 = &*a.borrow();
|
||||
Ok(PyObject::new(a2 % b2, self.get_type()))
|
||||
self.call_method(a, "__mod__".to_string(), vec![b])
|
||||
}
|
||||
|
||||
fn execute_binop(&mut self, op: &bytecode::BinaryOperator) -> Option<PyResult> {
|
||||
@@ -931,3 +928,30 @@ impl VirtualMachine {
|
||||
self.current_frame().get_lineno()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::objint;
|
||||
use super::objstr;
|
||||
use super::VirtualMachine;
|
||||
|
||||
#[test]
|
||||
fn test_add_py_integers() {
|
||||
let mut vm = VirtualMachine::new();
|
||||
let a = vm.ctx.new_int(33);
|
||||
let b = vm.ctx.new_int(12);
|
||||
let res = vm._add(a, b).unwrap();
|
||||
let value = objint::get_value(res);
|
||||
assert_eq!(value, 45);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiply_str() {
|
||||
let mut vm = VirtualMachine::new();
|
||||
let a = vm.ctx.new_str(String::from("Hello "));
|
||||
let b = vm.ctx.new_int(4);
|
||||
let res = vm._mul(a, b).unwrap();
|
||||
let value = objstr::get_value(res);
|
||||
assert_eq!(value, String::from("Hello Hello Hello Hello "))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user