diff --git a/tests/snippets/bools.py b/tests/snippets/bools.py index 0c277143b..23f22dce7 100644 --- a/tests/snippets/bools.py +++ b/tests/snippets/bools.py @@ -51,3 +51,29 @@ assert True > 0 assert int(True) == 1 assert True.conjugate() == 1 assert isinstance(True.conjugate(), int) + +# Boolean operations on pairs of Bools should return Bools, not ints +assert (False | True) is True +assert (False & True) is False +assert (False ^ True) is True +# But only if both are Bools +assert (False | 1) is not True +assert (0 | True) is not True +assert (False & 1) is not False +assert (0 & True) is not False +assert (False ^ 1) is not True +assert (0 ^ True) is not True + +# Check that the same works with __XXX__ methods +assert False.__or__(0) is not False +assert False.__or__(False) is False +assert False.__ror__(0) is not False +assert False.__ror__(False) is False +assert False.__and__(0) is not False +assert False.__and__(False) is False +assert False.__rand__(0) is not False +assert False.__rand__(False) is False +assert False.__xor__(0) is not False +assert False.__xor__(False) is False +assert False.__rxor__(0) is not False +assert False.__rxor__(False) is False diff --git a/vm/src/obj/objbool.rs b/vm/src/obj/objbool.rs index 724c8bc5a..c1a999e45 100644 --- a/vm/src/obj/objbool.rs +++ b/vm/src/obj/objbool.rs @@ -42,6 +42,12 @@ The class bool is a subclass of the class int, and cannot be subclassed."; extend_class!(context, bool_type, { "__new__" => context.new_rustfunc(bool_new), "__repr__" => context.new_rustfunc(bool_repr), + "__or__" => context.new_rustfunc(bool_or), + "__ror__" => context.new_rustfunc(bool_ror), + "__and__" => context.new_rustfunc(bool_and), + "__rand__" => context.new_rustfunc(bool_rand), + "__xor__" => context.new_rustfunc(bool_xor), + "__rxor__" => context.new_rustfunc(bool_rxor), "__doc__" => context.new_str(bool_doc.to_string()) }); } @@ -71,6 +77,72 @@ fn bool_repr(vm: &VirtualMachine, args: PyFuncArgs) -> Result PyResult { + if objtype::isinstance(lhs, &vm.ctx.bool_type()) + && objtype::isinstance(rhs, &vm.ctx.bool_type()) + { + let lhs = get_value(lhs); + let rhs = get_value(rhs); + (lhs || rhs).into_pyobject(vm) + } else { + Ok(lhs.payload::().unwrap().or(rhs.clone(), vm)) + } +} + +fn bool_or(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(lhs, None), (rhs, None)]); + do_bool_or(vm, lhs, rhs) +} + +fn bool_ror(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(rhs, None), (lhs, None)]); + do_bool_or(vm, lhs, rhs) +} + +fn do_bool_and(vm: &VirtualMachine, lhs: &PyObjectRef, rhs: &PyObjectRef) -> PyResult { + if objtype::isinstance(lhs, &vm.ctx.bool_type()) + && objtype::isinstance(rhs, &vm.ctx.bool_type()) + { + let lhs = get_value(lhs); + let rhs = get_value(rhs); + (lhs && rhs).into_pyobject(vm) + } else { + Ok(lhs.payload::().unwrap().and(rhs.clone(), vm)) + } +} + +fn bool_and(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(lhs, None), (rhs, None)]); + do_bool_and(vm, lhs, rhs) +} + +fn bool_rand(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(rhs, None), (lhs, None)]); + do_bool_and(vm, lhs, rhs) +} + +fn do_bool_xor(vm: &VirtualMachine, lhs: &PyObjectRef, rhs: &PyObjectRef) -> PyResult { + if objtype::isinstance(lhs, &vm.ctx.bool_type()) + && objtype::isinstance(rhs, &vm.ctx.bool_type()) + { + let lhs = get_value(lhs); + let rhs = get_value(rhs); + (lhs ^ rhs).into_pyobject(vm) + } else { + Ok(lhs.payload::().unwrap().xor(rhs.clone(), vm)) + } +} + +fn bool_xor(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(lhs, None), (rhs, None)]); + do_bool_xor(vm, lhs, rhs) +} + +fn bool_rxor(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(rhs, None), (lhs, None)]); + do_bool_xor(vm, lhs, rhs) +} + fn bool_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 88c91f89f..7effb0897 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -309,7 +309,7 @@ impl PyInt { } #[pymethod(name = "__xor__")] - fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + pub fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { if objtype::isinstance(&other, &vm.ctx.int_type()) { vm.ctx.new_int((&self.value) ^ get_value(&other)) } else { @@ -319,15 +319,11 @@ impl PyInt { #[pymethod(name = "__rxor__")] fn rxor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - if objtype::isinstance(&other, &vm.ctx.int_type()) { - vm.ctx.new_int(get_value(&other) ^ (&self.value)) - } else { - vm.ctx.not_implemented() - } + self.xor(other, vm) } #[pymethod(name = "__or__")] - fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + pub fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { if objtype::isinstance(&other, &vm.ctx.int_type()) { vm.ctx.new_int((&self.value) | get_value(&other)) } else { @@ -336,7 +332,7 @@ impl PyInt { } #[pymethod(name = "__and__")] - fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { + pub fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { if objtype::isinstance(&other, &vm.ctx.int_type()) { let v2 = get_value(&other); vm.ctx.new_int((&self.value) & v2)