diff --git a/vm/src/frame.rs b/vm/src/frame.rs index e257ea21f..0015557fe 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -106,20 +106,6 @@ impl Scope { } } - pub fn get(&self, name: &str) -> Option { - for dict in self.locals.iter() { - if let Some(value) = dict.get_item(name) { - return Some(value); - } - } - - if let Some(value) = self.globals.get_item(name) { - return Some(value); - } - - None - } - pub fn get_only_locals(&self) -> Option { self.locals.iter().next().cloned() } @@ -140,15 +126,31 @@ pub trait NameProtocol { fn load_name(&self, vm: &VirtualMachine, name: &str) -> Option; fn store_name(&self, vm: &VirtualMachine, name: &str, value: PyObjectRef); fn delete_name(&self, vm: &VirtualMachine, name: &str); + fn load_cell(&self, vm: &VirtualMachine, name: &str) -> Option; } impl NameProtocol for Scope { fn load_name(&self, vm: &VirtualMachine, name: &str) -> Option { - if let Some(value) = self.get(name) { - Some(value) - } else { - vm.builtins.get_item(name) + for dict in self.locals.iter() { + if let Some(value) = dict.get_item(name) { + return Some(value); + } } + + if let Some(value) = self.globals.get_item(name) { + return Some(value); + } + + vm.builtins.get_item(name) + } + + fn load_cell(&self, _vm: &VirtualMachine, name: &str) -> Option { + for dict in self.locals.iter().skip(1) { + if let Some(value) = dict.get_item(name) { + return Some(value); + } + } + None } fn store_name(&self, vm: &VirtualMachine, key: &str, value: PyObjectRef) { diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index 38e9cf0e7..2b2be0241 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -6,6 +6,7 @@ https://github.com/python/cpython/blob/50b48572d9a90c5bb36e2bef6179548ea927a35a/ */ +use crate::frame::NameProtocol; use crate::function::PyFuncArgs; use crate::obj::objstr; use crate::obj::objtype::PyClass; @@ -105,7 +106,7 @@ fn super_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { let py_type = if let Some(ty) = py_type { ty.clone() } else { - match vm.current_scope().get("__class__") { + match vm.current_scope().load_cell(vm, "__class__") { Some(obj) => obj.clone(), _ => { return Err(vm.new_type_error(