diff --git a/vm/src/objtype.rs b/vm/src/objtype.rs index 2185c26b1..c429e0e26 100644 --- a/vm/src/objtype.rs +++ b/vm/src/objtype.rs @@ -1,4 +1,5 @@ use super::objdict; +use super::objtype; // Required for arg_check! to use isinstance use super::pyobject::{ AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult, ToRust, TypeProtocol, @@ -28,7 +29,16 @@ pub fn init(context: &PyContext) { } fn type_mro(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - match _mro(args.args[0].clone()) { + println!("{:?}", args); + arg_check!( + vm, + args, + required = [ + (cls, Some(vm.ctx.type_type())), + (_typ, Some(vm.ctx.type_type())) + ] + ); + match _mro(cls.clone()) { Some(mro) => Ok(vm.context().new_tuple(mro)), None => Err(vm.new_type_error("Only classes have an MRO.".to_string())), } @@ -71,14 +81,29 @@ pub fn get_type_name(typ: &PyObjectRef) -> String { pub fn type_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { debug!("type.__new__{:?}", args); if args.args.len() == 2 { - Ok(args.args[1].typ()) + arg_check!( + vm, + args, + required = [(_typ, Some(vm.ctx.type_type())), (obj, None)] + ); + Ok(obj.typ()) } else if args.args.len() == 4 { - let typ = args.args[0].clone(); - let name = args.args[1].to_str().unwrap(); + arg_check!( + vm, + args, + required = [ + (typ, Some(vm.ctx.type_type())), + (name, Some(vm.ctx.str_type())), + // bases needs to be mutable, which arg_check! doesn't support, so we just check + // the type and extract it again below + // TODO: arg_check! should support specifying iterables + (_bases, None), + (dict, Some(vm.ctx.dict_type())) + ] + ); let mut bases = args.args[2].to_vec().unwrap(); bases.push(vm.context().object()); - let dict = args.args[3].clone(); - new(typ, &name, bases, dict) + new(typ.clone(), &name.to_str().unwrap(), bases, dict.clone()) } else { Err(vm.new_type_error(format!(": type_new: {:?}", args))) } @@ -185,9 +210,7 @@ pub fn new(typ: PyObjectRef, name: &str, bases: Vec, dict: PyObject } fn type_str(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - // TODO: fix macro: - // arg_check!(vm, args, required = [(obj, Some(vm.ctx.type_type()))]); - let obj = args.args[0].clone(); + arg_check!(vm, args, required = [(obj, Some(vm.ctx.type_type()))]); let type_name = get_type_name(&obj); Ok(vm.new_str(format!("", type_name))) }