diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index cf5d46466..7839bae65 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -3,15 +3,24 @@ use crate::obj::objtraceback::PyTracebackRef; use crate::obj::objtuple::{PyTuple, PyTupleRef}; use crate::obj::objtype; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{IdProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol}; +use crate::pyobject::{ + IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, +}; use crate::types::create_type; use crate::vm::VirtualMachine; use itertools::Itertools; +use std::cell::{Cell, RefCell}; use std::fs::File; use std::io::{self, BufRead, BufReader, Write}; +#[pyclass] #[derive(Debug)] -pub struct PyBaseException {} +pub struct PyBaseException { + traceback: RefCell>, + cause: RefCell>, + context: RefCell>, + suppress_context: Cell, +} pub type PyBaseExceptionRef = PyRef; impl PyValue for PyBaseException { @@ -22,19 +31,65 @@ impl PyValue for PyBaseException { } } -impl PyBaseExceptionRef { - fn new( +#[pyimpl] +impl PyBaseException { + #[pyslot(new)] + fn tp_new( cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine, ) -> PyResult { - let zelf = PyBaseException {}.into_ref_with_type(vm, cls)?; - let exc = zelf.clone().into_object(); - vm.set_attr(&exc, "__traceback__", vm.get_none())?; - vm.set_attr(&exc, "__cause__", vm.get_none())?; - vm.set_attr(&exc, "__context__", vm.get_none())?; - vm.set_attr(&exc, "__suppress_context__", vm.new_bool(false))?; - Ok(zelf) + PyBaseException { + traceback: RefCell::new(None), + cause: RefCell::new(None), + context: RefCell::new(None), + suppress_context: Cell::new(false), + } + .into_ref_with_type(vm, cls) + } + + #[pyproperty(name = "__traceback__")] + fn get_traceback(&self, _vm: &VirtualMachine) -> Option { + self.traceback.borrow().clone() + } + + #[pyproperty(name = "__traceback__", setter)] + fn set_traceback(&self, traceback: Option, vm: &VirtualMachine) -> PyResult { + self.traceback.replace(traceback); + Ok(vm.get_none()) + } + + #[pyproperty(name = "__cause__")] + fn get_cause(&self, _vm: &VirtualMachine) -> Option { + self.cause.borrow().clone() + } + + #[pyproperty(name = "__cause__", setter)] + fn set_cause(&self, cause: Option, vm: &VirtualMachine) -> PyResult { + self.cause.replace(cause); + Ok(vm.get_none()) + } + + #[pyproperty(name = "__context__")] + fn get_context(&self, _vm: &VirtualMachine) -> Option { + self.context.borrow().clone() + } + + #[pyproperty(name = "__context__", setter)] + fn set_context(&self, context: Option, vm: &VirtualMachine) -> PyResult { + self.context.replace(context); + Ok(vm.get_none()) + } + + #[pyproperty(name = "__suppress_context__")] + fn get_suppress_context(&self, _vm: &VirtualMachine) -> bool { + self.suppress_context.get() + } + + #[pyproperty(name = "__suppress_context__", setter)] + fn set_suppress_context(&self, suppress_context: bool, vm: &VirtualMachine) -> PyResult { + self.suppress_context.set(suppress_context); + Ok(vm.get_none()) } } @@ -441,8 +496,8 @@ fn import_error_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let base_exception_type = &context.exceptions.base_exception_type; + PyBaseException::extend_class(context, base_exception_type); extend_class!(context, base_exception_type, { - "__new__" => context.new_rustfunc(PyBaseExceptionRef::new), "__init__" => context.new_rustfunc(exception_init), "with_traceback" => context.new_rustfunc(exception_with_traceback) });