diff --git a/tests/snippets/builtin_format.py b/tests/snippets/builtin_format.py index 6c06cbd98..55a6a3da1 100644 --- a/tests/snippets/builtin_format.py +++ b/tests/snippets/builtin_format.py @@ -7,3 +7,8 @@ assert_raises(TypeError, lambda: format(2, 3), 'format called with number') assert format({}) == "{}" assert_raises(TypeError, lambda: format({}, 'b'), 'format_spec not empty for dict') + +class BadFormat: + def __format__(self, spec): + return 42 +assert_raises(TypeError, lambda: format(BadFormat())) diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 1de24b86a..9d5980759 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -6,21 +6,22 @@ use std::char; use std::io::{self, Write}; use std::path::PathBuf; -use num_traits::{Signed, ToPrimitive}; +use num_traits::Signed; use crate::compile; use crate::import::import_module; use crate::obj::objbool; use crate::obj::objdict::PyDictRef; -use crate::obj::objint; +use crate::obj::objint::{self, PyIntRef}; use crate::obj::objiter; -use crate::obj::objstr::{self, PyStringRef}; +use crate::obj::objstr::{self, PyString, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; use crate::frame::Scope; use crate::function::{Args, OptionalArg, PyFuncArgs}; use crate::pyobject::{ - DictProtocol, IdProtocol, PyContext, PyObjectRef, PyResult, TryFromObject, TypeProtocol, + DictProtocol, IdProtocol, PyContext, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -28,74 +29,53 @@ use crate::obj::objcode::PyCodeRef; #[cfg(not(target_arch = "wasm32"))] use crate::stdlib::io::io_open; -fn builtin_abs(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(x, None)]); +fn builtin_abs(x: PyObjectRef, vm: &VirtualMachine) -> PyResult { match vm.get_method(x.clone(), "__abs__") { Ok(attrib) => vm.invoke(attrib, PyFuncArgs::new(vec![], vec![])), Err(..) => Err(vm.new_type_error("bad operand for abs".to_string())), } } -fn builtin_all(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(iterable, None)]); - let items = vm.extract_elements(iterable)?; - for item in items { - let result = objbool::boolval(vm, item)?; - if !result { - return Ok(vm.new_bool(false)); +fn builtin_all(iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + for item in iterable.iter(vm)? { + if !item? { + return Ok(false); } } - Ok(vm.new_bool(true)) + Ok(true) } -fn builtin_any(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(iterable, None)]); - let iterator = objiter::get_iter(vm, iterable)?; - - while let Some(item) = objiter::get_next_object(vm, &iterator)? { - let result = objbool::boolval(vm, item)?; - if result { - return Ok(vm.new_bool(true)); +fn builtin_any(iterable: PyIterable, vm: &VirtualMachine) -> PyResult { + for item in iterable.iter(vm)? { + if item? { + return Ok(true); } } - - Ok(vm.new_bool(false)) + Ok(false) } // builtin_ascii -fn builtin_bin(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(number, Some(vm.ctx.int_type()))]); - - let n = objint::get_value(number); - let s = if n.is_negative() { - format!("-0b{:b}", n.abs()) +fn builtin_bin(x: PyIntRef, _vm: &VirtualMachine) -> String { + let x = x.as_bigint(); + if x.is_negative() { + format!("-0b{:b}", x.abs()) } else { - format!("0b{:b}", n) - }; - - Ok(vm.new_str(s)) + format!("0b{:b}", x) + } } // builtin_breakpoint -fn builtin_callable(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, None)]); - let is_callable = objtype::class_has_attr(&obj.class(), "__call__"); - Ok(vm.new_bool(is_callable)) +fn builtin_callable(obj: PyObjectRef, _vm: &VirtualMachine) -> bool { + objtype::class_has_attr(&obj.class(), "__call__") } -fn builtin_chr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(i, Some(vm.ctx.int_type()))]); - - let code_point = objint::get_value(i).to_u32().unwrap(); - - let txt = match char::from_u32(code_point) { +fn builtin_chr(i: u32, _vm: &VirtualMachine) -> String { + match char::from_u32(i) { Some(value) => value.to_string(), None => '_'.to_string(), - }; - - Ok(vm.new_str(txt)) + } } fn builtin_compile( @@ -128,13 +108,8 @@ fn builtin_compile( }) } -fn builtin_delattr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(obj, None), (attr, Some(vm.ctx.str_type()))] - ); - vm.del_attr(obj, attr.clone()) +fn builtin_delattr(obj: PyObjectRef, attr: PyStringRef, vm: &VirtualMachine) -> PyResult<()> { + vm.del_attr(&obj, attr.into_object()) } fn builtin_dir(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -258,17 +233,26 @@ fn make_scope( Ok(Scope::new(locals, globals)) } -fn builtin_format(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(obj, None)], - optional = [(format_spec, Some(vm.ctx.str_type()))] - ); - let format_spec = format_spec - .cloned() - .unwrap_or_else(|| vm.new_str("".to_string())); - vm.call_method(obj, "__format__", vec![format_spec]) +fn builtin_format( + value: PyObjectRef, + format_spec: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { + let format_spec = format_spec.into_option().unwrap_or_else(|| { + PyString { + value: "".to_string(), + } + .into_ref(vm) + }); + + vm.call_method(&value, "__format__", vec![format_spec.into_object()])? + .downcast() + .map_err(|obj| { + vm.new_type_error(format!( + "__format__ must return a str, not {}", + obj.class().name + )) + }) } fn catch_attr_exception(ex: PyObjectRef, default: T, vm: &VirtualMachine) -> PyResult { @@ -644,14 +628,11 @@ fn builtin_sorted(vm: &VirtualMachine, mut args: PyFuncArgs) -> PyResult { Ok(lst) } -fn builtin_sum(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(iterable, None)]); - let items = vm.extract_elements(iterable)?; - +fn builtin_sum(iterable: PyIterable, start: OptionalArg, vm: &VirtualMachine) -> PyResult { // Start with zero and add at will: - let mut sum = vm.ctx.new_int(0); - for item in items { - sum = vm._add(sum, item)?; + let mut sum = start.into_option().unwrap_or_else(|| vm.ctx.new_int(0)); + for item in iterable.iter(vm)? { + sum = vm._add(sum, item?)?; } Ok(sum) } diff --git a/vm/src/function.rs b/vm/src/function.rs index 1c61ffc87..a567e7903 100644 --- a/vm/src/function.rs +++ b/vm/src/function.rs @@ -308,7 +308,7 @@ where /// An argument that may or may not be provided by the caller. /// /// This style of argument is not possible in pure Python. -pub enum OptionalArg { +pub enum OptionalArg { Present(T), Missing, } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 4148184f4..36f4450e6 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -577,8 +577,9 @@ impl VirtualMachine { self.call_method(&obj, "__setattr__", vec![attr_name, attr_value]) } - pub fn del_attr(&self, obj: &PyObjectRef, attr_name: PyObjectRef) -> PyResult { - self.call_method(&obj, "__delattr__", vec![attr_name]) + pub fn del_attr(&self, obj: &PyObjectRef, attr_name: PyObjectRef) -> PyResult<()> { + self.call_method(&obj, "__delattr__", vec![attr_name])?; + Ok(()) } // get_method should be used for internal access to magic methods (by-passing