diff --git a/tests/snippets/membership.py b/tests/snippets/membership.py new file mode 100644 index 000000000..3b76b8052 --- /dev/null +++ b/tests/snippets/membership.py @@ -0,0 +1,62 @@ +# test lists +assert 3 in [1, 2, 3] +assert 3 not in [1, 2] + +assert not (3 in [1, 2]) +assert not (3 not in [1, 2, 3]) + +# test strings +assert "foo" in "foobar" +assert "whatever" not in "foobar" + +# test bytes +# TODO: uncomment this when bytes are implemented +# assert b"foo" in b"foobar" +# assert b"whatever" not in b"foobar" + +# test tuple +assert 1 in (1, 2) +assert 3 not in (1, 2) + +# test set +# TODO: uncomment this when sets are implemented +# assert 1 in set(1, 2) +# assert 3 not in set(1, 2) + +# test dicts +# TODO: test dicts when keys other than strings will be allowed +assert "a" in {"a": 0, "b": 0} +assert "c" not in {"a": 0, "b": 0} + +# test iter +assert 3 in iter([1, 2, 3]) +assert 3 not in iter([1, 2]) + +# test sequence +# TODO: uncomment this when ranges are usable +# assert 1 in range(0, 2) +# assert 3 not in range(0, 2) + +# test __contains__ in user objects +class MyNotContainingClass(): + pass + + +try: + 1 in MyNotContainingClass() +except TypeError: + pass +else: + assert False, "TypeError not raised" + + +class MyContainingClass(): + def __init__(self, value): + self.value = value + + def __contains__(self, something): + return something == self.value + + +assert 2 in MyContainingClass(2) +assert 1 not in MyContainingClass(2) diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index 6729278a5..910105b57 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -63,6 +63,26 @@ fn dict_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(s)) } +pub fn dict_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (dict, Some(vm.ctx.dict_type())), + (needle, Some(vm.ctx.str_type())) + ] + ); + + let needle = objstr::get_value(&needle); + for element in get_elements(dict).iter() { + if &needle == element.0 { + return Ok(vm.new_bool(true)); + } + } + + Ok(vm.new_bool(false)) +} + pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: PyObjectRef) { (*dict_type.borrow_mut()).kind = PyObjectKind::Class { name: String::from("dict"), @@ -75,6 +95,7 @@ pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: pub fn init(context: &PyContext) { let ref dict_type = context.dict_type; dict_type.set_attr("__len__", context.new_rustfunc(dict_len)); + dict_type.set_attr("__contains__", context.new_rustfunc(dict_contains)); dict_type.set_attr("__new__", context.new_rustfunc(dict_new)); dict_type.set_attr("__repr__", context.new_rustfunc(dict_repr)); } diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 09afd8ce3..90fff2d1d 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -53,6 +53,26 @@ fn iter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(iter.clone()) } +fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(iter, Some(vm.ctx.iter_type())), (needle, None)] + ); + loop { + match vm.call_method(&iter, "__next__", vec![]) { + Ok(element) => { + if &element == needle { + return Ok(vm.new_bool(true)); + } else { + continue; + } + } + Err(_) => return Ok(vm.new_bool(false)), + } + } +} + fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]); @@ -86,7 +106,8 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let ref iter_type = context.iter_type; - iter_type.set_attr("__new__", context.new_rustfunc(iter_new)); + iter_type.set_attr("__contains__", context.new_rustfunc(iter_contains)); iter_type.set_attr("__iter__", context.new_rustfunc(iter_iter)); + iter_type.set_attr("__new__", context.new_rustfunc(iter_new)); iter_type.set_attr("__next__", context.new_rustfunc(iter_next)); } diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index 605bfdc4a..24831e452 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -162,10 +162,27 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } +pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + trace!("list.len called with: {:?}", args); + arg_check!( + vm, + args, + required = [(list, Some(vm.ctx.list_type())), (x, None)] + ); + for element in get_elements(list).iter() { + if x == element { + return Ok(vm.new_bool(true)); + } + } + + Ok(vm.new_bool(false)) +} + pub fn init(context: &PyContext) { let ref list_type = context.list_type; - list_type.set_attr("__eq__", context.new_rustfunc(list_eq)); list_type.set_attr("__add__", context.new_rustfunc(list_add)); + list_type.set_attr("__contains__", context.new_rustfunc(contains)); + list_type.set_attr("__eq__", context.new_rustfunc(list_eq)); list_type.set_attr("__len__", context.new_rustfunc(list_len)); list_type.set_attr("__new__", context.new_rustfunc(list_new)); list_type.set_attr("__repr__", context.new_rustfunc(list_repr)); diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 5725e5b4d..009a0d9c5 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -8,8 +8,9 @@ use super::objtype; pub fn init(context: &PyContext) { let ref str_type = context.str_type; - str_type.set_attr("__eq__", context.new_rustfunc(str_eq)); str_type.set_attr("__add__", context.new_rustfunc(str_add)); + str_type.set_attr("__eq__", context.new_rustfunc(str_eq)); + str_type.set_attr("__contains__", context.new_rustfunc(str_contains)); str_type.set_attr("__len__", context.new_rustfunc(str_len)); str_type.set_attr("__mul__", context.new_rustfunc(str_mul)); str_type.set_attr("__new__", context.new_rustfunc(str_new)); @@ -197,6 +198,20 @@ fn str_startswith(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bool(value.starts_with(pat.as_str()))) } +fn str_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (s, Some(vm.ctx.str_type())), + (needle, Some(vm.ctx.str_type())) + ] + ); + let value = get_value(&s); + let needle = get_value(&needle); + Ok(vm.ctx.new_bool(value.contains(needle.as_str()))) +} + // TODO: should with following format // class str(object='') // class str(object=b'', encoding='utf-8', errors='strict') diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index 3643f4f52..adb4e8d47 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -50,6 +50,21 @@ fn tuple_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(s)) } +pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(tuple, Some(vm.ctx.tuple_type())), (x, None)] + ); + for element in get_elements(tuple).iter() { + if x == element { + return Ok(vm.new_bool(true)); + } + } + + Ok(vm.new_bool(false)) +} + pub fn get_elements(obj: &PyObjectRef) -> Vec { if let PyObjectKind::Tuple { elements } = &obj.borrow().kind { elements.to_vec() @@ -61,6 +76,7 @@ pub fn get_elements(obj: &PyObjectRef) -> Vec { pub fn init(context: &PyContext) { let ref tuple_type = context.tuple_type; tuple_type.set_attr("__eq__", context.new_rustfunc(tuple_eq)); + tuple_type.set_attr("__contains__", context.new_rustfunc(tuple_contains)); tuple_type.set_attr("__len__", context.new_rustfunc(tuple_len)); tuple_type.set_attr("__repr__", context.new_rustfunc(tuple_repr)); } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index af2b3677a..4ef7bddcc 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -580,6 +580,33 @@ impl VirtualMachine { a.get_id() } + // https://docs.python.org/3/reference/expressions.html#membership-test-operations + fn _membership(&mut self, needle: PyObjectRef, haystack: &PyObjectRef) -> PyResult { + self.call_method(&haystack, "__contains__", vec![needle]) + // TODO: implement __iter__ and __getitem__ cases when __contains__ is + // not implemented. + } + + fn _in(&mut self, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult { + match self._membership(needle, &haystack) { + Ok(found) => Ok(found), + Err(_) => Err(self.new_type_error(format!( + "{} has no __contains__ method", + objtype::get_type_name(&haystack.typ()) + ))), + } + } + + fn _not_in(&mut self, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult { + match self._membership(needle, &haystack) { + Ok(found) => Ok(self.ctx.new_bool(!objbool::get_value(&found))), + Err(_) => Err(self.new_type_error(format!( + "{} has no __contains__ method", + objtype::get_type_name(&haystack.typ()) + ))), + } + } + fn _is(&self, a: PyObjectRef, b: PyObjectRef) -> bool { // Pointer equal: a.is(&b) @@ -603,7 +630,8 @@ impl VirtualMachine { &bytecode::ComparisonOperator::GreaterOrEqual => self._ge(a, b), &bytecode::ComparisonOperator::Is => Ok(self.ctx.new_bool(self._is(a, b))), &bytecode::ComparisonOperator::IsNot => self._is_not(a, b), - _ => panic!("NOT IMPL {:?}", op), + &bytecode::ComparisonOperator::In => self._in(a, b), + &bytecode::ComparisonOperator::NotIn => self._not_in(a, b), }; match result { Ok(value) => {