diff --git a/tests/snippets/set.py b/tests/snippets/set.py new file mode 100644 index 000000000..8b31c7c23 --- /dev/null +++ b/tests/snippets/set.py @@ -0,0 +1,26 @@ +assert set([1,2]) == set([1,2]) +assert not set([1,2,3]) == set([1,2]) + +assert set([1,2,3]) >= set([1,2]) +assert set([1,2]) >= set([1,2]) +assert not set([1,3]) >= set([1,2]) + +assert set([1,2,3]).issuperset(set([1,2])) +assert set([1,2]).issuperset(set([1,2])) +assert not set([1,3]).issuperset(set([1,2])) + +assert set([1,2,3]) > set([1,2]) +assert not set([1,2]) > set([1,2]) +assert not set([1,3]) > set([1,2]) + +assert set([1,2]) <= set([1,2,3]) +assert set([1,2]) <= set([1,2]) +assert not set([1,3]) <= set([1,2]) + +assert set([1,2]).issubset(set([1,2,3])) +assert set([1,2]).issubset(set([1,2])) +assert not set([1,3]).issubset(set([1,2])) + +assert set([1,2]) < set([1,2,3]) +assert not set([1,2]) < set([1,2]) +assert not set([1,3]) < set([1,2]) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 5ac7507d6..2648fb4f3 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -126,6 +126,99 @@ pub fn set_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_bool(false)) } +fn set_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + return set_compare_inner( + vm, + args, + &|zelf: usize, other: usize| -> bool { zelf != other }, + false, + ); +} + +fn set_ge(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + return set_compare_inner( + vm, + args, + &|zelf: usize, other: usize| -> bool { zelf < other }, + false, + ); +} + +fn set_gt(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + return set_compare_inner( + vm, + args, + &|zelf: usize, other: usize| -> bool { zelf <= other }, + false, + ); +} + +fn set_le(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + return set_compare_inner( + vm, + args, + &|zelf: usize, other: usize| -> bool { zelf < other }, + true, + ); +} + +fn set_lt(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + return set_compare_inner( + vm, + args, + &|zelf: usize, other: usize| -> bool { zelf <= other }, + true, + ); +} + +fn set_compare_inner( + vm: &mut VirtualMachine, + args: PyFuncArgs, + size_func: &Fn(usize, usize) -> bool, + swap: bool, +) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, Some(vm.ctx.set_type())), + (other, Some(vm.ctx.set_type())) + ] + ); + + let get_zelf = |swap: bool| -> &PyObjectRef { + if swap { + other + } else { + zelf + } + }; + let get_other = |swap: bool| -> &PyObjectRef { + if swap { + zelf + } else { + other + } + }; + + let zelf_elements = get_elements(get_zelf(swap)); + let other_elements = get_elements(get_other(swap)); + if size_func(zelf_elements.len(), other_elements.len()) { + return Ok(vm.new_bool(false)); + } + for element in other_elements.iter() { + match vm.call_method(get_zelf(swap), "__contains__", vec![element.1.clone()]) { + Ok(value) => { + if !objbool::get_value(&value) { + return Ok(vm.new_bool(false)); + } + } + Err(_) => return Err(vm.new_type_error("".to_string())), + } + } + Ok(vm.new_bool(true)) +} + fn frozenset_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(o, Some(vm.ctx.frozenset_type()))]); @@ -159,6 +252,13 @@ pub fn init(context: &PyContext) { context.set_attr(&set_type, "__len__", context.new_rustfunc(set_len)); context.set_attr(&set_type, "__new__", context.new_rustfunc(set_new)); context.set_attr(&set_type, "__repr__", context.new_rustfunc(set_repr)); + context.set_attr(&set_type, "__eq__", context.new_rustfunc(set_eq)); + context.set_attr(&set_type, "__ge__", context.new_rustfunc(set_ge)); + context.set_attr(&set_type, "__gt__", context.new_rustfunc(set_gt)); + context.set_attr(&set_type, "__le__", context.new_rustfunc(set_le)); + context.set_attr(&set_type, "__lt__", context.new_rustfunc(set_lt)); + context.set_attr(&set_type, "issubset", context.new_rustfunc(set_le)); + context.set_attr(&set_type, "issuperset", context.new_rustfunc(set_ge)); context.set_attr(&set_type, "__doc__", context.new_str(set_doc.to_string())); context.set_attr(&set_type, "add", context.new_rustfunc(set_add));