diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 0df709bb13..5f87fbba5f 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -20,6 +20,7 @@ pub struct PyBaseException { cause: RefCell>, context: RefCell>, suppress_context: Cell, + args: RefCell, } pub type PyBaseExceptionRef = PyRef; @@ -44,10 +45,28 @@ impl PyBaseException { cause: RefCell::new(None), context: RefCell::new(None), suppress_context: Cell::new(false), + args: RefCell::new(vm.ctx.new_tuple(vec![])), } .into_ref_with_type(vm, cls) } + #[pymethod(name = "__init__")] + fn init(&self, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<()> { + self.args.replace(vm.ctx.new_tuple(args.args.to_vec())); + Ok(()) + } + + #[pyproperty] + fn args(&self, _vm: &VirtualMachine) -> PyObjectRef { + self.args.borrow().clone() + } + + #[pyproperty(setter)] + fn set_args(&self, args: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.args.replace(args); + Ok(vm.get_none()) + } + #[pyproperty(name = "__traceback__")] fn get_traceback(&self, _vm: &VirtualMachine) -> Option { self.traceback.borrow().clone() @@ -103,14 +122,6 @@ impl PyBaseException { } } -fn exception_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - let exc_self = args.args[0].clone(); - let exc_args = vm.ctx.new_tuple(args.args[1..].to_vec()); - vm.set_attr(&exc_self, "args", exc_args)?; - - Ok(vm.get_none()) -} - /// Print exception chain pub fn print_exception(vm: &VirtualMachine, exc: &PyObjectRef) { let _ = write_exception(io::stdout(), vm, exc); @@ -463,10 +474,9 @@ impl ExceptionZoo { } fn import_error_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - // TODO: call super().__init__(*args) instead - exception_init(vm, args.clone())?; - let exc_self = args.args[0].clone(); + + vm.set_attr(&exc_self, "args", vm.ctx.new_tuple(args.args[1..].to_vec()))?; vm.set_attr( &exc_self, "name", @@ -494,9 +504,6 @@ fn import_error_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let base_exception_type = &context.exceptions.base_exception_type; PyBaseException::extend_class(context, base_exception_type); - extend_class!(context, base_exception_type, { - "__init__" => context.new_rustfunc(exception_init), - }); let exception_type = &context.exceptions.exception_type; extend_class!(context, exception_type, {