diff --git a/vm/src/obj/mod.rs b/vm/src/obj/mod.rs index 95aacba41..6cbb5fa66 100644 --- a/vm/src/obj/mod.rs +++ b/vm/src/obj/mod.rs @@ -18,6 +18,7 @@ pub mod objiter; pub mod objlist; pub mod objmap; pub mod objmemory; +pub mod objmodule; pub mod objnone; pub mod objobject; pub mod objproperty; diff --git a/vm/src/obj/objmodule.rs b/vm/src/obj/objmodule.rs new file mode 100644 index 000000000..ef6cf2d6d --- /dev/null +++ b/vm/src/obj/objmodule.rs @@ -0,0 +1,30 @@ +use crate::frame::ScopeRef; +use crate::pyobject::{ + DictProtocol, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, +}; +use crate::vm::VirtualMachine; + +fn module_dir(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(obj, Some(vm.ctx.module_type()))]); + let scope = get_scope(obj); + let keys = scope + .locals + .get_key_value_pairs() + .iter() + .map(|(k, _v)| k.clone()) + .collect(); + Ok(vm.ctx.new_list(keys)) +} + +pub fn init(context: &PyContext) { + let module_type = &context.module_type; + context.set_attr(&module_type, "__dir__", context.new_rustfunc(module_dir)); +} + +fn get_scope(obj: &PyObjectRef) -> &ScopeRef { + if let PyObjectPayload::Module { ref scope, .. } = &obj.payload { + &scope + } else { + panic!("Can't get scope from non-module.") + } +} diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 6e0ab2701..00bce8458 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -31,6 +31,7 @@ use crate::obj::objiter; use crate::obj::objlist; use crate::obj::objmap; use crate::obj::objmemory; +use crate::obj::objmodule; use crate::obj::objnone; use crate::obj::objobject; use crate::obj::objproperty; @@ -317,6 +318,7 @@ impl PyContext { objcode::init(&context); objframe::init(&context); objnone::init(&context); + objmodule::init(&context); exceptions::init(&context); context } @@ -357,6 +359,10 @@ impl PyContext { self.list_type.clone() } + pub fn module_type(&self) -> PyObjectRef { + self.module_type.clone() + } + pub fn set_type(&self) -> PyObjectRef { self.set_type.clone() }