diff --git a/compiler/src/compile.rs b/compiler/src/compile.rs index 1daf80e50..23b46f1b9 100644 --- a/compiler/src/compile.rs +++ b/compiler/src/compile.rs @@ -188,7 +188,19 @@ impl Compiler { ) -> Result<(), CompileError> { let size_before = self.output_stack.len(); self.symbol_table_stack.push(symbol_table); - self.compile_statements(&program.statements)?; + + let (statements, doc) = get_doc(&program.statements); + if let Some(value) = doc { + self.emit(Instruction::LoadConst { + value: bytecode::Constant::String { value }, + }); + self.emit(Instruction::StoreName { + name: "__doc__".to_owned(), + scope: bytecode::NameScope::Global, + }); + } + self.compile_statements(statements)?; + assert_eq!(self.output_stack.len(), size_before); // Emit None at end: @@ -1897,30 +1909,36 @@ impl Compiler { } fn compile_string(&mut self, string: &ast::StringGroup) -> Result<(), CompileError> { - match string { - ast::StringGroup::Joined { values } => { - for value in values { - self.compile_string(value)?; + if let Some(value) = try_get_constant_string(string) { + self.emit(Instruction::LoadConst { + value: bytecode::Constant::String { value }, + }); + } else { + match string { + ast::StringGroup::Joined { values } => { + for value in values { + self.compile_string(value)?; + } + self.emit(Instruction::BuildString { size: values.len() }) + } + ast::StringGroup::Constant { value } => { + self.emit(Instruction::LoadConst { + value: bytecode::Constant::String { + value: value.to_string(), + }, + }); + } + ast::StringGroup::FormattedValue { + value, + conversion, + spec, + } => { + self.compile_expression(value)?; + self.emit(Instruction::FormatValue { + conversion: conversion.map(compile_conversion_flag), + spec: spec.clone(), + }); } - self.emit(Instruction::BuildString { size: values.len() }) - } - ast::StringGroup::Constant { value } => { - self.emit(Instruction::LoadConst { - value: bytecode::Constant::String { - value: value.to_string(), - }, - }); - } - ast::StringGroup::FormattedValue { - value, - conversion, - spec, - } => { - self.compile_expression(value)?; - self.emit(Instruction::FormatValue { - conversion: conversion.map(compile_conversion_flag), - spec: spec.clone(), - }); } } Ok(()) @@ -2000,13 +2018,11 @@ impl Compiler { } fn get_doc(body: &[ast::Statement]) -> (&[ast::Statement], Option) { - if let Some(val) = body.get(0) { + if let Some((val, body_rest)) = body.split_first() { if let ast::StatementType::Expression { ref expression } = val.node { if let ast::ExpressionType::String { value } = &expression.node { - if let ast::StringGroup::Constant { ref value } = value { - if let Some((_, body_rest)) = body.split_first() { - return (body_rest, Some(value.to_string())); - } + if let Some(value) = try_get_constant_string(value) { + return (body_rest, Some(value.to_string())); } } } @@ -2014,6 +2030,27 @@ fn get_doc(body: &[ast::Statement]) -> (&[ast::Statement], Option) { (body, None) } +fn try_get_constant_string(string: &ast::StringGroup) -> Option { + fn get_constant_string_inner(out_string: &mut String, string: &ast::StringGroup) -> bool { + match string { + ast::StringGroup::Constant { value } => { + out_string.push_str(&value); + true + } + ast::StringGroup::Joined { values } => values + .iter() + .all(|value| get_constant_string_inner(out_string, value)), + ast::StringGroup::FormattedValue { .. } => false, + } + } + let mut out_string = String::new(); + if get_constant_string_inner(&mut out_string, string) { + Some(out_string) + } else { + None + } +} + fn compile_location(location: &ast::Location) -> bytecode::Location { bytecode::Location::new(location.row(), location.column()) } diff --git a/src/main.rs b/src/main.rs index dcd55aa1d..5e2d7c60c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -312,7 +312,7 @@ fn run_rustpython(vm: &VirtualMachine, matches: &ArgMatches) -> PyResult<()> { } let scope = vm.new_scope_with_builtins(); - let main_module = vm.ctx.new_module("__main__", scope.globals.clone()); + let main_module = vm.new_module("__main__", scope.globals.clone()); vm.get_attribute(vm.sys_module.clone(), "modules")? .set_item("__main__", main_module, vm)?; diff --git a/vm/src/function.rs b/vm/src/function.rs index d199860b1..9874309fc 100644 --- a/vm/src/function.rs +++ b/vm/src/function.rs @@ -393,6 +393,18 @@ impl OptionalArg { } } +pub type OptionalOption = OptionalArg>; + +impl OptionalOption { + #[inline] + pub fn flat_option(self) -> Option { + match self { + Present(Some(value)) => Some(value), + _ => None, + } + } +} + impl FromArgs for OptionalArg where T: TryFromObject, diff --git a/vm/src/import.rs b/vm/src/import.rs index 188418c54..267adfc86 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -87,7 +87,7 @@ pub fn import_codeobj( if set_file_attr { attrs.set_item("__file__", vm.new_str(code_obj.source_path.to_owned()), vm)?; } - let module = vm.ctx.new_module(module_name, attrs.clone()); + let module = vm.new_module(module_name, attrs.clone()); // Store module in cache to prevent infinite loop with mutual importing libs: let sys_modules = vm.get_attribute(vm.sys_module.clone(), "modules")?; diff --git a/vm/src/macros.rs b/vm/src/macros.rs index 5ae975d8e..ea10abfdc 100644 --- a/vm/src/macros.rs +++ b/vm/src/macros.rs @@ -117,7 +117,7 @@ macro_rules! no_kwargs { #[macro_export] macro_rules! py_module { ( $vm:expr, $module_name:expr, { $($name:expr => $value:expr),* $(,)* }) => {{ - let module = $vm.ctx.new_module($module_name, $vm.ctx.new_dict()); + let module = $vm.new_module($module_name, $vm.ctx.new_dict()); $vm.set_attr(&module, "__name__", $vm.ctx.new_str($module_name.to_string())).unwrap(); $( $vm.set_attr(&module, $name, $value).unwrap(); diff --git a/vm/src/obj/objmodule.rs b/vm/src/obj/objmodule.rs index d5df5664b..902418498 100644 --- a/vm/src/obj/objmodule.rs +++ b/vm/src/obj/objmodule.rs @@ -1,6 +1,8 @@ +use crate::function::OptionalOption; +use crate::obj::objdict::PyDictRef; use crate::obj::objstr::PyStringRef; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{PyContext, PyRef, PyResult, PyValue}; +use crate::pyobject::{ItemProtocol, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; #[derive(Debug)] @@ -17,13 +19,47 @@ impl PyValue for PyModule { } } +pub fn init_module_dict( + vm: &VirtualMachine, + module_dict: &PyDictRef, + name: PyObjectRef, + doc: PyObjectRef, +) { + module_dict + .set_item("__name__", name, vm) + .expect("Failed to set __name__ on module"); + module_dict + .set_item("__doc__", doc, vm) + .expect("Failed to set __doc__ on module"); + module_dict + .set_item("__package__", vm.get_none(), vm) + .expect("Failed to set __package__ on module"); + module_dict + .set_item("__loader__", vm.get_none(), vm) + .expect("Failed to set __loader__ on module"); + module_dict + .set_item("__spec__", vm.get_none(), vm) + .expect("Failed to set __spec__ on module"); +} + impl PyModuleRef { - fn new(cls: PyClassRef, name: PyStringRef, vm: &VirtualMachine) -> PyResult { + fn new( + cls: PyClassRef, + name: PyStringRef, + doc: OptionalOption, + vm: &VirtualMachine, + ) -> PyResult { let zelf = PyModule { name: name.as_str().to_owned(), } .into_ref_with_type(vm, cls)?; - vm.set_attr(zelf.as_object(), "__name__", name)?; + init_module_dict( + vm, + zelf.as_object().dict.as_ref().unwrap(), + name.into_object(), + doc.flat_option() + .map_or_else(|| vm.get_none(), PyRef::into_object), + ); Ok(zelf) } diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index ffa863e4a..05049528c 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -430,16 +430,6 @@ impl PyContext { objtype::new(self.type_type(), name, vec![base], PyAttributes::new()).unwrap() } - pub fn new_module(&self, name: &str, dict: PyDictRef) -> PyObjectRef { - PyObject::new( - PyModule { - name: name.to_string(), - }, - self.types.module_type.clone(), - Some(dict), - ) - } - pub fn new_namespace(&self) -> PyObjectRef { PyObject::new(PyNamespace, self.namespace_type(), Some(self.new_dict())) } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 486550b3a..ce9382eaf 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -25,6 +25,7 @@ use crate::obj::objfunction::{PyFunction, PyMethod}; use crate::obj::objgenerator::PyGenerator; use crate::obj::objint::PyInt; use crate::obj::objiter; +use crate::obj::objmodule::{self, PyModule}; use crate::obj::objsequence; use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::objtuple::PyTupleRef; @@ -32,8 +33,8 @@ use crate::obj::objtype; use crate::obj::objtype::PyClassRef; use crate::pyhash; use crate::pyobject::{ - IdProtocol, ItemProtocol, PyContext, PyObjectRef, PyResult, PyValue, TryFromObject, TryIntoRef, - TypeProtocol, + IdProtocol, ItemProtocol, PyContext, PyObject, PyObjectRef, PyResult, PyValue, TryFromObject, + TryIntoRef, TypeProtocol, }; use crate::scope::Scope; use crate::stdlib; @@ -142,9 +143,23 @@ impl VirtualMachine { flame_guard!("init VirtualMachine"); let ctx = PyContext::new(); + // make a new module without access to the vm; doesn't + // set __spec__, __loader__, etc. attributes + let new_module = |name: &str, dict| { + PyObject::new( + PyModule { + name: name.to_owned(), + }, + ctx.types.module_type.clone(), + Some(dict), + ) + }; + // Hard-core modules: - let builtins = ctx.new_module("builtins", ctx.new_dict()); - let sysmod = ctx.new_module("sys", ctx.new_dict()); + let builtins_dict = ctx.new_dict(); + let builtins = new_module("builtins", builtins_dict.clone()); + let sysmod_dict = ctx.new_dict(); + let sysmod = new_module("sys", sysmod_dict.clone()); let stdlib_inits = RefCell::new(stdlib::get_module_inits()); let frozen = RefCell::new(frozen::get_module_inits()); @@ -168,6 +183,19 @@ impl VirtualMachine { signal_handlers: Default::default(), }; + objmodule::init_module_dict( + &vm, + &builtins_dict, + vm.new_str("builtins".to_owned()), + vm.get_none(), + ); + objmodule::init_module_dict( + &vm, + &sysmod_dict, + vm.new_str("sys".to_owned()), + vm.get_none(), + ); + builtins::make_module(&vm, builtins.clone()); sysmodule::make_module(&vm, sysmod, builtins); vm @@ -254,6 +282,17 @@ impl VirtualMachine { self.ctx.new_bool(b) } + pub fn new_module(&self, name: &str, dict: PyDictRef) -> PyObjectRef { + objmodule::init_module_dict(self, &dict, self.new_str(name.to_owned()), self.get_none()); + PyObject::new( + PyModule { + name: name.to_owned(), + }, + self.ctx.types.module_type.clone(), + Some(dict), + ) + } + #[cfg_attr(feature = "flame-it", flame("VirtualMachine"))] fn new_exception_obj(&self, exc_type: PyClassRef, args: Vec) -> PyResult { // TODO: add repr of args into logging? diff --git a/wasm/lib/src/vm_class.rs b/wasm/lib/src/vm_class.rs index 5015c3925..1f2584f0a 100644 --- a/wasm/lib/src/vm_class.rs +++ b/wasm/lib/src/vm_class.rs @@ -256,7 +256,7 @@ impl WASMVirtualMachine { let mod_name = name.clone(); let stdlib_init_fn = move |vm: &VirtualMachine| { - let module = vm.ctx.new_module(&name, vm.ctx.new_dict()); + let module = vm.new_module(&name, vm.ctx.new_dict()); for (key, value) in module_items.clone() { vm.set_attr(&module, key, value).unwrap(); }