diff --git a/tests/snippets/membership.py b/tests/snippets/membership.py index 3b76b8052..e2c14884d 100644 --- a/tests/snippets/membership.py +++ b/tests/snippets/membership.py @@ -19,9 +19,8 @@ assert 1 in (1, 2) assert 3 not in (1, 2) # test set -# TODO: uncomment this when sets are implemented -# assert 1 in set(1, 2) -# assert 3 not in set(1, 2) +assert 1 in set([1, 2]) +assert 3 not in set([1, 2]) # test dicts # TODO: test dicts when keys other than strings will be allowed diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 90fff2d1d..a32293e35 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -7,6 +7,7 @@ use super::super::pyobject::{ TypeProtocol, }; use super::super::vm::VirtualMachine; +use super::objbool; use super::objstr; use super::objtype; // Required for arg_check! to use isinstance @@ -61,13 +62,16 @@ fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ); loop { match vm.call_method(&iter, "__next__", vec![]) { - Ok(element) => { - if &element == needle { - return Ok(vm.new_bool(true)); - } else { - continue; + Ok(element) => match vm.call_method(needle, "__eq__", vec![element.clone()]) { + Ok(value) => { + if objbool::get_value(&value) { + return Ok(vm.new_bool(true)); + } else { + continue; + } } - } + Err(_) => return Err(vm.new_type_error("".to_string())), + }, Err(_) => return Ok(vm.new_bool(false)), } } diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index 24831e452..9389baf1f 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -2,6 +2,7 @@ use super::super::pyobject::{ AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol, }; use super::super::vm::VirtualMachine; +use super::objbool; use super::objiter; use super::objsequence::{seq_equal, PySliceableSequence}; use super::objstr; @@ -162,16 +163,21 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } -pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - trace!("list.len called with: {:?}", args); +fn list_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + trace!("list.contains called with: {:?}", args); arg_check!( vm, args, - required = [(list, Some(vm.ctx.list_type())), (x, None)] + required = [(list, Some(vm.ctx.list_type())), (needle, None)] ); for element in get_elements(list).iter() { - if x == element { - return Ok(vm.new_bool(true)); + match vm.call_method(needle, "__eq__", vec![element.clone()]) { + Ok(value) => { + if objbool::get_value(&value) { + return Ok(vm.new_bool(true)); + } + } + Err(_) => return Err(vm.new_type_error("".to_string())), } } @@ -181,7 +187,7 @@ pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let ref list_type = context.list_type; list_type.set_attr("__add__", context.new_rustfunc(list_add)); - list_type.set_attr("__contains__", context.new_rustfunc(contains)); + list_type.set_attr("__contains__", context.new_rustfunc(list_contains)); list_type.set_attr("__eq__", context.new_rustfunc(list_eq)); list_type.set_attr("__len__", context.new_rustfunc(list_len)); list_type.set_attr("__new__", context.new_rustfunc(list_new)); diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 0496915a7..73c319128 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -7,6 +7,7 @@ use super::super::pyobject::{ PyResult, TypeProtocol, }; use super::super::vm::VirtualMachine; +use super::objbool; use super::objiter; use super::objstr; use super::objtype; @@ -88,8 +89,29 @@ fn set_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(s)) } +pub fn set_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(set, Some(vm.ctx.set_type())), (needle, None)] + ); + for element in get_elements(set).iter() { + match vm.call_method(needle, "__eq__", vec![element.1.clone()]) { + Ok(value) => { + if objbool::get_value(&value) { + return Ok(vm.new_bool(true)); + } + } + Err(_) => return Err(vm.new_type_error("".to_string())), + } + } + + Ok(vm.new_bool(false)) +} + pub fn init(context: &PyContext) { let ref set_type = context.set_type; + set_type.set_attr("__contains__", context.new_rustfunc(set_contains)); set_type.set_attr("__len__", context.new_rustfunc(set_len)); set_type.set_attr("__new__", context.new_rustfunc(set_new)); set_type.set_attr("__repr__", context.new_rustfunc(set_repr)); diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index adb4e8d47..65143289f 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -2,6 +2,7 @@ use super::super::pyobject::{ AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol, }; use super::super::vm::VirtualMachine; +use super::objbool; use super::objsequence::seq_equal; use super::objstr; use super::objtype; @@ -54,11 +55,16 @@ pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, - required = [(tuple, Some(vm.ctx.tuple_type())), (x, None)] + required = [(tuple, Some(vm.ctx.tuple_type())), (needle, None)] ); for element in get_elements(tuple).iter() { - if x == element { - return Ok(vm.new_bool(true)); + match vm.call_method(needle, "__eq__", vec![element.clone()]) { + Ok(value) => { + if objbool::get_value(&value) { + return Ok(vm.new_bool(true)); + } + } + Err(_) => return Err(vm.new_type_error("".to_string())), } }