Fix PyCode constructor/replace (#6193)

* Fix PyCode constructor

* Reuse MarshalError
This commit is contained in:
Jeong, YunWon
2025-10-22 21:09:42 +09:00
committed by GitHub
parent 0fb7d0fae2
commit f22aed2614
6 changed files with 296 additions and 99 deletions

View File

@@ -222,8 +222,6 @@ class CodeTest(unittest.TestCase):
obj = List([1, 2, 3])
self.assertEqual(obj[0], "Foreign getitem: 1")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_constructor(self):
def func(): pass
co = func.__code__
@@ -255,8 +253,6 @@ class CodeTest(unittest.TestCase):
CodeTest.test_qualname.__qualname__
)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_replace(self):
def func():
x = 1
@@ -297,8 +293,6 @@ class CodeTest(unittest.TestCase):
self.assertEqual(new_code.co_varnames, code2.co_varnames)
self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_nlocals_mismatch(self):
def func():
x = 1

View File

@@ -65,8 +65,6 @@ class FunctionPropertiesTest(FuncAttrsTest):
return 3
self.assertNotEqual(self.b, duplicate)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_copying___code__(self):
def test(): pass
self.assertEqual(test(), None)

View File

@@ -165,6 +165,19 @@ impl<'a> ReadBorrowed<'a> for &'a [u8] {
}
}
/// Parses bytecode bytes into CodeUnit instructions.
/// Each instruction is 2 bytes: opcode and argument.
pub fn parse_instructions_from_bytes(bytes: &[u8]) -> Result<Box<[CodeUnit]>> {
bytes
.chunks_exact(2)
.map(|cu| {
let op = Instruction::try_from(cu[0])?;
let arg = OpArgByte(cu[1]);
Ok(CodeUnit { op, arg })
})
.collect()
}
pub struct Cursor<B> {
pub data: B,
pub position: usize,
@@ -185,14 +198,7 @@ pub fn deserialize_code<R: Read, Bag: ConstantBag>(
) -> Result<CodeObject<Bag::Constant>> {
let len = rdr.read_u32()?;
let instructions = rdr.read_slice(len * 2)?;
let instructions = instructions
.chunks_exact(2)
.map(|cu| {
let op = Instruction::try_from(cu[0])?;
let arg = OpArgByte(cu[1]);
Ok(CodeUnit { op, arg })
})
.collect::<Result<Box<[CodeUnit]>>>()?;
let instructions = parse_instructions_from_bytes(instructions)?;
let len = rdr.read_u32()?;
let locations = (0..len)

View File

@@ -2,21 +2,24 @@
*/
use super::{PyStrRef, PyTupleRef, PyType, PyTypeRef};
use super::{PyBytesRef, PyStrRef, PyTupleRef, PyType, PyTypeRef};
use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
builtins::PyStrInterned,
bytecode::{self, AsBag, BorrowedConstant, CodeFlags, Constant, ConstantBag},
bytecode::{self, AsBag, BorrowedConstant, CodeFlags, CodeUnit, Constant, ConstantBag},
class::{PyClassImpl, StaticType},
convert::ToPyObject,
frozen,
function::{FuncArgs, OptionalArg},
types::Representable,
function::OptionalArg,
types::{Constructor, Representable},
};
use malachite_bigint::BigInt;
use num_traits::Zero;
use rustpython_compiler_core::OneIndexed;
use rustpython_compiler_core::bytecode::PyCodeLocationInfoKind;
use rustpython_compiler_core::{
OneIndexed,
bytecode::PyCodeLocationInfoKind,
marshal::{MarshalError, parse_instructions_from_bytes},
};
use std::{borrow::Borrow, fmt, ops::Deref};
/// State for iterating through code address ranges
@@ -367,13 +370,158 @@ impl Representable for PyCode {
}
}
#[pyclass(with(Representable))]
impl PyCode {
#[pyslot]
fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
Err(vm.new_type_error("Cannot directly create code object"))
}
// Arguments for code object constructor
//
// Every field below is consumed by `Constructor::py_new` for `PyCode`.
// NOTE(review): presumably the field order defines the positional argument
// order generated by `FromArgs` — confirm before reordering anything here.
#[derive(FromArgs)]
pub struct PyCodeNewArgs {
// Stored as `CodeObject::arg_count`.
argcount: u32,
// Stored as `CodeObject::posonlyarg_count`.
posonlyargcount: u32,
// Stored as `CodeObject::kwonlyarg_count`.
kwonlyargcount: u32,
// Only validated: must equal `len(varnames)`, otherwise `py_new` raises a
// value error. It is not stored separately.
nlocals: u32,
// Stored as `CodeObject::max_stackdepth`.
stacksize: u32,
// Converted with `CodeFlags::from_bits_truncate`.
flags: u16,
// Raw bytecode; parsed into instructions by `parse_bytecode`.
co_code: PyBytesRef,
// Each item is cloned and wrapped in `Literal`.
consts: PyTupleRef,
// Must contain only strings; each one is interned.
names: PyTupleRef,
// Must contain only strings; each one is interned.
varnames: PyTupleRef,
// Interned and stored as `CodeObject::source_path`.
filename: PyStrRef,
// Interned and stored as `CodeObject::obj_name`.
name: PyStrRef,
// Interned and stored as `CodeObject::qualname`.
qualname: PyStrRef,
// Values <= 0 produce `first_line_number: None`.
firstlineno: i32,
// Copied byte-for-byte into `CodeObject::linetable`.
linetable: PyBytesRef,
// Copied byte-for-byte into `CodeObject::exceptiontable`.
exceptiontable: PyBytesRef,
// Must contain only strings; each one is interned.
freevars: PyTupleRef,
// Must contain only strings; each one is interned.
cellvars: PyTupleRef,
}
impl Constructor for PyCode {
type Args = PyCodeNewArgs;
fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
// Convert names tuple to vector of interned strings
let names: Box<[&'static PyStrInterned]> = args
.names
.iter()
.map(|obj| {
let s = obj.downcast_ref::<super::pystr::PyStr>().ok_or_else(|| {
vm.new_type_error("names must be tuple of strings".to_owned())
})?;
Ok(vm.ctx.intern_str(s.as_str()))
})
.collect::<PyResult<Vec<_>>>()?
.into_boxed_slice();
let varnames: Box<[&'static PyStrInterned]> = args
.varnames
.iter()
.map(|obj| {
let s = obj.downcast_ref::<super::pystr::PyStr>().ok_or_else(|| {
vm.new_type_error("varnames must be tuple of strings".to_owned())
})?;
Ok(vm.ctx.intern_str(s.as_str()))
})
.collect::<PyResult<Vec<_>>>()?
.into_boxed_slice();
let cellvars: Box<[&'static PyStrInterned]> = args
.cellvars
.iter()
.map(|obj| {
let s = obj.downcast_ref::<super::pystr::PyStr>().ok_or_else(|| {
vm.new_type_error("cellvars must be tuple of strings".to_owned())
})?;
Ok(vm.ctx.intern_str(s.as_str()))
})
.collect::<PyResult<Vec<_>>>()?
.into_boxed_slice();
let freevars: Box<[&'static PyStrInterned]> = args
.freevars
.iter()
.map(|obj| {
let s = obj.downcast_ref::<super::pystr::PyStr>().ok_or_else(|| {
vm.new_type_error("freevars must be tuple of strings".to_owned())
})?;
Ok(vm.ctx.intern_str(s.as_str()))
})
.collect::<PyResult<Vec<_>>>()?
.into_boxed_slice();
// Check nlocals matches varnames length
if args.nlocals as usize != varnames.len() {
return Err(vm.new_value_error(format!(
"nlocals ({}) != len(varnames) ({})",
args.nlocals,
varnames.len()
)));
}
// Parse and validate bytecode from bytes
let bytecode_bytes = args.co_code.as_bytes();
let instructions = parse_bytecode(bytecode_bytes)
.map_err(|e| vm.new_value_error(format!("invalid bytecode: {}", e)))?;
// Convert constants
let constants: Box<[Literal]> = args
.consts
.iter()
.map(|obj| {
// Convert PyObject to Literal constant
// For now, just wrap it
Literal(obj.clone())
})
.collect::<Vec<_>>()
.into_boxed_slice();
// Create locations
let row = if args.firstlineno > 0 {
OneIndexed::new(args.firstlineno as usize).unwrap_or(OneIndexed::MIN)
} else {
OneIndexed::MIN
};
let locations: Box<[rustpython_compiler_core::SourceLocation]> = vec![
rustpython_compiler_core::SourceLocation {
line: row,
character_offset: OneIndexed::from_zero_indexed(0),
};
instructions.len()
]
.into_boxed_slice();
// Build the CodeObject
let code = CodeObject {
instructions,
locations,
flags: CodeFlags::from_bits_truncate(args.flags),
posonlyarg_count: args.posonlyargcount,
arg_count: args.argcount,
kwonlyarg_count: args.kwonlyargcount,
source_path: vm.ctx.intern_str(args.filename.as_str()),
first_line_number: if args.firstlineno > 0 {
OneIndexed::new(args.firstlineno as usize)
} else {
None
},
max_stackdepth: args.stacksize,
obj_name: vm.ctx.intern_str(args.name.as_str()),
qualname: vm.ctx.intern_str(args.qualname.as_str()),
cell2arg: None, // TODO: reuse `fn cell2arg`
constants,
names,
varnames,
cellvars,
freevars,
linetable: args.linetable.as_bytes().to_vec().into_boxed_slice(),
exceptiontable: args.exceptiontable.as_bytes().to_vec().into_boxed_slice(),
};
Ok(PyCode::new(code)
.into_ref_with_type(vm, cls)?
.to_pyobject(vm))
}
}
#[pyclass(with(Representable, Constructor))]
impl PyCode {
#[pygetset]
const fn co_posonlyargcount(&self) -> usize {
self.code.posonlyarg_count as usize
@@ -397,9 +545,7 @@ impl PyCode {
#[pygetset]
pub fn co_cellvars(&self, vm: &VirtualMachine) -> PyTupleRef {
let cellvars = self
.code
.cellvars
.deref()
.iter()
.map(|name| name.to_pyobject(vm))
.collect();
@@ -408,7 +554,7 @@ impl PyCode {
#[pygetset]
fn co_nlocals(&self) -> usize {
self.varnames.len()
self.code.varnames.len()
}
#[pygetset]
@@ -690,42 +836,62 @@ impl PyCode {
#[pymethod]
pub fn replace(&self, args: ReplaceArgs, vm: &VirtualMachine) -> PyResult<Self> {
let posonlyarg_count = match args.co_posonlyargcount {
let ReplaceArgs {
co_posonlyargcount,
co_argcount,
co_kwonlyargcount,
co_filename,
co_firstlineno,
co_consts,
co_name,
co_names,
co_flags,
co_varnames,
co_nlocals,
co_stacksize,
co_code,
co_linetable,
co_exceptiontable,
co_freevars,
co_cellvars,
co_qualname,
} = args;
let posonlyarg_count = match co_posonlyargcount {
OptionalArg::Present(posonlyarg_count) => posonlyarg_count,
OptionalArg::Missing => self.code.posonlyarg_count,
};
let arg_count = match args.co_argcount {
let arg_count = match co_argcount {
OptionalArg::Present(arg_count) => arg_count,
OptionalArg::Missing => self.code.arg_count,
};
let source_path = match args.co_filename {
let source_path = match co_filename {
OptionalArg::Present(source_path) => source_path,
OptionalArg::Missing => self.code.source_path.to_owned(),
};
let first_line_number = match args.co_firstlineno {
let first_line_number = match co_firstlineno {
OptionalArg::Present(first_line_number) => OneIndexed::new(first_line_number as _),
OptionalArg::Missing => self.code.first_line_number,
};
let kwonlyarg_count = match args.co_kwonlyargcount {
let kwonlyarg_count = match co_kwonlyargcount {
OptionalArg::Present(kwonlyarg_count) => kwonlyarg_count,
OptionalArg::Missing => self.code.kwonlyarg_count,
};
let constants = match args.co_consts {
let constants = match co_consts {
OptionalArg::Present(constants) => constants,
OptionalArg::Missing => self.code.constants.iter().map(|x| x.0.clone()).collect(),
};
let obj_name = match args.co_name {
let obj_name = match co_name {
OptionalArg::Present(obj_name) => obj_name,
OptionalArg::Missing => self.code.obj_name.to_owned(),
};
let names = match args.co_names {
let names = match co_names {
OptionalArg::Present(names) => names,
OptionalArg::Missing => self
.code
@@ -736,37 +902,36 @@ impl PyCode {
.collect(),
};
let flags = match args.co_flags {
let flags = match co_flags {
OptionalArg::Present(flags) => flags,
OptionalArg::Missing => self.code.flags.bits(),
};
let varnames = match args.co_varnames {
let varnames = match co_varnames {
OptionalArg::Present(varnames) => varnames,
OptionalArg::Missing => self.code.varnames.iter().map(|s| s.to_object()).collect(),
};
let qualname = match args.co_qualname {
let qualname = match co_qualname {
OptionalArg::Present(qualname) => qualname,
OptionalArg::Missing => self.code.qualname.to_owned(),
};
let max_stackdepth = match args.co_stacksize {
let max_stackdepth = match co_stacksize {
OptionalArg::Present(stacksize) => stacksize,
OptionalArg::Missing => self.code.max_stackdepth,
};
let instructions = match args.co_code {
OptionalArg::Present(_code_bytes) => {
// Convert bytes back to instructions
// For now, keep the original instructions
// TODO: Properly parse bytecode from bytes
self.code.instructions.clone()
let instructions = match co_code {
OptionalArg::Present(code_bytes) => {
// Parse and validate bytecode from bytes
parse_bytecode(code_bytes.as_bytes())
.map_err(|e| vm.new_value_error(format!("invalid bytecode: {}", e)))?
}
OptionalArg::Missing => self.code.instructions.clone(),
};
let cellvars = match args.co_cellvars {
let cellvars = match co_cellvars {
OptionalArg::Present(cellvars) => cellvars
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
@@ -774,7 +939,7 @@ impl PyCode {
OptionalArg::Missing => self.code.cellvars.clone(),
};
let freevars = match args.co_freevars {
let freevars = match co_freevars {
OptionalArg::Present(freevars) => freevars
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
@@ -783,7 +948,7 @@ impl PyCode {
};
// Validate co_nlocals if provided
if let OptionalArg::Present(nlocals) = args.co_nlocals
if let OptionalArg::Present(nlocals) = co_nlocals
&& nlocals as usize != varnames.len()
{
return Err(vm.new_value_error(format!(
@@ -794,48 +959,50 @@ impl PyCode {
}
// Handle linetable and exceptiontable
let linetable = match args.co_linetable {
let linetable = match co_linetable {
OptionalArg::Present(linetable) => linetable.as_bytes().to_vec().into_boxed_slice(),
OptionalArg::Missing => self.code.linetable.clone(),
};
let exceptiontable = match args.co_exceptiontable {
let exceptiontable = match co_exceptiontable {
OptionalArg::Present(exceptiontable) => {
exceptiontable.as_bytes().to_vec().into_boxed_slice()
}
OptionalArg::Missing => self.code.exceptiontable.clone(),
};
Ok(Self {
code: CodeObject {
flags: CodeFlags::from_bits_truncate(flags),
posonlyarg_count,
arg_count,
kwonlyarg_count,
source_path: source_path.as_object().as_interned_str(vm).unwrap(),
first_line_number,
obj_name: obj_name.as_object().as_interned_str(vm).unwrap(),
qualname: qualname.as_object().as_interned_str(vm).unwrap(),
let new_code = CodeObject {
flags: CodeFlags::from_bits_truncate(flags),
posonlyarg_count,
arg_count,
kwonlyarg_count,
source_path: source_path.as_object().as_interned_str(vm).unwrap(),
first_line_number,
obj_name: obj_name.as_object().as_interned_str(vm).unwrap(),
qualname: qualname.as_object().as_interned_str(vm).unwrap(),
max_stackdepth,
instructions,
locations: self.code.locations.clone(),
constants: constants.into_iter().map(Literal).collect(),
names: names
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
varnames: varnames
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
cellvars,
freevars,
cell2arg: self.code.cell2arg.clone(),
linetable,
exceptiontable,
},
})
max_stackdepth,
instructions,
// FIXME: these locations are invalid for the new code object; they are copied unchanged from `self`.
// `locations` duplicates `linetable` and can be removed once all other code uses `linetable` only.
locations: self.code.locations.clone(),
constants: constants.into_iter().map(Literal).collect(),
names: names
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
varnames: varnames
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
cellvars,
freevars,
cell2arg: self.code.cell2arg.clone(),
linetable,
exceptiontable,
};
Ok(PyCode::new(new_code))
}
#[pymethod]
@@ -866,6 +1033,19 @@ impl ToPyObject for bytecode::CodeObject {
}
}
/// Turns raw `co_code` bytes into a boxed slice of `CodeUnit` instructions.
///
/// An odd-length input is rejected with `MarshalError::InvalidBytecode`;
/// otherwise the result of `parse_instructions_from_bytes` is returned, which
/// reports invalid opcodes.
///
/// `MarshalError` is reused here (temporarily, for simplicity) even though
/// this helper is not part of the marshalling API.
fn parse_bytecode(bytecode_bytes: &[u8]) -> Result<Box<[CodeUnit]>, MarshalError> {
    // Each instruction is two bytes wide, so only even lengths are well-formed.
    let well_formed = bytecode_bytes.len().is_multiple_of(2);
    if well_formed {
        parse_instructions_from_bytes(bytecode_bytes)
    } else {
        Err(MarshalError::InvalidBytecode)
    }
}
// Helper struct for reading linetable
struct LineTableReader<'a> {
data: &'a [u8],

View File

@@ -28,7 +28,7 @@ use rustpython_jit::CompiledCode;
#[pyclass(module = false, name = "function", traverse = "manual")]
#[derive(Debug)]
pub struct PyFunction {
code: PyRef<PyCode>,
code: PyMutex<PyRef<PyCode>>,
globals: PyDictRef,
builtins: PyObjectRef,
closure: Option<PyRef<PyTuple<PyCellRef>>>,
@@ -73,7 +73,7 @@ impl PyFunction {
let qualname = vm.ctx.new_str(code.qualname.as_str());
let func = Self {
code: code.clone(),
code: PyMutex::new(code.clone()),
globals,
builtins,
closure: None,
@@ -96,7 +96,7 @@ impl PyFunction {
func_args: FuncArgs,
vm: &VirtualMachine,
) -> PyResult<()> {
let code = &*self.code;
let code = &*self.code.lock();
let nargs = func_args.args.len();
let n_expected_args = code.arg_count as usize;
let total_args = code.arg_count as usize + code.kwonlyarg_count as usize;
@@ -392,14 +392,15 @@ impl Py<PyFunction> {
Err(err) => info!(
"jit: function `{}` is falling back to being interpreted because of the \
error: {}",
self.code.obj_name, err
self.code.lock().obj_name,
err
),
}
}
let code = &self.code;
let code = self.code.lock().clone();
let locals = if self.code.flags.contains(bytecode::CodeFlags::NEW_LOCALS) {
let locals = if code.flags.contains(bytecode::CodeFlags::NEW_LOCALS) {
ArgMapping::from_dict_exact(vm.ctx.new_dict())
} else if let Some(locals) = locals {
locals
@@ -451,7 +452,18 @@ impl PyPayload for PyFunction {
impl PyFunction {
#[pygetset]
fn __code__(&self) -> PyRef<PyCode> {
self.code.clone()
self.code.lock().clone()
}
#[pygetset(setter)]
fn set___code__(&self, code: PyRef<PyCode>) {
    // Swap in the new code object while the mutex guard is held.
    let mut current = self.code.lock();
    *current = code;
    // TODO: jit support — under the `jit` feature, any cached compiled code
    // (`jitted_code`) should be cleared here so it cannot outlive the code
    // object it was built from.
}
#[pygetset]
@@ -595,7 +607,8 @@ impl PyFunction {
.get_or_try_init(|| {
let arg_types = jit::get_jit_arg_types(&zelf, vm)?;
let ret_type = jit::jit_ret_type(&zelf, vm)?;
rustpython_jit::compile(&zelf.code.code, &arg_types, ret_type)
let code = zelf.code.lock();
rustpython_jit::compile(&code.code, &arg_types, ret_type)
.map_err(|err| jit::new_jit_error(err.to_string(), vm))
})
.map(drop)

View File

@@ -65,10 +65,10 @@ fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResu
}
pub fn get_jit_arg_types(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult<Vec<JitType>> {
let arg_names = func.code.arg_names();
let code = func.code.lock();
let arg_names = code.arg_names();
if func
.code
if code
.flags
.intersects(CodeFlags::HAS_VARARGS | CodeFlags::HAS_VARKEYWORDS)
{
@@ -157,9 +157,13 @@ pub(crate) fn get_jit_args<'a>(
) -> Result<Args<'a>, ArgsError> {
let mut jit_args = jitted_code.args_builder();
let nargs = func_args.args.len();
let arg_names = func.code.arg_names();
if nargs > func.code.arg_count as usize || nargs < func.code.posonlyarg_count as usize {
let code = func.code.lock();
let arg_names = code.arg_names();
let arg_count = code.arg_count;
let posonlyarg_count = code.posonlyarg_count;
if nargs > arg_count as usize || nargs < posonlyarg_count as usize {
return Err(ArgsError::WrongNumberOfArgs);
}
@@ -178,7 +182,7 @@ pub(crate) fn get_jit_args<'a>(
}
jit_args.set(arg_idx, get_jit_value(vm, value)?)?;
} else if let Some(kwarg_idx) = arg_pos(arg_names.kwonlyargs, name) {
let arg_idx = kwarg_idx + func.code.arg_count as usize;
let arg_idx = kwarg_idx + arg_count as usize;
if jit_args.is_set(arg_idx) {
return Err(ArgsError::ArgPassedMultipleTimes);
}
@@ -193,7 +197,7 @@ pub(crate) fn get_jit_args<'a>(
// fill in positional defaults
if let Some(defaults) = defaults {
for (i, default) in defaults.iter().enumerate() {
let arg_idx = i + func.code.arg_count as usize - defaults.len();
let arg_idx = i + arg_count as usize - defaults.len();
if !jit_args.is_set(arg_idx) {
jit_args.set(arg_idx, get_jit_value(vm, default)?)?;
}
@@ -203,7 +207,7 @@ pub(crate) fn get_jit_args<'a>(
// fill in keyword only defaults
if let Some(kw_only_defaults) = kwdefaults {
for (i, name) in arg_names.kwonlyargs.iter().enumerate() {
let arg_idx = i + func.code.arg_count as usize;
let arg_idx = i + arg_count as usize;
if !jit_args.is_set(arg_idx) {
let default = kw_only_defaults
.get_item(&**name, vm)
@@ -214,5 +218,7 @@ pub(crate) fn get_jit_args<'a>(
}
}
drop(code);
jit_args.into_args().ok_or(ArgsError::NotAllArgsPassed)
}