Fix builtin_round with non-int __round__

This commit is contained in:
Jeong YunWon
2019-10-11 00:21:19 +09:00
parent 8bce893118
commit 791e598e3d
2 changed files with 16 additions and 10 deletions

View File

@@ -19,3 +19,13 @@ with assert_raises(TypeError):
round(0, 0.0)
with assert_raises(TypeError):
round(0.0, 0.0)
class X:
def __round__(self, ndigits=None):
return 1.1
assert round(X(), 1) == 1.1
assert round(X(), None) == 1.1
assert round(X()) == 1.1

View File

@@ -674,24 +674,20 @@ fn builtin_round(
ndigits: OptionalArg<Option<PyIntRef>>,
vm: &VirtualMachine,
) -> PyResult {
match ndigits {
let rounded = match ndigits {
OptionalArg::Present(ndigits) => match ndigits {
Some(int) => {
let ndigits = vm.call_method(int.as_object(), "__int__", vec![])?;
let rounded = vm.call_method(&number, "__round__", vec![ndigits])?;
Ok(rounded)
}
None => {
let rounded = &vm.call_method(&number, "__round__", vec![])?;
Ok(vm.ctx.new_int(objint::get_value(rounded).clone()))
vm.call_method(&number, "__round__", vec![ndigits])?
}
None => vm.call_method(&number, "__round__", vec![])?,
},
OptionalArg::Missing => {
// without a parameter, the result type is coerced to int
let rounded = &vm.call_method(&number, "__round__", vec![])?;
Ok(vm.ctx.new_int(objint::get_value(rounded).clone()))
vm.call_method(&number, "__round__", vec![])?
}
}
};
Ok(rounded)
}
fn builtin_setattr(