diff --git a/tests/snippets/set.py b/tests/snippets/set.py index 0a4552444..efc070f54 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -25,7 +25,6 @@ assert set([1,2]) < set([1,2,3]) assert not set([1,2]) < set([1,2]) assert not set([1,3]) < set([1,2]) - class Hashable(object): def __init__(self, obj): self.obj = obj @@ -46,6 +45,30 @@ assert len(a) == 3 a.clear() assert len(a) == 0 +assert set([1,2,3]).union(set([4,5])) == set([1,2,3,4,5]) +assert set([1,2,3]).union(set([1,2,3,4,5])) == set([1,2,3,4,5]) + +assert set([1,2,3]) | set([4,5]) == set([1,2,3,4,5]) +assert set([1,2,3]) | set([1,2,3,4,5]) == set([1,2,3,4,5]) + +assert set([1,2,3]).intersection(set([1,2])) == set([1,2]) +assert set([1,2,3]).intersection(set([5,6])) == set([]) + +assert set([1,2,3]) & set([4,5]) == set([]) +assert set([1,2,3]) & set([1,2,3,4,5]) == set([1,2,3]) + +assert set([1,2,3]).difference(set([1,2])) == set([3]) +assert set([1,2,3]).difference(set([5,6])) == set([1,2,3]) + +assert set([1,2,3]) - set([4,5]) == set([1,2,3]) +assert set([1,2,3]) - set([1,2,3,4,5]) == set([]) + +assert set([1,2,3]).symmetric_difference(set([1,2])) == set([3]) +assert set([1,2,3]).symmetric_difference(set([5,6])) == set([1,2,3,5,6]) + +assert set([1,2,3]) ^ set([4,5]) == set([1,2,3,4,5]) +assert set([1,2,3]) ^ set([1,2,3,4,5]) == set([4,5]) + try: set([[]]) except TypeError: @@ -59,3 +82,4 @@ except TypeError: pass else: assert False, "TypeError was not raised" + diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index f7a623fe5..537378379 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -285,6 +285,103 @@ fn set_compare_inner( Ok(vm.new_bool(true)) } +fn set_union(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, Some(vm.ctx.set_type())), + (other, Some(vm.ctx.set_type())) + ] + ); + + let mut elements = get_elements(zelf).clone(); + elements.extend(get_elements(other).clone()); + + Ok(PyObject::new( + PyObjectPayload::Set { elements }, + vm.ctx.set_type(), + )) +} + +fn set_intersection(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + set_combine_inner(vm, args, SetCombineOperation::Intersection) +} + +fn set_difference(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + set_combine_inner(vm, args, SetCombineOperation::Difference) +} + +fn set_symmetric_difference(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, Some(vm.ctx.set_type())), + (other, Some(vm.ctx.set_type())) + ] + ); + + let mut elements = HashMap::new(); + + for element in get_elements(zelf).iter() { + let value = vm.call_method(other, "__contains__", vec![element.1.clone()])?; + if !objbool::get_value(&value) { + elements.insert(element.0.clone(), element.1.clone()); + } + } + + for element in get_elements(other).iter() { + let value = vm.call_method(zelf, "__contains__", vec![element.1.clone()])?; + if !objbool::get_value(&value) { + elements.insert(element.0.clone(), element.1.clone()); + } + } + + Ok(PyObject::new( + PyObjectPayload::Set { elements }, + vm.ctx.set_type(), + )) +} + +enum SetCombineOperation { + Intersection, + Difference, +} + +fn set_combine_inner( + vm: &mut VirtualMachine, + args: PyFuncArgs, + op: SetCombineOperation, +) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, Some(vm.ctx.set_type())), + (other, Some(vm.ctx.set_type())) + ] + ); + + let mut elements = HashMap::new(); + + for element in get_elements(zelf).iter() { + let value = vm.call_method(other, "__contains__", vec![element.1.clone()])?; + let should_add = match op { + SetCombineOperation::Intersection => objbool::get_value(&value), + SetCombineOperation::Difference => !objbool::get_value(&value), + }; + if should_add { + elements.insert(element.0.clone(), element.1.clone()); + } + } + + Ok(PyObject::new( + PyObjectPayload::Set { elements }, + vm.ctx.set_type(), + )) +} + fn frozenset_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(o, Some(vm.ctx.frozenset_type()))]); @@ -325,6 +422,30 @@ pub fn init(context: &PyContext) { context.set_attr(&set_type, "__lt__", context.new_rustfunc(set_lt)); context.set_attr(&set_type, "issubset", context.new_rustfunc(set_le)); context.set_attr(&set_type, "issuperset", context.new_rustfunc(set_ge)); + context.set_attr(&set_type, "union", context.new_rustfunc(set_union)); + context.set_attr(&set_type, "__or__", context.new_rustfunc(set_union)); + context.set_attr( + &set_type, + "intersection", + context.new_rustfunc(set_intersection), + ); + context.set_attr(&set_type, "__and__", context.new_rustfunc(set_intersection)); + context.set_attr( + &set_type, + "difference", + context.new_rustfunc(set_difference), + ); + context.set_attr(&set_type, "__sub__", context.new_rustfunc(set_difference)); + context.set_attr( + &set_type, + "symmetric_difference", + context.new_rustfunc(set_symmetric_difference), + ); + context.set_attr( + &set_type, + "__xor__", + context.new_rustfunc(set_symmetric_difference), + ); context.set_attr(&set_type, "__doc__", context.new_str(set_doc.to_string())); context.set_attr(&set_type, "add", context.new_rustfunc(set_add)); context.set_attr(&set_type, "remove", context.new_rustfunc(set_remove));