diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py index 2df8e4df7..4d0fa1202 100644 --- a/tests/snippets/builtin_complex.py +++ b/tests/snippets/builtin_complex.py @@ -24,6 +24,11 @@ assert complex(1, 2) != complex(1, 1) assert complex(1, 2) != 'foo' assert complex(1, 2).__eq__('foo') == NotImplemented +# __mul__ + +assert complex(2, -3) * complex(-5, 7) == complex(11, 29) +assert complex(2, -3) * 5 == complex(10, -15) + # __neg__ assert -complex(1, -1) == complex(-1, 1) diff --git a/tests/snippets/builtin_range.py b/tests/snippets/builtin_range.py index 2ea71f7a7..55d4f5194 100644 --- a/tests/snippets/builtin_range.py +++ b/tests/snippets/builtin_range.py @@ -27,6 +27,11 @@ assert range(4, 10, 2).count(4) == 1 assert range(4, 10, 2).count(7) == 0 assert range(10).count("foo") == 0 +# __eq__ +assert range(1, 2, 3) == range(1, 2, 3) +assert range(1, 2, 1) == range(1, 2) +assert range(2) == range(0, 2) + # __bool__ assert bool(range(1)) assert bool(range(1, 2)) @@ -54,3 +59,9 @@ assert list(reversed(range(1,10,5))) == [6, 1] # range retains the original int refs i = 2**64 assert range(i).stop is i + +# negative index +assert range(10)[-1] == 9 +assert_raises(IndexError, lambda: range(10)[-11], 'out of bound') +assert range(10)[-2:4] == range(8, 4) +assert range(10)[-6:-2] == range(4, 8) diff --git a/tests/snippets/bytearray.py b/tests/snippets/bytearray.py index 563da2c53..4286f8171 100644 --- a/tests/snippets/bytearray.py +++ b/tests/snippets/bytearray.py @@ -65,3 +65,9 @@ except IndexError: pass else: assert False + +a = bytearray(b'appen') +assert len(a) == 5 +a.append(100) +assert len(a) == 6 +assert a.pop() == 100 diff --git a/tests/snippets/floats.py b/tests/snippets/floats.py index 176ce927e..788c52195 100644 --- a/tests/snippets/floats.py +++ b/tests/snippets/floats.py @@ -8,6 +8,7 @@ a = 1.2 b = 1.3 c = 1.2 z = 2 +ov = 10 ** 1000 assert -a == -1.2 @@ -37,6 +38,20 @@ assert a / z == 0.6 assert 6 / a == 5.0 assert 2.0 % z == 0.0 assert z % 2.0 == 0.0 +assert_raises(OverflowError, lambda: a + ov) +assert_raises(OverflowError, lambda: a - ov) +assert_raises(OverflowError, lambda: a * ov) +assert_raises(OverflowError, lambda: a / ov) +assert_raises(OverflowError, lambda: a // ov) +assert_raises(OverflowError, lambda: a % ov) +assert_raises(OverflowError, lambda: a ** ov) +assert_raises(OverflowError, lambda: ov + a) +assert_raises(OverflowError, lambda: ov - a) +assert_raises(OverflowError, lambda: ov * a) +assert_raises(OverflowError, lambda: ov / a) +assert_raises(OverflowError, lambda: ov // a) +assert_raises(OverflowError, lambda: ov % a) +# assert_raises(OverflowError, lambda: ov ** a) assert a < 5 assert a <= 5 @@ -91,6 +106,8 @@ assert 2.0.__sub__(1.0) == 1.0 assert 2.0.__rmul__(1.0) == 2.0 assert 1.0.__truediv__(2.0) == 0.5 assert 1.0.__rtruediv__(2.0) == 2.0 +assert 2.5.__divmod__(2.0) == (1.0, 0.5) +assert 2.0.__rdivmod__(2.5) == (1.0, 0.5) assert 1.0.__add__(1) == 2.0 assert 1.0.__radd__(1) == 2.0 @@ -105,6 +122,11 @@ assert 2.0.__rmod__(2) == 0.0 assert_raises(ZeroDivisionError, lambda: 2.0 / 0) assert_raises(ZeroDivisionError, lambda: 2.0 // 0) assert_raises(ZeroDivisionError, lambda: 2.0 % 0) +assert_raises(ZeroDivisionError, lambda: divmod(2.0, 0)) +assert_raises(ZeroDivisionError, lambda: 2 / 0.0) +assert_raises(ZeroDivisionError, lambda: 2 // 0.0) +assert_raises(ZeroDivisionError, lambda: 2 % 0.0) +# assert_raises(ZeroDivisionError, lambda: divmod(2, 0.0)) assert 1.2.__int__() == 1 assert 1.2.__float__() == 1.2 @@ -123,6 +145,10 @@ assert 1.5.__round__(None) == 2.0 assert_raises(OverflowError, float('inf').__round__) assert_raises(ValueError, float('nan').__round__) +assert 1.2 ** 2 == 1.44 +assert_raises(OverflowError, lambda: 1.2 ** (10 ** 1000)) +assert 3 ** 2.0 == 9.0 + assert (1.7).real == 1.7 assert (1.3).is_integer() == False assert (1.0).is_integer() == True diff --git a/tests/snippets/ints.py b/tests/snippets/ints.py index 1fbbc2bef..a072cc542 100644 --- a/tests/snippets/ints.py +++ b/tests/snippets/ints.py @@ -74,3 +74,37 @@ with assertRaises(TypeError): with assertRaises(TypeError): # check that first parameter is truly positional only int(val_options=1) + +class A(object): + def __int__(self): + return 10 + +assert int(A()) == 10 + +class B(object): + pass + +b = B() +b.__int__ = lambda: 20 + +with assertRaises(TypeError): + assert int(b) == 20 + +class C(object): + def __int__(self): + return 'str' + +with assertRaises(TypeError): + int(C()) + +class I(int): + def __int__(self): + return 3 + +assert int(I(1)) == 3 + +class F(float): + def __int__(self): + return 3 + +assert int(F(1.2)) == 3 diff --git a/tests/snippets/testutils.py b/tests/snippets/testutils.py index 7dda57b17..49b7fedc1 100644 --- a/tests/snippets/testutils.py +++ b/tests/snippets/testutils.py @@ -14,9 +14,9 @@ def assert_raises(exc_type, expr, msg=None): except exc_type: pass else: - failmsg = '{!s} was not raised'.format(exc_type.__name__) + failmsg = '{} was not raised'.format(exc_type.__name__) if msg is not None: - failmsg += ': {!s}'.format(msg) + failmsg += ': {}'.format(msg) assert False, failmsg @@ -29,8 +29,8 @@ class assertRaises: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: - failmsg = '{!s} was not raised'.format(self.expected.__name__) - assert False, failmsg + failmsg = '{} was not raised'.format(self.expected.__name__) + assert False, failmsg if not issubclass(exc_type, self.expected): return False return True diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 41082707e..300d11b31 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -123,11 +123,14 @@ fn builtin_dir(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } fn builtin_divmod(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(x, None), (y, None)]); - match vm.get_method(x.clone(), "__divmod__") { - Ok(attrib) => vm.invoke(attrib, vec![y.clone()]), - Err(..) => Err(vm.new_type_error("unsupported operand type(s) for divmod".to_string())), - } + arg_check!(vm, args, required = [(a, None), (b, None)]); + vm.call_or_reflection( + a.clone(), + b.clone(), + "__divmod__", + "__rdivmod__", + |vm, a, b| Err(vm.new_unsupported_operand_error(a, b, "divmod")), + ) } /// Implements `eval`. diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index f6244399b..6f78ecf48 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -7,7 +7,7 @@ use std::ops::{Deref, DerefMut}; use num_traits::ToPrimitive; use crate::function::OptionalArg; -use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use super::objint; @@ -79,15 +79,12 @@ pub fn init(context: &PyContext) { "istitle" =>context.new_rustfunc(PyByteArrayRef::istitle), "isupper" => context.new_rustfunc(PyByteArrayRef::isupper), "lower" => context.new_rustfunc(PyByteArrayRef::lower), + "append" => context.new_rustfunc(PyByteArrayRef::append), "pop" => context.new_rustfunc(PyByteArrayRef::pop), "upper" => context.new_rustfunc(PyByteArrayRef::upper) }); - let bytearrayiterator_type = &context.bytearrayiterator_type; - extend_class!(context, bytearrayiterator_type, { - "__next__" => context.new_rustfunc(PyByteArrayIteratorRef::next), - "__iter__" => context.new_rustfunc(PyByteArrayIteratorRef::iter), - }); + PyByteArrayIterator::extend_class(context, &context.bytearrayiterator_type); } fn bytearray_new( @@ -213,6 +210,10 @@ impl PyByteArrayRef { self.value.borrow_mut().clear(); } + fn append(self, x: u8, _vm: &VirtualMachine) { + self.value.borrow_mut().push(x); + } + fn pop(self, vm: &VirtualMachine) -> PyResult { let mut bytes = self.value.borrow_mut(); bytes @@ -282,6 +283,7 @@ mod tests { } } +#[pyclass] #[derive(Debug)] pub struct PyByteArrayIterator { position: Cell, @@ -294,10 +296,10 @@ impl PyValue for PyByteArrayIterator { } } -type PyByteArrayIteratorRef = PyRef; - -impl PyByteArrayIteratorRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PyByteArrayIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { if self.position.get() < self.bytearray.value.borrow().len() { let ret = self.bytearray.value.borrow()[self.position.get()]; self.position.set(self.position.get() + 1); @@ -307,7 +309,8 @@ impl PyByteArrayIteratorRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 0556de053..df9795902 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -64,11 +64,7 @@ pub fn init(context: &PyContext) { extend_class!(context, bytes_type, { "fromhex" => context.new_rustfunc(PyBytesRef::fromhex), }); - let bytesiterator_type = &context.bytesiterator_type; - extend_class!(context, bytesiterator_type, { - "__next__" => context.new_rustfunc(PyBytesIteratorRef::next), - "__iter__" => context.new_rustfunc(PyBytesIteratorRef::iter), - }); + PyBytesIterator::extend_class(context, &context.bytesiterator_type); } #[pyimpl] @@ -271,6 +267,7 @@ impl PyBytesRef { } } +#[pyclass] #[derive(Debug)] pub struct PyBytesIterator { position: Cell, @@ -283,10 +280,10 @@ impl PyValue for PyBytesIterator { } } -type PyBytesIteratorRef = PyRef; - -impl PyBytesIteratorRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PyBytesIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { if self.position.get() < self.bytes.inner.len() { let ret = self.bytes[self.position.get()]; self.position.set(self.position.get() + 1); @@ -296,7 +293,8 @@ impl PyBytesIteratorRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 93a0559aa..f9ff9d9b0 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -67,7 +67,9 @@ impl PyComplexRef { } fn to_complex(value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if objtype::isinstance(&value, &vm.ctx.int_type()) { + if objtype::isinstance(&value, &vm.ctx.complex_type()) { + Ok(Some(get_value(&value))) + } else if objtype::isinstance(&value, &vm.ctx.int_type()) { match objint::get_value(&value).to_f64() { Some(v) => Ok(Some(Complex64::new(v, 0.0))), None => Err(vm.new_overflow_error("int too large to convert to float".to_string())), @@ -161,6 +163,28 @@ impl PyComplex { vm.ctx.new_bool(result) } + #[pymethod(name = "__float__")] + fn float(&self, vm: &VirtualMachine) -> PyResult { + return Err(vm.new_type_error(String::from("Can't convert complex to float"))); + } + + #[pymethod(name = "__int__")] + fn int(&self, vm: &VirtualMachine) -> PyResult { + return Err(vm.new_type_error(String::from("Can't convert complex to int"))); + } + + #[pymethod(name = "__mul__")] + fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + match to_complex(other, vm) { + Ok(Some(other)) => Ok(vm.ctx.new_complex(Complex64::new( + self.value.re * other.re - self.value.im * other.im, + self.value.re * other.im + self.value.im * other.re, + ))), + Ok(None) => Ok(vm.ctx.not_implemented()), + Err(err) => Err(err), + } + } + #[pymethod(name = "__neg__")] fn neg(&self, _vm: &VirtualMachine) -> PyComplex { PyComplex::from(-self.value) diff --git a/vm/src/obj/objenumerate.rs b/vm/src/obj/objenumerate.rs index c94aeabbd..56d232272 100644 --- a/vm/src/obj/objenumerate.rs +++ b/vm/src/obj/objenumerate.rs @@ -5,13 +5,14 @@ use num_bigint::BigInt; use num_traits::Zero; use crate::function::OptionalArg; -use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use super::objint::PyIntRef; use super::objiter; use super::objtype::PyClassRef; +#[pyclass] #[derive(Debug)] pub struct PyEnumerate { counter: RefCell, @@ -44,8 +45,10 @@ fn enumerate_new( .into_ref_with_type(vm, cls) } -impl PyEnumerateRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PyEnumerate { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { let iterator = &self.iterator; let counter = &self.counter; let next_obj = objiter::call_next(vm, iterator)?; @@ -58,16 +61,15 @@ impl PyEnumerateRef { Ok(result) } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } pub fn init(context: &PyContext) { - let enumerate_type = &context.enumerate_type; - extend_class!(context, enumerate_type, { + PyEnumerate::extend_class(context, &context.enumerate_type); + extend_class!(context, &context.enumerate_type, { "__new__" => context.new_rustfunc(enumerate_new), - "__next__" => context.new_rustfunc(PyEnumerateRef::next), - "__iter__" => context.new_rustfunc(PyEnumerateRef::iter), }); } diff --git a/vm/src/obj/objfilter.rs b/vm/src/obj/objfilter.rs index 858a0531a..24800d6aa 100644 --- a/vm/src/obj/objfilter.rs +++ b/vm/src/obj/objfilter.rs @@ -1,4 +1,4 @@ -use crate::pyobject::{IdProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; // Required for arg_check! to use isinstance use super::objbool; @@ -7,6 +7,11 @@ use crate::obj::objtype::PyClassRef; pub type PyFilterRef = PyRef; +/// filter(function or None, iterable) --> filter object +/// +/// Return an iterator yielding those items of iterable for which function(item) +/// is true. If function is None, return the items that are true. +#[pyclass] #[derive(Debug)] pub struct PyFilter { predicate: PyObjectRef, @@ -34,8 +39,10 @@ fn filter_new( .into_ref_with_type(vm, cls) } -impl PyFilterRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PyFilter { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { let predicate = &self.predicate; let iterator = &self.iterator; loop { @@ -53,23 +60,15 @@ impl PyFilterRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } pub fn init(context: &PyContext) { - let filter_type = &context.filter_type; - - let filter_doc = - "filter(function or None, iterable) --> filter object\n\n\ - Return an iterator yielding those items of iterable for which function(item)\n\ - is true. If function is None, return the items that are true."; - - extend_class!(context, filter_type, { + PyFilter::extend_class(context, &context.filter_type); + extend_class!(context, &context.filter_type, { "__new__" => context.new_rustfunc(filter_new), - "__doc__" => context.new_str(filter_doc.to_string()), - "__next__" => context.new_rustfunc(PyFilterRef::next), - "__iter__" => context.new_rustfunc(PyFilterRef::iter), }); } diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index 42810da8e..6c097b706 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -43,9 +43,27 @@ impl From for PyFloat { } } -fn mod_(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult { +fn try_float(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult> { + Ok(if objtype::isinstance(&value, &vm.ctx.float_type()) { + Some(get_value(&value)) + } else if objtype::isinstance(&value, &vm.ctx.int_type()) { + Some(objint::get_float_value(&value, vm)?) + } else { + None + }) +} + +fn inner_div(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult { if v2 != 0.0 { - Ok(vm.ctx.new_float(v1 % v2)) + Ok(v1 / v2) + } else { + Err(vm.new_zero_division_error("float division by zero".to_string())) + } +} + +fn inner_mod(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult { + if v2 != 0.0 { + Ok(v1 % v2) } else { Err(vm.new_zero_division_error("float mod by zero".to_string())) } @@ -73,6 +91,22 @@ fn try_to_bigint(value: f64, vm: &VirtualMachine) -> PyResult { } } +fn inner_floordiv(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult { + if v2 != 0.0 { + Ok((v1 / v2).floor()) + } else { + Err(vm.new_zero_division_error("float floordiv by zero".to_string())) + } +} + +fn inner_divmod(v1: f64, v2: f64, vm: &VirtualMachine) -> PyResult<(f64, f64)> { + if v2 != 0.0 { + Ok(((v1 / v2).floor(), v1 % v2)) + } else { + Err(vm.new_zero_division_error("float divmod()".to_string())) + } +} + #[pyimpl] impl PyFloat { #[pymethod(name = "__eq__")] @@ -153,20 +187,15 @@ impl PyFloat { } #[pymethod(name = "__add__")] - fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - let v1 = self.value; - if objtype::isinstance(&other, &vm.ctx.float_type()) { - vm.ctx.new_float(v1 + get_value(&other)) - } else if objtype::isinstance(&other, &vm.ctx.int_type()) { - vm.ctx - .new_float(v1 + objint::get_value(&other).to_f64().unwrap()) - } else { - vm.ctx.not_implemented() - } + fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value + other).into_pyobject(vm), + ) } #[pymethod(name = "__radd__")] - fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { self.add(other, vm) } @@ -177,45 +206,51 @@ impl PyFloat { #[pymethod(name = "__divmod__")] fn divmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if objtype::isinstance(&other, &vm.ctx.float_type()) - || objtype::isinstance(&other, &vm.ctx.int_type()) - { - let r1 = self.floordiv(other.clone(), vm)?; - let r2 = self.mod_(other, vm)?; - Ok(vm.ctx.new_tuple(vec![r1, r2])) - } else { - Ok(vm.ctx.not_implemented()) - } + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| { + let (r1, r2) = inner_divmod(self.value, other, vm)?; + Ok(vm + .ctx + .new_tuple(vec![vm.ctx.new_float(r1), vm.ctx.new_float(r2)])) + }, + ) + } + + #[pymethod(name = "__rdivmod__")] + fn rdivmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| { + let (r1, r2) = inner_divmod(other, self.value, vm)?; + Ok(vm + .ctx + .new_tuple(vec![vm.ctx.new_float(r1), vm.ctx.new_float(r2)])) + }, + ) } #[pymethod(name = "__floordiv__")] fn floordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v1 = self.value; - let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) { - get_value(&other) - } else if objtype::isinstance(&other, &vm.ctx.int_type) { - objint::get_float_value(&other, vm)? - } else { - return Ok(vm.ctx.not_implemented()); - }; + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| inner_floordiv(self.value, other, vm)?.into_pyobject(vm), + ) + } - if v2 != 0.0 { - Ok(vm.ctx.new_float((v1 / v2).floor())) - } else { - Err(vm.new_zero_division_error("float floordiv by zero".to_string())) - } + #[pymethod(name = "__rfloordiv__")] + fn rfloordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| inner_floordiv(other, self.value, vm)?.into_pyobject(vm), + ) } fn new_float(cls: PyClassRef, arg: PyObjectRef, vm: &VirtualMachine) -> PyResult { let value = if objtype::isinstance(&arg, &vm.ctx.float_type()) { get_value(&arg) } else if objtype::isinstance(&arg, &vm.ctx.int_type()) { - match objint::get_float_value(&arg, vm) { - Ok(f) => f, - Err(e) => { - return Err(e); - } - } + objint::get_float_value(&arg, vm)? } else if objtype::isinstance(&arg, &vm.ctx.str_type()) { match lexical::try_parse(objstr::get_value(&arg)) { Ok(f) => f, @@ -246,28 +281,18 @@ impl PyFloat { #[pymethod(name = "__mod__")] fn mod_(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v1 = self.value; - let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) { - get_value(&other) - } else if objtype::isinstance(&other, &vm.ctx.int_type) { - objint::get_float_value(&other, vm)? - } else { - return Ok(vm.ctx.not_implemented()); - }; - - mod_(v1, v2, vm) + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| inner_mod(self.value, other, vm)?.into_pyobject(vm), + ) } #[pymethod(name = "__rmod__")] fn rmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v2 = self.value; - let v1 = if objtype::isinstance(&other, &vm.ctx.int_type) { - objint::get_float_value(&other, vm)? - } else { - return Ok(vm.ctx.not_implemented()); - }; - - mod_(v1, v2, vm) + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| inner_mod(other, self.value, vm)?.into_pyobject(vm), + ) } #[pymethod(name = "__neg__")] @@ -276,44 +301,35 @@ impl PyFloat { } #[pymethod(name = "__pow__")] - fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - let v1 = self.value; - if objtype::isinstance(&other, &vm.ctx.float_type()) { - vm.ctx.new_float(v1.powf(get_value(&other))) - } else if objtype::isinstance(&other, &vm.ctx.int_type()) { - let result = v1.powf(objint::get_value(&other).to_f64().unwrap()); - vm.ctx.new_float(result) - } else { - vm.ctx.not_implemented() - } + fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| self.value.powf(other).into_pyobject(vm), + ) + } + + #[pymethod(name = "__rpow__")] + fn rpow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| other.powf(self.value).into_pyobject(vm), + ) } #[pymethod(name = "__sub__")] fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v1 = self.value; - if objtype::isinstance(&other, &vm.ctx.float_type()) { - Ok(vm.ctx.new_float(v1 - get_value(&other))) - } else if objtype::isinstance(&other, &vm.ctx.int_type()) { - Ok(vm - .ctx - .new_float(v1 - objint::get_value(&other).to_f64().unwrap())) - } else { - Ok(vm.ctx.not_implemented()) - } + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value - other).into_pyobject(vm), + ) } #[pymethod(name = "__rsub__")] fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v1 = self.value; - if objtype::isinstance(&other, &vm.ctx.float_type()) { - Ok(vm.ctx.new_float(get_value(&other) - v1)) - } else if objtype::isinstance(&other, &vm.ctx.int_type()) { - Ok(vm - .ctx - .new_float(objint::get_value(&other).to_f64().unwrap() - v1)) - } else { - Ok(vm.ctx.not_implemented()) - } + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (other - self.value).into_pyobject(vm), + ) } #[pymethod(name = "__repr__")] @@ -323,52 +339,26 @@ impl PyFloat { #[pymethod(name = "__truediv__")] fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v1 = self.value; - let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) { - get_value(&other) - } else if objtype::isinstance(&other, &vm.ctx.int_type) { - objint::get_float_value(&other, vm)? - } else { - return Ok(vm.ctx.not_implemented()); - }; - - if v2 != 0.0 { - Ok(vm.ctx.new_float(v1 / v2)) - } else { - Err(vm.new_zero_division_error("float division by zero".to_string())) - } + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| inner_div(self.value, other, vm)?.into_pyobject(vm), + ) } #[pymethod(name = "__rtruediv__")] fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v1 = self.value; - let v2 = if objtype::isinstance(&other, &vm.ctx.float_type) { - get_value(&other) - } else if objtype::isinstance(&other, &vm.ctx.int_type) { - objint::get_float_value(&other, vm)? - } else { - return Ok(vm.ctx.not_implemented()); - }; - - if v1 != 0.0 { - Ok(vm.ctx.new_float(v2 / v1)) - } else { - Err(vm.new_zero_division_error("float division by zero".to_string())) - } + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| inner_div(other, self.value, vm)?.into_pyobject(vm), + ) } #[pymethod(name = "__mul__")] fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let v1 = self.value; - if objtype::isinstance(&other, &vm.ctx.float_type) { - Ok(vm.ctx.new_float(v1 * get_value(&other))) - } else if objtype::isinstance(&other, &vm.ctx.int_type) { - Ok(vm - .ctx - .new_float(v1 * objint::get_value(&other).to_f64().unwrap())) - } else { - Ok(vm.ctx.not_implemented()) - } + try_float(&other, vm)?.map_or_else( + || Ok(vm.ctx.not_implemented()), + |other| (self.value * other).into_pyobject(vm), + ) } #[pymethod(name = "__rmul__")] diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index c9a5c9a40..3687ed2b0 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -1,12 +1,12 @@ use std::fmt; use std::hash::{Hash, Hasher}; -use num_bigint::{BigInt, ToBigInt}; +use num_bigint::BigInt; use num_integer::Integer; use num_traits::{Pow, Signed, ToPrimitive, Zero}; use crate::format::FormatSpec; -use crate::function::OptionalArg; +use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, @@ -510,11 +510,8 @@ fn int_new(cls: PyClassRef, options: IntOptions, vm: &VirtualMachine) -> PyResul } // Casting function: -// TODO: this should just call `__int__` on the object pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult { match_class!(obj.clone(), - i @ PyInt => Ok(i.as_bigint().clone()), - f @ PyFloat => Ok(f.to_f64().to_bigint().unwrap()), s @ PyString => { i32::from_str_radix(s.as_str(), base) .map(BigInt::from) @@ -523,10 +520,21 @@ pub fn to_int(vm: &VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult Err(vm.new_type_error(format!( - "int() argument must be a string or a number, not '{}'", - obj.class().name - ))) + obj => { + if let Ok(f) = vm.get_method(obj.clone(), "__int__") { + let int_res = vm.invoke(f, PyFuncArgs::default())?; + match int_res.payload::() { + Some(i) => Ok(i.as_bigint().clone()), + None => Err(vm.new_type_error(format!( + "TypeError: __int__ returned non-int (type '{}')", int_res.class().name))), + } + } else { + Err(vm.new_type_error(format!( + "int() argument must be a string or a number, not '{}'", + obj.class().name + ))) + } + } ) } diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 6e1dc1d06..efe672d62 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -4,7 +4,9 @@ use std::cell::Cell; -use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol}; +use crate::pyobject::{ + PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, +}; use crate::vm::VirtualMachine; use super::objtype; @@ -75,6 +77,7 @@ pub fn new_stop_iteration(vm: &VirtualMachine) -> PyObjectRef { vm.new_exception(stop_iteration_type, "End of iterator".to_string()) } +#[pyclass] #[derive(Debug)] pub struct PySequenceIterator { pub position: Cell, @@ -87,10 +90,10 @@ impl PyValue for PySequenceIterator { } } -type PySequenceIteratorRef = PyRef; - -impl PySequenceIteratorRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PySequenceIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { let number = vm.ctx.new_int(self.position.get()); match vm.call_method(&self.obj, "__getitem__", vec![number]) { Ok(val) => { @@ -105,16 +108,12 @@ impl PySequenceIteratorRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } pub fn init(context: &PyContext) { - let iter_type = &context.iter_type; - - extend_class!(context, iter_type, { - "__next__" => context.new_rustfunc(PySequenceIteratorRef::next), - "__iter__" => context.new_rustfunc(PySequenceIteratorRef::iter), - }); + PySequenceIterator::extend_class(context, &context.iter_type); } diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index 5c1bad5af..6b00b6828 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -8,7 +8,8 @@ use num_traits::{One, Signed, ToPrimitive, Zero}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - IdProtocol, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + IdProtocol, PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, + TryFromObject, }; use crate::vm::{ReprGuard, VirtualMachine}; @@ -776,6 +777,7 @@ fn list_sort(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_none()) } +#[pyclass] #[derive(Debug)] pub struct PyListIterator { pub position: Cell, @@ -788,10 +790,10 @@ impl PyValue for PyListIterator { } } -type PyListIteratorRef = PyRef; - -impl PyListIteratorRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PyListIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { if self.position.get() < self.list.elements.borrow().len() { let ret = self.list.elements.borrow()[self.position.get()].clone(); self.position.set(self.position.get() + 1); @@ -801,8 +803,9 @@ impl PyListIteratorRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } @@ -848,9 +851,5 @@ pub fn init(context: &PyContext) { "remove" => context.new_rustfunc(PyListRef::remove) }); - let listiterator_type = &context.listiterator_type; - extend_class!(context, listiterator_type, { - "__next__" => context.new_rustfunc(PyListIteratorRef::next), - "__iter__" => context.new_rustfunc(PyListIteratorRef::iter), - }); + PyListIterator::extend_class(context, &context.listiterator_type); } diff --git a/vm/src/obj/objmap.rs b/vm/src/obj/objmap.rs index 061f4179f..61fe67c0d 100644 --- a/vm/src/obj/objmap.rs +++ b/vm/src/obj/objmap.rs @@ -1,10 +1,15 @@ use crate::function::Args; -use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use super::objiter; use super::objtype::PyClassRef; +/// map(func, *iterables) --> map object +/// +/// Make an iterator that computes the function using arguments from +/// each of the iterables. Stops when the shortest iterable is exhausted. +#[pyclass] #[derive(Debug)] pub struct PyMap { mapper: PyObjectRef, @@ -35,8 +40,10 @@ fn map_new( .into_ref_with_type(vm, cls.clone()) } -impl PyMapRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PyMap { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { let next_objs = self .iterators .iter() @@ -47,22 +54,15 @@ impl PyMapRef { vm.invoke(self.mapper.clone(), next_objs) } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } pub fn init(context: &PyContext) { - let map_type = &context.map_type; - - let map_doc = "map(func, *iterables) --> map object\n\n\ - Make an iterator that computes the function using arguments from\n\ - each of the iterables. Stops when the shortest iterable is exhausted."; - - extend_class!(context, map_type, { + PyMap::extend_class(context, &context.map_type); + extend_class!(context, &context.map_type, { "__new__" => context.new_rustfunc(map_new), - "__next__" => context.new_rustfunc(PyMapRef::next), - "__iter__" => context.new_rustfunc(PyMapRef::iter), - "__doc__" => context.new_str(map_doc.to_string()) }); } diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index 1a21a49ad..2a43579c5 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -1,5 +1,4 @@ use std::cell::Cell; -use std::ops::Mul; use num_bigint::{BigInt, Sign}; use num_integer::Integer; @@ -7,14 +6,14 @@ use num_traits::{One, Signed, Zero}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, + PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; use super::objint::{PyInt, PyIntRef}; use super::objiter; use super::objslice::{PySlice, PySliceRef}; -use super::objtype::PyClassRef; +use super::objtype::{self, PyClassRef}; #[derive(Debug, Clone)] pub struct PyRange { @@ -65,15 +64,22 @@ impl PyRange { } #[inline] - pub fn get<'a, T>(&'a self, index: T) -> Option - where - &'a BigInt: Mul, - { + pub fn get(&self, index: &BigInt) -> Option { let start = self.start.as_bigint(); let stop = self.stop.as_bigint(); let step = self.step.as_bigint(); - let result = start + step * index; + let index = if index < &BigInt::zero() { + let index = stop + index; + if index < BigInt::zero() { + return None; + } + index + } else { + index.clone() + }; + + let result = start + step * &index; if (self.forward() && !self.is_empty() && &result < stop) || (!self.forward() && !self.is_empty() && &result > stop) @@ -104,6 +110,7 @@ pub fn init(context: &PyContext) { "__bool__" => context.new_rustfunc(PyRange::bool), "__contains__" => context.new_rustfunc(PyRange::contains), "__doc__" => context.new_str(range_doc.to_string()), + "__eq__" => context.new_rustfunc(PyRange::eq), "__getitem__" => context.new_rustfunc(PyRange::getitem), "__iter__" => context.new_rustfunc(PyRange::iter), "__len__" => context.new_rustfunc(PyRange::len), @@ -117,11 +124,7 @@ pub fn init(context: &PyContext) { "step" => context.new_property(PyRange::step), }); - let rangeiterator_type = &context.rangeiterator_type; - extend_class!(context, rangeiterator_type, { - "__next__" => context.new_rustfunc(PyRangeIteratorRef::next), - "__iter__" => context.new_rustfunc(PyRangeIteratorRef::iter), - }); + PyRangeIterator::extend_class(context, &context.rangeiterator_type); } type PyRangeRef = PyRef; @@ -240,6 +243,17 @@ impl PyRange { } } + fn eq(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> bool { + if objtype::isinstance(&rhs, &vm.ctx.range_type()) { + let rhs = get_value(&rhs); + self.start.as_bigint() == rhs.start.as_bigint() + && self.stop.as_bigint() == rhs.stop.as_bigint() + && self.step.as_bigint() == rhs.step.as_bigint() + } else { + false + } + } + fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Ok(int) = needle.downcast::() { match self.index_of(int.as_bigint()) { @@ -274,7 +288,7 @@ impl PyRange { } RangeIndex::Slice(slice) => { let new_start = if let Some(int) = slice.start_index(vm)? { - if let Some(i) = self.get(int) { + if let Some(i) = self.get(&int) { PyInt::new(i).into_ref(vm) } else { self.start.clone() @@ -284,7 +298,7 @@ impl PyRange { }; let new_end = if let Some(int) = slice.stop_index(vm)? { - if let Some(i) = self.get(int) { + if let Some(i) = self.get(&int) { PyInt::new(i).into_ref(vm) } else { self.stop.clone() @@ -323,6 +337,7 @@ fn range_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(range.into_object()) } +#[pyclass] #[derive(Debug)] pub struct PyRangeIterator { position: Cell, @@ -335,11 +350,12 @@ impl PyValue for PyRangeIterator { } } -type PyRangeIteratorRef = PyRef; - -impl PyRangeIteratorRef { - fn next(self, vm: &VirtualMachine) -> PyResult { - if let Some(int) = self.range.get(self.position.get()) { +#[pyimpl] +impl PyRangeIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let position = BigInt::from(self.position.get()); + if let Some(int) = self.range.get(&position) { self.position.set(self.position.get() + 1); Ok(int) } else { @@ -347,8 +363,9 @@ impl PyRangeIteratorRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index 216945284..42999189f 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -3,7 +3,7 @@ use std::fmt; use std::hash::{Hash, Hasher}; use crate::function::OptionalArg; -use crate::pyobject::{IdProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::{ReprGuard, VirtualMachine}; use super::objbool; @@ -224,6 +224,7 @@ fn tuple_new( PyTuple::from(elements).into_ref_with_type(vm, cls) } +#[pyclass] #[derive(Debug)] pub struct PyTupleIterator { position: Cell, @@ -236,10 +237,10 @@ impl PyValue for PyTupleIterator { } } -type PyTupleIteratorRef = PyRef; - -impl PyTupleIteratorRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl PyTupleIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { if self.position.get() < self.tuple.elements.borrow().len() { let ret = self.tuple.elements.borrow()[self.position.get()].clone(); self.position.set(self.position.get() + 1); @@ -249,8 +250,9 @@ impl PyTupleIteratorRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } @@ -282,9 +284,5 @@ If the argument is a tuple, the return value is the same object."; "index" => context.new_rustfunc(PyTupleRef::index) }); - let tupleiterator_type = &context.tupleiterator_type; - extend_class!(context, tupleiterator_type, { - "__next__" => context.new_rustfunc(PyTupleIteratorRef::next), - "__iter__" => context.new_rustfunc(PyTupleIteratorRef::iter), - }); + PyTupleIterator::extend_class(context, &context.tupleiterator_type); } diff --git a/vm/src/obj/objzip.rs b/vm/src/obj/objzip.rs index 81640a724..d9c577aa0 100644 --- a/vm/src/obj/objzip.rs +++ b/vm/src/obj/objzip.rs @@ -1,5 +1,5 @@ use crate::function::Args; -use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; use super::objiter; @@ -7,6 +7,7 @@ use crate::obj::objtype::PyClassRef; pub type PyZipRef = PyRef; +#[pyclass] #[derive(Debug)] pub struct PyZip { iterators: Vec, @@ -26,8 +27,10 @@ fn zip_new(cls: PyClassRef, iterables: Args, vm: &VirtualMachine) -> PyResult PyResult { +#[pyimpl] +impl PyZip { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { if self.iterators.is_empty() { Err(objiter::new_stop_iteration(vm)) } else { @@ -41,16 +44,15 @@ impl PyZipRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } pub fn init(context: &PyContext) { - let zip_type = &context.zip_type; - extend_class!(context, zip_type, { + PyZip::extend_class(context, &context.zip_type); + extend_class!(context, &context.zip_type, { "__new__" => context.new_rustfunc(zip_new), - "__next__" => context.new_rustfunc(PyZipRef::next), - "__iter__" => context.new_rustfunc(PyZipRef::iter), }); } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 6d81ca9c1..2b6c859eb 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -15,7 +15,7 @@ use crate::obj::objiter; use crate::obj::objstr; use crate::obj::objstr::PyStringRef; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{ItemProtocol, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; #[cfg(unix)] @@ -232,6 +232,7 @@ impl DirEntryRef { } } +#[pyclass] #[derive(Debug)] struct ScandirIterator { entries: RefCell, @@ -243,10 +244,10 @@ impl PyValue for ScandirIterator { } } -type ScandirIteratorRef = PyRef; - -impl ScandirIteratorRef { - fn next(self, vm: &VirtualMachine) -> PyResult { +#[pyimpl] +impl ScandirIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { match self.entries.borrow_mut().next() { Some(entry) => match entry { Ok(entry) => Ok(DirEntry { entry }.into_ref(vm).into_object()), @@ -256,8 +257,9 @@ impl ScandirIteratorRef { } } - fn iter(self, _vm: &VirtualMachine) -> Self { - self + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf } } @@ -418,10 +420,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let environ = _os_environ(vm); - let scandir_iter = py_class!(ctx, "ScandirIter", ctx.object(), { - "__iter__" => ctx.new_rustfunc(ScandirIteratorRef::iter), - "__next__" => ctx.new_rustfunc(ScandirIteratorRef::next), - }); + let scandir_iter = ctx.new_class("ScandirIter", ctx.object()); + ScandirIterator::extend_class(ctx, &scandir_iter); let dir_entry = py_class!(ctx, "DirEntry", ctx.object(), { "name" => ctx.new_property(DirEntryRef::name),