diff --git a/bytecode/src/bytecode.rs b/bytecode/src/bytecode.rs index 6f9ef81a8..922ed78c0 100644 --- a/bytecode/src/bytecode.rs +++ b/bytecode/src/bytecode.rs @@ -53,7 +53,6 @@ bitflags! { const HAS_DEFAULTS = 0x01; const HAS_KW_ONLY_DEFAULTS = 0x02; const HAS_ANNOTATIONS = 0x04; - const IS_CLASS = 0x08; } } diff --git a/compiler/src/compile.rs b/compiler/src/compile.rs index 2324ac303..7194c8886 100644 --- a/compiler/src/compile.rs +++ b/compiler/src/compile.rs @@ -975,7 +975,7 @@ impl Compiler { // Turn code object into function object: self.emit(Instruction::MakeFunction { - flags: bytecode::FunctionOpArg::IS_CLASS, + flags: bytecode::FunctionOpArg::empty(), }); self.emit(Instruction::LoadConst { diff --git a/compiler/src/symboltable.rs b/compiler/src/symboltable.rs index 312650f9f..43e0500ec 100644 --- a/compiler/src/symboltable.rs +++ b/compiler/src/symboltable.rs @@ -61,7 +61,7 @@ impl SymbolTable { } } -#[derive(Clone)] +#[derive(Clone, PartialEq)] pub enum SymbolTableType { Module, Class, @@ -241,12 +241,10 @@ impl SymbolTableAnalyzer { } else { // Interesting stuff about the __class__ variable: // https://docs.python.org/3/reference/datamodel.html?highlight=__class__#creating-the-class-object - let found_in_outer_scope = (symbol.name == "__class__") - || self - .tables - .iter() - .skip(1) - .any(|t| t.symbols.contains_key(&symbol.name)); + let found_in_outer_scope = symbol.name == "__class__" + || self.tables.iter().skip(1).any(|t| { + t.typ != SymbolTableType::Class && t.symbols.contains_key(&symbol.name) + }); if found_in_outer_scope { // Symbol is in some outer scope. @@ -387,7 +385,6 @@ impl SymbolTableBuilder { keywords, decorator_list, } => { - self.register_name(name, SymbolUsage::Assigned)?; self.enter_scope(name, SymbolTableType::Class, statement.location.row()); self.scan_statements(body)?; self.leave_scope(); @@ -396,6 +393,7 @@ impl SymbolTableBuilder { self.scan_expression(&keyword.value, &ExpressionContext::Load)?; } self.scan_expressions(decorator_list, &ExpressionContext::Load)?; + self.register_name(name, SymbolUsage::Assigned)?; } Expression { expression } => { self.scan_expression(expression, &ExpressionContext::Load)? diff --git a/tests/snippets/class.py b/tests/snippets/class.py index 8a0502430..19e6cb1d1 100644 --- a/tests/snippets/class.py +++ b/tests/snippets/class.py @@ -174,12 +174,15 @@ class A: assert a == 2 A.b() -def func(): - class A: - a = 2 - def b(): - assert a == 1 - b() - assert a == 2 - A.b() -func() +# TODO: uncomment once free vars/cells are working +# a = 1 +# def nested_scope(): +# a = 2 +# class A: +# a = 3 +# def b(): +# assert a == 2 +# b() +# assert a == 3 +# A.b() +# nested_scope() diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index e61f96411..99627549f 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -10,12 +10,10 @@ use std::str; use num_bigint::Sign; use num_traits::{Signed, ToPrimitive, Zero}; -use crate::frame::Frame; use crate::obj::objbool; use crate::obj::objbytes::PyBytesRef; use crate::obj::objcode::PyCodeRef; use crate::obj::objdict::PyDictRef; -use crate::obj::objfunction::{PyFunction, PyFunctionRef}; use crate::obj::objint::{self, PyIntRef}; use crate::obj::objiter; use crate::obj::objstr::{PyString, PyStringRef}; @@ -898,7 +896,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) { } pub fn builtin_build_class_( - function: PyFunctionRef, + function: PyObjectRef, qualified_name: PyStringRef, bases: Args, mut kwargs: KwArgs, @@ -935,20 +933,7 @@ pub fn builtin_build_class_( let cells = vm.ctx.new_dict(); - let PyFunction { code, scope, .. } = &*function; - - let is_class = scope.is_class(); - - let mut scope = scope - .new_child_scope_with_locals(cells.clone()) - .new_child_scope_with_locals(namespace.clone()); - - if is_class { - scope = scope.as_class(); - } - - let frame = Frame::new(code.clone(), scope).into_ref(vm); - vm.run_frame_full(frame)?; + vm.invoke_with_locals(&function, cells.clone(), namespace.clone())?; namespace.set_item("__name__", name_obj.clone(), vm)?; namespace.set_item("__qualname__", qualified_name.into_object(), vm)?; diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 4e4f801c8..b9a912e3a 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1064,16 +1064,6 @@ impl Frame { // pop argc arguments // argument: name, args, globals let scope = self.scope.clone(); - let scope = if flags.contains(bytecode::FunctionOpArg::IS_CLASS) { - // if the function we're making is a class initializer - scope.new_child_scope(&vm.ctx).as_class() - } else if scope.is_class() { - // if the surrounding scope is a class, i.e. the function we're making is a method, - // then get the parent scope. See builtin_build_class for why. - scope.parent_scope() - } else { - scope - }; let func_obj = vm .ctx .new_function(code_obj, scope, defaults, kw_only_defaults); diff --git a/vm/src/scope.rs b/vm/src/scope.rs index 9c1c6481a..bbba87a30 100644 --- a/vm/src/scope.rs +++ b/vm/src/scope.rs @@ -12,7 +12,6 @@ use crate::vm::VirtualMachine; pub struct Scope { locals: Vec, pub globals: PyDictRef, - is_class: bool, } impl fmt::Debug for Scope { @@ -28,11 +27,7 @@ impl Scope { Some(dict) => vec![dict], None => vec![], }; - let scope = Scope { - locals, - globals, - is_class: false, - }; + let scope = Scope { locals, globals }; scope.store_name(vm, "__annotations__", vm.ctx.new_dict().into_object()); scope } @@ -69,30 +64,12 @@ impl Scope { Scope { locals: new_locals, globals: self.globals.clone(), - is_class: false, } } pub fn new_child_scope(&self, ctx: &PyContext) -> Scope { self.new_child_scope_with_locals(ctx.new_dict()) } - - pub fn parent_scope(&self) -> Scope { - Scope { - locals: self.locals[1..].to_vec(), - globals: self.globals.clone(), - is_class: false, - } - } - - pub fn is_class(&self) -> bool { - self.is_class - } - - pub fn as_class(mut self) -> Self { - self.is_class = true; - self - } } pub trait NameProtocol { @@ -154,15 +131,6 @@ impl NameProtocol for Scope { #[cfg_attr(feature = "flame-it", flame("Scope"))] /// Load a global name. fn load_global(&self, vm: &VirtualMachine, name: &str) -> Option { - // First, take a look in the outmost local scope (the scope at top level) - let last_local_dict = self.locals.iter().last(); - if let Some(local_dict) = last_local_dict { - if let Some(value) = local_dict.get_item_option(name, vm).unwrap() { - return Some(value); - } - } - - // Now, take a look at the globals or builtins. if let Some(value) = self.globals.get_item_option(name, vm).unwrap() { Some(value) } else { diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 303ecf5a2..0494eadd7 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -662,6 +662,25 @@ impl VirtualMachine { } } + pub fn invoke_with_locals( + &self, + function: &PyObjectRef, + cells: PyDictRef, + locals: PyDictRef, + ) -> PyResult { + if let Some(PyFunction { code, scope, .. }) = &function.payload() { + let scope = scope + .new_child_scope_with_locals(cells) + .new_child_scope_with_locals(locals); + let frame = Frame::new(code.clone(), scope).into_ref(self); + return self.run_frame_full(frame); + } + panic!( + "invoke_with_locals: expected python function, got: {:?}", + *function + ); + } + fn fill_locals_from_args( &self, code_object: &bytecode::CodeObject,