Merge pull request #475 from palaviv/Add-set-funcs3

Add set.{union,intersection,difference,symmetric_difference}
This commit is contained in:
Windel Bouwman
2019-02-15 17:08:10 +01:00
committed by GitHub
2 changed files with 146 additions and 1 deletions

View File

@@ -25,7 +25,6 @@ assert set([1,2]) < set([1,2,3])
assert not set([1,2]) < set([1,2])
assert not set([1,3]) < set([1,2])
class Hashable(object):
def __init__(self, obj):
self.obj = obj
@@ -46,6 +45,30 @@ assert len(a) == 3
a.clear()
assert len(a) == 0
assert set([1,2,3]).union(set([4,5])) == set([1,2,3,4,5])
assert set([1,2,3]).union(set([1,2,3,4,5])) == set([1,2,3,4,5])
assert set([1,2,3]) | set([4,5]) == set([1,2,3,4,5])
assert set([1,2,3]) | set([1,2,3,4,5]) == set([1,2,3,4,5])
assert set([1,2,3]).intersection(set([1,2])) == set([1,2])
assert set([1,2,3]).intersection(set([5,6])) == set([])
assert set([1,2,3]) & set([4,5]) == set([])
assert set([1,2,3]) & set([1,2,3,4,5]) == set([1,2,3])
assert set([1,2,3]).difference(set([1,2])) == set([3])
assert set([1,2,3]).difference(set([5,6])) == set([1,2,3])
assert set([1,2,3]) - set([4,5]) == set([1,2,3])
assert set([1,2,3]) - set([1,2,3,4,5]) == set([])
assert set([1,2,3]).symmetric_difference(set([1,2])) == set([3])
assert set([1,2,3]).symmetric_difference(set([5,6])) == set([1,2,3,5,6])
assert set([1,2,3]) ^ set([4,5]) == set([1,2,3,4,5])
assert set([1,2,3]) ^ set([1,2,3,4,5]) == set([4,5])
try:
set([[]])
except TypeError:
@@ -59,3 +82,4 @@ except TypeError:
pass
else:
assert False, "TypeError was not raised"

View File

@@ -285,6 +285,103 @@ fn set_compare_inner(
Ok(vm.new_bool(true))
}
fn set_union(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [
(zelf, Some(vm.ctx.set_type())),
(other, Some(vm.ctx.set_type()))
]
);
let mut elements = get_elements(zelf).clone();
elements.extend(get_elements(other).clone());
Ok(PyObject::new(
PyObjectPayload::Set { elements },
vm.ctx.set_type(),
))
}
fn set_intersection(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
set_combine_inner(vm, args, SetCombineOperation::Intersection)
}
fn set_difference(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
set_combine_inner(vm, args, SetCombineOperation::Difference)
}
fn set_symmetric_difference(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [
(zelf, Some(vm.ctx.set_type())),
(other, Some(vm.ctx.set_type()))
]
);
let mut elements = HashMap::new();
for element in get_elements(zelf).iter() {
let value = vm.call_method(other, "__contains__", vec![element.1.clone()])?;
if !objbool::get_value(&value) {
elements.insert(element.0.clone(), element.1.clone());
}
}
for element in get_elements(other).iter() {
let value = vm.call_method(zelf, "__contains__", vec![element.1.clone()])?;
if !objbool::get_value(&value) {
elements.insert(element.0.clone(), element.1.clone());
}
}
Ok(PyObject::new(
PyObjectPayload::Set { elements },
vm.ctx.set_type(),
))
}
enum SetCombineOperation {
Intersection,
Difference,
}
fn set_combine_inner(
vm: &mut VirtualMachine,
args: PyFuncArgs,
op: SetCombineOperation,
) -> PyResult {
arg_check!(
vm,
args,
required = [
(zelf, Some(vm.ctx.set_type())),
(other, Some(vm.ctx.set_type()))
]
);
let mut elements = HashMap::new();
for element in get_elements(zelf).iter() {
let value = vm.call_method(other, "__contains__", vec![element.1.clone()])?;
let should_add = match op {
SetCombineOperation::Intersection => objbool::get_value(&value),
SetCombineOperation::Difference => !objbool::get_value(&value),
};
if should_add {
elements.insert(element.0.clone(), element.1.clone());
}
}
Ok(PyObject::new(
PyObjectPayload::Set { elements },
vm.ctx.set_type(),
))
}
fn frozenset_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(o, Some(vm.ctx.frozenset_type()))]);
@@ -325,6 +422,30 @@ pub fn init(context: &PyContext) {
context.set_attr(&set_type, "__lt__", context.new_rustfunc(set_lt));
context.set_attr(&set_type, "issubset", context.new_rustfunc(set_le));
context.set_attr(&set_type, "issuperset", context.new_rustfunc(set_ge));
context.set_attr(&set_type, "union", context.new_rustfunc(set_union));
context.set_attr(&set_type, "__or__", context.new_rustfunc(set_union));
context.set_attr(
&set_type,
"intersection",
context.new_rustfunc(set_intersection),
);
context.set_attr(&set_type, "__and__", context.new_rustfunc(set_intersection));
context.set_attr(
&set_type,
"difference",
context.new_rustfunc(set_difference),
);
context.set_attr(&set_type, "__sub__", context.new_rustfunc(set_difference));
context.set_attr(
&set_type,
"symmetric_difference",
context.new_rustfunc(set_symmetric_difference),
);
context.set_attr(
&set_type,
"__xor__",
context.new_rustfunc(set_symmetric_difference),
);
context.set_attr(&set_type, "__doc__", context.new_str(set_doc.to_string()));
context.set_attr(&set_type, "add", context.new_rustfunc(set_add));
context.set_attr(&set_type, "remove", context.new_rustfunc(set_remove));