diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 6b66633ad..99627549f 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -16,7 +16,7 @@ use crate::obj::objcode::PyCodeRef; use crate::obj::objdict::PyDictRef; use crate::obj::objint::{self, PyIntRef}; use crate::obj::objiter; -use crate::obj::objstr::{self, PyString, PyStringRef}; +use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; #[cfg(feature = "rustpython-compiler")] use rustpython_compiler::compile; @@ -154,8 +154,7 @@ fn builtin_dir(obj: OptionalArg, vm: &VirtualMachine) -> PyResult { Ok(sorted) } -fn builtin_divmod(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(a, None), (b, None)]); +fn builtin_divmod(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { vm.call_or_reflection( a.clone(), b.clone(), @@ -165,92 +164,63 @@ fn builtin_divmod(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { ) } +#[cfg(feature = "rustpython-compiler")] +#[derive(FromArgs)] +struct ScopeArgs { + #[pyarg(positional_or_keyword, default = "None")] + globals: Option, + // TODO: support any mapping for `locals` + #[pyarg(positional_or_keyword, default = "None")] + locals: Option, +} + /// Implements `eval`. /// See also: https://docs.python.org/3/library/functions.html#eval #[cfg(feature = "rustpython-compiler")] -fn builtin_eval(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - // TODO: support any mapping for `locals` - arg_check!( - vm, - args, - required = [(source, None)], - optional = [(globals, None), (locals, Some(vm.ctx.dict_type()))] - ); - - let scope = make_scope(vm, globals, locals)?; - - // Determine code object: - let code_obj = if let Ok(code_obj) = PyCodeRef::try_from_object(vm, source.clone()) { - code_obj - } else if objtype::isinstance(source, &vm.ctx.str_type()) { - let mode = compile::Mode::Eval; - let source = objstr::get_value(source); - vm.compile(&source, mode, "".to_string()) - .map_err(|err| vm.new_syntax_error(&err))? - } else { - return Err(vm.new_type_error("code argument must be str or code object".to_string())); - }; - - // Run the source: - vm.run_code_obj(code_obj, scope) +fn builtin_eval( + source: Either, + scope: ScopeArgs, + vm: &VirtualMachine, +) -> PyResult { + run_code(vm, source, scope, compile::Mode::Eval) } /// Implements `exec` /// https://docs.python.org/3/library/functions.html#exec #[cfg(feature = "rustpython-compiler")] -fn builtin_exec(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(source, None)], - optional = [(globals, None), (locals, None)] - ); +fn builtin_exec( + source: Either, + scope: ScopeArgs, + vm: &VirtualMachine, +) -> PyResult { + run_code(vm, source, scope, compile::Mode::Exec) +} - let scope = make_scope(vm, globals, locals)?; +fn run_code( + vm: &VirtualMachine, + source: Either, + scope: ScopeArgs, + mode: compile::Mode, +) -> PyResult { + let scope = make_scope(vm, scope)?; // Determine code object: - let code_obj = if objtype::isinstance(source, &vm.ctx.str_type()) { - let mode = compile::Mode::Exec; - let source = objstr::get_value(source); - vm.compile(&source, mode, "".to_string()) - .map_err(|err| vm.new_syntax_error(&err))? - } else if let Ok(code_obj) = PyCodeRef::try_from_object(vm, source.clone()) { - code_obj - } else { - return Err(vm.new_type_error("source argument must be str or code object".to_string())); + let code_obj = match source { + Either::A(string) => vm + .compile(string.as_str(), mode, "".to_string()) + .map_err(|err| vm.new_syntax_error(&err))?, + Either::B(code_obj) => code_obj, }; // Run the code: vm.run_code_obj(code_obj, scope) } -fn make_scope( - vm: &VirtualMachine, - globals: Option<&PyObjectRef>, - locals: Option<&PyObjectRef>, -) -> PyResult { - let dict_type = vm.ctx.dict_type(); - let globals = match globals { - Some(arg) => { - if arg.is(&vm.get_none()) { - None - } else if vm.isinstance(arg, &dict_type)? { - Some(arg) - } else { - let arg_typ = arg.class(); - let actual_type = vm.to_pystr(&arg_typ)?; - let expected_type_name = vm.to_pystr(&dict_type)?; - return Err(vm.new_type_error(format!( - "globals must be a {}, not {}", - expected_type_name, actual_type - ))); - } - } - None => None, - }; +fn make_scope(vm: &VirtualMachine, scope: ScopeArgs) -> PyResult { + let globals = scope.globals; let current_scope = vm.current_scope(); - let locals = match locals { - Some(dict) => dict.clone().downcast().ok(), + let locals = match scope.locals { + Some(dict) => Some(dict), None => { if globals.is_some() { None @@ -261,7 +231,6 @@ fn make_scope( }; let globals = match globals { Some(dict) => { - let dict: PyDictRef = dict.clone().downcast().unwrap(); if !dict.contains_key("__builtins__", vm) { let builtins_dict = vm.builtins.dict.as_ref().unwrap().as_object(); dict.set_item("__builtins__", builtins_dict.clone(), vm) @@ -332,17 +301,14 @@ fn builtin_hasattr(obj: PyObjectRef, attr: PyStringRef, vm: &VirtualMachine) -> } } -fn builtin_hash(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, None)]); - vm._hash(obj).and_then(|v| Ok(vm.new_int(v))) +fn builtin_hash(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + vm._hash(&obj).and_then(|v| Ok(vm.new_int(v))) } // builtin_help -fn builtin_hex(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(number, Some(vm.ctx.int_type()))]); - - let n = objint::get_value(number); +fn builtin_hex(number: PyIntRef, vm: &VirtualMachine) -> PyResult { + let n = number.as_bigint(); let s = if n.is_negative() { format!("-0x{:x}", n.abs()) } else { @@ -352,9 +318,7 @@ fn builtin_hex(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(s)) } -fn builtin_id(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, None)]); - +fn builtin_id(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { Ok(vm.context().new_int(obj.get_id())) } @@ -392,13 +356,11 @@ fn builtin_issubclass( ) } -fn builtin_iter(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(iter_target, None)]); - objiter::get_iter(vm, iter_target) +fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult { + objiter::get_iter(vm, &iter_target) } -fn builtin_len(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, None)]); +fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { let method = vm.get_method_or_type_error(obj.clone(), "__len__", || { format!("object of type '{}' has no len()", obj.class().name) })?; @@ -506,21 +468,18 @@ fn builtin_min(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(x) } -fn builtin_next(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(iterator, None)], - optional = [(default_value, None)] - ); - - match vm.call_method(iterator, "__next__", vec![]) { +fn builtin_next( + iterator: PyObjectRef, + default_value: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { + match vm.call_method(&iterator, "__next__", vec![]) { Ok(value) => Ok(value), Err(value) => { if objtype::isinstance(&value, &vm.ctx.exceptions.stop_iteration) { match default_value { - None => Err(value), - Some(value) => Ok(value.clone()), + OptionalArg::Missing => Err(value), + OptionalArg::Present(value) => Ok(value.clone()), } } else { Err(value) @@ -529,10 +488,8 @@ fn builtin_next(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } } -fn builtin_oct(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(number, Some(vm.ctx.int_type()))]); - - let n = objint::get_value(number); +fn builtin_oct(number: PyIntRef, vm: &VirtualMachine) -> PyResult { + let n = number.as_bigint(); let s = if n.is_negative() { format!("-0o{:o}", n.abs()) } else { @@ -542,75 +499,69 @@ fn builtin_oct(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(s)) } -fn builtin_ord(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(string, None)]); - if objtype::isinstance(string, &vm.ctx.str_type()) { - let string = objstr::borrow_value(string); - let string_len = string.chars().count(); - if string_len != 1 { - return Err(vm.new_type_error(format!( - "ord() expected a character, but string of length {} found", - string_len - ))); +fn builtin_ord(string: Either, vm: &VirtualMachine) -> PyResult { + match string { + Either::A(bytes) => { + let bytes_len = bytes.elements.len(); + if bytes_len != 1 { + return Err(vm.new_type_error(format!( + "ord() expected a character, but string of length {} found", + bytes_len + ))); + } + Ok(vm.context().new_int(bytes.elements[0])) } - match string.chars().next() { - Some(character) => Ok(vm.context().new_int(character as i32)), - None => Err(vm.new_type_error( - "ord() could not guess the integer representing this character".to_string(), - )), + Either::B(string) => { + let string = string.as_str(); + let string_len = string.chars().count(); + if string_len != 1 { + return Err(vm.new_type_error(format!( + "ord() expected a character, but string of length {} found", + string_len + ))); + } + match string.chars().next() { + Some(character) => Ok(vm.context().new_int(character as i32)), + None => Err(vm.new_type_error( + "ord() could not guess the integer representing this character".to_string(), + )), + } } - } else if objtype::isinstance(string, &vm.ctx.bytearray_type()) - || objtype::isinstance(string, &vm.ctx.bytes_type()) - { - let inner = PyByteInner::try_from_object(vm, string.clone()).unwrap(); - let bytes_len = inner.elements.len(); - if bytes_len != 1 { - return Err(vm.new_type_error(format!( - "ord() expected a character, but string of length {} found", - bytes_len - ))); - } - Ok(vm.context().new_int(inner.elements[0])) - } else { - Err(vm.new_type_error(format!( - "ord() expected a string, bytes or bytearray, but found {}", - string.class().name - ))) } } -fn builtin_pow(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(x, None), (y, None)], - optional = [(mod_value, Some(vm.ctx.int_type()))] - ); - +fn builtin_pow( + x: PyObjectRef, + y: PyObjectRef, + mod_value: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { match mod_value { - None => vm.call_or_reflection(x.clone(), y.clone(), "__pow__", "__rpow__", |vm, x, y| { - Err(vm.new_unsupported_operand_error(x, y, "pow")) - }), - Some(m) => { + OptionalArg::Missing => { + vm.call_or_reflection(x.clone(), y.clone(), "__pow__", "__rpow__", |vm, x, y| { + Err(vm.new_unsupported_operand_error(x, y, "pow")) + }) + } + OptionalArg::Present(m) => { // Check if the 3rd argument is defined and perform modulus on the result - if !(objtype::isinstance(x, &vm.ctx.int_type()) - && objtype::isinstance(y, &vm.ctx.int_type())) + if !(objtype::isinstance(&x, &vm.ctx.int_type()) + && objtype::isinstance(&y, &vm.ctx.int_type())) { return Err(vm.new_type_error( "pow() 3rd argument not allowed unless all arguments are integers".to_string(), )); } - let y = objint::get_value(y); + let y = objint::get_value(&y); if y.sign() == Sign::Minus { return Err(vm.new_value_error( "pow() 2nd argument cannot be negative when 3rd argument specified".to_string(), )); } - let m = objint::get_value(m); + let m = m.as_bigint(); if m.is_zero() { return Err(vm.new_value_error("pow() 3rd argument cannot be 0".to_string())); } - let x = objint::get_value(x); + let x = objint::get_value(&x); Ok(vm.new_int(x.modpow(&y, &m))) } } @@ -714,9 +665,7 @@ fn builtin_repr(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult vm.to_repr(&obj) } -fn builtin_reversed(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(obj, None)]); - +fn builtin_reversed(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Some(reversed_method) = vm.get_method(obj.clone(), "__reversed__") { vm.invoke(&reversed_method?, PyFuncArgs::default()) } else { @@ -733,31 +682,28 @@ fn builtin_reversed(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } } -fn builtin_round(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(number, Some(vm.ctx.object()))], - optional = [(ndigits, None)] - ); - if let Some(ndigits) = ndigits { - if objtype::isinstance(ndigits, &vm.ctx.int_type()) { - let ndigits = vm.call_method(ndigits, "__int__", vec![])?; - let rounded = vm.call_method(number, "__round__", vec![ndigits])?; - Ok(rounded) - } else if vm.ctx.none().is(ndigits) { - let rounded = &vm.call_method(number, "__round__", vec![])?; +fn builtin_round( + number: PyObjectRef, + ndigits: OptionalArg>, + vm: &VirtualMachine, +) -> PyResult { + match ndigits { + OptionalArg::Present(ndigits) => match ndigits { + Some(int) => { + let ndigits = vm.call_method(int.as_object(), "__int__", vec![])?; + let rounded = vm.call_method(&number, "__round__", vec![ndigits])?; + Ok(rounded) + } + None => { + let rounded = &vm.call_method(&number, "__round__", vec![])?; + Ok(vm.ctx.new_int(objint::get_value(rounded).clone())) + } + }, + OptionalArg::Missing => { + // without a parameter, the result type is coerced to int + let rounded = &vm.call_method(&number, "__round__", vec![])?; Ok(vm.ctx.new_int(objint::get_value(rounded).clone())) - } else { - Err(vm.new_type_error(format!( - "'{}' object cannot be interpreted as an integer", - ndigits.class().name - ))) } - } else { - // without a parameter, the result type is coerced to int - let rounded = &vm.call_method(number, "__round__", vec![])?; - Ok(vm.ctx.new_int(objint::get_value(rounded).clone())) } }