diff --git a/tests/snippets/import_name.py b/tests/snippets/import_name.py new file mode 100644 index 000000000..265fb9bf1 --- /dev/null +++ b/tests/snippets/import_name.py @@ -0,0 +1,2 @@ +def import_func(): + assert __name__ == "import_name" diff --git a/tests/snippets/name.py b/tests/snippets/name.py new file mode 100644 index 000000000..97f9367ec --- /dev/null +++ b/tests/snippets/name.py @@ -0,0 +1,9 @@ +#when name.py is run __name__ should equal to __main__ +assert __name__ == "__main__" + +from import_name import import_func + +#__name__ should be set to import_func +import_func() + +assert __name__ == "__main__" diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 42c16179c..9bff82526 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -339,7 +339,7 @@ pub fn builtin_print(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(v) => objstr::get_value(&v), Err(err) => return Err(err), }; - print!("{} ", s); + print!("{}", s); } println!(); io::stdout().flush().unwrap(); @@ -395,6 +395,11 @@ fn builtin_setattr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn make_module(ctx: &PyContext) -> PyObjectRef { // scope[String::from("print")] = print; let mut dict = HashMap::new(); + //set __name__ fixes: https://github.com/RustPython/RustPython/issues/146 + dict.insert( + String::from("__name__"), + ctx.new_str(String::from("__main__")), + ); dict.insert(String::from("abs"), ctx.new_rustfunc(builtin_abs)); dict.insert(String::from("all"), ctx.new_rustfunc(builtin_all)); dict.insert(String::from("any"), ctx.new_rustfunc(builtin_any)); diff --git a/vm/src/import.rs b/vm/src/import.rs index 534959fe3..e70cdb662 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -45,7 +45,7 @@ fn import_uncached_module(vm: &mut VirtualMachine, module: &str) -> PyResult { let builtins = vm.get_builtin_scope(); let scope = vm.ctx.new_scope(Some(builtins)); - + scope.set_item(&"__name__".to_string(), vm.new_str(module.to_string())); match vm.run_code_obj(code_obj, scope.clone()) { Ok(_) => {} Err(value) => return Err(value), @@ -83,7 +83,8 @@ fn find_source(vm: &VirtualMachine, name: &str) -> io::Result { .filter_map(|item| match item.borrow().kind { PyObjectKind::String { ref value } => Some(PathBuf::from(value)), _ => None, - }).collect(), + }) + .collect(), _ => panic!("sys.path unexpectedly not a list"), };