diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index f388aeac6..7c19968da 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -7,8 +7,7 @@ use num_traits::{Pow, Signed, ToPrimitive, Zero}; use crate::format::FormatSpec; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - FromPyObjectRef, IntoPyObject, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, - TypeProtocol, + IntoPyObject, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -388,7 +387,7 @@ fn int_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { None => Zero::zero(), }; Ok(PyInt::new(val) - .into_ref_with_type(vm, PyClassRef::from_pyobj(cls))? + .into_ref_with_type(vm, cls.clone().downcast().unwrap())? .into_object()) } diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index d2cabe1ff..81f04603b 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -4,8 +4,8 @@ use std::fmt; use crate::function::{Args, KwArgs, PyFuncArgs}; use crate::pyobject::{ - FromPyObjectRef, IdProtocol, PyAttributes, PyContext, PyObject, PyObjectRef, PyRef, PyResult, - PyValue, TypeProtocol, + IdProtocol, PyAttributes, PyContext, PyObject, PyObjectRef, PyRef, PyResult, PyValue, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -236,7 +236,7 @@ pub fn type_new_class( let mut bases: Vec = vm .extract_elements(bases)? .iter() - .map(|x| FromPyObjectRef::from_pyobj(x)) + .map(|x| x.clone().downcast().unwrap()) .collect(); bases.push(vm.ctx.object()); let name = objstr::get_value(name); @@ -385,7 +385,6 @@ pub fn new( #[cfg(test)] mod tests { - use super::FromPyObjectRef; use super::{linearise_mro, new}; use super::{HashMap, IdProtocol, PyClassRef, PyContext}; @@ -417,8 +416,8 @@ mod tests { ) .unwrap(); - let a: PyClassRef = FromPyObjectRef::from_pyobj(&a); - let b: PyClassRef = FromPyObjectRef::from_pyobj(&b); + let a: PyClassRef = a.downcast().unwrap(); + let b: PyClassRef = b.downcast().unwrap(); assert_eq!( map_ids(linearise_mro(vec![ diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 83b2deab9..41a1ca6f3 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -159,7 +159,7 @@ pub fn create_type(name: &str, type_type: &PyClassRef, base: &PyClassRef) -> PyC dict, ) .unwrap(); - FromPyObjectRef::from_pyobj(&new_type) + new_type.downcast().unwrap() } pub type PyNotImplementedRef = PyRef; @@ -205,7 +205,7 @@ fn init_type_hierarchy() -> (PyClassRef, PyClassRef) { dict: Some(RefCell::new(PyAttributes::new())), payload: PyClass { name: String::from("type"), - mro: vec![FromPyObjectRef::from_pyobj(&object_type)], + mro: vec![object_type.clone().downcast().unwrap()], }, } .into_ref(); @@ -216,8 +216,8 @@ fn init_type_hierarchy() -> (PyClassRef, PyClassRef) { ptr::write(&mut (*type_type_ptr).typ, type_type.clone()); ( - PyClassRef::from_pyobj(&type_type), - PyClassRef::from_pyobj(&object_type), + type_type.downcast().unwrap(), + object_type.downcast().unwrap(), ) } } @@ -583,7 +583,7 @@ impl PyContext { PyAttributes::new(), ) .unwrap(); - PyClassRef::from_pyobj(&typ) + typ.downcast().unwrap() } pub fn new_scope(&self) -> Scope { @@ -721,6 +721,21 @@ where pub payload: T, } +impl PyObject { + pub fn downcast(self: Rc) -> Option> { + if self.payload_is::() { + Some({ + PyRef { + obj: self, + _payload: PhantomData, + } + }) + } else { + None + } + } +} + /// A reference to a Python object. /// /// Note that a `PyRef` can only deref to a shared / immutable reference. @@ -860,16 +875,12 @@ impl IdProtocol for PyRef { } } -pub trait FromPyObjectRef { - fn from_pyobj(obj: &PyObjectRef) -> Self; -} - pub trait TypeProtocol { fn typ(&self) -> PyObjectRef { self.type_ref().clone() } fn type_pyref(&self) -> PyClassRef { - FromPyObjectRef::from_pyobj(self.type_ref()) + self.typ().downcast().unwrap() } fn type_ref(&self) -> &PyObjectRef; } @@ -1207,19 +1218,6 @@ impl PyObjectPayload for T { } } -impl FromPyObjectRef for PyRef { - fn from_pyobj(obj: &PyObjectRef) -> Self { - if obj.payload_is::() { - PyRef { - obj: obj.clone(), - _payload: PhantomData, - } - } else { - panic!("Error getting inner type: {:?}", obj.typ) - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/vm/src/stdlib/json.rs b/vm/src/stdlib/json.rs index e7e6d3b6f..e1cfbd93b 100644 --- a/vm/src/stdlib/json.rs +++ b/vm/src/stdlib/json.rs @@ -6,15 +6,13 @@ use serde::ser::{SerializeMap, SerializeSeq}; use serde_json; use crate::function::PyFuncArgs; -use crate::obj::objtype::PyClassRef; use crate::obj::{ objbool, objdict, objfloat, objint, objsequence, objstr::{self, PyString}, objtype, }; use crate::pyobject::{ - create_type, DictProtocol, FromPyObjectRef, IdProtocol, PyContext, PyObjectRef, PyResult, - TypeProtocol, + create_type, DictProtocol, IdProtocol, PyContext, PyObjectRef, PyResult, TypeProtocol, }; use crate::VirtualMachine; use num_traits::cast::ToPrimitive; @@ -208,7 +206,7 @@ pub fn de_pyobject(vm: &VirtualMachine, s: &str) -> PyResult { .unwrap() .get_item("JSONDecodeError") .unwrap(); - let json_decode_error = PyClassRef::from_pyobj(&json_decode_error); + let json_decode_error = json_decode_error.downcast().unwrap(); let exc = vm.new_exception(json_decode_error, format!("{}", err)); vm.ctx.set_attr(&exc, "lineno", vm.ctx.new_int(err.line())); vm.ctx.set_attr(&exc, "colno", vm.ctx.new_int(err.column())); diff --git a/vm/src/vm.rs b/vm/src/vm.rs index a0e0eaf78..4a3cc3abc 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -30,8 +30,8 @@ use crate::obj::objtuple::PyTuple; use crate::obj::objtype; use crate::obj::objtype::PyClassRef; use crate::pyobject::{ - DictProtocol, FromPyObjectRef, IdProtocol, PyContext, PyObjectRef, PyResult, TryFromObject, - TryIntoRef, TypeProtocol, + DictProtocol, IdProtocol, PyContext, PyObjectRef, PyResult, TryFromObject, TryIntoRef, + TypeProtocol, }; use crate::stdlib; use crate::sysmodule; @@ -112,7 +112,7 @@ impl VirtualMachine { let class = self .get_attribute(module.clone(), class) .unwrap_or_else(|_| panic!("module {} has no class {}", module, class)); - PyClassRef::from_pyobj(&class) + class.downcast().unwrap() } /// Create a new python string object.