diff --git a/vm/src/stdlib/ast.rs b/vm/src/stdlib/ast.rs index eeac00984..b1980d754 100644 --- a/vm/src/stdlib/ast.rs +++ b/vm/src/stdlib/ast.rs @@ -5,7 +5,6 @@ use crate::{ builtins::{self, PyStrRef, PyTypeRef}, - function::FuncArgs, IdProtocol, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyResult, PyValue, StaticType, TryFromObject, TypeProtocol, VirtualMachine, }; @@ -22,6 +21,56 @@ use rustpython_parser::parser; mod gen; +#[pymodule] +mod _ast { + use crate::{ + builtins::PyStrRef, function::FuncArgs, PyObjectRef, PyResult, PyValue, TypeProtocol, + VirtualMachine, + }; + #[pyattr] + #[pyclass(module = "_ast", name = "AST")] + #[derive(Debug, PyValue)] + pub(crate) struct AstNode; + + #[pyimpl(flags(BASETYPE, HAS_DICT))] + impl AstNode { + #[pymethod(magic)] + fn init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + let obj: PyObjectRef = zelf.clone_class().into(); + let fields = obj.get_attr("_fields", vm)?; + let fields = vm.extract_elements::(&fields)?; + let numargs = args.args.len(); + if numargs > fields.len() { + return Err(vm.new_type_error(format!( + "{} constructor takes at most {} positional argument{}", + zelf.class().name(), + fields.len(), + if fields.len() == 1 { "" } else { "s" }, + ))); + } + for (name, arg) in fields.iter().zip(args.args) { + zelf.set_attr(name.clone(), arg, vm)?; + } + for (key, value) in args.kwargs { + if let Some(pos) = fields.iter().position(|f| f.as_str() == key) { + if pos < numargs { + return Err(vm.new_type_error(format!( + "{} got multiple values for argument '{}'", + zelf.class().name(), + key + ))); + } + } + zelf.set_attr(key, value, vm)?; + } + Ok(()) + } + } + + #[pyattr(name = "PyCF_ONLY_AST")] + use super::PY_COMPILE_FLAG_AST_ONLY; +} + fn get_node_field(vm: &VirtualMachine, obj: &PyObjectRef, field: &str, typ: &str) -> PyResult { vm.get_attribute_opt(obj.clone(), field)?.ok_or_else(|| { vm.new_type_error(format!("required field \"{}\" missing from {}", field, typ)) @@ -38,48 +87,6 @@ fn get_node_field_opt( .filter(|obj| !vm.is_none(obj))) } -#[pyclass(module = "_ast", name = "AST")] -#[derive(Debug, PyValue)] -pub(crate) struct AstNode; - -#[pyimpl(flags(BASETYPE, HAS_DICT))] -impl AstNode { - #[pymethod(magic)] - fn init(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { - let obj: PyObjectRef = zelf.clone_class().into(); - let fields = obj.get_attr("_fields", vm)?; - let fields = vm.extract_elements::(&fields)?; - let numargs = args.args.len(); - if numargs > fields.len() { - return Err(vm.new_type_error(format!( - "{} constructor takes at most {} positional argument{}", - zelf.class().name(), - fields.len(), - if fields.len() == 1 { "" } else { "s" }, - ))); - } - for (name, arg) in fields.iter().zip(args.args) { - zelf.set_attr(name.clone(), arg, vm)?; - } - for (key, value) in args.kwargs { - if let Some(pos) = fields.iter().position(|f| f.as_str() == key) { - if pos < numargs { - return Err(vm.new_type_error(format!( - "{} got multiple values for argument '{}'", - zelf.class().name(), - key - ))); - } - } - zelf.set_attr(key, value, vm)?; - } - Ok(()) - } -} - -const MODULE_NAME: &str = "_ast"; -pub const PY_COMPILE_FLAG_AST_ONLY: i32 = 0x0400; - trait Node: Sized { fn ast_to_object(self, vm: &VirtualMachine) -> PyObjectRef; fn ast_from_object(vm: &VirtualMachine, object: PyObjectRef) -> PyResult; @@ -276,15 +283,13 @@ pub(crate) fn compile( Ok(vm.new_code_object(code).into()) } -pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { - let ctx = &vm.ctx; +// Required crate visibility for inclusion by gen.rs +pub(crate) use _ast::AstNode; +// Used by builtins::compile() +pub const PY_COMPILE_FLAG_AST_ONLY: i32 = 0x0400; - let ast_base = AstNode::make_class(ctx); - let module = py_module!(vm, MODULE_NAME, { - // TODO: There's got to be a better way! - "AST" => ast_base, - "PyCF_ONLY_AST" => ctx.new_int(PY_COMPILE_FLAG_AST_ONLY), - }); +pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { + let module = _ast::make_module(vm); gen::extend_module_nodes(vm, &module); module }