From 1ebacafb00a67ddba44cefe71bcbf8c8ebf94f19 Mon Sep 17 00:00:00 2001 From: Joey Hain Date: Fri, 8 Feb 2019 00:19:14 -0800 Subject: [PATCH] Add reversed builtin and range.__reversed__ --- tests/snippets/builtin_range.py | 4 +++ tests/snippets/builtin_reversed.py | 1 + vm/src/builtins.rs | 14 +++++++++++ vm/src/obj/objrange.rs | 39 ++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+) create mode 100644 tests/snippets/builtin_reversed.py diff --git a/tests/snippets/builtin_range.py b/tests/snippets/builtin_range.py index c822ce7e75..cdba9b6d59 100644 --- a/tests/snippets/builtin_range.py +++ b/tests/snippets/builtin_range.py @@ -52,3 +52,7 @@ assert not range(10).__contains__(-1) assert not range(10, 4, -2).__contains__(9) assert not range(10, 4, -2).__contains__(4) assert not range(10).__contains__('foo') + +# __reversed__ +assert list(range(5).__reversed__()) == [4, 3, 2, 1, 0] +assert list(range(5, 0, -1).__reversed__()) == [1, 2, 3, 4, 5] diff --git a/tests/snippets/builtin_reversed.py b/tests/snippets/builtin_reversed.py new file mode 100644 index 0000000000..2bbfcb98a2 --- /dev/null +++ b/tests/snippets/builtin_reversed.py @@ -0,0 +1 @@ +assert list(reversed(range(5))) == [4, 3, 2, 1, 0] diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 1db5c7e77c..c785a250b6 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -611,6 +611,19 @@ fn builtin_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, None)]); vm.to_repr(obj) } + +fn builtin_reversed(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(obj, None)]); + + match vm.get_method(obj.clone(), "__reversed__") { + Ok(value) => vm.invoke(value, PyFuncArgs::default()), + // TODO: fallback to using __len__ and __getitem__, if object supports sequence protocol + Err(..) => Err(vm.new_type_error(format!( + "'{}' object is not reversible", + objtype::get_type_name(&obj.typ()), + ))), + } +} // builtin_reversed // builtin_round @@ -725,6 +738,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "property", ctx.property_type()); ctx.set_attr(&py_mod, "range", ctx.range_type()); ctx.set_attr(&py_mod, "repr", ctx.new_rustfunc(builtin_repr)); + ctx.set_attr(&py_mod, "reversed", ctx.new_rustfunc(builtin_reversed)); ctx.set_attr(&py_mod, "set", ctx.set_type()); ctx.set_attr(&py_mod, "setattr", ctx.new_rustfunc(builtin_setattr)); ctx.set_attr(&py_mod, "staticmethod", ctx.staticmethod_type()); diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index cd897ff347..746822876b 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -76,12 +76,34 @@ impl RangeType { None } } + + #[inline] + pub fn reversed(&self) -> Self { + match self.step.sign() { + Sign::Plus => RangeType { + start: &self.end - 1, + end: &self.start - 1, + step: -&self.step, + }, + Sign::Minus => RangeType { + start: &self.end + 1, + end: &self.start + 1, + step: -&self.step, + }, + Sign::NoSign => unreachable!(), + } + } } pub fn init(context: &PyContext) { let ref range_type = context.range_type; context.set_attr(&range_type, "__new__", context.new_rustfunc(range_new)); context.set_attr(&range_type, "__iter__", context.new_rustfunc(range_iter)); + context.set_attr( + &range_type, + "__reversed__", + context.new_rustfunc(range_reversed), + ); context.set_attr(&range_type, "__len__", context.new_rustfunc(range_len)); context.set_attr( &range_type, @@ -150,6 +172,23 @@ fn range_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { )) } +fn range_reversed(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]); + + let range = match zelf.borrow().payload { + PyObjectPayload::Range { ref range } => range.reversed(), + _ => unreachable!(), + }; + + Ok(PyObject::new( + PyObjectPayload::Iterator { + position: 0, + iterated_obj: PyObject::new(PyObjectPayload::Range { range }, vm.ctx.range_type()), + }, + vm.ctx.iter_type(), + )) +} + fn range_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.range_type()))]);