forked from Rust-related/RustPython
Support recursion in JIT-ed functions (#5473)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
use super::{JitCompileError, JitSig, JitType};
|
||||
use cranelift::codegen::ir::FuncRef;
|
||||
use cranelift::prelude::*;
|
||||
use num_traits::cast::ToPrimitive;
|
||||
use rustpython_compiler_core::bytecode::{
|
||||
@@ -6,8 +8,6 @@ use rustpython_compiler_core::bytecode::{
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{JitCompileError, JitSig, JitType};
|
||||
|
||||
#[repr(u16)]
|
||||
enum CustomTrapCode {
|
||||
/// Raised when shifting by a negative number
|
||||
@@ -27,6 +27,7 @@ enum JitValue {
|
||||
Bool(Value),
|
||||
None,
|
||||
Tuple(Vec<JitValue>),
|
||||
FuncRef(FuncRef),
|
||||
}
|
||||
|
||||
impl JitValue {
|
||||
@@ -43,14 +44,14 @@ impl JitValue {
|
||||
JitValue::Int(_) => Some(JitType::Int),
|
||||
JitValue::Float(_) => Some(JitType::Float),
|
||||
JitValue::Bool(_) => Some(JitType::Bool),
|
||||
JitValue::None | JitValue::Tuple(_) => None,
|
||||
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_value(self) -> Option<Value> {
|
||||
match self {
|
||||
JitValue::Int(val) | JitValue::Float(val) | JitValue::Bool(val) => Some(val),
|
||||
JitValue::None | JitValue::Tuple(_) => None,
|
||||
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -68,6 +69,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
builder: &'a mut FunctionBuilder<'b>,
|
||||
num_variables: usize,
|
||||
arg_types: &[JitType],
|
||||
ret_type: Option<JitType>,
|
||||
entry_block: Block,
|
||||
) -> FunctionCompiler<'a, 'b> {
|
||||
let mut compiler = FunctionCompiler {
|
||||
@@ -77,7 +79,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
label_to_block: HashMap::new(),
|
||||
sig: JitSig {
|
||||
args: arg_types.to_vec(),
|
||||
ret: None,
|
||||
ret: ret_type,
|
||||
},
|
||||
};
|
||||
let params = compiler.builder.func.dfg.block_params(entry_block).to_vec();
|
||||
@@ -132,7 +134,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
}
|
||||
JitValue::Bool(val) => Ok(val),
|
||||
JitValue::None => Ok(self.builder.ins().iconst(types::I8, 0)),
|
||||
JitValue::Tuple(_) => Err(JitCompileError::NotSupported),
|
||||
JitValue::Tuple(_) | JitValue::FuncRef(_) => Err(JitCompileError::NotSupported),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,6 +148,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
|
||||
pub fn compile<C: bytecode::Constant>(
|
||||
&mut self,
|
||||
func_ref: FuncRef,
|
||||
bytecode: &CodeObject<C>,
|
||||
) -> Result<(), JitCompileError> {
|
||||
// TODO: figure out if this is sufficient -- previously individual labels were associated
|
||||
@@ -177,7 +180,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
continue;
|
||||
}
|
||||
|
||||
self.add_instruction(instruction, arg, &bytecode.constants)?;
|
||||
self.add_instruction(func_ref, bytecode, instruction, arg)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -229,9 +232,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
|
||||
pub fn add_instruction<C: bytecode::Constant>(
|
||||
&mut self,
|
||||
func_ref: FuncRef,
|
||||
bytecode: &CodeObject<C>,
|
||||
instruction: Instruction,
|
||||
arg: OpArg,
|
||||
constants: &[C],
|
||||
) -> Result<(), JitCompileError> {
|
||||
match instruction {
|
||||
Instruction::ExtendedArg => Ok(()),
|
||||
@@ -282,7 +286,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
self.store_variable(idx.get(arg), val)
|
||||
}
|
||||
Instruction::LoadConst { idx } => {
|
||||
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
|
||||
let val = self
|
||||
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
|
||||
self.stack.push(val);
|
||||
Ok(())
|
||||
}
|
||||
@@ -311,7 +316,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
self.return_value(val)
|
||||
}
|
||||
Instruction::ReturnConst { idx } => {
|
||||
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
|
||||
let val = self
|
||||
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
|
||||
self.return_value(val)
|
||||
}
|
||||
Instruction::CompareOperation { op, .. } => {
|
||||
@@ -508,6 +514,36 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
|
||||
// TODO: block support
|
||||
Ok(())
|
||||
}
|
||||
Instruction::LoadGlobal(idx) => {
|
||||
let name = &bytecode.names[idx.get(arg) as usize];
|
||||
|
||||
if name.as_ref() != bytecode.obj_name.as_ref() {
|
||||
Err(JitCompileError::NotSupported)
|
||||
} else {
|
||||
self.stack.push(JitValue::FuncRef(func_ref));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Instruction::CallFunctionPositional { nargs } => {
|
||||
let nargs = nargs.get(arg);
|
||||
|
||||
let mut args = Vec::new();
|
||||
for _ in 0..nargs {
|
||||
let arg = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
|
||||
args.push(arg.into_value().unwrap());
|
||||
}
|
||||
|
||||
match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
|
||||
JitValue::FuncRef(reference) => {
|
||||
let call = self.builder.ins().call(reference, &args);
|
||||
let returns = self.builder.inst_results(call);
|
||||
self.stack.push(JitValue::Int(returns[0]));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(JitCompileError::BadBytecode),
|
||||
}
|
||||
}
|
||||
_ => Err(JitCompileError::NotSupported),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ impl Jit {
|
||||
&mut self,
|
||||
bytecode: &bytecode::CodeObject<C>,
|
||||
args: &[JitType],
|
||||
ret: Option<JitType>,
|
||||
) -> Result<(FuncId, JitSig), JitCompileError> {
|
||||
for arg in args {
|
||||
self.ctx
|
||||
@@ -58,22 +59,13 @@ impl Jit {
|
||||
.push(AbiParam::new(arg.to_cranelift()));
|
||||
}
|
||||
|
||||
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
|
||||
let entry_block = builder.create_block();
|
||||
builder.append_block_params_for_function_params(entry_block);
|
||||
builder.switch_to_block(entry_block);
|
||||
|
||||
let sig = {
|
||||
let mut compiler =
|
||||
FunctionCompiler::new(&mut builder, bytecode.varnames.len(), args, entry_block);
|
||||
|
||||
compiler.compile(bytecode)?;
|
||||
|
||||
compiler.sig
|
||||
};
|
||||
|
||||
builder.seal_all_blocks();
|
||||
builder.finalize();
|
||||
if ret.is_some() {
|
||||
self.ctx
|
||||
.func
|
||||
.signature
|
||||
.returns
|
||||
.push(AbiParam::new(ret.clone().unwrap().to_cranelift()));
|
||||
}
|
||||
|
||||
let id = self.module.declare_function(
|
||||
&format!("jit_{}", bytecode.obj_name.as_ref()),
|
||||
@@ -81,6 +73,30 @@ impl Jit {
|
||||
&self.ctx.func.signature,
|
||||
)?;
|
||||
|
||||
let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func);
|
||||
|
||||
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
|
||||
let entry_block = builder.create_block();
|
||||
builder.append_block_params_for_function_params(entry_block);
|
||||
builder.switch_to_block(entry_block);
|
||||
|
||||
let sig = {
|
||||
let mut compiler = FunctionCompiler::new(
|
||||
&mut builder,
|
||||
bytecode.varnames.len(),
|
||||
args,
|
||||
ret,
|
||||
entry_block,
|
||||
);
|
||||
|
||||
compiler.compile(func_ref, bytecode)?;
|
||||
|
||||
compiler.sig
|
||||
};
|
||||
|
||||
builder.seal_all_blocks();
|
||||
builder.finalize();
|
||||
|
||||
self.module.define_function(id, &mut self.ctx)?;
|
||||
|
||||
self.module.clear_context(&mut self.ctx);
|
||||
@@ -92,10 +108,11 @@ impl Jit {
|
||||
pub fn compile<C: bytecode::Constant>(
|
||||
bytecode: &bytecode::CodeObject<C>,
|
||||
args: &[JitType],
|
||||
ret: Option<JitType>,
|
||||
) -> Result<CompiledCode, JitCompileError> {
|
||||
let mut jit = Jit::new();
|
||||
|
||||
let (id, sig) = jit.build_function(bytecode, args)?;
|
||||
let (id, sig) = jit.build_function(bytecode, args, ret)?;
|
||||
|
||||
jit.module.finalize_definitions();
|
||||
|
||||
|
||||
@@ -27,7 +27,17 @@ impl Function {
|
||||
arg_types.push(arg_type);
|
||||
}
|
||||
|
||||
rustpython_jit::compile(&self.code, &arg_types).expect("Compile failure")
|
||||
let ret_type = match self.annotations.get("return") {
|
||||
Some(StackValue::String(annotation)) => match annotation.as_str() {
|
||||
"int" => Some(JitType::Int),
|
||||
"float" => Some(JitType::Float),
|
||||
"bool" => Some(JitType::Bool),
|
||||
_ => panic!("Unrecognised jit type"),
|
||||
},
|
||||
_ => None,
|
||||
};
|
||||
|
||||
rustpython_jit::compile(&self.code, &arg_types, ret_type).expect("Compile failure")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -113,3 +113,15 @@ fn test_unpack_tuple() {
|
||||
assert_eq!(unpack_tuple(0, 1), Ok(1));
|
||||
assert_eq!(unpack_tuple(1, 2), Ok(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recursive_fib() {
|
||||
let fib = jit_function! { fib(n: i64) -> i64 => r##"
|
||||
def fib(n: int) -> int:
|
||||
if n == 0 or n == 1:
|
||||
return 1
|
||||
return fib(n-1) + fib(n-2)
|
||||
"## };
|
||||
|
||||
assert_eq!(fib(10), Ok(89));
|
||||
}
|
||||
|
||||
@@ -506,7 +506,8 @@ impl PyFunction {
|
||||
zelf.jitted_code
|
||||
.get_or_try_init(|| {
|
||||
let arg_types = jitfunc::get_jit_arg_types(&zelf, vm)?;
|
||||
rustpython_jit::compile(&zelf.code.code, &arg_types)
|
||||
let ret_type = jitfunc::jit_ret_type(&zelf, vm)?;
|
||||
rustpython_jit::compile(&zelf.code.code, &arg_types, ret_type)
|
||||
.map_err(|err| jitfunc::new_jit_error(err.to_string(), vm))
|
||||
})
|
||||
.map(drop)
|
||||
|
||||
@@ -52,7 +52,7 @@ fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResu
|
||||
Ok(JitType::Bool)
|
||||
} else {
|
||||
Err(new_jit_error(
|
||||
"Jit requires argument to be either int or float".to_owned(),
|
||||
"Jit requires argument to be either int, float or bool".to_owned(),
|
||||
vm,
|
||||
))
|
||||
}
|
||||
@@ -106,6 +106,25 @@ pub fn get_jit_arg_types(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult
|
||||
}
|
||||
}
|
||||
|
||||
pub fn jit_ret_type(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult<Option<JitType>> {
|
||||
let func_obj: PyObjectRef = func.as_ref().to_owned();
|
||||
let annotations = func_obj.get_attr("__annotations__", vm)?;
|
||||
if vm.is_none(&annotations) {
|
||||
Err(new_jit_error(
|
||||
"Jitting function requires return type to have annotations".to_owned(),
|
||||
vm,
|
||||
))
|
||||
} else if let Ok(dict) = PyDictRef::try_from_object(vm, annotations) {
|
||||
if dict.contains_key("return", vm) {
|
||||
get_jit_arg_type(&dict, "return", vm).map_or(Ok(None), |t| Ok(Some(t)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
Err(vm.new_type_error("Function annotations aren't a dict".to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_jit_value(vm: &VirtualMachine, obj: &PyObject) -> Result<AbiValue, ArgsError> {
|
||||
// This does exact type checks as subclasses of int/float can't be passed to jitted functions
|
||||
let cls = obj.class();
|
||||
|
||||
Reference in New Issue
Block a user