diff --git a/tests/snippets/floats.py b/tests/snippets/floats.py index d4cc279a8..9bc8b7277 100644 --- a/tests/snippets/floats.py +++ b/tests/snippets/floats.py @@ -1 +1,8 @@ 1 + 1.1 + +a = 1.2 +b = 1.3 +c = 1.2 +assert a < b +assert a <= b +assert a <= c diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index 5fa2fdd8e..861536e5e 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -61,6 +61,21 @@ fn float_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bool(result)) } +fn float_lt(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, Some(vm.ctx.float_type())), + (other, Some(vm.ctx.float_type())) + ] + ); + let zelf = get_value(zelf); + let other = get_value(other); + let result = zelf < other; + Ok(vm.ctx.new_bool(result)) +} + fn float_le(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -187,6 +202,7 @@ fn float_pow(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let ref float_type = context.float_type; float_type.set_attr("__eq__", context.new_rustfunc(float_eq)); + float_type.set_attr("__lt__", context.new_rustfunc(float_lt)); float_type.set_attr("__le__", context.new_rustfunc(float_le)); float_type.set_attr("__abs__", context.new_rustfunc(float_abs)); float_type.set_attr("__add__", context.new_rustfunc(float_add));