diff --git a/vm/src/obj/objmappingproxy.rs b/vm/src/obj/objmappingproxy.rs index b38e9d1ef..692fde2b3 100644 --- a/vm/src/obj/objmappingproxy.rs +++ b/vm/src/obj/objmappingproxy.rs @@ -1,6 +1,7 @@ use super::objstr::PyStringRef; use super::objtype::{self, PyClassRef}; -use crate::pyobject::{PyClassImpl, PyContext, PyRef, PyResult, PyValue}; +use crate::function::OptionalArg; +use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; #[pyclass] @@ -23,6 +24,14 @@ impl PyMappingProxy { PyMappingProxy { class } } + #[pymethod] + fn get(&self, key: PyStringRef, default: OptionalArg, vm: &VirtualMachine) -> PyObjectRef { + let default = default.into_option(); + objtype::class_get_attr(&self.class, key.as_str()) + .or(default) + .unwrap_or_else(|| vm.get_none()) + } + #[pymethod(name = "__getitem__")] pub fn getitem(&self, key: PyStringRef, vm: &VirtualMachine) -> PyResult { if let Some(value) = objtype::class_get_attr(&self.class, key.as_str()) { diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index 48e50bcc9..367ee6d3f 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::collections::HashMap; use std::fmt; -use crate::function::{Args, KwArgs, PyFuncArgs}; +use crate::function::PyFuncArgs; use crate::pyobject::{ IdProtocol, PyAttributes, PyContext, PyIterable, PyObject, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, @@ -203,6 +203,11 @@ impl PyClassRef { } } +fn type_mro(cls: PyClassRef, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx + .new_list(cls.mro.iter().map(|x| x.clone().into_object()).collect()) +} + /* * The magical type type */ @@ -213,13 +218,14 @@ pub fn init(ctx: &PyContext) { type(name, bases, dict) -> a new type"; extend_class!(&ctx, &ctx.types.type_type, { + "mro" => ctx.new_rustfunc(type_mro), "__call__" => ctx.new_rustfunc(type_call), "__dict__" => PropertyBuilder::new(ctx) .add_getter(type_dict) .add_setter(type_dict_setter) .create(), - "__new__" => ctx.new_rustfunc(type_new), + "__new__" => ctx.new_classmethod(type_new), "__mro__" => PropertyBuilder::new(ctx) .add_getter(PyClassRef::mro) @@ -260,13 +266,30 @@ pub fn issubclass(subclass: &PyClassRef, cls: &PyClassRef) -> bool { subclass.is(cls) || mro.iter().any(|c| c.is(cls.as_object())) } -pub fn type_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { +pub fn type_new( + zelf: PyClassRef, + cls: PyClassRef, + args: PyFuncArgs, + vm: &VirtualMachine, +) -> PyResult { vm_trace!("type.__new__ {:?}", args); - if args.args.len() == 2 { - Ok(args.args[1].class().into_object()) - } else if args.args.len() == 4 { - let (typ, name, bases, dict) = args.bind(vm)?; - type_new_class(vm, typ, name, bases, dict).map(PyRef::into_object) + + if !issubclass(&cls, &zelf) { + return Err(vm.new_type_error(format!( + "{zelf}.__new__({cls}): {cls} is not a subtype of {zelf}", + zelf = zelf.name, + cls = cls.name, + ))); + } + + // let new = class_get_super_attr(&zelf, "__new__").expect("Couldn't find __new__"); + + // vm.invoke(&new, args.insert(cls.into_object())); + if args.args.len() == 1 { + Ok(args.args[0].class().into_object()) + } else if args.args.len() == 3 { + let (name, bases, dict) = args.bind(vm)?; + type_new_class(vm, cls, name, bases, dict).map(PyRef::into_object) } else { Err(vm.new_type_error("type() takes 1 or 3 arguments".to_string())) } @@ -284,15 +307,21 @@ pub fn type_new_class( new(typ.clone(), name.as_str(), bases, dict.to_attributes()) } -pub fn type_call(class: PyClassRef, args: Args, kwargs: KwArgs, vm: &VirtualMachine) -> PyResult { +pub fn type_call(class: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { vm_trace!("type_call: {:?}", class); let new = class_get_attr(&class, "__new__").expect("All types should have a __new__."); - let new_wrapped = vm.call_get_descriptor(new, class.into_object())?; - let obj = vm.invoke(&new_wrapped, (&args, &kwargs))?; + let new_wrapped = vm.call_get_descriptor(new, class.clone().into_object())?; + // TODO: don't do this, init __new__ based on tp_new + let new_args = if class.is(&vm.ctx.types.type_type) { + args.insert(class.into_object()) + } else { + args.clone() + }; + let obj = vm.invoke(&new_wrapped, new_args)?; if let Some(init_method_or_err) = vm.get_method(obj.clone(), "__init__") { let init_method = init_method_or_err?; - let res = vm.invoke(&init_method, (&args, &kwargs))?; + let res = vm.invoke(&init_method, args)?; if !res.is(&vm.get_none()) { return Err(vm.new_type_error("__init__ must return None".to_string())); } @@ -314,15 +343,19 @@ fn type_dict_setter(_instance: PyClassRef, _value: PyObjectRef, vm: &VirtualMach pub fn class_get_attr(class: &PyClassRef, attr_name: &str) -> Option { flame_guard!(format!("class_get_attr({:?})", attr_name)); - if let Some(item) = class.attributes.borrow().get(attr_name).cloned() { - return Some(item); - } - for class in &class.mro { - if let Some(item) = class.attributes.borrow().get(attr_name).cloned() { - return Some(item); - } - } - None + class + .attributes + .borrow() + .get(attr_name) + .cloned() + .or_else(|| class_get_super_attr(class, attr_name)) +} + +pub fn class_get_super_attr(class: &PyClassRef, attr_name: &str) -> Option { + class + .mro + .iter() + .find_map(|class| class.attributes.borrow().get(attr_name).cloned()) } // This is the internal has_attr implementation for fast lookup on a class.