From 8f9b733a7799323412280f7824ef388832ececbb Mon Sep 17 00:00:00 2001 From: ben Date: Sun, 3 Feb 2019 13:59:43 +1300 Subject: [PATCH] Add tuple.__mul__ --- tests/snippets/list.py | 1 + tests/snippets/tuple.py | 5 +++++ vm/src/obj/objlist.rs | 12 ++---------- vm/src/obj/objsequence.rs | 15 +++++++++++++++ vm/src/obj/objtuple.rs | 18 +++++++++++++++++- 5 files changed, 40 insertions(+), 11 deletions(-) diff --git a/tests/snippets/list.py b/tests/snippets/list.py index 85010a755..22b7dec82 100644 --- a/tests/snippets/list.py +++ b/tests/snippets/list.py @@ -10,6 +10,7 @@ y.extend(x) assert y == [2, 1, 2, 3, 1, 2, 3] assert x * 0 == [], "list __mul__ by 0 failed" +assert x * -1 == [], "list __mul__ by -1 failed" assert x * 2 == [1, 2, 3, 1, 2, 3], "list __mul__ by 2 failed" assert ['a', 'b', 'c'].index('b') == 1 diff --git a/tests/snippets/tuple.py b/tests/snippets/tuple.py index 9227f9591..781a56ee5 100644 --- a/tests/snippets/tuple.py +++ b/tests/snippets/tuple.py @@ -5,3 +5,8 @@ assert x[0] == 1 y = (1,) assert y[0] == 1 + +assert x * 3 == (1, 2, 1, 2, 1, 2) +# assert 3 * x == (1, 2, 1, 2, 1, 2) +assert x * 0 == () +assert x * -1 == () # integers less than zero treated as 0 diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index 641bb9aa6..30df25ab5 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -5,7 +5,7 @@ use super::super::vm::VirtualMachine; use super::objbool; use super::objint; use super::objsequence::{ - get_elements, get_item, get_mut_elements, seq_equal, PySliceableSequence, + get_elements, get_item, get_mut_elements, seq_equal, seq_mul, PySliceableSequence, }; use super::objstr; use super::objtype; @@ -259,15 +259,7 @@ fn list_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ] ); - let counter = objint::get_value(&product).to_usize().unwrap(); - - let elements = get_elements(list); - let current_len = elements.len(); - let mut new_elements = Vec::with_capacity(counter * current_len); - - for _ in 0..counter { - new_elements.extend(elements.clone()); - } + let new_elements = seq_mul(&get_elements(list), product); Ok(vm.ctx.new_list(new_elements)) } diff --git a/vm/src/obj/objsequence.rs b/vm/src/obj/objsequence.rs index cd26391a9..8dc63b51c 100644 --- a/vm/src/obj/objsequence.rs +++ b/vm/src/obj/objsequence.rs @@ -1,6 +1,7 @@ use super::super::pyobject::{PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol}; use super::super::vm::VirtualMachine; use super::objbool; +use super::objint; use num_traits::ToPrimitive; use std::cell::{Ref, RefMut}; use std::marker::Sized; @@ -120,6 +121,20 @@ pub fn seq_equal( } } +pub fn seq_mul(elements: &Vec, product: &PyObjectRef) -> Vec { + let counter = objint::get_value(&product).to_isize().unwrap(); + + let current_len = elements.len(); + let new_len = counter.max(0) as usize * current_len; + let mut new_elements = Vec::with_capacity(new_len); + + for _ in 0..counter { + new_elements.extend(elements.clone()); + } + + new_elements +} + pub fn get_elements<'a>(obj: &'a PyObjectRef) -> impl Deref> + 'a { Ref::map(obj.borrow(), |x| { if let PyObjectPayload::Sequence { ref elements } = x.payload { diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index 9657328a1..8a0e92d0c 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -4,7 +4,7 @@ use super::super::pyobject::{ use super::super::vm::VirtualMachine; use super::objbool; use super::objint; -use super::objsequence::{get_elements, get_item, seq_equal}; +use super::objsequence::{get_elements, get_item, seq_equal, seq_mul}; use super::objstr; use super::objtype; use num_bigint::ToBigInt; @@ -119,6 +119,21 @@ fn tuple_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(s)) } +fn tuple_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, Some(vm.ctx.tuple_type())), + (product, Some(vm.ctx.int_type())) + ] + ); + + let new_elements = seq_mul(&get_elements(zelf), product); + + Ok(vm.ctx.new_tuple(new_elements)) +} + fn tuple_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -165,6 +180,7 @@ pub fn init(context: &PyContext) { context.set_attr(&tuple_type, "__iter__", context.new_rustfunc(tuple_iter)); context.set_attr(&tuple_type, "__len__", context.new_rustfunc(tuple_len)); context.set_attr(&tuple_type, "__new__", context.new_rustfunc(tuple_new)); + context.set_attr(&tuple_type, "__mul__", context.new_rustfunc(tuple_mul)); context.set_attr(&tuple_type, "__repr__", context.new_rustfunc(tuple_repr)); context.set_attr(&tuple_type, "count", context.new_rustfunc(tuple_count)); }