Support recursion in JIT-ed functions (#5473)

This commit is contained in:
Shubham Patil
2025-01-13 11:25:27 +05:30
committed by GitHub
parent 76c699b4ba
commit 53db70e784
6 changed files with 125 additions and 30 deletions

View File

@@ -1,3 +1,5 @@
use super::{JitCompileError, JitSig, JitType};
use cranelift::codegen::ir::FuncRef;
use cranelift::prelude::*;
use num_traits::cast::ToPrimitive;
use rustpython_compiler_core::bytecode::{
@@ -6,8 +8,6 @@ use rustpython_compiler_core::bytecode::{
};
use std::collections::HashMap;
use super::{JitCompileError, JitSig, JitType};
#[repr(u16)]
enum CustomTrapCode {
/// Raised when shifting by a negative number
@@ -27,6 +27,7 @@ enum JitValue {
Bool(Value),
None,
Tuple(Vec<JitValue>),
FuncRef(FuncRef),
}
impl JitValue {
@@ -43,14 +44,14 @@ impl JitValue {
JitValue::Int(_) => Some(JitType::Int),
JitValue::Float(_) => Some(JitType::Float),
JitValue::Bool(_) => Some(JitType::Bool),
JitValue::None | JitValue::Tuple(_) => None,
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
}
}
fn into_value(self) -> Option<Value> {
match self {
JitValue::Int(val) | JitValue::Float(val) | JitValue::Bool(val) => Some(val),
JitValue::None | JitValue::Tuple(_) => None,
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
}
}
}
@@ -68,6 +69,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
builder: &'a mut FunctionBuilder<'b>,
num_variables: usize,
arg_types: &[JitType],
ret_type: Option<JitType>,
entry_block: Block,
) -> FunctionCompiler<'a, 'b> {
let mut compiler = FunctionCompiler {
@@ -77,7 +79,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
label_to_block: HashMap::new(),
sig: JitSig {
args: arg_types.to_vec(),
ret: None,
ret: ret_type,
},
};
let params = compiler.builder.func.dfg.block_params(entry_block).to_vec();
@@ -132,7 +134,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
}
JitValue::Bool(val) => Ok(val),
JitValue::None => Ok(self.builder.ins().iconst(types::I8, 0)),
JitValue::Tuple(_) => Err(JitCompileError::NotSupported),
JitValue::Tuple(_) | JitValue::FuncRef(_) => Err(JitCompileError::NotSupported),
}
}
@@ -146,6 +148,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
pub fn compile<C: bytecode::Constant>(
&mut self,
func_ref: FuncRef,
bytecode: &CodeObject<C>,
) -> Result<(), JitCompileError> {
// TODO: figure out if this is sufficient -- previously individual labels were associated
@@ -177,7 +180,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
continue;
}
self.add_instruction(instruction, arg, &bytecode.constants)?;
self.add_instruction(func_ref, bytecode, instruction, arg)?;
}
Ok(())
@@ -229,9 +232,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
pub fn add_instruction<C: bytecode::Constant>(
&mut self,
func_ref: FuncRef,
bytecode: &CodeObject<C>,
instruction: Instruction,
arg: OpArg,
constants: &[C],
) -> Result<(), JitCompileError> {
match instruction {
Instruction::ExtendedArg => Ok(()),
@@ -282,7 +286,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
self.store_variable(idx.get(arg), val)
}
Instruction::LoadConst { idx } => {
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
let val = self
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
self.stack.push(val);
Ok(())
}
@@ -311,7 +316,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
self.return_value(val)
}
Instruction::ReturnConst { idx } => {
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
let val = self
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
self.return_value(val)
}
Instruction::CompareOperation { op, .. } => {
@@ -508,6 +514,36 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
// TODO: block support
Ok(())
}
Instruction::LoadGlobal(idx) => {
let name = &bytecode.names[idx.get(arg) as usize];
if name.as_ref() != bytecode.obj_name.as_ref() {
Err(JitCompileError::NotSupported)
} else {
self.stack.push(JitValue::FuncRef(func_ref));
Ok(())
}
}
Instruction::CallFunctionPositional { nargs } => {
let nargs = nargs.get(arg);
let mut args = Vec::new();
for _ in 0..nargs {
let arg = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
args.push(arg.into_value().unwrap());
}
match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
JitValue::FuncRef(reference) => {
let call = self.builder.ins().call(reference, &args);
let returns = self.builder.inst_results(call);
self.stack.push(JitValue::Int(returns[0]));
Ok(())
}
_ => Err(JitCompileError::BadBytecode),
}
}
_ => Err(JitCompileError::NotSupported),
}
}

View File

@@ -49,6 +49,7 @@ impl Jit {
&mut self,
bytecode: &bytecode::CodeObject<C>,
args: &[JitType],
ret: Option<JitType>,
) -> Result<(FuncId, JitSig), JitCompileError> {
for arg in args {
self.ctx
@@ -58,22 +59,13 @@ impl Jit {
.push(AbiParam::new(arg.to_cranelift()));
}
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
let sig = {
let mut compiler =
FunctionCompiler::new(&mut builder, bytecode.varnames.len(), args, entry_block);
compiler.compile(bytecode)?;
compiler.sig
};
builder.seal_all_blocks();
builder.finalize();
if ret.is_some() {
self.ctx
.func
.signature
.returns
.push(AbiParam::new(ret.clone().unwrap().to_cranelift()));
}
let id = self.module.declare_function(
&format!("jit_{}", bytecode.obj_name.as_ref()),
@@ -81,6 +73,30 @@ impl Jit {
&self.ctx.func.signature,
)?;
let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func);
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);
let sig = {
let mut compiler = FunctionCompiler::new(
&mut builder,
bytecode.varnames.len(),
args,
ret,
entry_block,
);
compiler.compile(func_ref, bytecode)?;
compiler.sig
};
builder.seal_all_blocks();
builder.finalize();
self.module.define_function(id, &mut self.ctx)?;
self.module.clear_context(&mut self.ctx);
@@ -92,10 +108,11 @@ impl Jit {
pub fn compile<C: bytecode::Constant>(
bytecode: &bytecode::CodeObject<C>,
args: &[JitType],
ret: Option<JitType>,
) -> Result<CompiledCode, JitCompileError> {
let mut jit = Jit::new();
let (id, sig) = jit.build_function(bytecode, args)?;
let (id, sig) = jit.build_function(bytecode, args, ret)?;
jit.module.finalize_definitions();

View File

@@ -27,7 +27,17 @@ impl Function {
arg_types.push(arg_type);
}
rustpython_jit::compile(&self.code, &arg_types).expect("Compile failure")
let ret_type = match self.annotations.get("return") {
Some(StackValue::String(annotation)) => match annotation.as_str() {
"int" => Some(JitType::Int),
"float" => Some(JitType::Float),
"bool" => Some(JitType::Bool),
_ => panic!("Unrecognised jit type"),
},
_ => None,
};
rustpython_jit::compile(&self.code, &arg_types, ret_type).expect("Compile failure")
}
}

View File

@@ -113,3 +113,15 @@ fn test_unpack_tuple() {
assert_eq!(unpack_tuple(0, 1), Ok(1));
assert_eq!(unpack_tuple(1, 2), Ok(2));
}
#[test]
fn test_recursive_fib() {
let fib = jit_function! { fib(n: i64) -> i64 => r##"
def fib(n: int) -> int:
if n == 0 or n == 1:
return 1
return fib(n-1) + fib(n-2)
"## };
assert_eq!(fib(10), Ok(89));
}

View File

@@ -506,7 +506,8 @@ impl PyFunction {
zelf.jitted_code
.get_or_try_init(|| {
let arg_types = jitfunc::get_jit_arg_types(&zelf, vm)?;
rustpython_jit::compile(&zelf.code.code, &arg_types)
let ret_type = jitfunc::jit_ret_type(&zelf, vm)?;
rustpython_jit::compile(&zelf.code.code, &arg_types, ret_type)
.map_err(|err| jitfunc::new_jit_error(err.to_string(), vm))
})
.map(drop)

View File

@@ -52,7 +52,7 @@ fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResu
Ok(JitType::Bool)
} else {
Err(new_jit_error(
"Jit requires argument to be either int or float".to_owned(),
"Jit requires argument to be either int, float or bool".to_owned(),
vm,
))
}
@@ -106,6 +106,25 @@ pub fn get_jit_arg_types(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult
}
}
pub fn jit_ret_type(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult<Option<JitType>> {
let func_obj: PyObjectRef = func.as_ref().to_owned();
let annotations = func_obj.get_attr("__annotations__", vm)?;
if vm.is_none(&annotations) {
Err(new_jit_error(
"Jitting function requires return type to have annotations".to_owned(),
vm,
))
} else if let Ok(dict) = PyDictRef::try_from_object(vm, annotations) {
if dict.contains_key("return", vm) {
get_jit_arg_type(&dict, "return", vm).map_or(Ok(None), |t| Ok(Some(t)))
} else {
Ok(None)
}
} else {
Err(vm.new_type_error("Function annotations aren't a dict".to_owned()))
}
}
fn get_jit_value(vm: &VirtualMachine, obj: &PyObject) -> Result<AbiValue, ArgsError> {
// This does exact type checks as subclasses of int/float can't be passed to jitted functions
let cls = obj.class();