diff --git a/parser/src/ast.rs b/parser/src/ast.rs index e607c4e4b..749003d9a 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -357,6 +357,17 @@ pub enum Number { Complex { real: f64, imag: f64 }, } +/// Transforms a value prior to formatting it. +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ConversionFlag { + /// Converts by calling `str()`. + Str, + /// Converts by calling `ascii()`. + Ascii, + /// Converts by calling `repr()`. + Repr, +} + #[derive(Debug, PartialEq)] pub enum StringGroup { Constant { @@ -364,6 +375,7 @@ pub enum StringGroup { }, FormattedValue { value: Box, + conversion: Option, spec: String, }, Joined { diff --git a/parser/src/fstring.rs b/parser/src/fstring.rs new file mode 100644 index 000000000..e6e5586b0 --- /dev/null +++ b/parser/src/fstring.rs @@ -0,0 +1,234 @@ +use std::iter; +use std::mem; +use std::str; + +use lalrpop_util::ParseError as LalrpopError; + +use crate::ast::{ConversionFlag, StringGroup}; +use crate::lexer::{LexicalError, Location, Tok}; +use crate::parser::parse_expression; + +use self::FStringError::*; +use self::StringGroup::*; + +// TODO: consolidate these with ParseError +#[derive(Debug, PartialEq)] +pub enum FStringError { + UnclosedLbrace, + UnopenedRbrace, + InvalidExpression, + InvalidConversionFlag, + EmptyExpression, + MismatchedDelimiter, +} + +impl From for LalrpopError { + fn from(_err: FStringError) -> Self { + lalrpop_util::ParseError::User { + error: LexicalError::StringError, + } + } +} + +struct FStringParser<'a> { + chars: iter::Peekable>, +} + +impl<'a> FStringParser<'a> { + fn new(source: &'a str) -> Self { + Self { + chars: source.chars().peekable(), + } + } + + fn parse_formatted_value(&mut self) -> Result { + let mut expression = String::new(); + let mut spec = String::new(); + let mut delims = Vec::new(); + let mut conversion = None; + + while let Some(ch) = self.chars.next() { + match ch { + '!' if delims.is_empty() => { + conversion = Some(match self.chars.next() { + Some('s') => ConversionFlag::Str, + Some('a') => ConversionFlag::Ascii, + Some('r') => ConversionFlag::Repr, + Some(_) => { + return Err(InvalidConversionFlag); + } + None => { + break; + } + }) + } + ':' if delims.is_empty() => { + while let Some(&next) = self.chars.peek() { + if next != '}' { + spec.push(next); + self.chars.next(); + } else { + break; + } + } + } + '(' | '{' | '[' => { + expression.push(ch); + delims.push(ch); + } + ')' => { + if delims.pop() != Some('(') { + return Err(MismatchedDelimiter); + } + expression.push(ch); + } + ']' => { + if delims.pop() != Some('[') { + return Err(MismatchedDelimiter); + } + expression.push(ch); + } + '}' if !delims.is_empty() => { + if delims.pop() != Some('{') { + return Err(MismatchedDelimiter); + } + expression.push(ch); + } + '}' => { + if expression.is_empty() { + return Err(EmptyExpression); + } + return Ok(FormattedValue { + value: Box::new( + parse_expression(expression.trim()).map_err(|_| InvalidExpression)?, + ), + conversion, + spec, + }); + } + '"' | '\'' => { + expression.push(ch); + while let Some(next) = self.chars.next() { + expression.push(next); + if next == ch { + break; + } + } + } + _ => { + expression.push(ch); + } + } + } + + return Err(UnclosedLbrace); + } + + fn parse(mut self) -> Result { + let mut content = String::new(); + let mut values = vec![]; + + while let Some(ch) = self.chars.next() { + match ch { + '{' => { + if let Some('{') = self.chars.peek() { + self.chars.next(); + content.push('{'); + } else { + if !content.is_empty() { + values.push(Constant { + value: mem::replace(&mut content, String::new()), + }); + } + + values.push(self.parse_formatted_value()?); + } + } + '}' => { + if let Some('}') = self.chars.peek() { + self.chars.next(); + content.push('}'); + } else { + return Err(UnopenedRbrace); + } + } + _ => { + content.push(ch); + } + } + } + + if !content.is_empty() { + values.push(Constant { value: content }) + } + + Ok(match values.len() { + 0 => Constant { + value: String::new(), + }, + 1 => values.into_iter().next().unwrap(), + _ => Joined { values }, + }) + } +} + +pub fn parse_fstring(source: &str) -> Result { + FStringParser::new(source).parse() +} + +#[cfg(test)] +mod tests { + use crate::ast; + + use super::*; + + fn mk_ident(name: &str) -> ast::Expression { + ast::Expression::Identifier { + name: name.to_owned(), + } + } + + #[test] + fn test_parse_fstring() { + let source = String::from("{a}{ b }{{foo}}"); + let parse_ast = parse_fstring(&source).unwrap(); + + assert_eq!( + parse_ast, + Joined { + values: vec![ + FormattedValue { + value: Box::new(mk_ident("a")), + conversion: None, + spec: String::new(), + }, + FormattedValue { + value: Box::new(mk_ident("b")), + conversion: None, + spec: String::new(), + }, + Constant { + value: "{foo}".to_owned() + } + ] + } + ); + } + + #[test] + fn test_parse_empty_fstring() { + assert_eq!( + parse_fstring(""), + Ok(Constant { + value: String::new(), + }), + ); + } + + #[test] + fn test_parse_invalid_fstring() { + assert_eq!(parse_fstring("{"), Err(UnclosedLbrace)); + assert_eq!(parse_fstring("}"), Err(UnopenedRbrace)); + assert_eq!(parse_fstring("{class}"), Err(InvalidExpression)); + } +} diff --git a/parser/src/lib.rs b/parser/src/lib.rs index f7f369968..b10c3d551 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -3,6 +3,7 @@ extern crate log; pub mod ast; pub mod error; +mod fstring; pub mod lexer; pub mod parser; #[cfg_attr(rustfmt, rustfmt_skip)] diff --git a/parser/src/parser.rs b/parser/src/parser.rs index fb8f2f0e5..2a39a41f8 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -65,180 +65,12 @@ pub fn parse_expression(source: &str) -> Result { do_lalr_parsing!(source, Expression, StartExpression) } -// TODO: consolidate these with ParseError -#[derive(Debug, PartialEq)] -pub enum FStringError { - UnclosedLbrace, - UnopenedRbrace, - InvalidExpression, -} - -impl From - for lalrpop_util::ParseError -{ - fn from(_err: FStringError) -> Self { - lalrpop_util::ParseError::User { - error: lexer::LexicalError::StringError, - } - } -} - -enum ParseState { - Text { - content: String, - }, - FormattedValue { - expression: String, - spec: Option, - depth: usize, - }, -} - -pub fn parse_fstring(source: &str) -> Result { - use self::ParseState::*; - - let mut values = vec![]; - let mut state = ParseState::Text { - content: String::new(), - }; - - let mut chars = source.chars().peekable(); - while let Some(ch) = chars.next() { - state = match state { - Text { mut content } => match ch { - '{' => { - if let Some('{') = chars.peek() { - chars.next(); - content.push('{'); - Text { content } - } else { - if !content.is_empty() { - values.push(ast::StringGroup::Constant { value: content }); - } - - FormattedValue { - expression: String::new(), - spec: None, - depth: 0, - } - } - } - '}' => { - if let Some('}') = chars.peek() { - chars.next(); - content.push('}'); - Text { content } - } else { - return Err(FStringError::UnopenedRbrace); - } - } - _ => { - content.push(ch); - Text { content } - } - }, - - FormattedValue { - mut expression, - mut spec, - depth, - } => match ch { - ':' if depth == 0 => FormattedValue { - expression, - spec: Some(String::new()), - depth, - }, - '{' => { - if let Some('{') = chars.peek() { - expression.push_str("{{"); - chars.next(); - FormattedValue { - expression, - spec, - depth, - } - } else { - expression.push('{'); - FormattedValue { - expression, - spec, - depth: depth + 1, - } - } - } - '}' => { - if let Some('}') = chars.peek() { - expression.push_str("}}"); - chars.next(); - FormattedValue { - expression, - spec, - depth, - } - } else if depth > 0 { - expression.push('}'); - FormattedValue { - expression, - spec, - depth: depth - 1, - } - } else { - values.push(ast::StringGroup::FormattedValue { - value: Box::new(match parse_expression(expression.trim()) { - Ok(expr) => expr, - Err(_) => return Err(FStringError::InvalidExpression), - }), - spec: spec.unwrap_or_default(), - }); - Text { - content: String::new(), - } - } - } - _ => { - if let Some(spec) = spec.as_mut() { - spec.push(ch) - } else { - expression.push(ch); - } - FormattedValue { - expression, - spec, - depth, - } - } - }, - }; - } - - match state { - Text { content } => { - if !content.is_empty() { - values.push(ast::StringGroup::Constant { value: content }) - } - } - FormattedValue { .. } => { - return Err(FStringError::UnclosedLbrace); - } - } - - Ok(match values.len() { - 0 => ast::StringGroup::Constant { - value: String::new(), - }, - 1 => values.into_iter().next().unwrap(), - _ => ast::StringGroup::Joined { values }, - }) -} - #[cfg(test)] mod tests { use super::ast; use super::parse_expression; - use super::parse_fstring; use super::parse_program; use super::parse_statement; - use super::FStringError; use num_bigint::BigInt; #[test] @@ -630,55 +462,4 @@ mod tests { } ); } - - fn mk_ident(name: &str) -> ast::Expression { - ast::Expression::Identifier { - name: name.to_owned(), - } - } - - #[test] - fn test_parse_fstring() { - let source = String::from("{a}{ b }{{foo}}"); - let parse_ast = parse_fstring(&source).unwrap(); - - assert_eq!( - parse_ast, - ast::StringGroup::Joined { - values: vec![ - ast::StringGroup::FormattedValue { - value: Box::new(mk_ident("a")), - spec: String::new(), - }, - ast::StringGroup::FormattedValue { - value: Box::new(mk_ident("b")), - spec: String::new(), - }, - ast::StringGroup::Constant { - value: "{foo}".to_owned() - } - ] - } - ); - } - - #[test] - fn test_parse_empty_fstring() { - assert_eq!( - parse_fstring(""), - Ok(ast::StringGroup::Constant { - value: String::new(), - }), - ); - } - - #[test] - fn test_parse_invalid_fstring() { - assert_eq!(parse_fstring("{"), Err(FStringError::UnclosedLbrace)); - assert_eq!(parse_fstring("}"), Err(FStringError::UnopenedRbrace)); - assert_eq!( - parse_fstring("{class}"), - Err(FStringError::InvalidExpression) - ); - } } diff --git a/parser/src/python.lalrpop b/parser/src/python.lalrpop index 8ee407f9a..2748877dd 100644 --- a/parser/src/python.lalrpop +++ b/parser/src/python.lalrpop @@ -4,10 +4,12 @@ // See also: https://greentreesnakes.readthedocs.io/en/latest/nodes.html#keyword #![allow(unknown_lints,clippy)] -use super::ast; -use super::lexer; -use super::parser; use std::iter::FromIterator; + +use crate::ast; +use crate::fstring::parse_fstring; +use crate::lexer; + use num_bigint::BigInt; grammar; @@ -1008,7 +1010,7 @@ StringGroup: ast::StringGroup = { let mut values = vec![]; for (value, is_fstring) in s { values.push(if is_fstring { - parser::parse_fstring(&value)? + parse_fstring(&value)? } else { ast::StringGroup::Constant { value } }) diff --git a/tests/snippets/fstrings.py b/tests/snippets/fstrings.py index 2ee45742f..c76967acb 100644 --- a/tests/snippets/fstrings.py +++ b/tests/snippets/fstrings.py @@ -11,7 +11,31 @@ assert f"{f'{{}}'}" == '{}' # don't include escaped braces in nested f-strings assert f'{f"{{"}' == '{' assert f'{f"}}"}' == '}' assert f'{foo}' f"{foo}" 'foo' == 'barbarfoo' -#assert f'{"!:"}' == '!:' -#assert f"{1 != 2}" == 'True' +assert f'{"!:"}' == '!:' assert fr'x={4*10}\n' == 'x=40\\n' assert f'{16:0>+#10x}' == '00000+0x10' +assert f"{{{(lambda x: f'hello, {x}')('world}')}" == '{hello, world}' + +# Normally `!` cannot appear outside of delimiters in the expression but +# cpython makes an exception for `!=`, so we should too. + +# assert f'{1 != 2}' == 'True' + + +# conversion flags + +class Value: + def __format__(self, spec): + return "foo" + + def __repr__(self): + return "bar" + + def __str__(self): + return "baz" + +v = Value() + +assert f'{v}' == 'foo' +assert f'{v!r}' == 'bar' +assert f'{v!s}' == 'baz' diff --git a/tests/snippets/stdlib_socket.py b/tests/snippets/stdlib_socket.py new file mode 100644 index 000000000..5419d802f --- /dev/null +++ b/tests/snippets/stdlib_socket.py @@ -0,0 +1,20 @@ +import socket + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +listener.bind(("127.0.0.1", 8080)) +listener.listen(1) + +connector = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +connector.connect(("127.0.0.1", 8080)) +connection = listener.accept()[0] + +message_a = b'aaaa' +message_b = b'bbbbb' + +connector.send(message_a) +connector.close() +recv_a = connection.recv(10) + +connection.close() +listener.close() + diff --git a/vm/src/bytecode.rs b/vm/src/bytecode.rs index 1e0b85f13..b7cefe0b3 100644 --- a/vm/src/bytecode.rs +++ b/vm/src/bytecode.rs @@ -169,6 +169,7 @@ pub enum Instruction { }, Unpack, FormatValue { + conversion: Option, spec: String, }, } @@ -361,7 +362,10 @@ impl Instruction { UnpackSequence { size } => w!(UnpackSequence, size), UnpackEx { before, after } => w!(UnpackEx, before, after), Unpack => w!(Unpack), - FormatValue { spec } => w!(FormatValue, spec), + FormatValue { + conversion: _, + spec, + } => w!(FormatValue, spec), // TODO: write conversion } } } diff --git a/vm/src/compile.rs b/vm/src/compile.rs index 866b0dad7..a562c1d3d 100644 --- a/vm/src/compile.rs +++ b/vm/src/compile.rs @@ -1352,9 +1352,16 @@ impl Compiler { }, }); } - ast::StringGroup::FormattedValue { value, spec } => { + ast::StringGroup::FormattedValue { + value, + conversion, + spec, + } => { self.compile_expression(value)?; - self.emit(Instruction::FormatValue { spec: spec.clone() }); + self.emit(Instruction::FormatValue { + conversion: *conversion, + spec: spec.clone(), + }); } } Ok(()) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 5de7f6809..f97be6894 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -654,8 +654,15 @@ impl Frame { } Ok(None) } - bytecode::Instruction::FormatValue { spec } => { - let value = self.pop_value(); + bytecode::Instruction::FormatValue { conversion, spec } => { + use ast::ConversionFlag::*; + let value = match conversion { + Some(Str) => vm.to_str(&self.pop_value())?, + Some(Repr) => vm.to_repr(&self.pop_value())?, + Some(Ascii) => self.pop_value(), // TODO + None => self.pop_value(), + }; + let spec = vm.new_str(spec.clone()); let formatted = vm.call_method(&value, "__format__", vec![spec])?; self.push_value(formatted); diff --git a/vm/src/obj/objbool.rs b/vm/src/obj/objbool.rs index c5c9cb43b..41681a476 100644 --- a/vm/src/obj/objbool.rs +++ b/vm/src/obj/objbool.rs @@ -1,10 +1,16 @@ use super::objtype; use crate::pyobject::{ - PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, + IntoPyObject, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, }; use crate::vm::VirtualMachine; use num_traits::Zero; +impl IntoPyObject for bool { + fn into_pyobject(self, ctx: &PyContext) -> PyResult { + Ok(ctx.new_bool(self)) + } +} + pub fn boolval(vm: &mut VirtualMachine, obj: PyObjectRef) -> Result { let result = match obj.borrow().payload { PyObjectPayload::Integer { ref value } => !value.is_zero(), diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index de53f5082..c6bb93b9a 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -3,8 +3,8 @@ use super::objstr; use super::objtype; use crate::format::FormatSpec; use crate::pyobject::{ - FromPyObjectRef, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, - TypeProtocol, + FromPyObjectRef, IntoPyObject, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, + PyResult, TypeProtocol, }; use crate::vm::VirtualMachine; use num_bigint::{BigInt, ToBigInt}; @@ -15,6 +15,22 @@ use std::hash::{Hash, Hasher}; // This proxy allows for easy switching between types. type IntType = BigInt; +pub type PyInt = BigInt; + +impl IntoPyObject for PyInt { + fn into_pyobject(self, ctx: &PyContext) -> PyResult { + Ok(ctx.new_int(self)) + } +} + +// TODO: macro to impl for all primitive ints + +impl IntoPyObject for usize { + fn into_pyobject(self, ctx: &PyContext) -> PyResult { + Ok(ctx.new_int(self)) + } +} + fn int_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(int, Some(vm.ctx.int_type()))]); let v = get_value(int); diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index c5075ccbd..e2ac5de45 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -1,7 +1,8 @@ use super::objint; use super::objtype; use crate::pyobject::{ - PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, + FromPyObject, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, + TypeProtocol, }; use crate::vm::VirtualMachine; use num_bigint::{BigInt, Sign}; @@ -18,6 +19,18 @@ pub struct RangeType { pub step: BigInt, } +type PyRange = RangeType; + +impl FromPyObject for PyRange { + fn typ(ctx: &PyContext) -> Option { + Some(ctx.range_type()) + } + + fn from_pyobject(obj: PyObjectRef) -> PyResult { + Ok(get_value(&obj)) + } +} + impl RangeType { #[inline] pub fn try_len(&self) -> Option { @@ -345,22 +358,12 @@ fn range_bool(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bool(len > 0)) } -fn range_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(zelf, Some(vm.ctx.range_type())), (needle, None)] - ); - - let range = get_value(zelf); - - let result = if objtype::isinstance(needle, &vm.ctx.int_type()) { - range.contains(&objint::get_value(needle)) +fn range_contains(vm: &mut VirtualMachine, zelf: PyRange, needle: PyObjectRef) -> bool { + if objtype::isinstance(&needle, &vm.ctx.int_type()) { + zelf.contains(&objint::get_value(&needle)) } else { false - }; - - Ok(vm.ctx.new_bool(result)) + } } fn range_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 9c2051e6f..04394f7c5 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -3,7 +3,8 @@ use super::objsequence::PySliceableSequence; use super::objtype; use crate::format::{FormatParseError, FormatPart, FormatString}; use crate::pyobject::{ - PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, + FromPyObject, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, + TypeProtocol, }; use crate::vm::VirtualMachine; use num_traits::ToPrimitive; @@ -17,6 +18,16 @@ extern crate unicode_segmentation; use self::unicode_segmentation::UnicodeSegmentation; +impl FromPyObject for String { + fn typ(ctx: &PyContext) -> Option { + Some(ctx.str_type()) + } + + fn from_pyobject(obj: PyObjectRef) -> PyResult { + Ok(get_value(&obj)) + } +} + pub fn init(context: &PyContext) { let str_type = &context.str_type; context.set_attr(&str_type, "__add__", context.new_rustfunc(str_add)); @@ -474,15 +485,8 @@ fn str_rstrip(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_str(value)) } -fn str_endswith(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(s, Some(vm.ctx.str_type())), (pat, Some(vm.ctx.str_type()))] - ); - let value = get_value(&s); - let pat = get_value(&pat); - Ok(vm.ctx.new_bool(value.ends_with(pat.as_str()))) +fn str_endswith(_vm: &mut VirtualMachine, zelf: String, suffix: String) -> bool { + zelf.ends_with(&suffix) } fn str_isidentifier(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index dd1f839ff..073a5d646 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -29,6 +29,7 @@ use crate::obj::objsuper; use crate::obj::objtuple; use crate::obj::objtype; use crate::obj::objzip; +use crate::stdlib::socket::Socket; use crate::vm::VirtualMachine; use num_bigint::BigInt; use num_bigint::ToBigInt; @@ -72,7 +73,7 @@ pub type PyObjectWeakRef = Weak>; /// Use this type for function which return a python object or and exception. /// Both the python object and the python exception are `PyObjectRef` types /// since exceptions are also python objects. -pub type PyResult = Result; // A valid value, or an exception +pub type PyResult = Result; // A valid value, or an exception /// For attributes we do not use a dict, but a hashmap. This is probably /// faster, unordered, and only supports strings as keys. @@ -553,13 +554,13 @@ impl PyContext { ) } - pub fn new_rustfunc PyResult>( - &self, - function: F, - ) -> PyObjectRef { + pub fn new_rustfunc(&self, factory: F) -> PyObjectRef + where + F: PyNativeFuncFactory, + { PyObject::new( PyObjectPayload::RustFunction { - function: Box::new(function), + function: factory.create(self), }, self.builtin_function_or_method_type(), ) @@ -945,6 +946,203 @@ impl PyFuncArgs { } } +pub trait FromPyObject: Sized { + fn typ(ctx: &PyContext) -> Option; + + fn from_pyobject(obj: PyObjectRef) -> PyResult; +} + +impl FromPyObject for PyObjectRef { + fn typ(_ctx: &PyContext) -> Option { + None + } + + fn from_pyobject(obj: PyObjectRef) -> PyResult { + Ok(obj) + } +} + +pub trait IntoPyObject { + fn into_pyobject(self, ctx: &PyContext) -> PyResult; +} + +impl IntoPyObject for PyObjectRef { + fn into_pyobject(self, _ctx: &PyContext) -> PyResult { + Ok(self) + } +} + +impl IntoPyObject for PyResult { + fn into_pyobject(self, _ctx: &PyContext) -> PyResult { + self + } +} + +pub trait FromPyFuncArgs: Sized { + fn required_params(ctx: &PyContext) -> Vec; + + fn from_py_func_args(args: &mut PyFuncArgs) -> PyResult; +} + +macro_rules! tuple_from_py_func_args { + ($($T:ident),+) => { + impl<$($T),+> FromPyFuncArgs for ($($T,)+) + where + $($T: FromPyFuncArgs),+ + { + fn required_params(ctx: &PyContext) -> Vec { + vec![$($T::required_params(ctx),)+].into_iter().flatten().collect() + } + + fn from_py_func_args(args: &mut PyFuncArgs) -> PyResult { + Ok(($($T::from_py_func_args(args)?,)+)) + } + } + }; +} + +tuple_from_py_func_args!(A); +tuple_from_py_func_args!(A, B); +tuple_from_py_func_args!(A, B, C); +tuple_from_py_func_args!(A, B, C, D); +tuple_from_py_func_args!(A, B, C, D, E); + +impl FromPyFuncArgs for T +where + T: FromPyObject, +{ + fn required_params(ctx: &PyContext) -> Vec { + vec![Parameter { + kind: PositionalOnly, + typ: T::typ(ctx), + }] + } + + fn from_py_func_args(args: &mut PyFuncArgs) -> PyResult { + Self::from_pyobject(args.shift()) + } +} + +pub type PyNativeFunc = Box PyResult>; + +pub trait PyNativeFuncFactory { + fn create(self, ctx: &PyContext) -> PyNativeFunc; +} + +impl PyNativeFuncFactory for F +where + F: Fn(&mut VirtualMachine, PyFuncArgs) -> PyResult + 'static, +{ + fn create(self, _ctx: &PyContext) -> PyNativeFunc { + Box::new(self) + } +} + +macro_rules! tuple_py_native_func_factory { + ($($T:ident),+) => { + impl PyNativeFuncFactory<($($T,)+), R> for F + where + F: Fn(&mut VirtualMachine, $($T),+) -> R + 'static, + $($T: FromPyFuncArgs,)+ + R: IntoPyObject, + { + fn create(self, ctx: &PyContext) -> PyNativeFunc { + let parameters = vec![$($T::required_params(ctx)),+] + .into_iter() + .flatten() + .collect(); + let signature = Signature::new(parameters); + + Box::new(move |vm, mut args| { + signature.check(vm, &mut args)?; + + (self)(vm, $($T::from_py_func_args(&mut args)?,)+) + .into_pyobject(&vm.ctx) + }) + } + } + }; +} + +tuple_py_native_func_factory!(A); +tuple_py_native_func_factory!(A, B); +tuple_py_native_func_factory!(A, B, C); +tuple_py_native_func_factory!(A, B, C, D); +tuple_py_native_func_factory!(A, B, C, D, E); + +#[derive(Debug)] +pub struct Signature { + positional_params: Vec, + keyword_params: HashMap, +} + +impl Signature { + fn new(params: Vec) -> Self { + let mut positional_params = Vec::new(); + let mut keyword_params = HashMap::new(); + for param in params { + match param.kind { + PositionalOnly => { + positional_params.push(param); + } + KeywordOnly { ref name } => { + keyword_params.insert(name.clone(), param); + } + } + } + + Self { + positional_params, + keyword_params, + } + } + + fn arg_type(&self, pos: usize) -> Option<&PyObjectRef> { + self.positional_params[pos].typ.as_ref() + } + + #[allow(unused)] + fn kwarg_type(&self, name: &str) -> Option<&PyObjectRef> { + self.keyword_params[name].typ.as_ref() + } + + fn check(&self, vm: &mut VirtualMachine, args: &PyFuncArgs) -> PyResult<()> { + // TODO: check arity + + for (pos, arg) in args.args.iter().enumerate() { + if let Some(expected_type) = self.arg_type(pos) { + if !objtype::isinstance(arg, expected_type) { + let arg_typ = arg.typ(); + let expected_type_name = vm.to_pystr(&expected_type)?; + let actual_type = vm.to_pystr(&arg_typ)?; + return Err(vm.new_type_error(format!( + "argument of type {} is required for parameter {} (got: {})", + expected_type_name, + pos + 1, + actual_type + ))); + } + } + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct Parameter { + typ: Option, + kind: ParameterKind, +} + +#[derive(Debug)] +pub enum ParameterKind { + PositionalOnly, + KeywordOnly { name: String }, +} + +use self::ParameterKind::*; + /// Rather than determining the type of a python object, this enum is more /// a holder for the rust payload of a python object. It is more a carrier /// of rust data for a particular python object. Determine the python type @@ -1045,6 +1243,9 @@ pub enum PyObjectPayload { RustFunction { function: Box PyResult>, }, + Socket { + socket: Socket, + }, } impl fmt::Debug for PyObjectPayload { @@ -1082,6 +1283,7 @@ impl fmt::Debug for PyObjectPayload { PyObjectPayload::Instance { .. } => write!(f, "instance"), PyObjectPayload::RustFunction { .. } => write!(f, "rust function"), PyObjectPayload::Frame { .. } => write!(f, "frame"), + PyObjectPayload::Socket { .. } => write!(f, "socket"), } } } diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index 8c542df99..0dabba9b1 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -6,6 +6,7 @@ mod math; mod pystruct; mod random; mod re; +pub mod socket; mod string; mod time_module; mod tokenize; @@ -25,24 +26,25 @@ pub type StdlibInitFunc = Box PyObjectRef>; pub fn get_module_inits() -> HashMap { let mut modules = HashMap::new(); modules.insert("ast".to_string(), Box::new(ast::mk_module) as StdlibInitFunc); - modules.insert("dis".to_string(), Box::new(dis::mk_module) as StdlibInitFunc); - modules.insert("json".to_string(), Box::new(json::mk_module) as StdlibInitFunc); - modules.insert("keyword".to_string(), Box::new(keyword::mk_module) as StdlibInitFunc); - modules.insert("math".to_string(), Box::new(math::mk_module) as StdlibInitFunc); - modules.insert("re".to_string(), Box::new(re::mk_module) as StdlibInitFunc); - modules.insert("random".to_string(), Box::new(random::mk_module) as StdlibInitFunc); - modules.insert("string".to_string(), Box::new(string::mk_module) as StdlibInitFunc); - modules.insert("struct".to_string(), Box::new(pystruct::mk_module) as StdlibInitFunc); - modules.insert("time".to_string(), Box::new(time_module::mk_module) as StdlibInitFunc); - modules.insert( "tokenize".to_string(), Box::new(tokenize::mk_module) as StdlibInitFunc); - modules.insert("types".to_string(), Box::new(types::mk_module) as StdlibInitFunc); - modules.insert("_weakref".to_string(), Box::new(weakref::mk_module) as StdlibInitFunc); + modules.insert("dis".to_string(), Box::new(dis::mk_module)); + modules.insert("json".to_string(), Box::new(json::mk_module)); + modules.insert("keyword".to_string(), Box::new(keyword::mk_module)); + modules.insert("math".to_string(), Box::new(math::mk_module)); + modules.insert("re".to_string(), Box::new(re::mk_module)); + modules.insert("random".to_string(), Box::new(random::mk_module)); + modules.insert("string".to_string(), Box::new(string::mk_module)); + modules.insert("struct".to_string(), Box::new(pystruct::mk_module)); + modules.insert("time".to_string(), Box::new(time_module::mk_module)); + modules.insert("tokenize".to_string(), Box::new(tokenize::mk_module)); + modules.insert("types".to_string(), Box::new(types::mk_module)); + modules.insert("_weakref".to_string(), Box::new(weakref::mk_module)); // disable some modules on WASM #[cfg(not(target_arch = "wasm32"))] { - modules.insert("io".to_string(), Box::new(io::mk_module) as StdlibInitFunc); - modules.insert("os".to_string(), Box::new(os::mk_module) as StdlibInitFunc); + modules.insert("io".to_string(), Box::new(io::mk_module)); + modules.insert("os".to_string(), Box::new(os::mk_module)); + modules.insert("socket".to_string(), Box::new(socket::mk_module)); } modules diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs new file mode 100644 index 000000000..f1634d645 --- /dev/null +++ b/vm/src/stdlib/socket.rs @@ -0,0 +1,315 @@ +use crate::obj::objbytes; +use crate::obj::objint; +use crate::obj::objsequence::get_elements; +use crate::obj::objstr; +use crate::pyobject::{ + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, +}; +use crate::vm::VirtualMachine; + +use num_traits::ToPrimitive; +use std::io; +use std::io::Read; +use std::io::Write; +use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; + +#[derive(Copy, Clone)] +enum AddressFamily { + AfUnix = 1, + AfInet = 2, + AfInet6 = 3, +} + +impl AddressFamily { + fn from_i32(value: i32) -> AddressFamily { + match value { + 1 => AddressFamily::AfUnix, + 2 => AddressFamily::AfInet, + 3 => AddressFamily::AfInet6, + _ => panic!("Unknown value: {}", value), + } + } +} + +#[derive(Copy, Clone)] +enum SocketKind { + SockStream = 1, + SockDgram = 2, +} + +impl SocketKind { + fn from_i32(value: i32) -> SocketKind { + match value { + 1 => SocketKind::SockStream, + 2 => SocketKind::SockDgram, + _ => panic!("Unknown value: {}", value), + } + } +} + +enum Connection { + TcpListener(TcpListener), + TcpStream(TcpStream), + UdpSocket(UdpSocket), +} + +impl Connection { + fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> { + match self { + Connection::TcpListener(con) => con.accept(), + _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), + } + } +} + +impl Read for Connection { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + Connection::TcpStream(con) => con.read(buf), + _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), + } + } +} + +impl Write for Connection { + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + Connection::TcpStream(con) => con.write(buf), + _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), + } + } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +pub struct Socket { + address_family: AddressFamily, + sk: SocketKind, + con: Option, +} + +impl Socket { + fn new(address_family: AddressFamily, sk: SocketKind) -> Socket { + Socket { + address_family, + sk: sk, + con: None, + } + } +} + +fn socket_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (cls, None), + (family_int, Some(vm.ctx.int_type())), + (kind_int, Some(vm.ctx.int_type())) + ] + ); + + let address_family = AddressFamily::from_i32(objint::get_value(family_int).to_i32().unwrap()); + let kind = SocketKind::from_i32(objint::get_value(kind_int).to_i32().unwrap()); + + let socket = Socket::new(address_family, kind); + + Ok(PyObject::new( + PyObjectPayload::Socket { socket }, + cls.clone(), + )) +} + +fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))] + ); + + let elements = get_elements(address); + let host = objstr::get_value(&elements[0]); + let port = objint::get_value(&elements[1]); + + let address_string = format!("{}:{}", host, port.to_string()); + + let mut mut_obj = zelf.borrow_mut(); + + match mut_obj.payload { + PyObjectPayload::Socket { ref mut socket } => { + if let Ok(stream) = TcpStream::connect(address_string) { + socket.con = Some(Connection::TcpStream(stream)); + Ok(vm.get_none()) + } else { + // TODO: Socket error + Err(vm.new_type_error("socket failed".to_string())) + } + } + _ => Err(vm.new_type_error("".to_string())), + } +} + +fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))] + ); + + let elements = get_elements(address); + let host = objstr::get_value(&elements[0]); + let port = objint::get_value(&elements[1]); + + let address_string = format!("{}:{}", host, port.to_string()); + + let mut mut_obj = zelf.borrow_mut(); + + match mut_obj.payload { + PyObjectPayload::Socket { ref mut socket } => { + if let Ok(stream) = TcpListener::bind(address_string) { + socket.con = Some(Connection::TcpListener(stream)); + Ok(vm.get_none()) + } else { + // TODO: Socket error + Err(vm.new_type_error("socket failed".to_string())) + } + } + _ => Err(vm.new_type_error("".to_string())), + } +} + +fn socket_listen(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + Ok(vm.get_none()) +} + +fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, None)]); + + let mut mut_obj = zelf.borrow_mut(); + + match mut_obj.payload { + PyObjectPayload::Socket { ref mut socket } => { + let ret = match socket.con { + Some(ref mut v) => v.accept(), + None => return Err(vm.new_type_error("".to_string())), + }; + + let tcp_stream = match ret { + Ok((socket, _addr)) => socket, + _ => return Err(vm.new_type_error("".to_string())), + }; + + let socket = Socket { + address_family: socket.address_family.clone(), + sk: socket.sk.clone(), + con: Some(Connection::TcpStream(tcp_stream)), + }; + + let sock_obj = PyObject::new(PyObjectPayload::Socket { socket }, mut_obj.typ()); + + let elements = vec![sock_obj, vm.get_none()]; + + Ok(PyObject::new( + PyObjectPayload::Sequence { elements }, + vm.ctx.tuple_type(), + )) + } + _ => Err(vm.new_type_error("".to_string())), + } +} + +fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))] + ); + let mut mut_obj = zelf.borrow_mut(); + + match mut_obj.payload { + PyObjectPayload::Socket { ref mut socket } => { + let mut buffer = Vec::new(); + let _temp = match socket.con { + Some(ref mut v) => v.read_to_end(&mut buffer).unwrap(), + None => 0, + }; + Ok(PyObject::new( + PyObjectPayload::Bytes { value: buffer }, + vm.ctx.bytes_type(), + )) + } + _ => Err(vm.new_type_error("".to_string())), + } +} + +fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, None), (bytes, Some(vm.ctx.bytes_type()))] + ); + let mut mut_obj = zelf.borrow_mut(); + + match mut_obj.payload { + PyObjectPayload::Socket { ref mut socket } => { + match socket.con { + Some(ref mut v) => v.write(&objbytes::get_value(&bytes)).unwrap(), + None => return Err(vm.new_type_error("".to_string())), + }; + Ok(vm.get_none()) + } + _ => Err(vm.new_type_error("".to_string())), + } +} + +fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, None)]); + let mut mut_obj = zelf.borrow_mut(); + + match mut_obj.payload { + PyObjectPayload::Socket { ref mut socket } => match socket.address_family { + AddressFamily::AfInet => match socket.sk { + SocketKind::SockStream => { + socket.con = None; + Ok(vm.get_none()) + } + _ => Err(vm.new_type_error("".to_string())), + }, + _ => Err(vm.new_type_error("".to_string())), + }, + _ => Err(vm.new_type_error("".to_string())), + } +} + +pub fn mk_module(ctx: &PyContext) -> PyObjectRef { + let py_mod = ctx.new_module(&"socket".to_string(), ctx.new_scope(None)); + + ctx.set_attr( + &py_mod, + "AF_INET", + ctx.new_int(AddressFamily::AfInet as i32), + ); + + ctx.set_attr( + &py_mod, + "SOCK_STREAM", + ctx.new_int(SocketKind::SockStream as i32), + ); + + let socket = { + let socket = ctx.new_class("socket", ctx.object()); + ctx.set_attr(&socket, "__new__", ctx.new_rustfunc(socket_new)); + ctx.set_attr(&socket, "connect", ctx.new_rustfunc(socket_connect)); + ctx.set_attr(&socket, "recv", ctx.new_rustfunc(socket_recv)); + ctx.set_attr(&socket, "send", ctx.new_rustfunc(socket_send)); + ctx.set_attr(&socket, "bind", ctx.new_rustfunc(socket_bind)); + ctx.set_attr(&socket, "accept", ctx.new_rustfunc(socket_accept)); + ctx.set_attr(&socket, "listen", ctx.new_rustfunc(socket_listen)); + ctx.set_attr(&socket, "close", ctx.new_rustfunc(socket_close)); + socket + }; + ctx.set_attr(&py_mod, "socket", socket.clone()); + + py_mod +} diff --git a/wasm/demo/src/index.ejs b/wasm/demo/src/index.ejs index 2d06c3200..0cb2def1d 100644 --- a/wasm/demo/src/index.ejs +++ b/wasm/demo/src/index.ejs @@ -20,7 +20,7 @@