diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index 60eb723d9..f1b1d1f95 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -61,10 +61,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { let builder = &mut self.builder; let local = self.variables.entry(name).or_insert_with(|| { let var = Variable::new(len); - let local = Local { - var, - ty: val.ty, - }; + let local = Local { var, ty: val.ty }; builder.declare_var(var, val.ty.to_cranelift()); local }); diff --git a/vm/src/obj/objfunction.rs b/vm/src/obj/objfunction.rs index 48b61ce29..1066149d4 100644 --- a/vm/src/obj/objfunction.rs +++ b/vm/src/obj/objfunction.rs @@ -1,9 +1,8 @@ +#[cfg(feature = "jit")] +mod jitfunc; + use super::objcode::PyCodeRef; use super::objdict::PyDictRef; -#[cfg(feature = "jit")] -use super::objfloat; -#[cfg(feature = "jit")] -use super::objint; use super::objstr::PyStringRef; use super::objtuple::PyTupleRef; use super::objtype::PyClassRef; @@ -13,34 +12,20 @@ use crate::function::{OptionalArg, PyFuncArgs}; use crate::obj::objasyncgenerator::PyAsyncGen; use crate::obj::objcoroutine::PyCoroutine; use crate::obj::objgenerator::PyGenerator; +#[cfg(feature = "jit")] +use crate::pyobject::IntoPyObject; use crate::pyobject::{ BorrowValue, IdProtocol, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, }; -#[cfg(feature = "jit")] -use crate::pyobject::{IntoPyObject, TryFromObject}; use crate::scope::Scope; use crate::slots::{SlotCall, SlotDescriptor}; use crate::VirtualMachine; use itertools::Itertools; #[cfg(feature = "jit")] -use num_traits::ToPrimitive; -#[cfg(feature = "jit")] -use rustpython_bytecode::bytecode::CodeFlags; -#[cfg(feature = "jit")] use rustpython_common::cell::OnceCell; #[cfg(feature = "jit")] -use rustpython_jit::{AbiValue, Args, CompiledCode, JitType}; - -#[cfg(feature = "jit")] -impl IntoPyObject for AbiValue { - fn into_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { - match self { - AbiValue::Int(i) => i.into_pyobject(vm), - AbiValue::Float(f) => f.into_pyobject(vm), - } - } -} +use rustpython_jit::CompiledCode; pub type PyFunctionRef = PyRef; @@ -253,7 +238,7 @@ impl PyFunction { ) -> PyResult { #[cfg(feature = "jit")] if let Some(jitted_code) = self.jitted_code.get() { - if let Some(args) = self.get_jit_args(&func_args, jitted_code, vm) { + if let Some(args) = jitfunc::get_jit_args(self, &func_args, jitted_code, vm) { return Ok(jitted_code.invoke(&args).into_pyobject(vm)); } } @@ -285,145 +270,6 @@ impl PyFunction { pub fn invoke(&self, func_args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { self.invoke_with_scope(func_args, &self.scope, vm) } - - #[cfg(feature = "jit")] - fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResult { - if let Some(value) = dict.get_item_option(name, vm)? { - if value.is(&vm.ctx.types.int_type) { - Ok(JitType::Int) - } else if value.is(&vm.ctx.types.float_type) { - Ok(JitType::Float) - } else { - Err(vm.new_runtime_error( - "Jit requires argument to be either int or float".to_owned(), - )) - } - } else { - Err(vm.new_runtime_error(format!("argument {} needs annotation", name))) - } - } - - #[cfg(feature = "jit")] - fn get_jit_arg_types(zelf: &PyRef, vm: &VirtualMachine) -> PyResult> { - if zelf - .code - .flags - .intersects(CodeFlags::HAS_VARARGS | CodeFlags::HAS_VARKEYWORDS) - { - return Err(vm.new_runtime_error( - "Can't jit functions with variable number of arguments".to_owned(), - )); - } - - if zelf.code.arg_names.is_empty() && zelf.code.kwonlyarg_names.is_empty() { - return Ok(Vec::new()); - } - - let annotations = vm.get_attribute(zelf.clone().into_object(), "__annotations__")?; - if vm.is_none(&annotations) { - Err(vm.new_runtime_error( - "Jitting function requires arguments to have annotations".to_owned(), - )) - } else if let Ok(dict) = PyDictRef::try_from_object(vm, annotations) { - let mut arg_types = Vec::new(); - - for arg in &zelf.code.arg_names { - arg_types.push(Self::get_jit_arg_type(&dict, arg, vm)?); - } - - for arg in &zelf.code.kwonlyarg_names { - arg_types.push(Self::get_jit_arg_type(&dict, arg, vm)?); - } - - Ok(arg_types) - } else { - Err(vm.new_type_error("Function annotations aren't a dict".to_owned())) - } - } - - #[cfg(feature = "jit")] - fn get_jit_value(vm: &VirtualMachine, obj: &PyObjectRef) -> Option { - // This does exact type checks as subclasses of int/float can't be passed to jitted functions - let cls = obj.lease_class(); - if cls.is(&vm.ctx.types.int_type) { - objint::get_value(&obj).to_i64().map(AbiValue::Int) - } else if cls.is(&vm.ctx.types.float_type) { - Some(AbiValue::Float(objfloat::get_value(&obj))) - } else { - None - } - } - - /// Like `fill_locals_from_args` but to populate arguments for calling a jit function. - /// This also doesn't do full error handling but instead return None if anything is wrong. In - /// that case it falls back to the executing the bytecode version which will call - /// `fill_locals_from_args` which will raise the actual exception if needed. - #[cfg(feature = "jit")] - fn get_jit_args<'a>( - &self, - func_args: &PyFuncArgs, - jitted_code: &'a CompiledCode, - vm: &VirtualMachine, - ) -> Option> { - let mut jit_args = jitted_code.args_builder(); - let nargs = func_args.args.len(); - - if nargs > self.code.arg_names.len() || nargs < self.code.posonlyarg_count { - return None; - } - - // Add positional arguments - for i in 0..nargs { - jit_args.set(i, Self::get_jit_value(vm, &func_args.args[i])?); - } - - // Handle keyword arguments - for (name, value) in &func_args.kwargs { - if let Some(arg_idx) = self.code.arg_names.iter().position(|arg| arg == name) { - if jit_args.is_set(arg_idx) { - return None; - } - jit_args.set(arg_idx, Self::get_jit_value(vm, &value)?); - } else if let Some(kwarg_idx) = - self.code.kwonlyarg_names.iter().position(|arg| arg == name) - { - let arg_idx = kwarg_idx + self.code.arg_names.len(); - if jit_args.is_set(arg_idx) { - return None; - } - jit_args.set(arg_idx, Self::get_jit_value(vm, &value)?); - } else { - return None; - } - } - - // fill in positional defaults - if let Some(defaults) = &self.defaults { - let defaults = defaults.borrow_value(); - for (i, default) in defaults.iter().enumerate() { - let arg_idx = i + self.code.arg_names.len() - defaults.len(); - if !jit_args.is_set(arg_idx) { - jit_args.set(arg_idx, Self::get_jit_value(vm, default)?); - } - } - } - - // fill in keyword only defaults - if let Some(kw_only_defaults) = &self.kw_only_defaults { - for (i, name) in self.code.kwonlyarg_names.iter().enumerate() { - let arg_idx = i + self.code.arg_names.len(); - if !jit_args.is_set(arg_idx) { - let default = kw_only_defaults - .get_item(name.as_str(), vm) - .ok() - .and_then(|obj| Self::get_jit_value(vm, &obj))?; - jit_args.set(arg_idx, default); - } - } - } - - jit_args.into_args() - } } impl PyValue for PyFunction { @@ -465,7 +311,7 @@ impl PyFunction { fn jit(zelf: PyRef, vm: &VirtualMachine) -> PyResult<()> { zelf.jitted_code .get_or_try_init(|| { - let arg_types = PyFunction::get_jit_arg_types(&zelf, vm)?; + let arg_types = jitfunc::get_jit_arg_types(&zelf, vm)?; rustpython_jit::compile(&zelf.code.code, &arg_types) .map_err(|err| vm.new_runtime_error(err.to_string())) }) diff --git a/vm/src/obj/objfunction/jitfunc.rs b/vm/src/obj/objfunction/jitfunc.rs new file mode 100644 index 000000000..0657f56fc --- /dev/null +++ b/vm/src/obj/objfunction/jitfunc.rs @@ -0,0 +1,154 @@ +use crate::function::PyFuncArgs; +use crate::obj::objdict::PyDictRef; +use crate::obj::objfunction::{PyFunction, PyFunctionRef}; +use crate::obj::{objfloat, objint}; +use crate::pyobject::{ + BorrowValue, IdProtocol, IntoPyObject, ItemProtocol, PyObjectRef, PyResult, TryFromObject, + TypeProtocol, +}; +use crate::VirtualMachine; +use num_traits::ToPrimitive; +use rustpython_bytecode::bytecode::CodeFlags; +use rustpython_jit::{AbiValue, Args, CompiledCode, JitType}; + +impl IntoPyObject for AbiValue { + fn into_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + match self { + AbiValue::Int(i) => i.into_pyobject(vm), + AbiValue::Float(f) => f.into_pyobject(vm), + } + } +} + +fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResult { + if let Some(value) = dict.get_item_option(name, vm)? { + if value.is(&vm.ctx.types.int_type) { + Ok(JitType::Int) + } else if value.is(&vm.ctx.types.float_type) { + Ok(JitType::Float) + } else { + Err(vm.new_runtime_error("Jit requires argument to be either int or float".to_owned())) + } + } else { + Err(vm.new_runtime_error(format!("argument {} needs annotation", name))) + } +} + +pub fn get_jit_arg_types(func: &PyFunctionRef, vm: &VirtualMachine) -> PyResult> { + if func + .code + .flags + .intersects(CodeFlags::HAS_VARARGS | CodeFlags::HAS_VARKEYWORDS) + { + return Err(vm.new_runtime_error( + "Can't jit functions with variable number of arguments".to_owned(), + )); + } + + if func.code.arg_names.is_empty() && func.code.kwonlyarg_names.is_empty() { + return Ok(Vec::new()); + } + + let annotations = vm.get_attribute(func.clone().into_object(), "__annotations__")?; + if vm.is_none(&annotations) { + Err(vm.new_runtime_error( + "Jitting function requires arguments to have annotations".to_owned(), + )) + } else if let Ok(dict) = PyDictRef::try_from_object(vm, annotations) { + let mut arg_types = Vec::new(); + + for arg in &func.code.arg_names { + arg_types.push(get_jit_arg_type(&dict, arg, vm)?); + } + + for arg in &func.code.kwonlyarg_names { + arg_types.push(get_jit_arg_type(&dict, arg, vm)?); + } + + Ok(arg_types) + } else { + Err(vm.new_type_error("Function annotations aren't a dict".to_owned())) + } +} + +fn get_jit_value(vm: &VirtualMachine, obj: &PyObjectRef) -> Option { + // This does exact type checks as subclasses of int/float can't be passed to jitted functions + let cls = obj.lease_class(); + if cls.is(&vm.ctx.types.int_type) { + objint::get_value(&obj).to_i64().map(AbiValue::Int) + } else if cls.is(&vm.ctx.types.float_type) { + Some(AbiValue::Float(objfloat::get_value(&obj))) + } else { + None + } +} + +/// Like `fill_locals_from_args` but to populate arguments for calling a jit function. +/// This also doesn't do full error handling but instead return None if anything is wrong. In +/// that case it falls back to the executing the bytecode version which will call +/// `fill_locals_from_args` which will raise the actual exception if needed. +#[cfg(feature = "jit")] +pub(crate) fn get_jit_args<'a>( + func: &PyFunction, + func_args: &PyFuncArgs, + jitted_code: &'a CompiledCode, + vm: &VirtualMachine, +) -> Option> { + let mut jit_args = jitted_code.args_builder(); + let nargs = func_args.args.len(); + + if nargs > func.code.arg_names.len() || nargs < func.code.posonlyarg_count { + return None; + } + + // Add positional arguments + for i in 0..nargs { + jit_args.set(i, get_jit_value(vm, &func_args.args[i])?); + } + + // Handle keyword arguments + for (name, value) in &func_args.kwargs { + if let Some(arg_idx) = func.code.arg_names.iter().position(|arg| arg == name) { + if jit_args.is_set(arg_idx) { + return None; + } + jit_args.set(arg_idx, get_jit_value(vm, &value)?); + } else if let Some(kwarg_idx) = func.code.kwonlyarg_names.iter().position(|arg| arg == name) + { + let arg_idx = kwarg_idx + func.code.arg_names.len(); + if jit_args.is_set(arg_idx) { + return None; + } + jit_args.set(arg_idx, get_jit_value(vm, &value)?); + } else { + return None; + } + } + + // fill in positional defaults + if let Some(defaults) = &func.defaults { + let defaults = defaults.borrow_value(); + for (i, default) in defaults.iter().enumerate() { + let arg_idx = i + func.code.arg_names.len() - defaults.len(); + if !jit_args.is_set(arg_idx) { + jit_args.set(arg_idx, get_jit_value(vm, default)?); + } + } + } + + // fill in keyword only defaults + if let Some(kw_only_defaults) = &func.kw_only_defaults { + for (i, name) in func.code.kwonlyarg_names.iter().enumerate() { + let arg_idx = i + func.code.arg_names.len(); + if !jit_args.is_set(arg_idx) { + let default = kw_only_defaults + .get_item(name.as_str(), vm) + .ok() + .and_then(|obj| get_jit_value(vm, &obj))?; + jit_args.set(arg_idx, default); + } + } + } + + jit_args.into_args() +}