diff --git a/tests/snippets/bools.py b/tests/snippets/bools.py index 50bd39f7e..2aa817ca4 100644 --- a/tests/snippets/bools.py +++ b/tests/snippets/bools.py @@ -40,8 +40,6 @@ assert (True and True) assert not (False and fake) assert (True and 5) == 5 -assert bool.__doc__ == "bool(x) -> bool\n\nReturns True when the argument x is true, False otherwise.\nThe builtins True and False are the only two instances of the class bool.\nThe class bool is a subclass of the class int, and cannot be subclassed." - # Bools are also ints. assert isinstance(True, int) assert True + True == 2 diff --git a/tests/snippets/builtin_enumerate.py b/tests/snippets/builtin_enumerate.py new file mode 100644 index 000000000..75f6f7412 --- /dev/null +++ b/tests/snippets/builtin_enumerate.py @@ -0,0 +1,23 @@ +assert list(enumerate(['a', 'b', 'c'])) == [(0, 'a'), (1, 'b'), (2, 'c')] + +assert type(enumerate([])) == enumerate + +assert list(enumerate(['a', 'b', 'c'], -100)) == [(-100, 'a'), (-99, 'b'), (-98, 'c')] +assert list(enumerate(['a', 'b', 'c'], 2**200)) == [(2**200, 'a'), (2**200 + 1, 'b'), (2**200 + 2, 'c')] + + +# test infinite iterator +class Counter(object): + counter = 0 + + def __next__(self): + self.counter += 1 + return self.counter + + def __iter__(self): + return self + + +it = enumerate(Counter()) +assert next(it) == (0, 1) +assert next(it) == (1, 2) diff --git a/tests/snippets/builtin_range.py b/tests/snippets/builtin_range.py index 5e58ae069..c8efb189d 100644 --- a/tests/snippets/builtin_range.py +++ b/tests/snippets/builtin_range.py @@ -18,10 +18,16 @@ def assert_raises(expr, exc_type): assert range(2**63+1)[2**63] == 9223372036854775808 +# len tests +assert len(range(10, 5)) == 0, 'Range with no elements should have length = 0' +assert len(range(10, 5, -2)) == 3, 'Expected length 3, for elements: 10, 8, 6' +assert len(range(5, 10, 2)) == 3, 'Expected length 3, for elements: 5, 7, 9' + # index tests assert range(10).index(6) == 6 assert range(4, 10).index(6) == 2 assert range(4, 10, 2).index(6) == 1 +assert range(10, 4, -2).index(8) == 1 # index raises value error on out of bounds assert_raises(lambda _: range(10).index(-1), ValueError) @@ -29,3 +35,25 @@ assert_raises(lambda _: range(10).index(10), ValueError) # index raises value error if out of step assert_raises(lambda _: range(4, 10, 2).index(5), ValueError) + +# index raises value error if needle is not an int +assert_raises(lambda _: range(10).index('foo'), ValueError) + +# __bool__ +assert range(1).__bool__() +assert range(1, 2).__bool__() + +assert not range(0).__bool__() +assert not range(1, 1).__bool__() + +# __contains__ +assert range(10).__contains__(6) +assert range(4, 10).__contains__(6) +assert range(4, 10, 2).__contains__(6) +assert range(10, 4, -2).__contains__(10) +assert range(10, 4, -2).__contains__(8) + +assert not range(10).__contains__(-1) +assert not range(10, 4, -2).__contains__(9) +assert not range(10, 4, -2).__contains__(4) +assert not range(10).__contains__('foo') diff --git a/tests/snippets/builtin_slice.py b/tests/snippets/builtin_slice.py new file mode 100644 index 000000000..b7c3922c0 --- /dev/null +++ b/tests/snippets/builtin_slice.py @@ -0,0 +1,77 @@ + +a = [] +assert a[:] == [] +assert a[:2**100] == [] +assert a[-2**100:] == [] +assert a[::2**100] == [] +assert a[10:20] == [] +assert a[-20:-10] == [] + +b = [1, 2] + +assert b[:] == [1, 2] +assert b[:2**100] == [1, 2] +assert b[-2**100:] == [1, 2] +assert b[2**100:] == [] +assert b[::2**100] == [1] +assert b[-10:1] == [1] +assert b[0:0] == [] +assert b[1:0] == [] + +try: + _ = b[::0] +except ValueError: + pass +else: + assert False, "Zero step slice should raise ValueError" + +assert b[::-1] == [2, 1] +assert b[1::-1] == [2, 1] +assert b[0::-1] == [1] +assert b[0:-5:-1] == [1] +assert b[:0:-1] == [2] +assert b[5:0:-1] == [2] + +c = list(range(10)) + +assert c[9:6:-3] == [9] +assert c[9::-3] == [9, 6, 3, 0] +assert c[9::-4] == [9, 5, 1] +assert c[8::-2**100] == [8] + +assert c[7:7:-2] == [] +assert c[7:8:-2] == [] + +d = "123456" + +assert d[3::-1] == "4321" +assert d[4::-3] == "52" + + +slice_a = slice(5) +assert slice_a.start is None +assert slice_a.stop == 5 +assert slice_a.step is None + +slice_b = slice(1, 5) +assert slice_b.start == 1 +assert slice_b.stop == 5 +assert slice_b.step is None + +slice_c = slice(1, 5, 2) +assert slice_c.start == 1 +assert slice_c.stop == 5 +assert slice_c.step == 2 + + +class SubScript(object): + def __getitem__(self, item): + assert type(item) == slice + + def __setitem__(self, key, value): + assert type(key) == slice + + +ss = SubScript() +_ = ss[:] +ss[:1] = 1 diff --git a/tests/snippets/builtin_zip.py b/tests/snippets/builtin_zip.py new file mode 100644 index 000000000..3665c7702 --- /dev/null +++ b/tests/snippets/builtin_zip.py @@ -0,0 +1,24 @@ +assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)] + +assert list(zip(['a', 'b', 'c'])) == [('a',), ('b',), ('c',)] +assert list(zip()) == [] + +assert list(zip(*zip(['a', 'b', 'c'], range(1, 4)))) == [('a', 'b', 'c'), (1, 2, 3)] + + +# test infinite iterator +class Counter(object): + def __init__(self, counter=0): + self.counter = counter + + def __next__(self): + self.counter += 1 + return self.counter + + def __iter__(self): + return self + + +it = zip(Counter(), Counter(3)) +assert next(it) == (1, 4) +assert next(it) == (2, 5) diff --git a/tests/snippets/builtins.py b/tests/snippets/builtins.py index 539b49ef7..76b28a7b9 100644 --- a/tests/snippets/builtins.py +++ b/tests/snippets/builtins.py @@ -5,12 +5,8 @@ assert callable(type) # TODO: # assert callable(callable) -assert list(enumerate(['a', 'b', 'c'])) == [(0, 'a'), (1, 'b'), (2, 'c')] - assert type(frozenset) is type -assert list(zip(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)] - assert 3 == eval('1+2') code = compile('5+3', 'x.py', 'eval') diff --git a/tests/snippets/division_by_zero.py b/tests/snippets/division_by_zero.py new file mode 100644 index 000000000..7cb68cd76 --- /dev/null +++ b/tests/snippets/division_by_zero.py @@ -0,0 +1,34 @@ +try: + 5 / 0 +except ZeroDivisionError: + pass +else: + assert False, 'Expected ZeroDivisionError' + +try: + 5 / -0.0 +except ZeroDivisionError: + pass +else: + assert False, 'Expected ZeroDivisionError' + +try: + 5 / (2-2) +except ZeroDivisionError: + pass +else: + assert False, 'Expected ZeroDivisionError' + +try: + 5 % 0 +except ZeroDivisionError: + pass +else: + assert False, 'Expected ZeroDivisionError' + +try: + raise ZeroDivisionError('Is an ArithmeticError subclass?') +except ArithmeticError: + pass +else: + assert False, 'Expected ZeroDivisionError' diff --git a/tests/snippets/numbers.py b/tests/snippets/numbers.py index 12e892385..ecd2263af 100644 --- a/tests/snippets/numbers.py +++ b/tests/snippets/numbers.py @@ -2,8 +2,6 @@ x = 5 x.__init__(6) assert x == 5 -assert int.__doc__ == "int(x=0) -> integer\nint(x, base=10) -> integer\n\nConvert a number or string to an integer, or return 0 if no arguments\nare given. If x is a number, return x.__int__(). For floating point\nnumbers, this truncates towards zero.\n\nIf x is not a number or if base is given, then x must be a string,\nbytes, or bytearray instance representing an integer literal in the\ngiven base. The literal can be preceded by '+' or '-' and be surrounded\nby whitespace. The base defaults to 10. Valid bases are 0 and 2-36.\nBase 0 means to interpret the base from the string as an integer literal.\n>>> int('0b100', base=0)\n4" - class A(int): pass diff --git a/tests/snippets/tuple.py b/tests/snippets/tuple.py index 56a61f827..eb5102fa3 100644 --- a/tests/snippets/tuple.py +++ b/tests/snippets/tuple.py @@ -15,3 +15,7 @@ assert x * -1 == () # integers less than zero treated as 0 assert y < x, "tuple __lt__ failed" assert x > y, "tuple __gt__ failed" + + +b = (1,2,3) +assert b.index(2) == 1 diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 2e789383e..cad71c42e 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -21,8 +21,7 @@ use super::pyobject::{ use super::stdlib::io::io_open; use super::vm::VirtualMachine; -use num_bigint::ToBigInt; -use num_traits::{Signed, ToPrimitive, Zero}; +use num_traits::{Signed, ToPrimitive}; fn get_locals(vm: &mut VirtualMachine) -> PyObjectRef { let d = vm.new_dict(); @@ -136,11 +135,11 @@ fn builtin_compile(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let mode = { let mode = objstr::get_value(mode); - if mode == String::from("exec") { + if mode == "exec" { compile::Mode::Exec - } else if mode == "eval".to_string() { + } else if mode == "eval" { compile::Mode::Eval - } else if mode == "single".to_string() { + } else if mode == "single" { compile::Mode::Single } else { return Err( @@ -180,29 +179,6 @@ fn builtin_divmod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } -fn builtin_enumerate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(iterable, None)], - optional = [(start, None)] - ); - let items = vm.extract_elements(iterable)?; - let start = if let Some(start) = start { - objint::get_value(start) - } else { - Zero::zero() - }; - let mut new_items = vec![]; - for (i, item) in items.into_iter().enumerate() { - let element = vm - .ctx - .new_tuple(vec![vm.ctx.new_int(i.to_bigint().unwrap() + &start), item]); - new_items.push(element); - } - Ok(vm.ctx.new_list(new_items)) -} - /// Implements `eval`. /// See also: https://docs.python.org/3/library/functions.html#eval fn builtin_eval(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -359,7 +335,7 @@ fn builtin_hex(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { fn builtin_id(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, None)]); - Ok(vm.context().new_int(obj.get_id().to_bigint().unwrap())) + Ok(vm.context().new_int(obj.get_id())) } // builtin_input @@ -553,9 +529,7 @@ fn builtin_ord(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ))); } match string.chars().next() { - Some(character) => Ok(vm - .context() - .new_int((character as i32).to_bigint().unwrap())), + Some(character) => Ok(vm.context().new_int(character as i32)), None => Err(vm.new_type_error( "ord() could not guess the integer representing this character".to_string(), )), @@ -635,7 +609,7 @@ fn builtin_sum(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let items = vm.extract_elements(iterable)?; // Start with zero and add at will: - let mut sum = vm.ctx.new_int(Zero::zero()); + let mut sum = vm.ctx.new_int(0); for item in items { sum = vm._add(sum, item)?; } @@ -643,32 +617,6 @@ fn builtin_sum(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } // builtin_vars - -fn builtin_zip(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - no_kwargs!(vm, args); - - // TODO: process one element at a time from iterators. - let mut iterables = vec![]; - for iterable in args.args.iter() { - let iterable = vm.extract_elements(iterable)?; - iterables.push(iterable); - } - - let minsize: usize = iterables.iter().map(|i| i.len()).min().unwrap_or(0); - - let mut new_items = vec![]; - for i in 0..minsize { - let items = iterables - .iter() - .map(|iterable| iterable[i].clone()) - .collect(); - let element = vm.ctx.new_tuple(items); - new_items.push(element); - } - - Ok(vm.ctx.new_list(new_items)) -} - // builtin___import__ pub fn make_module(ctx: &PyContext) -> PyObjectRef { @@ -694,7 +642,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "dict", ctx.dict_type()); ctx.set_attr(&py_mod, "divmod", ctx.new_rustfunc(builtin_divmod)); ctx.set_attr(&py_mod, "dir", ctx.new_rustfunc(builtin_dir)); - ctx.set_attr(&py_mod, "enumerate", ctx.new_rustfunc(builtin_enumerate)); + ctx.set_attr(&py_mod, "enumerate", ctx.enumerate_type()); ctx.set_attr(&py_mod, "eval", ctx.new_rustfunc(builtin_eval)); ctx.set_attr(&py_mod, "exec", ctx.new_rustfunc(builtin_exec)); ctx.set_attr(&py_mod, "float", ctx.float_type()); @@ -729,13 +677,14 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "repr", ctx.new_rustfunc(builtin_repr)); ctx.set_attr(&py_mod, "set", ctx.set_type()); ctx.set_attr(&py_mod, "setattr", ctx.new_rustfunc(builtin_setattr)); + ctx.set_attr(&py_mod, "slice", ctx.slice_type()); ctx.set_attr(&py_mod, "staticmethod", ctx.staticmethod_type()); ctx.set_attr(&py_mod, "str", ctx.str_type()); ctx.set_attr(&py_mod, "sum", ctx.new_rustfunc(builtin_sum)); ctx.set_attr(&py_mod, "super", ctx.super_type()); ctx.set_attr(&py_mod, "tuple", ctx.tuple_type()); ctx.set_attr(&py_mod, "type", ctx.type_type()); - ctx.set_attr(&py_mod, "zip", ctx.new_rustfunc(builtin_zip)); + ctx.set_attr(&py_mod, "zip", ctx.zip_type()); // Exceptions: ctx.set_attr( @@ -744,6 +693,11 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.exceptions.base_exception_type.clone(), ); ctx.set_attr(&py_mod, "Exception", ctx.exceptions.exception_type.clone()); + ctx.set_attr( + &py_mod, + "ArithmeticError", + ctx.exceptions.arithmetic_error.clone(), + ); ctx.set_attr( &py_mod, "AssertionError", @@ -755,6 +709,11 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.exceptions.attribute_error.clone(), ); ctx.set_attr(&py_mod, "NameError", ctx.exceptions.name_error.clone()); + ctx.set_attr( + &py_mod, + "OverflowError", + ctx.exceptions.overflow_error.clone(), + ); ctx.set_attr( &py_mod, "RuntimeError", @@ -774,6 +733,11 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { "StopIteration", ctx.exceptions.stop_iteration.clone(), ); + ctx.set_attr( + &py_mod, + "ZeroDivisionError", + ctx.exceptions.zero_division_error.clone(), + ); py_mod } diff --git a/vm/src/compile.rs b/vm/src/compile.rs index fb9215341..9050e59ce 100644 --- a/vm/src/compile.rs +++ b/vm/src/compile.rs @@ -45,9 +45,8 @@ pub fn compile( }, }; - match result { - Err(msg) => return Err(vm.new_exception(syntax_error.clone(), msg)), - _ => {} + if let Err(msg) = result { + return Err(vm.new_exception(syntax_error.clone(), msg)); } let code = compiler.pop_code_object(); @@ -589,7 +588,7 @@ impl Compiler { ast::Statement::Assign { targets, value } => { self.compile_expression(value)?; - for (i, target) in targets.into_iter().enumerate() { + for (i, target) in targets.iter().enumerate() { if i + 1 != targets.len() { self.emit(Instruction::Duplicate); } @@ -665,7 +664,7 @@ impl Compiler { let mut flags = bytecode::FunctionOpArg::empty(); if have_kwargs { - flags = flags | bytecode::FunctionOpArg::HAS_DEFAULTS; + flags |= bytecode::FunctionOpArg::HAS_DEFAULTS; } Ok(flags) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs new file mode 100644 index 000000000..3c4dee22a --- /dev/null +++ b/vm/src/dictdatatype.rs @@ -0,0 +1,230 @@ +use super::obj::objbool; +use super::obj::objint; +use super::pyobject::{IdProtocol, PyObjectRef, PyResult}; +use super::vm::VirtualMachine; +use num_traits::ToPrimitive; +/// Ordered dictionary implementation. +/// Inspired by: https://morepypy.blogspot.com/2015/01/faster-more-memory-efficient-and-more.html +/// And: https://www.youtube.com/watch?v=p33CVV29OG8 +/// And: http://code.activestate.com/recipes/578375/ +use std::collections::HashMap; + +pub struct Dict { + size: usize, + indices: HashMap, + entries: Vec>, +} + +struct DictEntry { + hash: usize, + key: PyObjectRef, + value: PyObjectRef, +} + +impl Dict { + pub fn new() -> Self { + Dict { + size: 0, + indices: HashMap::new(), + entries: Vec::new(), + } + } + + /// Store a key + pub fn insert( + &mut self, + vm: &mut VirtualMachine, + key: &PyObjectRef, + value: PyObjectRef, + ) -> Result<(), PyObjectRef> { + match self.lookup(vm, key)? { + LookupResult::Existing(index) => { + // Update existing key + if let Some(ref mut entry) = self.entries[index] { + entry.value = value; + Ok(()) + } else { + panic!("Lookup returned invalid index into entries!"); + } + } + LookupResult::NewIndex { + hash_index, + hash_value, + } => { + // New key: + let entry = DictEntry { + hash: hash_value, + key: key.clone(), + value, + }; + let index = self.entries.len(); + self.entries.push(Some(entry)); + self.indices.insert(hash_index, index); + self.size += 1; + Ok(()) + } + } + } + + pub fn contains( + &self, + vm: &mut VirtualMachine, + key: &PyObjectRef, + ) -> Result { + if let LookupResult::Existing(_index) = self.lookup(vm, key)? { + Ok(true) + } else { + Ok(false) + } + } + + /// Retrieve a key + pub fn get(&self, vm: &mut VirtualMachine, key: &PyObjectRef) -> PyResult { + if let LookupResult::Existing(index) = self.lookup(vm, key)? { + if let Some(entry) = &self.entries[index] { + Ok(entry.value.clone()) + } else { + panic!("Lookup returned invalid index into entries!"); + } + } else { + let key_repr = vm.to_pystr(key)?; + Err(vm.new_value_error(format!("Key not found: {}", key_repr))) + } + } + + /// Delete a key + pub fn delete( + &mut self, + vm: &mut VirtualMachine, + key: &PyObjectRef, + ) -> Result<(), PyObjectRef> { + if let LookupResult::Existing(index) = self.lookup(vm, key)? { + self.entries[index] = None; + self.size -= 1; + Ok(()) + } else { + let key_repr = vm.to_pystr(key)?; + Err(vm.new_value_error(format!("Key not found: {}", key_repr))) + } + } + + pub fn len(&self) -> usize { + self.size + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn get_items(&self) -> Vec<(PyObjectRef, PyObjectRef)> { + self.entries + .iter() + .filter(|e| e.is_some()) + .map(|e| e.as_ref().unwrap()) + .map(|e| (e.key.clone(), e.value.clone())) + .collect() + } + + /// Lookup the index for the given key. + fn lookup( + &self, + vm: &mut VirtualMachine, + key: &PyObjectRef, + ) -> Result { + let hash_value = calc_hash(vm, key)?; + let perturb = hash_value; + let mut hash_index: usize = hash_value; + loop { + if self.indices.contains_key(&hash_index) { + // Now we have an index, lets check the key. + let index = self.indices[&hash_index]; + if let Some(entry) = &self.entries[index] { + // Okay, we have an entry at this place + if entry.key.is(key) { + // Literally the same object + break Ok(LookupResult::Existing(index)); + } else if entry.hash == hash_value { + if do_eq(vm, &entry.key, key)? { + break Ok(LookupResult::Existing(index)); + } else { + // entry mismatch. + } + } else { + // entry mismatch. + } + } else { + // Removed entry, continue search... + } + } else { + // Hash not in table, we are at free slot now. + break Ok(LookupResult::NewIndex { + hash_value, + hash_index, + }); + } + + // Update i to next probe location: + hash_index = hash_index + .wrapping_mul(5) + .wrapping_add(perturb) + .wrapping_add(1); + // warn!("Perturb value: {}", i); + } + } +} + +enum LookupResult { + NewIndex { + hash_value: usize, + hash_index: usize, + }, // return not found, index into indices + Existing(usize), // Existing record, index into entries +} + +fn calc_hash(vm: &mut VirtualMachine, key: &PyObjectRef) -> Result { + let hash = vm.call_method(key, "__hash__", vec![])?; + Ok(objint::get_value(&hash).to_usize().unwrap()) +} + +/// Invoke __eq__ on two keys +fn do_eq( + vm: &mut VirtualMachine, + key1: &PyObjectRef, + key2: &PyObjectRef, +) -> Result { + let result = vm._eq(key1, key2.clone())?; + Ok(objbool::get_value(&result)) +} + +#[cfg(test)] +mod tests { + use super::{Dict, VirtualMachine}; + + #[test] + fn test_insert() { + let mut vm = VirtualMachine::new(); + let mut dict = Dict::new(); + assert_eq!(0, dict.len()); + + let key1 = vm.new_bool(true); + let value1 = vm.new_str("abc".to_string()); + dict.insert(&mut vm, &key1, value1.clone()).unwrap(); + assert_eq!(1, dict.len()); + + let key2 = vm.new_str("x".to_string()); + let value2 = vm.new_str("def".to_string()); + dict.insert(&mut vm, &key2, value2.clone()).unwrap(); + assert_eq!(2, dict.len()); + + dict.insert(&mut vm, &key1, value2.clone()).unwrap(); + assert_eq!(2, dict.len()); + + dict.delete(&mut vm, &key1).unwrap(); + assert_eq!(1, dict.len()); + + dict.insert(&mut vm, &key1, value2).unwrap(); + assert_eq!(2, dict.len()); + + assert_eq!(true, dict.contains(&mut vm, &key1).unwrap()); + } +} diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 305297919..be3652a73 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -81,6 +81,7 @@ fn exception_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { #[derive(Debug)] pub struct ExceptionZoo { + pub arithmetic_error: PyObjectRef, pub assertion_error: PyObjectRef, pub attribute_error: PyObjectRef, pub base_exception_type: PyObjectRef, @@ -93,12 +94,14 @@ pub struct ExceptionZoo { pub name_error: PyObjectRef, pub not_implemented_error: PyObjectRef, pub os_error: PyObjectRef, + pub overflow_error: PyObjectRef, pub permission_error: PyObjectRef, pub runtime_error: PyObjectRef, pub stop_iteration: PyObjectRef, pub syntax_error: PyObjectRef, pub type_error: PyObjectRef, pub value_error: PyObjectRef, + pub zero_division_error: PyObjectRef, } impl ExceptionZoo { @@ -113,6 +116,8 @@ impl ExceptionZoo { let exception_type = create_type("Exception", &type_type, &base_exception_type, &dict_type); + let arithmetic_error = + create_type("ArithmeticError", &type_type, &exception_type, &dict_type); let assertion_error = create_type("AssertionError", &type_type, &exception_type, &dict_type); let attribute_error = @@ -128,8 +133,18 @@ impl ExceptionZoo { let type_error = create_type("TypeError", &type_type, &exception_type, &dict_type); let value_error = create_type("ValueError", &type_type, &exception_type, &dict_type); + let overflow_error = + create_type("OverflowError", &type_type, &arithmetic_error, &dict_type); + let zero_division_error = create_type( + "ZeroDivisionError", + &type_type, + &arithmetic_error, + &dict_type, + ); + let module_not_found_error = create_type("ModuleNotFoundError", &type_type, &import_error, &dict_type); + let not_implemented_error = create_type( "NotImplementedError", &type_type, @@ -142,6 +157,7 @@ impl ExceptionZoo { let permission_error = create_type("PermissionError", &type_type, &os_error, &dict_type); ExceptionZoo { + arithmetic_error, assertion_error, attribute_error, base_exception_type, @@ -154,12 +170,14 @@ impl ExceptionZoo { name_error, not_implemented_error, os_error, + overflow_error, permission_error, runtime_error, stop_iteration, syntax_error, type_error, value_error, + zero_division_error, } } } diff --git a/vm/src/format.rs b/vm/src/format.rs index 5b4035ef7..6dfa77c7e 100644 --- a/vm/src/format.rs +++ b/vm/src/format.rs @@ -314,7 +314,7 @@ impl FormatSpec { } None => Ok(magnitude.to_str_radix(10)), }; - if !raw_magnitude_string_result.is_ok() { + if raw_magnitude_string_result.is_err() { return raw_magnitude_string_result; } let magnitude_string = format!( diff --git a/vm/src/frame.rs b/vm/src/frame.rs index ac118dcb7..86c0bc29d 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -20,8 +20,7 @@ use super::pyobject::{ PyResult, TypeProtocol, }; use super::vm::VirtualMachine; -use num_bigint::ToBigInt; -use num_traits::ToPrimitive; +use num_bigint::BigInt; #[derive(Clone, Debug)] enum Block { @@ -121,7 +120,7 @@ impl Frame { trace!("Adding to traceback: {:?} {:?}", traceback, lineno); let pos = vm.ctx.new_tuple(vec![ vm.ctx.new_str(filename.clone()), - vm.ctx.new_int(lineno.get_row().to_bigint().unwrap()), + vm.ctx.new_int(lineno.get_row()), vm.ctx.new_str(run_obj_name.clone()), ]); objlist::list_append( @@ -263,22 +262,22 @@ impl Frame { assert!(*size == 2 || *size == 3); let elements = self.pop_multiple(*size); - let mut out: Vec> = elements + let mut out: Vec> = elements .into_iter() .map(|x| match x.borrow().payload { - PyObjectPayload::Integer { ref value } => Some(value.to_i32().unwrap()), + PyObjectPayload::Integer { ref value } => Some(value.clone()), PyObjectPayload::None => None, _ => panic!("Expect Int or None as BUILD_SLICE arguments, got {:?}", x), }) .collect(); - let start = out[0]; - let stop = out[1]; - let step = if out.len() == 3 { out[2] } else { None }; + let start = out[0].take(); + let stop = out[1].take(); + let step = if out.len() == 3 { out[2].take() } else { None }; let obj = PyObject::new( PyObjectPayload::Slice { start, stop, step }, - vm.ctx.type_type(), + vm.ctx.slice_type(), ); self.push_value(obj); Ok(None) @@ -611,7 +610,7 @@ impl Frame { .iter() .skip(*before) .take(middle) - .map(|x| x.clone()) + .cloned() .collect(); let t = vm.ctx.new_list(middle_elements); self.push_value(t); @@ -887,25 +886,25 @@ impl Frame { ) -> FrameResult { let b_ref = self.pop_value(); let a_ref = self.pop_value(); - let value = match op { - &bytecode::BinaryOperator::Subtract => vm._sub(a_ref, b_ref), - &bytecode::BinaryOperator::Add => vm._add(a_ref, b_ref), - &bytecode::BinaryOperator::Multiply => vm._mul(a_ref, b_ref), - &bytecode::BinaryOperator::MatrixMultiply => { + let value = match *op { + bytecode::BinaryOperator::Subtract => vm._sub(a_ref, b_ref), + bytecode::BinaryOperator::Add => vm._add(a_ref, b_ref), + bytecode::BinaryOperator::Multiply => vm._mul(a_ref, b_ref), + bytecode::BinaryOperator::MatrixMultiply => { vm.call_method(&a_ref, "__matmul__", vec![b_ref]) } - &bytecode::BinaryOperator::Power => vm._pow(a_ref, b_ref), - &bytecode::BinaryOperator::Divide => vm._div(a_ref, b_ref), - &bytecode::BinaryOperator::FloorDivide => { + bytecode::BinaryOperator::Power => vm._pow(a_ref, b_ref), + bytecode::BinaryOperator::Divide => vm._div(a_ref, b_ref), + bytecode::BinaryOperator::FloorDivide => { vm.call_method(&a_ref, "__floordiv__", vec![b_ref]) } - &bytecode::BinaryOperator::Subscript => self.subscript(vm, a_ref, b_ref), - &bytecode::BinaryOperator::Modulo => vm._modulo(a_ref, b_ref), - &bytecode::BinaryOperator::Lshift => vm.call_method(&a_ref, "__lshift__", vec![b_ref]), - &bytecode::BinaryOperator::Rshift => vm.call_method(&a_ref, "__rshift__", vec![b_ref]), - &bytecode::BinaryOperator::Xor => vm._xor(a_ref, b_ref), - &bytecode::BinaryOperator::Or => vm._or(a_ref, b_ref), - &bytecode::BinaryOperator::And => vm._and(a_ref, b_ref), + bytecode::BinaryOperator::Subscript => self.subscript(vm, a_ref, b_ref), + bytecode::BinaryOperator::Modulo => vm._modulo(a_ref, b_ref), + bytecode::BinaryOperator::Lshift => vm.call_method(&a_ref, "__lshift__", vec![b_ref]), + bytecode::BinaryOperator::Rshift => vm.call_method(&a_ref, "__rshift__", vec![b_ref]), + bytecode::BinaryOperator::Xor => vm._xor(a_ref, b_ref), + bytecode::BinaryOperator::Or => vm._or(a_ref, b_ref), + bytecode::BinaryOperator::And => vm._and(a_ref, b_ref), }?; self.push_value(value); @@ -918,11 +917,11 @@ impl Frame { op: &bytecode::UnaryOperator, ) -> FrameResult { let a = self.pop_value(); - let value = match op { - &bytecode::UnaryOperator::Minus => vm.call_method(&a, "__neg__", vec![])?, - &bytecode::UnaryOperator::Plus => vm.call_method(&a, "__pos__", vec![])?, - &bytecode::UnaryOperator::Invert => vm.call_method(&a, "__invert__", vec![])?, - &bytecode::UnaryOperator::Not => { + let value = match *op { + bytecode::UnaryOperator::Minus => vm.call_method(&a, "__neg__", vec![])?, + bytecode::UnaryOperator::Plus => vm.call_method(&a, "__pos__", vec![])?, + bytecode::UnaryOperator::Invert => vm.call_method(&a, "__invert__", vec![])?, + bytecode::UnaryOperator::Not => { let value = objbool::boolval(vm, a)?; vm.ctx.new_bool(!value) } @@ -995,17 +994,17 @@ impl Frame { ) -> FrameResult { let b = self.pop_value(); let a = self.pop_value(); - let value = match op { - &bytecode::ComparisonOperator::Equal => vm._eq(&a, b)?, - &bytecode::ComparisonOperator::NotEqual => vm._ne(&a, b)?, - &bytecode::ComparisonOperator::Less => vm._lt(&a, b)?, - &bytecode::ComparisonOperator::LessOrEqual => vm._le(&a, b)?, - &bytecode::ComparisonOperator::Greater => vm._gt(&a, b)?, - &bytecode::ComparisonOperator::GreaterOrEqual => vm._ge(&a, b)?, - &bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)), - &bytecode::ComparisonOperator::IsNot => self._is_not(vm, a, b)?, - &bytecode::ComparisonOperator::In => self._in(vm, a, b)?, - &bytecode::ComparisonOperator::NotIn => self._not_in(vm, a, b)?, + let value = match *op { + bytecode::ComparisonOperator::Equal => vm._eq(&a, b)?, + bytecode::ComparisonOperator::NotEqual => vm._ne(&a, b)?, + bytecode::ComparisonOperator::Less => vm._lt(&a, b)?, + bytecode::ComparisonOperator::LessOrEqual => vm._le(&a, b)?, + bytecode::ComparisonOperator::Greater => vm._gt(&a, b)?, + bytecode::ComparisonOperator::GreaterOrEqual => vm._ge(&a, b)?, + bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)), + bytecode::ComparisonOperator::IsNot => self._is_not(vm, a, b)?, + bytecode::ComparisonOperator::In => self._in(vm, a, b)?, + bytecode::ComparisonOperator::NotIn => self._not_in(vm, a, b)?, }; self.push_value(value); @@ -1036,7 +1035,7 @@ impl Frame { fn unwrap_constant(&self, vm: &VirtualMachine, value: &bytecode::Constant) -> PyObjectRef { match *value { - bytecode::Constant::Integer { ref value } => vm.ctx.new_int(value.to_bigint().unwrap()), + bytecode::Constant::Integer { ref value } => vm.ctx.new_int(value.clone()), bytecode::Constant::Float { ref value } => vm.ctx.new_float(*value), bytecode::Constant::Complex { ref value } => vm.ctx.new_complex(*value), bytecode::Constant::String { ref value } => vm.new_str(value.clone()), diff --git a/vm/src/import.rs b/vm/src/import.rs index e560d3b18..be8807881 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -103,7 +103,7 @@ fn find_source(vm: &VirtualMachine, current_path: PathBuf, name: &str) -> Result } } - match filepaths.iter().filter(|p| p.exists()).next() { + match filepaths.iter().find(|p| p.exists()) { Some(path) => Ok(path.to_path_buf()), None => Err(format!("No module named '{}'", name)), } diff --git a/vm/src/obj/mod.rs b/vm/src/obj/mod.rs index a60ce79ff..f08998fcb 100644 --- a/vm/src/obj/mod.rs +++ b/vm/src/obj/mod.rs @@ -6,6 +6,7 @@ pub mod objbytes; pub mod objcode; pub mod objcomplex; pub mod objdict; +pub mod objenumerate; pub mod objfilter; pub mod objfloat; pub mod objframe; @@ -21,7 +22,9 @@ pub mod objproperty; pub mod objrange; pub mod objsequence; pub mod objset; +pub mod objslice; pub mod objstr; pub mod objsuper; pub mod objtuple; pub mod objtype; +pub mod objzip; diff --git a/vm/src/obj/objbool.rs b/vm/src/obj/objbool.rs index f3a2fafb8..dbec65a7c 100644 --- a/vm/src/obj/objbool.rs +++ b/vm/src/obj/objbool.rs @@ -20,6 +20,7 @@ pub fn boolval(vm: &mut VirtualMachine, obj: PyObjectRef) -> Result !value.is_zero(), _ => return Err(vm.new_type_error(String::from("TypeError"))), }; + v } else { true @@ -81,7 +82,7 @@ fn bool_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(match val { Some(val) => { let bv = boolval(vm, val.clone())?; - vm.new_bool(bv.clone()) + vm.new_bool(bv) } None => vm.context().new_bool(false), }) diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index f62b689af..b4afaa639 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -9,7 +9,6 @@ use super::objint; use super::super::vm::VirtualMachine; use super::objbytes::get_value; use super::objtype; -use num_bigint::ToBigInt; use num_traits::ToPrimitive; // Binary data support @@ -17,6 +16,20 @@ use num_traits::ToPrimitive; /// Fill bytearray class methods dictionary. pub fn init(context: &PyContext) { let bytearray_type = &context.bytearray_type; + + let bytearray_doc = + "bytearray(iterable_of_ints) -> bytearray\n\ + bytearray(string, encoding[, errors]) -> bytearray\n\ + bytearray(bytes_or_buffer) -> mutable copy of bytes_or_buffer\n\ + bytearray(int) -> bytes array of size given by the parameter initialized with null bytes\n\ + bytearray() -> empty bytes array\n\n\ + Construct a mutable bytearray object from:\n \ + - an iterable yielding integers in range(256)\n \ + - a text string encoded using the specified encoding\n \ + - a bytes or a buffer object\n \ + - any object implementing the buffer API.\n \ + - an integer"; + context.set_attr( &bytearray_type, "__eq__", @@ -37,6 +50,11 @@ pub fn init(context: &PyContext) { "__len__", context.new_rustfunc(bytesarray_len), ); + context.set_attr( + &bytearray_type, + "__doc__", + context.new_str(bytearray_doc.to_string()), + ); context.set_attr( &bytearray_type, "isalnum", @@ -110,8 +128,7 @@ fn bytesarray_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(a, Some(vm.ctx.bytearray_type()))]); let byte_vec = get_value(a).to_vec(); - let value = byte_vec.len().to_bigint(); - Ok(vm.ctx.new_int(value.unwrap())) + Ok(vm.ctx.new_int(byte_vec.len())) } fn bytearray_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -205,11 +222,12 @@ fn bytearray_istitle(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } }; - if is_cased(current) && next.is_uppercase() && !prev_cased { - return Ok(vm.new_bool(false)); - } else if !is_cased(current) && next.is_lowercase() { + if (is_cased(current) && next.is_uppercase() && !prev_cased) + || (!is_cased(current) && next.is_lowercase()) + { return Ok(vm.new_bool(false)); } + prev_cased = is_cased(current); } diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 6288264a2..1909247cb 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -4,7 +4,6 @@ use super::super::pyobject::{ use super::super::vm::VirtualMachine; use super::objint; use super::objtype; -use num_bigint::ToBigInt; use num_traits::ToPrimitive; use std::cell::Ref; use std::hash::{Hash, Hasher}; @@ -15,12 +14,29 @@ use std::ops::Deref; // Fill bytes class methods: pub fn init(context: &PyContext) { let bytes_type = &context.bytes_type; + + let bytes_doc = + "bytes(iterable_of_ints) -> bytes\n\ + bytes(string, encoding[, errors]) -> bytes\n\ + bytes(bytes_or_buffer) -> immutable copy of bytes_or_buffer\n\ + bytes(int) -> bytes object of size given by the parameter initialized with null bytes\n\ + bytes() -> empty bytes object\n\nConstruct an immutable array of bytes from:\n \ + - an iterable yielding integers in range(256)\n \ + - a text string encoded using the specified encoding\n \ + - any object implementing the buffer API.\n \ + - an integer"; + context.set_attr(bytes_type, "__eq__", context.new_rustfunc(bytes_eq)); context.set_attr(bytes_type, "__hash__", context.new_rustfunc(bytes_hash)); context.set_attr(bytes_type, "__new__", context.new_rustfunc(bytes_new)); context.set_attr(bytes_type, "__repr__", context.new_rustfunc(bytes_repr)); context.set_attr(bytes_type, "__len__", context.new_rustfunc(bytes_len)); - context.set_attr(bytes_type, "__iter__", context.new_rustfunc(bytes_iter)) + context.set_attr(bytes_type, "__iter__", context.new_rustfunc(bytes_iter)); + context.set_attr( + bytes_type, + "__doc__", + context.new_str(bytes_doc.to_string()), + ); } fn bytes_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -70,8 +86,7 @@ fn bytes_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(a, Some(vm.ctx.bytes_type()))]); let byte_vec = get_value(a).to_vec(); - let value = byte_vec.len().to_bigint(); - Ok(vm.ctx.new_int(value.unwrap())) + Ok(vm.ctx.new_int(byte_vec.len())) } fn bytes_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -80,7 +95,7 @@ fn bytes_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let mut hasher = std::collections::hash_map::DefaultHasher::new(); data.hash(&mut hasher); let hash = hasher.finish(); - Ok(vm.ctx.new_int(hash.to_bigint().unwrap())) + Ok(vm.ctx.new_int(hash)) } pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref> + 'a { diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index bc624d50f..353b89b7f 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -8,8 +8,18 @@ use num_complex::Complex64; pub fn init(context: &PyContext) { let complex_type = &context.complex_type; + + let complex_doc = + "Create a complex number from a real part and an optional imaginary part.\n\n\ + This is equivalent to (real + imag*1j) where imag defaults to 0."; + context.set_attr(&complex_type, "__add__", context.new_rustfunc(complex_add)); context.set_attr(&complex_type, "__new__", context.new_rustfunc(complex_new)); + context.set_attr( + &complex_type, + "__doc__", + context.new_str(complex_doc.to_string()), + ); context.set_attr( &complex_type, "__repr__", diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index b547eb906..8a3e77ff8 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -5,8 +5,7 @@ use super::super::vm::VirtualMachine; use super::objiter; use super::objstr; use super::objtype; -use num_bigint::ToBigInt; -use std::cell::{Ref, RefMut}; +use std::cell::{Ref, RefCell, RefMut}; use std::collections::HashMap; use std::ops::{Deref, DerefMut}; @@ -106,10 +105,7 @@ pub fn contains_key_str(dict: &PyObjectRef, key: &str) -> bool { pub fn content_contains_key_str(elements: &DictContentType, key: &str) -> bool { // TODO: let hash: usize = key; - match elements.get(key) { - Some(_) => true, - None => false, - } + elements.get(key).is_some() } // Python dict methods: @@ -140,7 +136,7 @@ fn dict_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let elem_iter = objiter::get_iter(vm, &element)?; let needle = objiter::get_next_object(vm, &elem_iter)?.ok_or_else(|| err(vm))?; let value = objiter::get_next_object(vm, &elem_iter)?.ok_or_else(|| err(vm))?; - if let Some(_) = objiter::get_next_object(vm, &elem_iter)? { + if objiter::get_next_object(vm, &elem_iter)?.is_some() { return Err(err(vm)); } set_item(&dict, &needle, &value); @@ -156,7 +152,7 @@ fn dict_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { fn dict_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(dict_obj, Some(vm.ctx.dict_type()))]); let elements = get_elements(dict_obj); - Ok(vm.ctx.new_int(elements.len().to_bigint().unwrap())) + Ok(vm.ctx.new_int(elements.len())) } fn dict_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -277,7 +273,7 @@ fn dict_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: PyObjectRef) { (*dict_type.borrow_mut()).payload = PyObjectPayload::Class { name: String::from("dict"), - dict: new(dict_type.clone()), + dict: RefCell::new(HashMap::new()), mro: vec![object_type], }; (*dict_type.borrow_mut()).typ = Some(type_type.clone()); diff --git a/vm/src/obj/objenumerate.rs b/vm/src/obj/objenumerate.rs new file mode 100644 index 000000000..582f89852 --- /dev/null +++ b/vm/src/obj/objenumerate.rs @@ -0,0 +1,69 @@ +use super::super::pyobject::{ + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, +}; +use super::super::vm::VirtualMachine; +use super::objint; +use super::objiter; +use super::objtype; // Required for arg_check! to use isinstance +use num_bigint::BigInt; +use num_traits::Zero; +use std::ops::AddAssign; + +fn enumerate_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(cls, Some(vm.ctx.type_type())), (iterable, None)], + optional = [(start, Some(vm.ctx.int_type()))] + ); + let counter = if let Some(x) = start { + objint::get_value(x) + } else { + BigInt::zero() + }; + let iterator = objiter::get_iter(vm, iterable)?; + Ok(PyObject::new( + PyObjectPayload::EnumerateIterator { counter, iterator }, + cls.clone(), + )) +} + +fn enumerate_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(enumerate, Some(vm.ctx.enumerate_type()))] + ); + + if let PyObjectPayload::EnumerateIterator { + ref mut counter, + ref mut iterator, + } = enumerate.borrow_mut().payload + { + let next_obj = objiter::call_next(vm, iterator)?; + let result = vm + .ctx + .new_tuple(vec![vm.ctx.new_int(counter.clone()), next_obj]); + + AddAssign::add_assign(counter, 1); + + Ok(result) + } else { + panic!("enumerate doesn't have correct payload"); + } +} + +pub fn init(context: &PyContext) { + let enumerate_type = &context.enumerate_type; + objiter::iter_type_init(context, enumerate_type); + context.set_attr( + enumerate_type, + "__new__", + context.new_rustfunc(enumerate_new), + ); + context.set_attr( + enumerate_type, + "__next__", + context.new_rustfunc(enumerate_next), + ); +} diff --git a/vm/src/obj/objfilter.rs b/vm/src/obj/objfilter.rs index 009d2ad2e..b4bc4ff5e 100644 --- a/vm/src/obj/objfilter.rs +++ b/vm/src/obj/objfilter.rs @@ -7,7 +7,7 @@ use super::objbool; use super::objiter; use super::objtype; // Required for arg_check! to use isinstance -pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { +fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, @@ -23,21 +23,6 @@ pub fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { )) } -fn filter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]); - // Return self: - Ok(filter.clone()) -} - -fn filter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(filter, Some(vm.ctx.filter_type())), (needle, None)] - ); - objiter::contains(vm, filter, needle) -} - fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(filter, Some(vm.ctx.filter_type()))]); @@ -72,12 +57,7 @@ fn filter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let filter_type = &context.filter_type; - context.set_attr( - &filter_type, - "__contains__", - context.new_rustfunc(filter_contains), - ); - context.set_attr(&filter_type, "__iter__", context.new_rustfunc(filter_iter)); + objiter::iter_type_init(context, filter_type); context.set_attr(&filter_type, "__new__", context.new_rustfunc(filter_new)); context.set_attr(&filter_type, "__next__", context.new_rustfunc(filter_next)); } diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index 4cb996a51..1bb57e572 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -268,6 +268,9 @@ fn float_pow(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let float_type = &context.float_type; + + let float_doc = "Convert a string or number to a floating point number, if possible."; + context.set_attr(&float_type, "__eq__", context.new_rustfunc(float_eq)); context.set_attr(&float_type, "__lt__", context.new_rustfunc(float_lt)); context.set_attr(&float_type, "__le__", context.new_rustfunc(float_le)); @@ -291,4 +294,9 @@ pub fn init(context: &PyContext) { context.set_attr(&float_type, "__pow__", context.new_rustfunc(float_pow)); context.set_attr(&float_type, "__sub__", context.new_rustfunc(float_sub)); context.set_attr(&float_type, "__repr__", context.new_rustfunc(float_repr)); + context.set_attr( + &float_type, + "__doc__", + context.new_str(float_doc.to_string()), + ); } diff --git a/vm/src/obj/objfunction.rs b/vm/src/obj/objfunction.rs index f04897481..01c05aab1 100644 --- a/vm/src/obj/objfunction.rs +++ b/vm/src/obj/objfunction.rs @@ -1,6 +1,6 @@ use super::super::pyobject::{ - AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, - PyResult, TypeProtocol, + AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, + TypeProtocol, }; use super::super::vm::VirtualMachine; use super::objtype; @@ -110,12 +110,7 @@ fn classmethod_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { trace!("classmethod.__new__ {:?}", args.args); arg_check!(vm, args, required = [(cls, None), (callable, None)]); - let py_obj = PyObject::new( - PyObjectPayload::Instance { - dict: vm.ctx.new_dict(), - }, - cls.clone(), - ); + let py_obj = vm.ctx.new_instance(cls.clone(), None); vm.ctx.set_attr(&py_obj, "function", callable.clone()); Ok(py_obj) } @@ -148,12 +143,7 @@ fn staticmethod_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { trace!("staticmethod.__new__ {:?}", args.args); arg_check!(vm, args, required = [(cls, None), (callable, None)]); - let py_obj = PyObject::new( - PyObjectPayload::Instance { - dict: vm.ctx.new_dict(), - }, - cls.clone(), - ); + let py_obj = vm.ctx.new_instance(cls.clone(), None); vm.ctx.set_attr(&py_obj, "function", callable.clone()); Ok(py_obj) } diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 06424e149..d37af68c9 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -250,7 +250,7 @@ fn int_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let mut hasher = std::collections::hash_map::DefaultHasher::new(); value.hash(&mut hasher); let hash = hasher.finish(); - Ok(vm.ctx.new_int(hash.to_bigint().unwrap())) + Ok(vm.ctx.new_int(hash)) } fn int_abs(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -366,17 +366,25 @@ fn int_truediv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { args, required = [(i, Some(vm.ctx.int_type())), (i2, None)] ); - let v1 = get_value(i); - if objtype::isinstance(i2, &vm.ctx.int_type()) { - Ok(vm - .ctx - .new_float(v1.to_f64().unwrap() / get_value(i2).to_f64().unwrap())) + + let v1 = get_value(i) + .to_f64() + .ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_string()))?; + + let v2 = if objtype::isinstance(i2, &vm.ctx.int_type()) { + get_value(i2) + .to_f64() + .ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_string()))? } else if objtype::isinstance(i2, &vm.ctx.float_type()) { - Ok(vm - .ctx - .new_float(v1.to_f64().unwrap() / objfloat::get_value(i2))) + objfloat::get_value(i2) } else { - Err(vm.new_type_error(format!("Cannot divide {} and {}", i.borrow(), i2.borrow()))) + return Err(vm.new_type_error(format!("Cannot divide {} and {}", i.borrow(), i2.borrow()))); + }; + + if v2 == 0.0 { + Err(vm.new_zero_division_error("integer division by zero".to_string())) + } else { + Ok(vm.ctx.new_float(v1 / v2)) } } @@ -388,7 +396,13 @@ fn int_mod(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ); let v1 = get_value(i); if objtype::isinstance(i2, &vm.ctx.int_type()) { - Ok(vm.ctx.new_int(v1 % get_value(i2))) + let v2 = get_value(i2); + + if v2 != BigInt::zero() { + Ok(vm.ctx.new_int(v1 % get_value(i2))) + } else { + Err(vm.new_zero_division_error("integer modulo by zero".to_string())) + } } else { Err(vm.new_type_error(format!("Cannot modulo {} and {}", i.borrow(), i2.borrow()))) } @@ -414,7 +428,7 @@ fn int_pow(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let v1 = get_value(i); if objtype::isinstance(i2, &vm.ctx.int_type()) { let v2 = get_value(i2).to_u32().unwrap(); - Ok(vm.ctx.new_int(v1.pow(v2).to_bigint().unwrap())) + Ok(vm.ctx.new_int(v1.pow(v2))) } else if objtype::isinstance(i2, &vm.ctx.float_type()) { let v2 = objfloat::get_value(i2); Ok(vm.ctx.new_float((v1.to_f64().unwrap()).powf(v2))) @@ -513,7 +527,7 @@ fn int_bit_length(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(i, Some(vm.ctx.int_type()))]); let v = get_value(i); let bits = v.bits(); - Ok(vm.ctx.new_int(bits.to_bigint().unwrap())) + Ok(vm.ctx.new_int(bits)) } fn int_conjugate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index b7368f7b2..40633e984 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -9,7 +9,6 @@ use super::super::vm::VirtualMachine; use super::objbool; // use super::objstr; use super::objtype; // Required for arg_check! to use isinstance -use num_bigint::{BigInt, ToBigInt}; /* * This helper function is called at multiple places. First, it is called @@ -65,7 +64,17 @@ pub fn get_all( Ok(elements) } -pub fn contains(vm: &mut VirtualMachine, iter: &PyObjectRef, needle: &PyObjectRef) -> PyResult { +pub fn new_stop_iteration(vm: &mut VirtualMachine) -> PyObjectRef { + let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone(); + vm.new_exception(stop_iteration_type, "End of iterator".to_string()) +} + +fn contains(vm: &mut VirtualMachine, args: PyFuncArgs, iter_type: PyObjectRef) -> PyResult { + arg_check!( + vm, + args, + required = [(iter, Some(iter_type)), (needle, None)] + ); loop { if let Some(element) = get_next_object(vm, iter)? { let equal = vm.call_method(needle, "__eq__", vec![element.clone()])?; @@ -80,6 +89,34 @@ pub fn contains(vm: &mut VirtualMachine, iter: &PyObjectRef, needle: &PyObjectRe } } +/// Common setup for iter types, adds __iter__ and __contains__ methods +pub fn iter_type_init(context: &PyContext, iter_type: &PyObjectRef) { + let contains_func = { + let cloned_iter_type = iter_type.clone(); + move |vm: &mut VirtualMachine, args: PyFuncArgs| { + contains(vm, args, cloned_iter_type.clone()) + } + }; + context.set_attr( + &iter_type, + "__contains__", + context.new_rustfunc(contains_func), + ); + let iter_func = { + let cloned_iter_type = iter_type.clone(); + move |vm: &mut VirtualMachine, args: PyFuncArgs| { + arg_check!( + vm, + args, + required = [(iter, Some(cloned_iter_type.clone()))] + ); + // Return self: + Ok(iter.clone()) + } + }; + context.set_attr(&iter_type, "__iter__", context.new_rustfunc(iter_func)); +} + // Sequence iterator: fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(iter_target, None)]); @@ -87,21 +124,6 @@ fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { get_iter(vm, iter_target) } -fn iter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]); - // Return self: - Ok(iter.clone()) -} - -fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(iter, Some(vm.ctx.iter_type())), (needle, None)] - ); - contains(vm, iter, needle) -} - fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]); @@ -118,35 +140,26 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { *position += 1; Ok(obj_ref) } else { - let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone(); - let stop_iteration = - vm.new_exception(stop_iteration_type, "End of iterator".to_string()); - Err(stop_iteration) + Err(new_stop_iteration(vm)) } } PyObjectPayload::Range { ref range } => { - if let Some(int) = range.get(BigInt::from(*position)) { + if let Some(int) = range.get(*position) { *position += 1; - Ok(vm.ctx.new_int(int.to_bigint().unwrap())) + Ok(vm.ctx.new_int(int)) } else { - let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone(); - let stop_iteration = - vm.new_exception(stop_iteration_type, "End of iterator".to_string()); - Err(stop_iteration) + Err(new_stop_iteration(vm)) } } PyObjectPayload::Bytes { ref value } => { if *position < value.len() { - let obj_ref = vm.ctx.new_int(value[*position].to_bigint().unwrap()); + let obj_ref = vm.ctx.new_int(value[*position]); *position += 1; Ok(obj_ref) } else { - let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone(); - let stop_iteration = - vm.new_exception(stop_iteration_type, "End of iterator".to_string()); - Err(stop_iteration) + Err(new_stop_iteration(vm)) } } @@ -161,12 +174,7 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let iter_type = &context.iter_type; - context.set_attr( - &iter_type, - "__contains__", - context.new_rustfunc(iter_contains), - ); - context.set_attr(&iter_type, "__iter__", context.new_rustfunc(iter_iter)); + iter_type_init(context, iter_type); context.set_attr(&iter_type, "__new__", context.new_rustfunc(iter_new)); context.set_attr(&iter_type, "__next__", context.new_rustfunc(iter_next)); } diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index 4ec4a7e80..369ed9fc5 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -10,7 +10,6 @@ use super::objsequence::{ }; use super::objstr; use super::objtype; -use num_bigint::ToBigInt; use num_traits::ToPrimitive; // set_item: @@ -22,9 +21,12 @@ fn set_item( ) -> PyResult { if objtype::isinstance(&idx, &vm.ctx.int_type()) { let value = objint::get_value(&idx).to_i32().unwrap(); - let pos_index = l.get_pos(value); - l[pos_index] = obj; - Ok(vm.get_none()) + if let Some(pos_index) = l.get_pos(value) { + l[pos_index] = obj; + Ok(vm.get_none()) + } else { + Err(vm.new_index_error("list index out of range".to_string())) + } } else { panic!( "TypeError: indexing type {:?} with index {:?} is not supported (yet?)", @@ -172,7 +174,7 @@ fn list_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { if objtype::isinstance(o2, &vm.ctx.list_type()) { let e1 = get_elements(o); let e2 = get_elements(o2); - let elements = e1.iter().chain(e2.iter()).map(|e| e.clone()).collect(); + let elements = e1.iter().chain(e2.iter()).cloned().collect(); Ok(vm.ctx.new_list(elements)) } else { Err(vm.new_type_error(format!("Cannot add {} and {}", o.borrow(), o2.borrow()))) @@ -224,10 +226,10 @@ fn list_count(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { for element in elements.iter() { let is_eq = vm._eq(element, value.clone())?; if objbool::boolval(vm, is_eq)? { - count = count + 1; + count += 1; } } - Ok(vm.context().new_int(count.to_bigint().unwrap())) + Ok(vm.context().new_int(count)) } pub fn list_extend(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -252,7 +254,7 @@ fn list_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { for (index, element) in get_elements(list).iter().enumerate() { let py_equal = vm.call_method(needle, "__eq__", vec![element.clone()])?; if objbool::get_value(&py_equal) { - return Ok(vm.context().new_int(index.to_bigint().unwrap())); + return Ok(vm.context().new_int(index)); } } let needle_str = objstr::get_value(&vm.to_str(needle).unwrap()); @@ -263,7 +265,7 @@ fn list_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { trace!("list.len called with: {:?}", args); arg_check!(vm, args, required = [(list, Some(vm.ctx.list_type()))]); let elements = get_elements(list); - Ok(vm.context().new_int(elements.len().to_bigint().unwrap())) + Ok(vm.context().new_int(elements.len())) } fn list_reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objmap.rs b/vm/src/obj/objmap.rs index ed6130643..722eba01a 100644 --- a/vm/src/obj/objmap.rs +++ b/vm/src/obj/objmap.rs @@ -5,7 +5,7 @@ use super::super::vm::VirtualMachine; use super::objiter; use super::objtype; // Required for arg_check! to use isinstance -pub fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { +fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { no_kwargs!(vm, args); let cls = &args.args[0]; if args.args.len() < 3 { @@ -27,21 +27,6 @@ pub fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } -fn map_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]); - // Return self: - Ok(map.clone()) -} - -fn map_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(map, Some(vm.ctx.map_type())), (needle, None)] - ); - objiter::contains(vm, map, needle) -} - fn map_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(map, Some(vm.ctx.map_type()))]); @@ -70,12 +55,7 @@ fn map_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let map_type = &context.map_type; - context.set_attr( - &map_type, - "__contains__", - context.new_rustfunc(map_contains), - ); - context.set_attr(&map_type, "__iter__", context.new_rustfunc(map_iter)); + objiter::iter_type_init(context, map_type); context.set_attr(&map_type, "__new__", context.new_rustfunc(map_new)); context.set_attr(&map_type, "__next__", context.new_rustfunc(map_next)); } diff --git a/vm/src/obj/objobject.rs b/vm/src/obj/objobject.rs index 190f3078b..cb5065b20 100644 --- a/vm/src/obj/objobject.rs +++ b/vm/src/obj/objobject.rs @@ -1,25 +1,25 @@ use super::super::pyobject::{ - AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, - PyResult, TypeProtocol, + AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, + TypeProtocol, }; use super::super::vm::VirtualMachine; use super::objbool; -use super::objdict; use super::objstr; use super::objtype; +use std::cell::RefCell; +use std::collections::HashMap; pub fn new_instance(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> PyResult { // more or less __new__ operator let type_ref = args.shift(); - let dict = vm.new_dict(); - let obj = PyObject::new(PyObjectPayload::Instance { dict }, type_ref.clone()); + let obj = vm.ctx.new_instance(type_ref.clone(), None); Ok(obj) } -pub fn create_object(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: PyObjectRef) { +pub fn create_object(type_type: PyObjectRef, object_type: PyObjectRef, _dict_type: PyObjectRef) { (*object_type.borrow_mut()).payload = PyObjectPayload::Class { name: String::from("object"), - dict: objdict::new(dict_type), + dict: RefCell::new(HashMap::new()), mro: vec![], }; (*object_type.borrow_mut()).typ = Some(type_type.clone()); @@ -62,15 +62,14 @@ fn object_delattr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ] ); - // Get dict: - let dict = match zelf.borrow().payload { - PyObjectPayload::Class { ref dict, .. } => dict.clone(), - PyObjectPayload::Instance { ref dict, .. } => dict.clone(), - _ => return Err(vm.new_type_error("TypeError: no dictionary.".to_string())), - }; - - // Delete attr from dict: - vm.call_method(&dict, "__delitem__", vec![attr.clone()]) + match zelf.borrow().payload { + PyObjectPayload::Class { ref dict, .. } | PyObjectPayload::Instance { ref dict, .. } => { + let attr_name = objstr::get_value(attr); + dict.borrow_mut().remove(&attr_name); + Ok(vm.get_none()) + } + _ => Err(vm.new_type_error("TypeError: no dictionary.".to_string())), + } } fn object_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -116,8 +115,13 @@ fn object_init(vm: &mut VirtualMachine, _args: PyFuncArgs) -> PyResult { fn object_dict(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { match args.args[0].borrow().payload { - PyObjectPayload::Class { ref dict, .. } => Ok(dict.clone()), - PyObjectPayload::Instance { ref dict, .. } => Ok(dict.clone()), + PyObjectPayload::Class { ref dict, .. } | PyObjectPayload::Instance { ref dict, .. } => { + let new_dict = vm.new_dict(); + for (attr, value) in dict.borrow().iter() { + vm.ctx.set_item(&new_dict, &attr, value.clone()); + } + Ok(new_dict) + } _ => Err(vm.new_type_error("TypeError: no dictionary.".to_string())), } } diff --git a/vm/src/obj/objproperty.rs b/vm/src/obj/objproperty.rs index 1d146889b..ada639270 100644 --- a/vm/src/obj/objproperty.rs +++ b/vm/src/obj/objproperty.rs @@ -2,14 +2,41 @@ */ -use super::super::pyobject::{ - PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, -}; +use super::super::pyobject::{PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol}; use super::super::vm::VirtualMachine; use super::objtype; pub fn init(context: &PyContext) { let property_type = &context.property_type; + + let property_doc = + "Property attribute.\n\n \ + fget\n \ + function to be used for getting an attribute value\n \ + fset\n \ + function to be used for setting an attribute value\n \ + fdel\n \ + function to be used for del\'ing an attribute\n \ + doc\n \ + docstring\n\n\ + Typical use is to define a managed attribute x:\n\n\ + class C(object):\n \ + def getx(self): return self._x\n \ + def setx(self, value): self._x = value\n \ + def delx(self): del self._x\n \ + x = property(getx, setx, delx, \"I\'m the \'x\' property.\")\n\n\ + Decorators make defining new properties or modifying existing ones easy:\n\n\ + class C(object):\n \ + @property\n \ + def x(self):\n \"I am the \'x\' property.\"\n \ + return self._x\n \ + @x.setter\n \ + def x(self, value):\n \ + self._x = value\n \ + @x.deleter\n \ + def x(self):\n \ + del self._x"; + context.set_attr( &property_type, "__get__", @@ -20,6 +47,11 @@ pub fn init(context: &PyContext) { "__new__", context.new_rustfunc(property_new), ); + context.set_attr( + &property_type, + "__doc__", + context.new_str(property_doc.to_string()), + ); // TODO: how to handle __set__ ? } @@ -55,12 +87,7 @@ fn property_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { trace!("property.__new__ {:?}", args.args); arg_check!(vm, args, required = [(cls, None), (fget, None)]); - let py_obj = PyObject::new( - PyObjectPayload::Instance { - dict: vm.ctx.new_dict(), - }, - cls.clone(), - ); + let py_obj = vm.ctx.new_instance(cls.clone(), None); vm.ctx.set_attr(&py_obj, "fget", fget.clone()); Ok(py_obj) } diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 647298ea2..91b875ae3 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -4,9 +4,10 @@ use super::super::pyobject::{ use super::super::vm::VirtualMachine; use super::objint; use super::objtype; -use num_bigint::{BigInt, ToBigInt}; +use num_bigint::{BigInt, Sign}; use num_integer::Integer; use num_traits::{One, Signed, ToPrimitive, Zero}; +use std::ops::Mul; #[derive(Debug, Clone)] pub struct RangeType { @@ -18,25 +19,50 @@ pub struct RangeType { } impl RangeType { + #[inline] + pub fn try_len(&self) -> Option { + match self.step.sign() { + Sign::Plus if self.start < self.end => ((&self.end - &self.start - 1usize) + / &self.step) + .to_usize() + .map(|sz| sz + 1), + Sign::Minus if self.start > self.end => ((&self.start - &self.end - 1usize) + / (-&self.step)) + .to_usize() + .map(|sz| sz + 1), + _ => Some(0), + } + } + #[inline] pub fn len(&self) -> usize { - ((self.end.clone() - self.start.clone()) / self.step.clone()) - .abs() - .to_usize() - .unwrap() + self.try_len().unwrap() + } + + #[inline] + fn offset(&self, value: &BigInt) -> Option { + match self.step.sign() { + Sign::Plus if value >= &self.start && value < &self.end => Some(value - &self.start), + Sign::Minus if value <= &self.start && value > &self.end => Some(&self.start - value), + _ => None, + } + } + + #[inline] + pub fn contains(&self, value: &BigInt) -> bool { + match self.offset(value) { + Some(ref offset) => offset.is_multiple_of(&self.step), + None => false, + } } #[inline] pub fn index_of(&self, value: &BigInt) -> Option { - if value < &self.start || value >= &self.end { - return None; - } - - let offset = value - &self.start; - if offset.is_multiple_of(&self.step) { - Some(offset / &self.step) - } else { - None + match self.offset(value) { + Some(ref offset) if offset.is_multiple_of(&self.step) => { + Some((offset / &self.step).abs()) + } + Some(_) | None => None, } } @@ -52,17 +78,29 @@ impl RangeType { } #[inline] - pub fn get(&self, index: BigInt) -> Option { - let result = self.start.clone() + self.step.clone() * index; + pub fn get<'a, T>(&'a self, index: T) -> Option + where + &'a BigInt: Mul, + { + let result = &self.start + &self.step * index; - if self.forward() && !self.is_empty() && result < self.end { - Some(result) - } else if !self.forward() && !self.is_empty() && result > self.end { + if (self.forward() && !self.is_empty() && result < self.end) + || (!self.forward() && !self.is_empty() && result > self.end) + { Some(result) } else { None } } + + #[inline] + pub fn repr(&self) -> String { + if self.step == BigInt::one() { + format!("range({}, {})", self.start, self.end) + } else { + format!("range({}, {}, {})", self.start, self.end, self.step) + } + } } pub fn init(context: &PyContext) { @@ -75,6 +113,13 @@ pub fn init(context: &PyContext) { "__getitem__", context.new_rustfunc(range_getitem), ); + context.set_attr(&range_type, "__repr__", context.new_rustfunc(range_repr)); + context.set_attr(&range_type, "__bool__", context.new_rustfunc(range_bool)); + context.set_attr( + &range_type, + "__contains__", + context.new_rustfunc(range_contains), + ); context.set_attr(&range_type, "index", context.new_rustfunc(range_index)); } @@ -134,12 +179,14 @@ fn range_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { fn range_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); - let len = match zelf.borrow().payload { - PyObjectPayload::Range { ref range } => range.len(), + if let Some(len) = match zelf.borrow().payload { + PyObjectPayload::Range { ref range } => range.try_len(), _ => unreachable!(), - }; - - Ok(vm.ctx.new_int(len.to_bigint().unwrap())) + } { + Ok(vm.ctx.new_int(len)) + } else { + Err(vm.new_overflow_error("Python int too large to convert to Rust usize".to_string())) + } } fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -156,20 +203,19 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { match subscript.borrow().payload { PyObjectPayload::Integer { ref value } => { - if let Some(int) = zrange.get(value.clone()) { - Ok(PyObject::new( - PyObjectPayload::Integer { - value: int.to_bigint().unwrap(), - }, - vm.ctx.int_type(), - )) + if let Some(int) = zrange.get(value) { + Ok(vm.ctx.new_int(int)) } else { Err(vm.new_index_error("range object index out of range".to_string())) } } - PyObjectPayload::Slice { start, stop, step } => { + PyObjectPayload::Slice { + ref start, + ref stop, + ref step, + } => { let new_start = if let Some(int) = start { - if let Some(i) = zrange.get(int.into()) { + if let Some(i) = zrange.get(int) { i } else { zrange.start.clone() @@ -179,7 +225,7 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { }; let new_end = if let Some(int) = stop { - if let Some(i) = zrange.get(int.into()) { + if let Some(i) = zrange.get(int) { i } else { zrange.end @@ -189,7 +235,7 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { }; let new_step = if let Some(int) = step { - (int as i64) * zrange.step + int * zrange.step } else { zrange.step }; @@ -210,6 +256,45 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +fn range_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); + + let s = match zelf.borrow().payload { + PyObjectPayload::Range { ref range } => range.repr(), + _ => unreachable!(), + }; + + Ok(vm.ctx.new_str(s)) +} + +fn range_bool(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); + + let len = match zelf.borrow().payload { + PyObjectPayload::Range { ref range } => range.len(), + _ => unreachable!(), + }; + + Ok(vm.ctx.new_bool(len > 0)) +} + +fn range_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, Some(vm.ctx.range_type())), (needle, None)] + ); + + if let PyObjectPayload::Range { ref range } = zelf.borrow().payload { + Ok(vm.ctx.new_bool(match needle.borrow().payload { + PyObjectPayload::Integer { ref value } => range.contains(value), + _ => false, + })) + } else { + unreachable!() + } +} + fn range_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, diff --git a/vm/src/obj/objsequence.rs b/vm/src/obj/objsequence.rs index c5504b22d..5d85b12a0 100644 --- a/vm/src/obj/objsequence.rs +++ b/vm/src/obj/objsequence.rs @@ -2,53 +2,97 @@ use super::super::pyobject::{PyObject, PyObjectPayload, PyObjectRef, PyResult, T use super::super::vm::VirtualMachine; use super::objbool; use super::objint; -use num_traits::ToPrimitive; +use num_bigint::BigInt; +use num_traits::{One, Signed, ToPrimitive, Zero}; use std::cell::{Ref, RefMut}; use std::marker::Sized; -use std::ops::{Deref, DerefMut}; +use std::ops::{Deref, DerefMut, Range}; pub trait PySliceableSequence { - fn do_slice(&self, start: usize, stop: usize) -> Self; - fn do_stepped_slice(&self, start: usize, stop: usize, step: usize) -> Self; + fn do_slice(&self, range: Range) -> Self; + fn do_slice_reverse(&self, range: Range) -> Self; + fn do_stepped_slice(&self, range: Range, step: usize) -> Self; + fn do_stepped_slice_reverse(&self, range: Range, step: usize) -> Self; + fn empty() -> Self; fn len(&self) -> usize; - fn get_pos(&self, p: i32) -> usize { + fn get_pos(&self, p: i32) -> Option { if p < 0 { if -p as usize > self.len() { - // return something that is out of bounds so `get_item` raises an IndexError - self.len() + 1 + None } else { - self.len() - ((-p) as usize) + Some(self.len() - ((-p) as usize)) } - } else if p as usize > self.len() { - // This is for the slicing case where the end element is greater than the length of the - // sequence - self.len() + } else if p as usize >= self.len() { + None } else { - p as usize + Some(p as usize) } } - fn get_slice_items(&self, slice: &PyObjectRef) -> Self + + fn get_slice_pos(&self, slice_pos: &BigInt) -> usize { + if let Some(pos) = slice_pos.to_i32() { + if let Some(index) = self.get_pos(pos) { + // within bounds + return index; + } + } + + if slice_pos.is_negative() { + 0 + } else { + self.len() + } + } + + fn get_slice_range(&self, start: &Option, stop: &Option) -> Range { + let start = start.as_ref().map(|x| self.get_slice_pos(x)).unwrap_or(0); + let stop = stop + .as_ref() + .map(|x| self.get_slice_pos(x)) + .unwrap_or(self.len()); + + start..stop + } + + fn get_slice_items( + &self, + vm: &mut VirtualMachine, + slice: &PyObjectRef, + ) -> Result where Self: Sized, { // TODO: we could potentially avoid this copy and use slice match &(slice.borrow()).payload { PyObjectPayload::Slice { start, stop, step } => { - let start = match start { - &Some(start) => self.get_pos(start), - &None => 0, - }; - let stop = match stop { - &Some(stop) => self.get_pos(stop), - &None => self.len() as usize, - }; - match step { - &None | &Some(1) => self.do_slice(start, stop), - &Some(num) => { - if num < 0 { - unimplemented!("negative step indexing not yet supported") - }; - self.do_stepped_slice(start, stop, num as usize) + let step = step.clone().unwrap_or(BigInt::one()); + if step.is_zero() { + Err(vm.new_value_error("slice step cannot be zero".to_string())) + } else if step.is_positive() { + let range = self.get_slice_range(start, stop); + if range.start < range.end { + match step.to_i32() { + Some(1) => Ok(self.do_slice(range)), + Some(num) => Ok(self.do_stepped_slice(range, num as usize)), + None => Ok(self.do_slice(range.start..range.start + 1)), + } + } else { + Ok(Self::empty()) + } + } else { + // calculate the range for the reverse slice, first the bounds needs to be made + // exclusive around stop, the lower number + let start = start.as_ref().map(|x| x + 1); + let stop = stop.as_ref().map(|x| x + 1); + let range = self.get_slice_range(&stop, &start); + if range.start < range.end { + match (-step).to_i32() { + Some(1) => Ok(self.do_slice_reverse(range)), + Some(num) => Ok(self.do_stepped_slice_reverse(range, num as usize)), + None => Ok(self.do_slice(range.end - 1..range.end)), + } + } else { + Ok(Self::empty()) } } } @@ -58,12 +102,28 @@ pub trait PySliceableSequence { } impl PySliceableSequence for Vec { - fn do_slice(&self, start: usize, stop: usize) -> Self { - self[start..stop].to_vec() + fn do_slice(&self, range: Range) -> Self { + self[range].to_vec() } - fn do_stepped_slice(&self, start: usize, stop: usize, step: usize) -> Self { - self[start..stop].iter().step_by(step).cloned().collect() + + fn do_slice_reverse(&self, range: Range) -> Self { + let mut slice = self[range].to_vec(); + slice.reverse(); + slice } + + fn do_stepped_slice(&self, range: Range, step: usize) -> Self { + self[range].iter().step_by(step).cloned().collect() + } + + fn do_stepped_slice_reverse(&self, range: Range, step: usize) -> Self { + self[range].iter().rev().step_by(step).cloned().collect() + } + + fn empty() -> Self { + Vec::new() + } + fn len(&self) -> usize { self.len() } @@ -78,8 +138,7 @@ pub fn get_item( match &(subscript.borrow()).payload { PyObjectPayload::Integer { value } => match value.to_i32() { Some(value) => { - let pos_index = elements.to_vec().get_pos(value); - if pos_index < elements.len() { + if let Some(pos_index) = elements.to_vec().get_pos(value) { let obj = elements[pos_index].clone(); Ok(obj) } else { @@ -94,7 +153,7 @@ pub fn get_item( PyObjectPayload::Slice { .. } => Ok(PyObject::new( match &(sequence.borrow()).payload { PyObjectPayload::Sequence { .. } => PyObjectPayload::Sequence { - elements: elements.to_vec().get_slice_items(&subscript), + elements: elements.to_vec().get_slice_items(vm, &subscript)?, }, ref payload => panic!("sequence get_item called for non-sequence: {:?}", payload), }, diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index b22213dd0..5ac7507d6 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -11,7 +11,6 @@ use super::objbool; use super::objiter; use super::objstr; use super::objtype; -use num_bigint::ToBigInt; use std::collections::HashMap; pub fn get_elements(obj: &PyObjectRef) -> HashMap { @@ -67,15 +66,10 @@ fn set_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Some(iterable) => { let mut elements = HashMap::new(); let iterator = objiter::get_iter(vm, iterable)?; - loop { - match vm.call_method(&iterator, "__next__", vec![]) { - Ok(v) => { - // TODO: should we use the hash function here? - let key = v.get_id(); - elements.insert(key, v); - } - _ => break, - } + while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { + // TODO: should we use the hash function here? + let key = v.get_id(); + elements.insert(key, v); } elements } @@ -91,7 +85,7 @@ fn set_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { trace!("set.len called with: {:?}", args); arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]); let elements = get_elements(s); - Ok(vm.context().new_int(elements.len().to_bigint().unwrap())) + Ok(vm.context().new_int(elements.len())) } fn set_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -152,6 +146,11 @@ fn frozenset_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let set_type = &context.set_type; + + let set_doc = "set() -> new empty set object\n\ + set(iterable) -> new set object\n\n\ + Build an unordered collection of unique elements."; + context.set_attr( &set_type, "__contains__", @@ -160,15 +159,26 @@ pub fn init(context: &PyContext) { context.set_attr(&set_type, "__len__", context.new_rustfunc(set_len)); context.set_attr(&set_type, "__new__", context.new_rustfunc(set_new)); context.set_attr(&set_type, "__repr__", context.new_rustfunc(set_repr)); + context.set_attr(&set_type, "__doc__", context.new_str(set_doc.to_string())); context.set_attr(&set_type, "add", context.new_rustfunc(set_add)); let frozenset_type = &context.frozenset_type; + + let frozenset_doc = "frozenset() -> empty frozenset object\n\ + frozenset(iterable) -> frozenset object\n\n\ + Build an immutable unordered collection of unique elements."; + context.set_attr( &frozenset_type, "__contains__", context.new_rustfunc(set_contains), ); context.set_attr(&frozenset_type, "__len__", context.new_rustfunc(set_len)); + context.set_attr( + &frozenset_type, + "__doc__", + context.new_str(frozenset_doc.to_string()), + ); context.set_attr( &frozenset_type, "__repr__", diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs new file mode 100644 index 000000000..547e252ba --- /dev/null +++ b/vm/src/obj/objslice.rs @@ -0,0 +1,95 @@ +use super::super::pyobject::{ + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, +}; +use super::super::vm::VirtualMachine; +use super::objint; +use super::objtype; // Required for arg_check! to use isinstance +use num_bigint::BigInt; + +fn slice_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + no_kwargs!(vm, args); + let (cls, start, stop, step): ( + &PyObjectRef, + Option<&PyObjectRef>, + Option<&PyObjectRef>, + Option<&PyObjectRef>, + ) = match args.args.len() { + 0 | 1 => Err(vm.new_type_error("slice() must have at least one arguments.".to_owned())), + 2 => { + arg_check!( + vm, + args, + required = [ + (cls, Some(vm.ctx.type_type())), + (stop, Some(vm.ctx.int_type())) + ] + ); + Ok((cls, None, Some(stop), None)) + } + _ => { + arg_check!( + vm, + args, + required = [ + (cls, Some(vm.ctx.type_type())), + (start, Some(vm.ctx.int_type())), + (stop, Some(vm.ctx.int_type())) + ], + optional = [(step, Some(vm.ctx.int_type()))] + ); + Ok((cls, Some(start), Some(stop), step)) + } + }?; + Ok(PyObject::new( + PyObjectPayload::Slice { + start: start.map(|x| objint::get_value(x)), + stop: stop.map(|x| objint::get_value(x)), + step: step.map(|x| objint::get_value(x)), + }, + cls.clone(), + )) +} + +fn get_property_value(vm: &mut VirtualMachine, value: &Option) -> PyResult { + if let Some(value) = value { + Ok(vm.ctx.new_int(value.clone())) + } else { + Ok(vm.get_none()) + } +} + +fn slice_start(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(slice, Some(vm.ctx.slice_type()))]); + if let PyObjectPayload::Slice { start, .. } = &slice.borrow().payload { + get_property_value(vm, start) + } else { + panic!("Slice has incorrect payload."); + } +} + +fn slice_stop(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(slice, Some(vm.ctx.slice_type()))]); + if let PyObjectPayload::Slice { stop, .. } = &slice.borrow().payload { + get_property_value(vm, stop) + } else { + panic!("Slice has incorrect payload."); + } +} + +fn slice_step(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(slice, Some(vm.ctx.slice_type()))]); + if let PyObjectPayload::Slice { step, .. } = &slice.borrow().payload { + get_property_value(vm, step) + } else { + panic!("Slice has incorrect payload."); + } +} + +pub fn init(context: &PyContext) { + let zip_type = &context.slice_type; + + context.set_attr(zip_type, "__new__", context.new_rustfunc(slice_new)); + context.set_attr(zip_type, "start", context.new_property(slice_start)); + context.set_attr(zip_type, "stop", context.new_property(slice_stop)); + context.set_attr(zip_type, "step", context.new_property(slice_step)); +} diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index d71cf3a69..a49e7a2c6 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -6,9 +6,9 @@ use super::super::vm::VirtualMachine; use super::objint; use super::objsequence::PySliceableSequence; use super::objtype; -use num_bigint::ToBigInt; use num_traits::ToPrimitive; use std::hash::{Hash, Hasher}; +use std::ops::Range; // rust's builtin to_lowercase isn't sufficient for casefold extern crate caseless; extern crate unicode_segmentation; @@ -308,13 +308,13 @@ fn str_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let mut hasher = std::collections::hash_map::DefaultHasher::new(); value.hash(&mut hasher); let hash = hasher.finish(); - Ok(vm.ctx.new_int(hash.to_bigint().unwrap())) + Ok(vm.ctx.new_int(hash)) } fn str_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]); let sv = get_value(s); - Ok(vm.ctx.new_int(sv.chars().count().to_bigint().unwrap())) + Ok(vm.ctx.new_int(sv.chars().count())) } fn str_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -512,12 +512,11 @@ fn str_zfill(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ); let value = get_value(&s); let len = objint::get_value(&len).to_usize().unwrap(); - let new_str: String; - if len <= value.len() { - new_str = value; + let new_str = if len <= value.len() { + value } else { - new_str = format!("{}{}", "0".repeat(len - value.len()), value); - } + format!("{}{}", "0".repeat(len - value.len()), value) + }; Ok(vm.ctx.new_str(new_str)) } @@ -554,7 +553,7 @@ fn str_count(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Err(e) => return Err(vm.new_index_error(e)), }; let num_occur: usize = value[start..end].matches(&sub).count(); - Ok(vm.ctx.new_int(num_occur.to_bigint().unwrap())) + Ok(vm.ctx.new_int(num_occur)) } fn str_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -573,13 +572,13 @@ fn str_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok((start, end)) => (start, end), Err(e) => return Err(vm.new_index_error(e)), }; - let ind: usize = match value[start..end + 1].find(&sub) { + let ind: usize = match value[start..=end].find(&sub) { Some(num) => num, None => { return Err(vm.new_value_error("substring not found".to_string())); } }; - Ok(vm.ctx.new_int(ind.to_bigint().unwrap())) + Ok(vm.ctx.new_int(ind)) } fn str_find(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -598,11 +597,11 @@ fn str_find(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok((start, end)) => (start, end), Err(e) => return Err(vm.new_index_error(e)), }; - let ind: i128 = match value[start..end + 1].find(&sub) { + let ind: i128 = match value[start..=end].find(&sub) { Some(num) => num as i128, None => -1 as i128, }; - Ok(vm.ctx.new_int(ind.to_bigint().unwrap())) + Ok(vm.ctx.new_int(ind)) } // casefold is much more aggresive than lower @@ -797,7 +796,7 @@ fn str_istitle(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { if value.is_empty() { is_titled = false; } else { - for word in value.split(" ") { + for word in value.split(' ') { if word != make_title(&word) { is_titled = false; break; @@ -888,13 +887,13 @@ fn str_rindex(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok((start, end)) => (start, end), Err(e) => return Err(vm.new_index_error(e)), }; - let ind: i64 = match value[start..end + 1].rfind(&sub) { + let ind: i64 = match value[start..=end].rfind(&sub) { Some(num) => num as i64, None => { return Err(vm.new_value_error("substring not found".to_string())); } }; - Ok(vm.ctx.new_int(ind.to_bigint().unwrap())) + Ok(vm.ctx.new_int(ind)) } fn str_rfind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -913,11 +912,11 @@ fn str_rfind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok((start, end)) => (start, end), Err(e) => return Err(vm.new_index_error(e)), }; - let ind = match value[start..end + 1].rfind(&sub) { + let ind = match value[start..=end].rfind(&sub) { Some(num) => num as i128, None => -1 as i128, }; - Ok(vm.ctx.new_int(ind.to_bigint().unwrap())) + Ok(vm.ctx.new_int(ind)) } fn str_isnumeric(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -1005,14 +1004,23 @@ fn str_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } impl PySliceableSequence for String { - fn do_slice(&self, start: usize, stop: usize) -> Self { + fn do_slice(&self, range: Range) -> Self { to_graphemes(self) - .get(start..stop) + .get(range) .map_or(String::default(), |c| c.join("")) } - fn do_stepped_slice(&self, start: usize, stop: usize, step: usize) -> Self { - if let Some(s) = to_graphemes(self).get(start..stop) { + fn do_slice_reverse(&self, range: Range) -> Self { + to_graphemes(self) + .get_mut(range) + .map_or(String::default(), |slice| { + slice.reverse(); + slice.join("") + }) + } + + fn do_stepped_slice(&self, range: Range, step: usize) -> Self { + if let Some(s) = to_graphemes(self).get(range) { return s .iter() .cloned() @@ -1023,6 +1031,23 @@ impl PySliceableSequence for String { String::default() } + fn do_stepped_slice_reverse(&self, range: Range, step: usize) -> Self { + if let Some(s) = to_graphemes(self).get(range) { + return s + .iter() + .rev() + .cloned() + .step_by(step) + .collect::>() + .join(""); + } + String::default() + } + + fn empty() -> Self { + String::default() + } + fn len(&self) -> usize { to_graphemes(self).len() } @@ -1041,20 +1066,21 @@ pub fn subscript(vm: &mut VirtualMachine, value: &str, b: PyObjectRef) -> PyResu match objint::get_value(&b).to_i32() { Some(pos) => { let graphemes = to_graphemes(value); - let idx = graphemes.get_pos(pos); - graphemes - .get(idx) - .map(|c| vm.new_str(c.to_string())) - .ok_or(vm.new_index_error("string index out of range".to_string())) + if let Some(idx) = graphemes.get_pos(pos) { + Ok(vm.new_str(graphemes[idx].to_string())) + } else { + Err(vm.new_index_error("string index out of range".to_string())) + } } None => { Err(vm.new_index_error("cannot fit 'int' into an index-sized integer".to_string())) } } } else { - match &(*b.borrow()).payload { - &PyObjectPayload::Slice { .. } => { - Ok(vm.new_str(value.to_string().get_slice_items(&b).to_string())) + match (*b.borrow()).payload { + PyObjectPayload::Slice { .. } => { + let string = value.to_string().get_slice_items(vm, &b)?; + Ok(vm.new_str(string)) } _ => panic!( "TypeError: indexing type {:?} with index {:?} is not supported (yet?)", diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index 8a0830503..4fa239903 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -12,7 +12,27 @@ use super::objtype; pub fn init(context: &PyContext) { let super_type = &context.super_type; + + let super_doc = "super() -> same as super(__class__, )\n\ + super(type) -> unbound super object\n\ + super(type, obj) -> bound super object; requires isinstance(obj, type)\n\ + super(type, type2) -> bound super object; requires issubclass(type2, type)\n\ + Typical use to call a cooperative superclass method:\n\ + class C(B):\n \ + def meth(self, arg):\n \ + super().meth(arg)\n\ + This works for class methods too:\n\ + class C(B):\n \ + @classmethod\n \ + def cmeth(cls, arg):\n \ + super().cmeth(arg)\n"; + context.set_attr(&super_type, "__init__", context.new_rustfunc(super_init)); + context.set_attr( + &super_type, + "__doc__", + context.new_str(super_doc.to_string()), + ); } fn super_init(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index b969142cc..2d99db2d4 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -9,7 +9,6 @@ use super::objsequence::{ }; use super::objstr; use super::objtype; -use num_bigint::ToBigInt; use std::hash::{Hash, Hasher}; fn tuple_lt(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -110,7 +109,7 @@ fn tuple_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { if objtype::isinstance(other, &vm.ctx.tuple_type()) { let e1 = get_elements(zelf); let e2 = get_elements(other); - let elements = e1.iter().chain(e2.iter()).map(|e| e.clone()).collect(); + let elements = e1.iter().chain(e2.iter()).cloned().collect(); Ok(vm.ctx.new_tuple(elements)) } else { Err(vm.new_type_error(format!( @@ -132,10 +131,10 @@ fn tuple_count(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { for element in elements.iter() { let is_eq = vm._eq(element, value.clone())?; if objbool::boolval(vm, is_eq)? { - count = count + 1; + count += 1; } } - Ok(vm.context().new_int(count.to_bigint().unwrap())) + Ok(vm.context().new_int(count)) } fn tuple_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -164,7 +163,7 @@ fn tuple_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { element_hash.hash(&mut hasher); } let hash = hasher.finish(); - Ok(vm.ctx.new_int(hash.to_bigint().unwrap())) + Ok(vm.ctx.new_int(hash)) } fn tuple_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -184,7 +183,7 @@ fn tuple_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { fn tuple_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.tuple_type()))]); let elements = get_elements(zelf); - Ok(vm.context().new_int(elements.len().to_bigint().unwrap())) + Ok(vm.context().new_int(elements.len())) } fn tuple_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -254,6 +253,21 @@ fn tuple_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { get_item(vm, tuple, &get_elements(&tuple), needle.clone()) } +pub fn tuple_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(tuple, Some(vm.ctx.tuple_type())), (needle, None)] + ); + for (index, element) in get_elements(tuple).iter().enumerate() { + let py_equal = vm.call_method(needle, "__eq__", vec![element.clone()])?; + if objbool::get_value(&py_equal) { + return Ok(vm.context().new_int(index)); + } + } + Err(vm.new_value_error("tuple.index(x): x not in tuple".to_string())) +} + pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -276,6 +290,10 @@ pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let tuple_type = &context.tuple_type; + let tuple_doc = "tuple() -> empty tuple +tuple(iterable) -> tuple initialized from iterable's items + +If the argument is a tuple, the return value is the same object."; context.set_attr(&tuple_type, "__add__", context.new_rustfunc(tuple_add)); context.set_attr(&tuple_type, "__eq__", context.new_rustfunc(tuple_eq)); context.set_attr( @@ -299,4 +317,10 @@ pub fn init(context: &PyContext) { context.set_attr(&tuple_type, "__le__", context.new_rustfunc(tuple_le)); context.set_attr(&tuple_type, "__gt__", context.new_rustfunc(tuple_gt)); context.set_attr(&tuple_type, "__ge__", context.new_rustfunc(tuple_ge)); + context.set_attr( + &tuple_type, + "__doc__", + context.new_str(tuple_doc.to_string()), + ); + context.set_attr(&tuple_type, "index", context.new_rustfunc(tuple_index)); } diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index 7b2ed48c5..dd5fb3b2c 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -1,21 +1,22 @@ use super::super::pyobject::{ - AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, - PyResult, TypeProtocol, + AttributeProtocol, IdProtocol, PyAttributes, PyContext, PyFuncArgs, PyObject, PyObjectPayload, + PyObjectRef, PyResult, TypeProtocol, }; use super::super::vm::VirtualMachine; use super::objdict; use super::objstr; use super::objtype; // Required for arg_check! to use isinstance +use std::cell::RefCell; use std::collections::HashMap; /* * The magical type type */ -pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: PyObjectRef) { +pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, _dict_type: PyObjectRef) { (*type_type.borrow_mut()).payload = PyObjectPayload::Class { name: String::from("type"), - dict: objdict::new(dict_type), + dict: RefCell::new(PyAttributes::new()), mro: vec![object_type], }; (*type_type.borrow_mut()).typ = Some(type_type.clone()); @@ -23,6 +24,11 @@ pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: pub fn init(context: &PyContext) { let type_type = &context.type_type; + + let type_doc = "type(object_or_name, bases, dict)\n\ + type(object) -> the object's type\n\ + type(name, bases, dict) -> a new type"; + context.set_attr(&type_type, "__call__", context.new_rustfunc(type_call)); context.set_attr(&type_type, "__new__", context.new_rustfunc(type_new)); context.set_attr( @@ -46,6 +52,7 @@ pub fn init(context: &PyContext) { "__getattribute__", context.new_rustfunc(type_getattribute), ); + context.set_attr(&type_type, "__doc__", context.new_str(type_doc.to_string())); } fn type_mro(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -119,12 +126,22 @@ pub fn type_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let mut bases = vm.extract_elements(bases)?; bases.push(vm.context().object()); let name = objstr::get_value(name); - new(typ.clone(), &name, bases, dict.clone()) + new(typ.clone(), &name, bases, py_dict_to_attributes(dict)) } else { Err(vm.new_type_error(format!(": type_new: {:?}", args))) } } +/// Take a python dictionary and convert it to attributes. +fn py_dict_to_attributes(dict: &PyObjectRef) -> PyAttributes { + let mut attrs = PyAttributes::new(); + for (key, value) in objdict::get_key_value_pairs(dict) { + let key = objstr::get_value(&key); + attrs.insert(key, value); + } + attrs +} + pub fn type_call(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> PyResult { debug!("type_call: {:?}", args); let cls = args.shift(); @@ -204,18 +221,16 @@ pub fn type_getattribute(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult } } -pub fn get_attributes(obj: &PyObjectRef) -> HashMap { +pub fn get_attributes(obj: &PyObjectRef) -> PyAttributes { // Gather all members here: - let mut attributes: HashMap = HashMap::new(); + let mut attributes = PyAttributes::new(); // Get class attributes: let mut base_classes = objtype::base_classes(obj); base_classes.reverse(); for bc in base_classes { if let PyObjectPayload::Class { dict, .. } = &bc.borrow().payload { - let elements = objdict::get_key_value_pairs(dict); - for (name, value) in elements.iter() { - let name = objstr::get_value(name); + for (name, value) in dict.borrow().iter() { attributes.insert(name.to_string(), value.clone()); } } @@ -223,9 +238,7 @@ pub fn get_attributes(obj: &PyObjectRef) -> HashMap { // Get instance attributes: if let PyObjectPayload::Instance { dict } = &obj.borrow().payload { - let elements = objdict::get_key_value_pairs(dict); - for (name, value) in elements.iter() { - let name = objstr::get_value(name); + for (name, value) in dict.borrow().iter() { attributes.insert(name.to_string(), value.clone()); } } @@ -242,8 +255,8 @@ fn take_next_base( for base in &bases { let head = base[0].clone(); if !(&bases) - .into_iter() - .any(|x| x[1..].into_iter().any(|x| x.get_id() == head.get_id())) + .iter() + .any(|x| x[1..].iter().any(|x| x.get_id() == head.get_id())) { next = Some(head); break; @@ -265,7 +278,7 @@ fn linearise_mro(mut bases: Vec>) -> Option> { debug!("Linearising MRO: {:?}", bases); let mut result = vec![]; loop { - if (&bases).into_iter().all(|x| x.is_empty()) { + if (&bases).iter().all(|x| x.is_empty()) { break; } match take_next_base(bases) { @@ -279,13 +292,18 @@ fn linearise_mro(mut bases: Vec>) -> Option> { Some(result) } -pub fn new(typ: PyObjectRef, name: &str, bases: Vec, dict: PyObjectRef) -> PyResult { +pub fn new( + typ: PyObjectRef, + name: &str, + bases: Vec, + dict: HashMap, +) -> PyResult { let mros = bases.into_iter().map(|x| _mro(x).unwrap()).collect(); let mro = linearise_mro(mros).unwrap(); Ok(PyObject::new( PyObjectPayload::Class { name: String::from(name), - dict, + dict: RefCell::new(dict), mro, }, typ, @@ -305,7 +323,7 @@ fn type_prepare(vm: &mut VirtualMachine, _args: PyFuncArgs) -> PyResult { #[cfg(test)] mod tests { use super::{linearise_mro, new}; - use super::{IdProtocol, PyContext, PyObjectRef}; + use super::{HashMap, IdProtocol, PyContext, PyObjectRef}; fn map_ids(obj: Option>) -> Option> { match obj { @@ -320,20 +338,8 @@ mod tests { let object = context.object; let type_type = context.type_type; - let a = new( - type_type.clone(), - "A", - vec![object.clone()], - type_type.clone(), - ) - .unwrap(); - let b = new( - type_type.clone(), - "B", - vec![object.clone()], - type_type.clone(), - ) - .unwrap(); + let a = new(type_type.clone(), "A", vec![object.clone()], HashMap::new()).unwrap(); + let b = new(type_type.clone(), "B", vec![object.clone()], HashMap::new()).unwrap(); assert_eq!( map_ids(linearise_mro(vec![ diff --git a/vm/src/obj/objzip.rs b/vm/src/obj/objzip.rs new file mode 100644 index 000000000..471df9c2b --- /dev/null +++ b/vm/src/obj/objzip.rs @@ -0,0 +1,46 @@ +use super::super::pyobject::{ + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, +}; +use super::super::vm::VirtualMachine; +use super::objiter; +use super::objtype; // Required for arg_check! to use isinstance + +fn zip_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + no_kwargs!(vm, args); + let cls = &args.args[0]; + let iterables = &args.args[1..]; + let iterators = iterables + .into_iter() + .map(|iterable| objiter::get_iter(vm, iterable)) + .collect::, _>>()?; + Ok(PyObject::new( + PyObjectPayload::ZipIterator { iterators }, + cls.clone(), + )) +} + +fn zip_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zip, Some(vm.ctx.zip_type()))]); + + if let PyObjectPayload::ZipIterator { ref mut iterators } = zip.borrow_mut().payload { + if iterators.is_empty() { + Err(objiter::new_stop_iteration(vm)) + } else { + let next_objs = iterators + .iter() + .map(|iterator| objiter::call_next(vm, iterator)) + .collect::, _>>()?; + + Ok(vm.ctx.new_tuple(next_objs)) + } + } else { + panic!("zip doesn't have correct payload"); + } +} + +pub fn init(context: &PyContext) { + let zip_type = &context.zip_type; + objiter::iter_type_init(context, zip_type); + context.set_attr(zip_type, "__new__", context.new_rustfunc(zip_new)); + context.set_attr(zip_type, "__next__", context.new_rustfunc(zip_next)); +} diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 072ebb6bc..212757671 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -7,6 +7,7 @@ use super::obj::objbytes; use super::obj::objcode; use super::obj::objcomplex; use super::obj::objdict; +use super::obj::objenumerate; use super::obj::objfilter; use super::obj::objfloat; use super::obj::objframe; @@ -21,12 +22,15 @@ use super::obj::objobject; use super::obj::objproperty; use super::obj::objrange; use super::obj::objset; +use super::obj::objslice; use super::obj::objstr; use super::obj::objsuper; use super::obj::objtuple; use super::obj::objtype; +use super::obj::objzip; use super::vm::VirtualMachine; use num_bigint::BigInt; +use num_bigint::ToBigInt; use num_complex::Complex64; use num_traits::{One, Zero}; use std::cell::RefCell; @@ -69,6 +73,10 @@ pub type PyObjectWeakRef = Weak>; /// since exceptions are also python objects. pub type PyResult = Result; // A valid value, or an exception +/// For attributes we do not use a dict, but a hashmap. This is probably +/// faster, unordered, and only supports strings as keys. +pub type PyAttributes = HashMap; + impl fmt::Display for PyObject { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use self::TypeProtocol; @@ -108,6 +116,7 @@ pub struct PyContext { pub classmethod_type: PyObjectRef, pub code_type: PyObjectRef, pub dict_type: PyObjectRef, + pub enumerate_type: PyObjectRef, pub filter_type: PyObjectRef, pub float_type: PyObjectRef, pub frame_type: PyObjectRef, @@ -128,7 +137,9 @@ pub struct PyContext { pub super_type: PyObjectRef, pub str_type: PyObjectRef, pub range_type: PyObjectRef, + pub slice_type: PyObjectRef, pub type_type: PyObjectRef, + pub zip_type: PyObjectRef, pub function_type: PyObjectRef, pub property_type: PyObjectRef, pub module_type: PyObjectRef, @@ -160,14 +171,9 @@ pub fn create_type( name: &str, type_type: &PyObjectRef, base: &PyObjectRef, - dict_type: &PyObjectRef, + _dict_type: &PyObjectRef, ) -> PyObjectRef { - let dict = PyObject::new( - PyObjectPayload::Dict { - elements: HashMap::new(), - }, - dict_type.clone(), - ); + let dict = PyAttributes::new(); objtype::new(type_type.clone(), name, vec![base.clone()], dict).unwrap() } @@ -204,12 +210,15 @@ impl PyContext { let bytearray_type = create_type("bytearray", &type_type, &object_type, &dict_type); let tuple_type = create_type("tuple", &type_type, &object_type, &dict_type); let iter_type = create_type("iter", &type_type, &object_type, &dict_type); + let enumerate_type = create_type("enumerate", &type_type, &object_type, &dict_type); let filter_type = create_type("filter", &type_type, &object_type, &dict_type); let map_type = create_type("map", &type_type, &object_type, &dict_type); + let zip_type = create_type("zip", &type_type, &object_type, &dict_type); let bool_type = create_type("bool", &type_type, &int_type, &dict_type); let memoryview_type = create_type("memoryview", &type_type, &object_type, &dict_type); let code_type = create_type("code", &type_type, &int_type, &dict_type); let range_type = create_type("range", &type_type, &object_type, &dict_type); + let slice_type = create_type("slice", &type_type, &object_type, &dict_type); let exceptions = exceptions::ExceptionZoo::new(&type_type, &object_type, &dict_type); let none = PyObject::new( @@ -246,12 +255,15 @@ impl PyContext { false_value, tuple_type, iter_type, + enumerate_type, filter_type, map_type, + zip_type, dict_type, none, str_type, range_type, + slice_type, object: object_type, function_type, super_type, @@ -280,11 +292,14 @@ impl PyContext { objmemory::init(&context); objstr::init(&context); objrange::init(&context); + objslice::init(&context); objsuper::init(&context); objtuple::init(&context); objiter::init(&context); + objenumerate::init(&context); objfilter::init(&context); objmap::init(&context); + objzip::init(&context); objbool::init(&context); objcode::init(&context); objframe::init(&context); @@ -336,6 +351,10 @@ impl PyContext { self.range_type.clone() } + pub fn slice_type(&self) -> PyObjectRef { + self.slice_type.clone() + } + pub fn frozenset_type(&self) -> PyObjectRef { self.frozenset_type.clone() } @@ -356,6 +375,10 @@ impl PyContext { self.iter_type.clone() } + pub fn enumerate_type(&self) -> PyObjectRef { + self.enumerate_type.clone() + } + pub fn filter_type(&self) -> PyObjectRef { self.filter_type.clone() } @@ -364,6 +387,10 @@ impl PyContext { self.map_type.clone() } + pub fn zip_type(&self) -> PyObjectRef { + self.zip_type.clone() + } + pub fn str_type(&self) -> PyObjectRef { self.str_type.clone() } @@ -410,16 +437,16 @@ impl PyContext { } pub fn new_object(&self) -> PyObjectRef { - PyObject::new( - PyObjectPayload::Instance { - dict: self.new_dict(), - }, - self.object(), - ) + self.new_instance(self.object(), None) } - pub fn new_int(&self, i: BigInt) -> PyObjectRef { - PyObject::new(PyObjectPayload::Integer { value: i }, self.int_type()) + pub fn new_int(&self, i: T) -> PyObjectRef { + PyObject::new( + PyObjectPayload::Integer { + value: i.to_bigint().unwrap(), + }, + self.int_type(), + ) } pub fn new_float(&self, i: f64) -> PyObjectRef { @@ -476,7 +503,7 @@ impl PyContext { } pub fn new_class(&self, name: &str, base: PyObjectRef) -> PyObjectRef { - objtype::new(self.type_type(), name, vec![base], self.new_dict()).unwrap() + objtype::new(self.type_type(), name, vec![base], PyAttributes::new()).unwrap() } pub fn new_scope(&self, parent: Option) -> PyObjectRef { @@ -530,12 +557,7 @@ impl PyContext { function: F, ) -> PyObjectRef { let fget = self.new_rustfunc(function); - let py_obj = PyObject::new( - PyObjectPayload::Instance { - dict: self.new_dict(), - }, - self.property_type(), - ); + let py_obj = self.new_instance(self.property_type(), None); self.set_attr(&py_obj, "fget", fget.clone()); py_obj } @@ -571,13 +593,23 @@ impl PyContext { &self, function: F, ) -> PyObjectRef { - let dict = self.new_dict(); - self.set_item(&dict, "function", self.new_rustfunc(function)); - self.new_instance(dict, self.member_descriptor_type()) + let mut dict = PyAttributes::new(); + dict.insert("function".to_string(), self.new_rustfunc(function)); + self.new_instance(self.member_descriptor_type(), Some(dict)) } - pub fn new_instance(&self, dict: PyObjectRef, class: PyObjectRef) -> PyObjectRef { - PyObject::new(PyObjectPayload::Instance { dict }, class) + pub fn new_instance(&self, class: PyObjectRef, dict: Option) -> PyObjectRef { + let dict = if let Some(dict) = dict { + dict + } else { + PyAttributes::new() + }; + PyObject::new( + PyObjectPayload::Instance { + dict: RefCell::new(dict), + }, + class, + ) } // Item set/get: @@ -605,8 +637,9 @@ impl PyContext { pub fn set_attr(&self, obj: &PyObjectRef, attr_name: &str, value: PyObjectRef) { match obj.borrow().payload { PyObjectPayload::Module { ref dict, .. } => self.set_item(dict, attr_name, value), - PyObjectPayload::Instance { ref dict } => self.set_item(dict, attr_name, value), - PyObjectPayload::Class { ref dict, .. } => self.set_item(dict, attr_name, value), + PyObjectPayload::Instance { ref dict } | PyObjectPayload::Class { ref dict, .. } => { + dict.borrow_mut().insert(attr_name.to_string(), value); + } ref payload => unimplemented!("set_attr unimplemented for: {:?}", payload), }; } @@ -667,10 +700,7 @@ pub trait ParentProtocol { impl ParentProtocol for PyObjectRef { fn has_parent(&self) -> bool { match self.borrow().payload { - PyObjectPayload::Scope { ref scope } => match scope.parent { - Some(_) => true, - None => false, - }, + PyObjectPayload::Scope { ref scope } => scope.parent.is_some(), _ => panic!("Only scopes have parent (not {:?}", self), } } @@ -694,7 +724,7 @@ pub trait AttributeProtocol { fn class_get_item(class: &PyObjectRef, attr_name: &str) -> Option { let class = class.borrow(); match class.payload { - PyObjectPayload::Class { ref dict, .. } => dict.get_item(attr_name), + PyObjectPayload::Class { ref dict, .. } => dict.borrow().get(attr_name).map(|v| v.clone()), _ => panic!("Only classes should be in MRO!"), } } @@ -702,7 +732,7 @@ fn class_get_item(class: &PyObjectRef, attr_name: &str) -> Option { fn class_has_item(class: &PyObjectRef, attr_name: &str) -> bool { let class = class.borrow(); match class.payload { - PyObjectPayload::Class { ref dict, .. } => dict.contains_key(attr_name), + PyObjectPayload::Class { ref dict, .. } => dict.borrow().contains_key(attr_name), _ => panic!("Only classes should be in MRO!"), } } @@ -723,7 +753,9 @@ impl AttributeProtocol for PyObjectRef { } None } - PyObjectPayload::Instance { ref dict } => dict.get_item(attr_name), + PyObjectPayload::Instance { ref dict } => { + dict.borrow().get(attr_name).map(|v| v.clone()) + } _ => None, } } @@ -733,10 +765,9 @@ impl AttributeProtocol for PyObjectRef { match obj.payload { PyObjectPayload::Module { ref dict, .. } => dict.contains_key(attr_name), PyObjectPayload::Class { ref mro, .. } => { - class_has_item(self, attr_name) - || mro.into_iter().any(|d| class_has_item(d, attr_name)) + class_has_item(self, attr_name) || mro.iter().any(|d| class_has_item(d, attr_name)) } - PyObjectPayload::Instance { ref dict } => dict.contains_key(attr_name), + PyObjectPayload::Instance { ref dict } => dict.borrow().contains_key(attr_name), _ => false, } } @@ -882,6 +913,10 @@ pub enum PyObjectPayload { position: usize, iterated_obj: PyObjectRef, }, + EnumerateIterator { + counter: BigInt, + iterator: PyObjectRef, + }, FilterIterator { predicate: PyObjectRef, iterator: PyObjectRef, @@ -890,10 +925,13 @@ pub enum PyObjectPayload { mapper: PyObjectRef, iterators: Vec, }, + ZipIterator { + iterators: Vec, + }, Slice { - start: Option, - stop: Option, - step: Option, + start: Option, + stop: Option, + step: Option, }, Range { range: objrange::RangeType, @@ -929,14 +967,14 @@ pub enum PyObjectPayload { None, Class { name: String, - dict: PyObjectRef, + dict: RefCell, mro: Vec, }, WeakRef { referent: PyObjectWeakRef, }, Instance { - dict: PyObjectRef, + dict: RefCell, }, RustFunction { function: Box PyResult>, @@ -958,8 +996,10 @@ impl fmt::Debug for PyObjectPayload { PyObjectPayload::WeakRef { .. } => write!(f, "weakref"), PyObjectPayload::Range { .. } => write!(f, "range"), PyObjectPayload::Iterator { .. } => write!(f, "iterator"), + PyObjectPayload::EnumerateIterator { .. } => write!(f, "enumerate"), PyObjectPayload::FilterIterator { .. } => write!(f, "filter"), PyObjectPayload::MapIterator { .. } => write!(f, "map"), + PyObjectPayload::ZipIterator { .. } => write!(f, "zip"), PyObjectPayload::Slice { .. } => write!(f, "slice"), PyObjectPayload::Code { ref code } => write!(f, "code: {:?}", code), PyObjectPayload::Function { .. } => write!(f, "function"), @@ -1055,8 +1095,10 @@ impl PyObject { position, iterated_obj.borrow_mut().str() ), + PyObjectPayload::EnumerateIterator { .. } => format!(""), PyObjectPayload::FilterIterator { .. } => format!(""), PyObjectPayload::MapIterator { .. } => format!(""), + PyObjectPayload::ZipIterator { .. } => format!(""), } } diff --git a/vm/src/stdlib/ast.rs b/vm/src/stdlib/ast.rs index 6b9eb66c9..6703534f4 100644 --- a/vm/src/stdlib/ast.rs +++ b/vm/src/stdlib/ast.rs @@ -9,9 +9,7 @@ use self::rustpython_parser::{ast, parser}; use super::super::obj::{objstr, objtype}; use super::super::pyobject::{PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol}; use super::super::VirtualMachine; -use num_bigint::ToBigInt; use num_complex::Complex64; -use num_traits::One; use std::ops::Deref; /* @@ -47,8 +45,7 @@ fn program_to_ast(ctx: &PyContext, program: &ast::Program) -> PyObjectRef { fn create_node(ctx: &PyContext, _name: &str) -> PyObjectRef { // TODO: instantiate a class of type given by name // TODO: lookup in the current module? - let node = ctx.new_object(); - node + ctx.new_object() } fn statements_to_ast(ctx: &PyContext, statements: &[ast::LocatedStatement]) -> PyObjectRef { @@ -101,18 +98,9 @@ fn statement_to_ast(ctx: &PyContext, statement: &ast::LocatedStatement) -> PyObj ctx.set_attr(&node, "decorator_list", py_decorator_list); node } - ast::Statement::Continue => { - let node = create_node(ctx, "Continue"); - node - } - ast::Statement::Break => { - let node = create_node(ctx, "Break"); - node - } - ast::Statement::Pass => { - let node = create_node(ctx, "Pass"); - node - } + ast::Statement::Continue => create_node(ctx, "Continue"), + ast::Statement::Break => create_node(ctx, "Break"), + ast::Statement::Pass => create_node(ctx, "Pass"), ast::Statement::Assert { test, msg } => { let node = create_node(ctx, "Pass"); @@ -129,12 +117,8 @@ fn statement_to_ast(ctx: &PyContext, statement: &ast::LocatedStatement) -> PyObj ast::Statement::Delete { targets } => { let node = create_node(ctx, "Delete"); - let py_targets = ctx.new_tuple( - targets - .into_iter() - .map(|v| expression_to_ast(ctx, v)) - .collect(), - ); + let py_targets = + ctx.new_tuple(targets.iter().map(|v| expression_to_ast(ctx, v)).collect()); ctx.set_attr(&node, "targets", py_targets); node @@ -143,12 +127,7 @@ fn statement_to_ast(ctx: &PyContext, statement: &ast::LocatedStatement) -> PyObj let node = create_node(ctx, "Return"); let py_value = if let Some(value) = value { - ctx.new_tuple( - value - .into_iter() - .map(|v| expression_to_ast(ctx, v)) - .collect(), - ) + ctx.new_tuple(value.iter().map(|v| expression_to_ast(ctx, v)).collect()) } else { ctx.none() }; @@ -232,7 +211,7 @@ fn statement_to_ast(ctx: &PyContext, statement: &ast::LocatedStatement) -> PyObj }; // set lineno on node: - let lineno = ctx.new_int(statement.location.get_row().to_bigint().unwrap()); + let lineno = ctx.new_int(statement.location.get_row()); ctx.set_attr(&node, "lineno", lineno); node @@ -385,7 +364,7 @@ fn expression_to_ast(ctx: &PyContext, expression: &ast::Expression) -> PyObjectR let node = create_node(ctx, "Num"); let py_n = match value { - ast::Number::Integer { value } => ctx.new_int(value.to_bigint().unwrap()), + ast::Number::Integer { value } => ctx.new_int(value.clone()), ast::Number::Float { value } => ctx.new_float(*value), ast::Number::Complex { real, imag } => { ctx.new_complex(Complex64::new(*real, *imag)) @@ -550,7 +529,7 @@ fn expression_to_ast(ctx: &PyContext, expression: &ast::Expression) -> PyObjectR }; // TODO: retrieve correct lineno: - let lineno = ctx.new_int(One::one()); + let lineno = ctx.new_int(1); ctx.set_attr(&node, "lineno", lineno); node diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index c4391d116..e254b2417 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -27,7 +27,7 @@ use super::super::pyobject::{ use super::super::vm::VirtualMachine; fn compute_c_flag(mode: &str) -> u16 { - match mode.as_ref() { + match mode { "w" => 512, "x" => 512, "a" => 8, @@ -85,11 +85,8 @@ fn buffered_reader_read(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { .map_err(|_| vm.new_value_error("IO Error".to_string()))?; //Copy bytes from the buffer vector into the results vector - match buffer.borrow_mut().payload { - PyObjectPayload::Bytes { ref mut value } => { - result.extend(value.iter().cloned()); - } - _ => {} + if let PyObjectPayload::Bytes { ref mut value } = buffer.borrow_mut().payload { + result.extend(value.iter().cloned()); }; let len = vm.get_method(buffer.clone(), &"__len__".to_string()); @@ -171,21 +168,18 @@ fn file_io_readinto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let handle = os::rust_file(raw_fd); let mut f = handle.take(length); - match obj.borrow_mut().payload { + if let PyObjectPayload::Bytes { ref mut value } = obj.borrow_mut().payload { //TODO: Implement for MemoryView - PyObjectPayload::Bytes { ref mut value } => { - value.clear(); - match f.read_to_end(&mut *value) { - Ok(_) => {} - Err(_) => return Err(vm.new_value_error("Error reading from Take".to_string())), - } + + value.clear(); + match f.read_to_end(&mut *value) { + Ok(_) => {} + Err(_) => return Err(vm.new_value_error("Error reading from Take".to_string())), } - _ => {} }; - let updated = os::raw_file_number(f.into_inner()).to_bigint(); - vm.ctx - .set_attr(&file_io, "fileno", vm.ctx.new_int(updated.unwrap())); + let updated = os::raw_file_number(f.into_inner()); + vm.ctx.set_attr(&file_io, "fileno", vm.ctx.new_int(updated)); Ok(vm.get_none()) } @@ -209,12 +203,11 @@ fn file_io_write(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { match handle.write(&value[..]) { Ok(len) => { //reset raw fd on the FileIO object - let updated = os::raw_file_number(handle).to_bigint(); - vm.ctx - .set_attr(&file_io, "fileno", vm.ctx.new_int(updated.unwrap())); + let updated = os::raw_file_number(handle); + vm.ctx.set_attr(&file_io, "fileno", vm.ctx.new_int(updated)); //return number of bytes written - Ok(vm.ctx.new_int(len.to_bigint().unwrap())) + Ok(vm.ctx.new_int(len)) } Err(_) => Err(vm.new_value_error("Error Writing Bytes to Handle".to_string())), } @@ -320,7 +313,7 @@ pub fn io_open(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { //the operation in the mode. //There are 3 possible classes here, each inheriting from the RawBaseIO // creating || writing || appending => BufferedWriter - let buffered = if rust_mode.contains("w") { + let buffered = if rust_mode.contains('w') { vm.invoke( buffered_writer_class, PyFuncArgs::new(vec![file_io.clone()], vec![]), @@ -334,7 +327,7 @@ pub fn io_open(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { //TODO: updating => PyBufferedRandom }; - if rust_mode.contains("t") { + if rust_mode.contains('t') { //If the mode is text this buffer type is consumed on construction of //a TextIOWrapper which is subsequently returned. vm.invoke( diff --git a/vm/src/stdlib/json.rs b/vm/src/stdlib/json.rs index 9f9e705ec..1e7ec411e 100644 --- a/vm/src/stdlib/json.rs +++ b/vm/src/stdlib/json.rs @@ -11,7 +11,6 @@ use super::super::pyobject::{ TypeProtocol, }; use super::super::VirtualMachine; -use num_bigint::ToBigInt; use num_traits::cast::ToPrimitive; // We need to have a VM available to serialise a PyObject based on its subclass, so we implement @@ -54,10 +53,9 @@ impl<'s> serde::Serialize for PyObjectSerializer<'s> { serializer.serialize_i64(v.to_i64().unwrap()) // Although this may seem nice, it does not give the right result: // v.serialize(serializer) - } else if objtype::isinstance(self.pyobject, &self.ctx.list_type()) { - let elements = objsequence::get_elements(self.pyobject); - serialize_seq_elements(serializer, &elements) - } else if objtype::isinstance(self.pyobject, &self.ctx.tuple_type()) { + } else if objtype::isinstance(self.pyobject, &self.ctx.list_type()) + || objtype::isinstance(self.pyobject, &self.ctx.tuple_type()) + { let elements = objsequence::get_elements(self.pyobject); serialize_seq_elements(serializer, &elements) } else if objtype::isinstance(self.pyobject, &self.ctx.dict_type()) { @@ -119,7 +117,7 @@ impl<'de> serde::de::DeserializeSeed<'de> for PyObjectDeserializer<'de> { { // The JSON deserialiser always uses the i64/u64 deserialisers, so we only need to // implement those for now - Ok(self.ctx.new_int(value.to_bigint().unwrap())) + Ok(self.ctx.new_int(value)) } fn visit_u64(self, value: u64) -> Result @@ -128,7 +126,7 @@ impl<'de> serde::de::DeserializeSeed<'de> for PyObjectDeserializer<'de> { { // The JSON deserialiser always uses the i64/u64 deserialisers, so we only need to // implement those for now - Ok(self.ctx.new_int(value.to_bigint().unwrap())) + Ok(self.ctx.new_int(value)) } fn visit_f64(self, value: f64) -> Result @@ -223,10 +221,8 @@ fn loads(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { .get_item("JSONDecodeError") .unwrap(); let exc = vm.new_exception(json_decode_error, format!("{}", err)); - vm.ctx - .set_item(&exc, "lineno", vm.ctx.new_int(err.line().into())); - vm.ctx - .set_item(&exc, "colno", vm.ctx.new_int(err.column().into())); + vm.ctx.set_item(&exc, "lineno", vm.ctx.new_int(err.line())); + vm.ctx.set_item(&exc, "colno", vm.ctx.new_int(err.column())); exc }) } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index ac284bddc..4671f5fe9 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -4,7 +4,6 @@ use std::fs::OpenOptions; use std::io::ErrorKind; //3rd party imports -use num_bigint::ToBigInt; use num_traits::cast::ToPrimitive; //custom imports @@ -91,11 +90,7 @@ pub fn os_open(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { _ => vm.new_value_error("Unhandled file IO error".to_string()), })?; - Ok(vm.ctx.new_int( - raw_file_number(handle) - .to_bigint() - .expect("Invalid file descriptor"), - )) + Ok(vm.ctx.new_int(raw_file_number(handle))) } pub fn mk_module(ctx: &PyContext) -> PyObjectRef { @@ -103,11 +98,11 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "open", ctx.new_rustfunc(os_open)); ctx.set_attr(&py_mod, "close", ctx.new_rustfunc(os_close)); - ctx.set_attr(&py_mod, "O_RDONLY", ctx.new_int(0.to_bigint().unwrap())); - ctx.set_attr(&py_mod, "O_WRONLY", ctx.new_int(1.to_bigint().unwrap())); - ctx.set_attr(&py_mod, "O_RDWR", ctx.new_int(2.to_bigint().unwrap())); - ctx.set_attr(&py_mod, "O_NONBLOCK", ctx.new_int(4.to_bigint().unwrap())); - ctx.set_attr(&py_mod, "O_APPEND", ctx.new_int(8.to_bigint().unwrap())); - ctx.set_attr(&py_mod, "O_CREAT", ctx.new_int(512.to_bigint().unwrap())); + ctx.set_attr(&py_mod, "O_RDONLY", ctx.new_int(0)); + ctx.set_attr(&py_mod, "O_WRONLY", ctx.new_int(1)); + ctx.set_attr(&py_mod, "O_RDWR", ctx.new_int(2)); + ctx.set_attr(&py_mod, "O_NONBLOCK", ctx.new_int(4)); + ctx.set_attr(&py_mod, "O_APPEND", ctx.new_int(8)); + ctx.set_attr(&py_mod, "O_CREAT", ctx.new_int(512)); py_mod } diff --git a/vm/src/stdlib/pystruct.rs b/vm/src/stdlib/pystruct.rs index 7babc6000..166007197 100644 --- a/vm/src/stdlib/pystruct.rs +++ b/vm/src/stdlib/pystruct.rs @@ -13,7 +13,7 @@ use self::byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use super::super::obj::{objbool, objbytes, objfloat, objint, objstr, objtype}; use super::super::pyobject::{PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol}; use super::super::VirtualMachine; -use num_bigint::{BigInt, ToBigInt}; +use num_bigint::BigInt; use num_traits::ToPrimitive; use std::io::{Cursor, Read, Write}; @@ -224,14 +224,14 @@ fn struct_pack(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { fn unpack_i8(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_i8() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } fn unpack_u8(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_u8() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } @@ -245,42 +245,42 @@ fn unpack_bool(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { fn unpack_i16(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_i16::() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } fn unpack_u16(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_u16::() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } fn unpack_i32(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_i32::() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } fn unpack_u32(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_u32::() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } fn unpack_i64(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_i64::() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } fn unpack_u64(vm: &mut VirtualMachine, rdr: &mut Read) -> PyResult { match rdr.read_u64::() { Err(err) => panic!("Error in reading {:?}", err), - Ok(v) => Ok(vm.ctx.new_int(v.to_bigint().unwrap())), + Ok(v) => Ok(vm.ctx.new_int(v)), } } diff --git a/vm/src/stdlib/types.rs b/vm/src/stdlib/types.rs index 2060ad1f1..a744d89b7 100644 --- a/vm/src/stdlib/types.rs +++ b/vm/src/stdlib/types.rs @@ -3,7 +3,9 @@ */ use super::super::obj::{objsequence, objstr, objtype}; -use super::super::pyobject::{PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol}; +use super::super::pyobject::{ + PyAttributes, PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol, +}; use super::super::VirtualMachine; fn types_new_class(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -15,7 +17,6 @@ fn types_new_class(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ); let name = objstr::get_value(name); - let dict = vm.ctx.new_dict(); let bases = match bases { Some(b) => { @@ -28,7 +29,7 @@ fn types_new_class(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { None => vec![vm.ctx.object()], }; - objtype::new(vm.ctx.type_type(), &name, bases, dict) + objtype::new(vm.ctx.type_type(), &name, bases, PyAttributes::new()) } pub fn mk_module(ctx: &PyContext) -> PyObjectRef { diff --git a/vm/src/sysmodule.rs b/vm/src/sysmodule.rs index e747fda5a..43b01d0d9 100644 --- a/vm/src/sysmodule.rs +++ b/vm/src/sysmodule.rs @@ -1,4 +1,3 @@ -use num_bigint::ToBigInt; use obj::objtype; use pyobject::{PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol}; use std::rc::Rc; @@ -26,14 +25,14 @@ fn getframe(vm: &mut VirtualMachine, _args: PyFuncArgs) -> PyResult { fn sys_getrefcount(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(object, None)]); let size = Rc::strong_count(&object); - Ok(vm.ctx.new_int(size.to_bigint().unwrap())) + Ok(vm.ctx.new_int(size)) } fn sys_getsizeof(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(object, None)]); // TODO: implement default optional argument. let size = mem::size_of_val(&object.borrow()); - Ok(vm.ctx.new_int(size.to_bigint().unwrap())) + Ok(vm.ctx.new_int(size)) } pub fn mk_module(ctx: &PyContext) -> PyObjectRef { @@ -62,11 +61,7 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef { ctx.set_item(&sys_mod, "argv", argv(ctx)); ctx.set_item(&sys_mod, "getrefcount", ctx.new_rustfunc(sys_getrefcount)); ctx.set_item(&sys_mod, "getsizeof", ctx.new_rustfunc(sys_getsizeof)); - ctx.set_item( - &sys_mod, - "maxsize", - ctx.new_int(std::usize::MAX.to_bigint().unwrap()), - ); + ctx.set_item(&sys_mod, "maxsize", ctx.new_int(std::usize::MAX)); ctx.set_item(&sys_mod, "path", path); ctx.set_item(&sys_mod, "ps1", ctx.new_str(">>>>> ".to_string())); ctx.set_item(&sys_mod, "ps2", ctx.new_str("..... ".to_string())); diff --git a/vm/src/vm.rs b/vm/src/vm.rs index b1570e386..435e7bcbc 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -91,8 +91,7 @@ impl VirtualMachine { }; // Call function: - let exception = self.invoke(exc_type, args).unwrap(); - exception + self.invoke(exc_type, args).unwrap() } pub fn new_type_error(&mut self, msg: String) -> PyObjectRef { @@ -105,6 +104,11 @@ impl VirtualMachine { self.new_exception(os_error, msg) } + pub fn new_overflow_error(&mut self, msg: String) -> PyObjectRef { + let overflow_error = self.ctx.exceptions.overflow_error.clone(); + self.new_exception(overflow_error, msg) + } + /// Create a new python ValueError object. Useful for raising errors from /// python functions implemented in rust. pub fn new_value_error(&mut self, msg: String) -> PyObjectRef { @@ -123,8 +127,13 @@ impl VirtualMachine { } pub fn new_not_implemented_error(&mut self, msg: String) -> PyObjectRef { - let value_error = self.ctx.exceptions.not_implemented_error.clone(); - self.new_exception(value_error, msg) + let not_implemented_error = self.ctx.exceptions.not_implemented_error.clone(); + self.new_exception(not_implemented_error, msg) + } + + pub fn new_zero_division_error(&mut self, msg: String) -> PyObjectRef { + let zero_division_error = self.ctx.exceptions.zero_division_error.clone(); + self.new_exception(zero_division_error, msg) } pub fn new_scope(&mut self, parent_scope: Option) -> PyObjectRef { @@ -440,9 +449,9 @@ impl VirtualMachine { value: &PyObjectRef, ) -> Result, PyObjectRef> { // Extract elements from item, if possible: - let elements = if objtype::isinstance(value, &self.ctx.tuple_type()) { - objsequence::get_elements(value).to_vec() - } else if objtype::isinstance(value, &self.ctx.list_type()) { + let elements = if objtype::isinstance(value, &self.ctx.tuple_type()) + || objtype::isinstance(value, &self.ctx.list_type()) + { objsequence::get_elements(value).to_vec() } else { let iter = objiter::get_iter(self, value)?; @@ -615,8 +624,8 @@ mod tests { #[test] fn test_add_py_integers() { let mut vm = VirtualMachine::new(); - let a = vm.ctx.new_int(33_i32.to_bigint().unwrap()); - let b = vm.ctx.new_int(12_i32.to_bigint().unwrap()); + let a = vm.ctx.new_int(33_i32); + let b = vm.ctx.new_int(12_i32); let res = vm._add(a, b).unwrap(); let value = objint::get_value(&res); assert_eq!(value, 45_i32.to_bigint().unwrap()); @@ -626,7 +635,7 @@ mod tests { fn test_multiply_str() { let mut vm = VirtualMachine::new(); let a = vm.ctx.new_str(String::from("Hello ")); - let b = vm.ctx.new_int(4_i32.to_bigint().unwrap()); + let b = vm.ctx.new_int(4_i32); let res = vm._mul(a, b).unwrap(); let value = objstr::get_value(&res); assert_eq!(value, String::from("Hello Hello Hello Hello "))