diff --git a/README.md b/README.md index b060a29e9..7e7b0bd70 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ $ pipenv install $ pipenv run pytest -v ``` -There also are some unit tests, you can run those will cargo: +There also are some unit tests, you can run those with cargo: ```shell $ cargo test --all diff --git a/parser/src/lexer.rs b/parser/src/lexer.rs index 0c3d88780..e49ef5148 100644 --- a/parser/src/lexer.rs +++ b/parser/src/lexer.rs @@ -719,16 +719,13 @@ where Some('+') => { let tok_start = self.get_pos(); self.next_char(); - match self.chr0 { - Some('=') => { - self.next_char(); - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::PlusEqual, tok_end))); - } - _ => { - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::Plus, tok_end))); - } + if let Some('=') = self.chr0 { + self.next_char(); + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::PlusEqual, tok_end))); + } else { + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::Plus, tok_end))); } } Some('*') => { @@ -792,61 +789,49 @@ where Some('%') => { let tok_start = self.get_pos(); self.next_char(); - match self.chr0 { - Some('=') => { - self.next_char(); - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::PercentEqual, tok_end))); - } - _ => { - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::Percent, tok_end))); - } + if let Some('=') = self.chr0 { + self.next_char(); + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::PercentEqual, tok_end))); + } else { + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::Percent, tok_end))); } } Some('|') => { let tok_start = self.get_pos(); self.next_char(); - match self.chr0 { - Some('=') => { - self.next_char(); - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::VbarEqual, tok_end))); - } - _ => { - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::Vbar, tok_end))); - } + if let Some('=') = self.chr0 { + self.next_char(); + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::VbarEqual, tok_end))); + } else { + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::Vbar, tok_end))); } } Some('^') => { let tok_start = self.get_pos(); self.next_char(); - match self.chr0 { - Some('=') => { - self.next_char(); - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::CircumflexEqual, tok_end))); - } - _ => { - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::CircumFlex, tok_end))); - } + if let Some('=') = self.chr0 { + self.next_char(); + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::CircumflexEqual, tok_end))); + } else { + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::CircumFlex, tok_end))); } } Some('&') => { let tok_start = self.get_pos(); self.next_char(); - match self.chr0 { - Some('=') => { - self.next_char(); - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::AmperEqual, tok_end))); - } - _ => { - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::Amper, tok_end))); - } + if let Some('=') = self.chr0 { + self.next_char(); + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::AmperEqual, tok_end))); + } else { + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::Amper, tok_end))); } } Some('-') => { @@ -872,16 +857,13 @@ where Some('@') => { let tok_start = self.get_pos(); self.next_char(); - match self.chr0 { - Some('=') => { - self.next_char(); - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::AtEqual, tok_end))); - } - _ => { - let tok_end = self.get_pos(); - return Some(Ok((tok_start, Tok::At, tok_end))); - } + if let Some('=') = self.chr0 { + self.next_char(); + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::AtEqual, tok_end))); + } else { + let tok_end = self.get_pos(); + return Some(Ok((tok_start, Tok::At, tok_end))); } } Some('!') => { diff --git a/tests/snippets/stdlib_socket.py b/tests/snippets/stdlib_socket.py index e0ea6d168..f984515ea 100644 --- a/tests/snippets/stdlib_socket.py +++ b/tests/snippets/stdlib_socket.py @@ -1,6 +1,10 @@ import socket from testutils import assertRaises +MESSAGE_A = b'aaaa' +MESSAGE_B= b'bbbbb' + +# TCP listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) listener.bind(("127.0.0.1", 0)) @@ -8,18 +12,15 @@ listener.listen(1) connector = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connector.connect(("127.0.0.1", listener.getsockname()[1])) -connection = listener.accept()[0] - -message_a = b'aaaa' -message_b = b'bbbbb' - -connector.send(message_a) -connection.send(message_b) -recv_a = connection.recv(len(message_a)) -recv_b = connector.recv(len(message_b)) -assert recv_a == message_a -assert recv_b == message_b +(connection, addr) = listener.accept() +assert addr == connector.getsockname() +connector.send(MESSAGE_A) +connection.send(MESSAGE_B) +recv_a = connection.recv(len(MESSAGE_A)) +recv_b = connector.recv(len(MESSAGE_B)) +assert recv_a == MESSAGE_A +assert recv_b == MESSAGE_B connection.close() connector.close() listener.close() @@ -35,3 +36,40 @@ with assertRaises(TypeError): s.bind((888, 8888)) s.close() + +# UDP +sock1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +sock1.bind(("127.0.0.1", 0)) + +sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + +sock2.sendto(MESSAGE_A, sock1.getsockname()) +(recv_a, addr1) = sock1.recvfrom(len(MESSAGE_A)) +assert recv_a == MESSAGE_A + +sock2.sendto(MESSAGE_B, sock1.getsockname()) +(recv_b, addr2) = sock1.recvfrom(len(MESSAGE_B)) +assert recv_b == MESSAGE_B +assert addr1[0] == addr2[0] +assert addr1[1] == addr2[1] + +sock2.close() + +sock3 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +sock3.bind(("127.0.0.1", 0)) +sock3.sendto(MESSAGE_A, sock1.getsockname()) +(recv_a, addr) = sock1.recvfrom(len(MESSAGE_A)) +assert recv_a == MESSAGE_A +assert addr == sock3.getsockname() + +sock1.connect(("127.0.0.1", sock3.getsockname()[1])) +sock3.connect(("127.0.0.1", sock1.getsockname()[1])) + +sock1.send(MESSAGE_A) +sock3.send(MESSAGE_B) +recv_a = sock3.recv(len(MESSAGE_A)) +recv_b = sock1.recv(len(MESSAGE_B)) +assert recv_a == MESSAGE_A +assert recv_b == MESSAGE_B +sock1.close() +sock3.close() diff --git a/tests/snippets/test_exec.py b/tests/snippets/test_exec.py new file mode 100644 index 000000000..37ba33ff1 --- /dev/null +++ b/tests/snippets/test_exec.py @@ -0,0 +1,42 @@ +exec("def square(x):\n return x * x\n") +assert 16 == square(4) + +d = {} +exec("def square(x):\n return x * x\n", {}, d) +assert 16 == d['square'](4) + +exec("assert 2 == x", {}, {'x': 2}) +exec("assert 2 == x", {'x': 2}, {}) +exec("assert 4 == x", {'x': 2}, {'x': 4}) + +exec("assert max(1, 2) == 2", {}, {}) + +exec("assert max(1, 5, square(5)) == 25", None) + +# +# These doesn't work yet: +# +# Local environment shouldn't replace global environment: +# +# exec("assert max(1, 5, square(5)) == 25", None, {}) +# +# Closures aren't available if local scope is replaced: +# +# def g(): +# seven = "seven" +# def f(): +# try: +# exec("seven", None, {}) +# except NameError: +# pass +# else: +# raise NameError("seven shouldn't be in scope") +# f() +# g() + +try: + exec("", 1) +except TypeError: + pass +else: + raise TypeError("exec should fail unless globals is a dict or None") diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 0e48315e9..b146fcf91 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -18,7 +18,6 @@ use crate::frame::{Scope, ScopeRef}; use crate::pyobject::{ AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol, }; -use std::rc::Rc; #[cfg(not(target_arch = "wasm32"))] use crate::stdlib::io::io_open; @@ -191,12 +190,11 @@ fn builtin_eval(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { vm, args, required = [(source, None)], - optional = [ - (_globals, Some(vm.ctx.dict_type())), - (locals, Some(vm.ctx.dict_type())) - ] + optional = [(globals, None), (locals, Some(vm.ctx.dict_type()))] ); + let scope = make_scope(vm, globals, locals)?; + // Determine code object: let code_obj = if objtype::isinstance(source, &vm.ctx.code_type()) { source.clone() @@ -215,8 +213,6 @@ fn builtin_eval(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { return Err(vm.new_type_error("code argument must be str or code object".to_string())); }; - let scope = make_scope(vm, locals); - // Run the source: vm.run_code_obj(code_obj.clone(), scope) } @@ -228,12 +224,11 @@ fn builtin_exec(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { vm, args, required = [(source, None)], - optional = [ - (_globals, Some(vm.ctx.dict_type())), - (locals, Some(vm.ctx.dict_type())) - ] + optional = [(globals, None), (locals, Some(vm.ctx.dict_type()))] ); + let scope = make_scope(vm, globals, locals)?; + // Determine code object: let code_obj = if objtype::isinstance(source, &vm.ctx.str_type()) { let mode = compile::Mode::Exec; @@ -252,26 +247,48 @@ fn builtin_exec(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { return Err(vm.new_type_error("source argument must be str or code object".to_string())); }; - let scope = make_scope(vm, locals); - // Run the code: vm.run_code_obj(code_obj, scope) } -fn make_scope(vm: &mut VirtualMachine, locals: Option<&PyObjectRef>) -> ScopeRef { - // handle optional global and locals - let locals = if let Some(locals) = locals { - locals.clone() - } else { - vm.new_dict() +fn make_scope( + vm: &mut VirtualMachine, + globals: Option<&PyObjectRef>, + locals: Option<&PyObjectRef>, +) -> PyResult { + let dict_type = vm.ctx.dict_type(); + let globals = match globals { + Some(arg) => { + if arg.is(&vm.get_none()) { + None + } else { + if vm.isinstance(arg, &dict_type)? { + Some(arg) + } else { + let arg_typ = arg.typ(); + let actual_type = vm.to_pystr(&arg_typ)?; + let expected_type_name = vm.to_pystr(&dict_type)?; + return Err(vm.new_type_error(format!( + "globals must be a {}, not {}", + expected_type_name, actual_type + ))); + } + } + } + None => None, }; - // TODO: handle optional globals - // Construct new scope: - Rc::new(Scope { - locals, - parent: None, - }) + let current_scope = vm.current_scope(); + let parent = match globals { + Some(dict) => Some(Scope::new(dict.clone(), Some(vm.get_builtin_scope()))), + None => current_scope.parent.clone(), + }; + let locals = match locals { + Some(dict) => dict.clone(), + None => current_scope.locals.clone(), + }; + + Ok(Scope::new(locals, parent)) } fn builtin_format(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -347,7 +364,7 @@ fn builtin_isinstance(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { required = [(obj, None), (typ, Some(vm.get_type()))] ); - let isinstance = objtype::real_isinstance(vm, obj, typ)?; + let isinstance = vm.isinstance(obj, typ)?; Ok(vm.new_bool(isinstance)) } @@ -358,7 +375,7 @@ fn builtin_issubclass(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { required = [(subclass, Some(vm.get_type())), (cls, Some(vm.get_type()))] ); - let issubclass = objtype::real_issubclass(vm, subclass, cls)?; + let issubclass = vm.issubclass(subclass, cls)?; Ok(vm.context().new_bool(issubclass)) } @@ -814,7 +831,7 @@ pub fn builtin_build_class_(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> Py let mut metaclass = args.get_kwarg("metaclass", vm.get_type()); for base in bases.clone() { - if objtype::real_issubclass(vm, &base.typ(), &metaclass)? { + if objtype::issubclass(&base.typ(), &metaclass) { metaclass = base.typ(); } else if !objtype::issubclass(&metaclass, &base.typ()) { return Err(vm.new_type_error("metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases".to_string())); diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 2ba30c066..df6b3bb5a 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -66,11 +66,7 @@ impl Dict { } } - pub fn contains( - &self, - vm: &mut VirtualMachine, - key: &PyObjectRef, - ) -> PyResult { + pub fn contains(&self, vm: &mut VirtualMachine, key: &PyObjectRef) -> PyResult { if let LookupResult::Existing(_index) = self.lookup(vm, key)? { Ok(true) } else { @@ -93,11 +89,7 @@ impl Dict { } /// Delete a key - pub fn delete( - &mut self, - vm: &mut VirtualMachine, - key: &PyObjectRef, - ) -> PyResult<()> { + pub fn delete(&mut self, vm: &mut VirtualMachine, key: &PyObjectRef) -> PyResult<()> { if let LookupResult::Existing(index) = self.lookup(vm, key)? { self.entries[index] = None; self.size -= 1; @@ -126,11 +118,7 @@ impl Dict { } /// Lookup the index for the given key. - fn lookup( - &self, - vm: &mut VirtualMachine, - key: &PyObjectRef, - ) -> PyResult { + fn lookup(&self, vm: &mut VirtualMachine, key: &PyObjectRef) -> PyResult { let hash_value = calc_hash(vm, key)?; let perturb = hash_value; let mut hash_index: usize = hash_value; diff --git a/vm/src/frame.rs b/vm/src/frame.rs index ec0c09923..c76552b76 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -35,6 +35,12 @@ pub struct Scope { } pub type ScopeRef = Rc; +impl Scope { + pub fn new(locals: PyObjectRef, parent: Option) -> ScopeRef { + Rc::new(Scope { locals, parent }) + } +} + #[derive(Clone, Debug)] struct Block { /// The type of block. diff --git a/vm/src/function.rs b/vm/src/function.rs new file mode 100644 index 000000000..131fce129 --- /dev/null +++ b/vm/src/function.rs @@ -0,0 +1,83 @@ +use std::marker::PhantomData; +use std::ops::Deref; + +use crate::obj::objtype; +use crate::pyobject::{ + IntoPyObject, PyContext, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult, + TryFromObject, TypeProtocol, +}; +use crate::vm::VirtualMachine; + +// TODO: Move PyFuncArgs, FromArgs, etc. here + +// TODO: `PyRef` probably actually belongs in the pyobject module. + +/// A reference to the payload of a built-in object. +/// +/// Note that a `PyRef` can only deref to a shared / immutable reference. +/// It is the payload type's responsibility to handle (possibly concurrent) +/// mutability with locks or concurrent data structures if required. +/// +/// A `PyRef` can be directly returned from a built-in function to handle +/// situations (such as when implementing in-place methods such as `__iadd__`) +/// where a reference to the same object must be returned. +pub struct PyRef { + // invariant: this obj must always have payload of type T + obj: PyObjectRef, + _payload: PhantomData, +} + +impl PyRef +where + T: PyObjectPayload2, +{ + pub fn new(ctx: &PyContext, payload: T) -> Self { + PyRef { + obj: PyObject::new( + PyObjectPayload::AnyRustValue { + value: Box::new(payload), + }, + T::required_type(ctx), + ), + _payload: PhantomData, + } + } +} + +impl Deref for PyRef +where + T: PyObjectPayload2, +{ + type Target = T; + + fn deref(&self) -> &T { + self.obj.payload().expect("unexpected payload for type") + } +} + +impl TryFromObject for PyRef +where + T: PyObjectPayload2, +{ + fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult { + if objtype::isinstance(&obj, &T::required_type(&vm.ctx)) { + Ok(PyRef { + obj, + _payload: PhantomData, + }) + } else { + let expected_type = vm.to_pystr(&T::required_type(&vm.ctx))?; + let actual_type = vm.to_pystr(&obj.typ())?; + Err(vm.new_type_error(format!( + "Expected type {}, not {}", + expected_type, actual_type, + ))) + } + } +} + +impl IntoPyObject for PyRef { + fn into_pyobject(self, _ctx: &PyContext) -> PyResult { + Ok(self.obj) + } +} diff --git a/vm/src/lib.rs b/vm/src/lib.rs index b32b91cb7..78696f9f3 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -41,6 +41,7 @@ pub mod eval; mod exceptions; pub mod format; pub mod frame; +pub mod function; pub mod import; pub mod obj; pub mod pyobject; diff --git a/vm/src/macros.rs b/vm/src/macros.rs index 2c48286d3..f91cd7c61 100644 --- a/vm/src/macros.rs +++ b/vm/src/macros.rs @@ -19,7 +19,7 @@ macro_rules! type_check { if let Some(expected_type) = $arg_type { let arg = &$args.args[$arg_count]; - if !$crate::obj::objtype::real_isinstance($vm, arg, &expected_type)? { + if !$crate::obj::objtype::isinstance(arg, &expected_type) { let arg_typ = arg.typ(); let expected_type_name = $vm.to_pystr(&expected_type)?; let actual_type = $vm.to_pystr(&arg_typ)?; @@ -124,3 +124,16 @@ macro_rules! py_module { } } } + +#[macro_export] +macro_rules! py_class { + ( $ctx:expr, $class_name:expr, $class_base:expr, { $($name:expr => $value:expr),* $(,)* }) => { + { + let py_class = $ctx.new_class($class_name, $class_base); + $( + $ctx.set_attr(&py_class, $name, $value); + )* + py_class + } + } +} diff --git a/vm/src/obj/objbool.rs b/vm/src/obj/objbool.rs index eddd744e5..84b048e3a 100644 --- a/vm/src/obj/objbool.rs +++ b/vm/src/obj/objbool.rs @@ -1,3 +1,5 @@ +use super::objfloat::PyFloat; +use super::objstr::PyString; use super::objtype; use crate::pyobject::{ IntoPyObject, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, @@ -12,12 +14,16 @@ impl IntoPyObject for bool { } pub fn boolval(vm: &mut VirtualMachine, obj: PyObjectRef) -> Result { + if let Some(s) = obj.payload::() { + return Ok(!s.value.is_empty()); + } + if let Some(value) = obj.payload::() { + return Ok(*value != PyFloat::from(0.0)); + } let result = match obj.payload { PyObjectPayload::Integer { ref value } => !value.is_zero(), - PyObjectPayload::Float { value } => value != 0.0, PyObjectPayload::Sequence { ref elements } => !elements.borrow().is_empty(), PyObjectPayload::Dict { ref elements } => !elements.borrow().is_empty(), - PyObjectPayload::String { ref value } => !value.is_empty(), PyObjectPayload::None { .. } => false, _ => { if let Ok(f) = vm.get_method(obj.clone(), "__bool__") { diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 0cabf23fd..7e03d34d3 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -1,17 +1,47 @@ //! Implementation of the python bytearray object. use std::cell::RefCell; +use std::ops::{Deref, DerefMut}; -use crate::pyobject::{PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyResult, TypeProtocol}; +use crate::pyobject::{ + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult, + TypeProtocol, +}; use super::objint; -use super::objbytes::get_mut_value; -use super::objbytes::get_value; use super::objtype; use crate::vm::VirtualMachine; use num_traits::ToPrimitive; +#[derive(Debug)] +pub struct PyByteArray { + // TODO: shouldn't be public + pub value: RefCell>, +} + +impl PyByteArray { + pub fn new(data: Vec) -> Self { + PyByteArray { + value: RefCell::new(data), + } + } +} + +impl PyObjectPayload2 for PyByteArray { + fn required_type(ctx: &PyContext) -> PyObjectRef { + ctx.bytearray_type() + } +} + +pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref> + 'a { + obj.payload::().unwrap().value.borrow() +} + +pub fn get_mut_value<'a>(obj: &'a PyObjectRef) -> impl DerefMut> + 'a { + obj.payload::().unwrap().value.borrow_mut() +} + // Binary data support /// Fill bytearray class methods dictionary. @@ -143,8 +173,8 @@ fn bytearray_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { vec![] }; Ok(PyObject::new( - PyObjectPayload::Bytes { - value: RefCell::new(value), + PyObjectPayload::AnyRustValue { + value: Box::new(PyByteArray::new(value)), }, cls.clone(), )) @@ -290,13 +320,8 @@ fn bytearray_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { fn bytearray_clear(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, Some(vm.ctx.bytearray_type()))]); - match zelf.payload { - PyObjectPayload::Bytes { ref value } => { - value.borrow_mut().clear(); - Ok(vm.get_none()) - } - _ => panic!("Bytearray has incorrect payload."), - } + get_mut_value(zelf).clear(); + Ok(vm.get_none()) } fn bytearray_pop(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 6ab2f72af..d285b86e9 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -1,16 +1,41 @@ -use std::cell::{Cell, RefCell}; +use std::cell::Cell; use std::hash::{Hash, Hasher}; use std::ops::Deref; -use std::ops::DerefMut; use super::objint; use super::objtype; use crate::pyobject::{ - PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult, + TypeProtocol, }; use crate::vm::VirtualMachine; use num_traits::ToPrimitive; +#[derive(Debug)] +pub struct PyBytes { + value: Vec, +} + +impl PyBytes { + pub fn new(data: Vec) -> Self { + PyBytes { value: data } + } +} + +impl Deref for PyBytes { + type Target = [u8]; + + fn deref(&self) -> &[u8] { + &self.value + } +} + +impl PyObjectPayload2 for PyBytes { + fn required_type(ctx: &PyContext) -> PyObjectRef { + ctx.bytes_type() + } +} + // Binary data support // Fill bytes class methods: @@ -71,8 +96,8 @@ fn bytes_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { }; Ok(PyObject::new( - PyObjectPayload::Bytes { - value: RefCell::new(value), + PyObjectPayload::AnyRustValue { + value: Box::new(PyBytes::new(value)), }, cls.clone(), )) @@ -170,19 +195,7 @@ fn bytes_hash(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref> + 'a { - if let PyObjectPayload::Bytes { ref value } = obj.payload { - value.borrow() - } else { - panic!("Inner error getting bytearray {:?}", obj); - } -} - -pub fn get_mut_value<'a>(obj: &'a PyObjectRef) -> impl DerefMut> + 'a { - if let PyObjectPayload::Bytes { ref value } = obj.payload { - value.borrow_mut() - } else { - panic!("Inner error getting bytearray {:?}", obj); - } + &obj.payload::().unwrap().value } fn bytes_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objcode.rs b/vm/src/obj/objcode.rs index c9f244466..53ea55f4c 100644 --- a/vm/src/obj/objcode.rs +++ b/vm/src/obj/objcode.rs @@ -55,10 +55,7 @@ fn code_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(repr)) } -fn member_code_obj( - vm: &mut VirtualMachine, - args: PyFuncArgs, -) -> PyResult { +fn member_code_obj(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index c6be5e5f8..66205559c 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -2,12 +2,30 @@ use super::objfloat; use super::objint; use super::objtype; use crate::pyobject::{ - PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult, + TypeProtocol, }; use crate::vm::VirtualMachine; use num_complex::Complex64; use num_traits::ToPrimitive; +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct PyComplex { + value: Complex64, +} + +impl PyObjectPayload2 for PyComplex { + fn required_type(ctx: &PyContext) -> PyObjectRef { + ctx.complex_type() + } +} + +impl From for PyComplex { + fn from(value: Complex64) -> Self { + PyComplex { value } + } +} + pub fn init(context: &PyContext) { let complex_type = &context.complex_type; @@ -45,11 +63,7 @@ pub fn init(context: &PyContext) { } pub fn get_value(obj: &PyObjectRef) -> Complex64 { - if let PyObjectPayload::Complex { value } = &obj.payload { - *value - } else { - panic!("Inner error getting complex"); - } + obj.payload::().unwrap().value } fn complex_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -77,7 +91,9 @@ fn complex_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let value = Complex64::new(real, imag); Ok(PyObject::new( - PyObjectPayload::Complex { value }, + PyObjectPayload::AnyRustValue { + value: Box::new(PyComplex { value }), + }, cls.clone(), )) } diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index b59e893e3..a2038bffb 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -3,12 +3,30 @@ use super::objint; use super::objstr; use super::objtype; use crate::pyobject::{ - PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult, + TypeProtocol, }; use crate::vm::VirtualMachine; use num_bigint::ToBigInt; use num_traits::ToPrimitive; +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct PyFloat { + value: f64, +} + +impl PyObjectPayload2 for PyFloat { + fn required_type(ctx: &PyContext) -> PyObjectRef { + ctx.float_type() + } +} + +impl From for PyFloat { + fn from(value: f64) -> Self { + PyFloat { value } + } +} + fn float_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(float, Some(vm.ctx.float_type()))]); let v = get_value(float); @@ -50,16 +68,18 @@ fn float_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let type_name = objtype::get_type_name(&arg.typ()); return Err(vm.new_type_error(format!("can't convert {} to float", type_name))); }; - Ok(PyObject::new(PyObjectPayload::Float { value }, cls.clone())) + + Ok(PyObject::new( + PyObjectPayload::AnyRustValue { + value: Box::new(PyFloat { value }), + }, + cls.clone(), + )) } // Retrieve inner float value: pub fn get_value(obj: &PyObjectRef) -> f64 { - if let PyObjectPayload::Float { value } = &obj.payload { - *value - } else { - panic!("Inner error getting float: {}", obj); - } + obj.payload::().unwrap().value } pub fn make_float(vm: &mut VirtualMachine, obj: &PyObjectRef) -> PyResult { diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 480eaf675..a79dfb6dd 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -4,7 +4,7 @@ use super::objtype; use crate::format::FormatSpec; use crate::pyobject::{ FromPyObjectRef, IntoPyObject, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, - PyResult, TypeProtocol, + PyResult, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; use num_bigint::{BigInt, ToBigInt}; @@ -31,6 +31,26 @@ impl IntoPyObject for usize { } } +impl TryFromObject for usize { + fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult { + // FIXME: don't use get_value + match get_value(&obj).to_usize() { + Some(value) => Ok(value), + None => Err(vm.new_overflow_error("Int value cannot fit into Rust usize".to_string())), + } + } +} + +impl TryFromObject for isize { + fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult { + // FIXME: don't use get_value + match get_value(&obj).to_isize() { + Some(value) => Ok(value), + None => Err(vm.new_overflow_error("Int value cannot fit into Rust isize".to_string())), + } + } +} + fn int_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(int, Some(vm.ctx.int_type()))]); let v = get_value(int); @@ -63,11 +83,7 @@ fn int_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } // Casting function: -pub fn to_int( - vm: &mut VirtualMachine, - obj: &PyObjectRef, - base: u32, -) -> PyResult { +pub fn to_int(vm: &mut VirtualMachine, obj: &PyObjectRef, base: u32) -> PyResult { let val = if objtype::isinstance(obj, &vm.ctx.int_type()) { get_value(obj) } else if objtype::isinstance(obj, &vm.ctx.float_type()) { diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 2010da906..049ea8be3 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -2,13 +2,16 @@ * Various types to support iteration. */ -use super::objbool; use crate::pyobject::{ PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, }; use crate::vm::VirtualMachine; -// use super::objstr; -use super::objtype; // Required for arg_check! to use isinstance + +use super::objbool; +use super::objbytearray::PyByteArray; +use super::objbytes::PyBytes; +use super::objrange::PyRange; +use super::objtype; /* * This helper function is called at multiple places. First, it is called @@ -129,38 +132,43 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { iterated_obj: ref iterated_obj_ref, } = iter.payload { - match iterated_obj_ref.payload { - PyObjectPayload::Sequence { ref elements } => { - if position.get() < elements.borrow().len() { - let obj_ref = elements.borrow()[position.get()].clone(); - position.set(position.get() + 1); - Ok(obj_ref) - } else { - Err(new_stop_iteration(vm)) - } + if let Some(range) = iterated_obj_ref.payload::() { + if let Some(int) = range.get(position.get()) { + position.set(position.get() + 1); + Ok(vm.ctx.new_int(int)) + } else { + Err(new_stop_iteration(vm)) } - - PyObjectPayload::Range { ref range } => { - if let Some(int) = range.get(position.get()) { - position.set(position.get() + 1); - Ok(vm.ctx.new_int(int)) - } else { - Err(new_stop_iteration(vm)) - } + } else if let Some(bytes) = iterated_obj_ref.payload::() { + if position.get() < bytes.len() { + let obj_ref = vm.ctx.new_int(bytes[position.get()]); + position.set(position.get() + 1); + Ok(obj_ref) + } else { + Err(new_stop_iteration(vm)) } - - PyObjectPayload::Bytes { ref value } => { - if position.get() < value.borrow().len() { - let obj_ref = vm.ctx.new_int(value.borrow()[position.get()]); - position.set(position.get() + 1); - Ok(obj_ref) - } else { - Err(new_stop_iteration(vm)) - } + } else if let Some(bytes) = iterated_obj_ref.payload::() { + if position.get() < bytes.value.borrow().len() { + let obj_ref = vm.ctx.new_int(bytes.value.borrow()[position.get()]); + position.set(position.get() + 1); + Ok(obj_ref) + } else { + Err(new_stop_iteration(vm)) } - - _ => { - panic!("NOT IMPL"); + } else { + match iterated_obj_ref.payload { + PyObjectPayload::Sequence { ref elements } => { + if position.get() < elements.borrow().len() { + let obj_ref = elements.borrow()[position.get()].clone(); + position.set(position.get() + 1); + Ok(obj_ref) + } else { + Err(new_stop_iteration(vm)) + } + } + _ => { + panic!("NOT IMPL"); + } } } } else { diff --git a/vm/src/obj/objrange.rs b/vm/src/obj/objrange.rs index d80b6aa01..41a4c6964 100644 --- a/vm/src/obj/objrange.rs +++ b/vm/src/obj/objrange.rs @@ -4,7 +4,7 @@ use std::ops::Mul; use super::objint; use super::objtype; use crate::pyobject::{ - FromPyObject, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, + PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectPayload2, PyObjectRef, PyResult, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -13,7 +13,7 @@ use num_integer::Integer; use num_traits::{One, Signed, ToPrimitive, Zero}; #[derive(Debug, Clone)] -pub struct RangeType { +pub struct PyRange { // Unfortunately Rust's built in range type doesn't support things like indexing // or ranges where start > end so we need to roll our own. pub start: BigInt, @@ -21,19 +21,13 @@ pub struct RangeType { pub step: BigInt, } -type PyRange = RangeType; - -impl FromPyObject for PyRange { - fn typ(ctx: &PyContext) -> Option { - Some(ctx.range_type()) - } - - fn from_pyobject(obj: PyObjectRef) -> PyResult { - Ok(get_value(&obj)) +impl PyObjectPayload2 for PyRange { + fn required_type(ctx: &PyContext) -> PyObjectRef { + ctx.range_type() } } -impl RangeType { +impl PyRange { #[inline] pub fn try_len(&self) -> Option { match self.step.sign() { @@ -129,12 +123,12 @@ impl RangeType { }; match self.step.sign() { - Sign::Plus => RangeType { + Sign::Plus => PyRange { start, end: &self.start - 1, step: -&self.step, }, - Sign::Minus => RangeType { + Sign::Minus => PyRange { start, end: &self.start + 1, step: -&self.step, @@ -152,12 +146,8 @@ impl RangeType { } } -pub fn get_value(obj: &PyObjectRef) -> RangeType { - if let PyObjectPayload::Range { range } = &obj.payload { - range.clone() - } else { - panic!("Inner error getting range {:?}", obj); - } +pub fn get_value(obj: &PyObjectRef) -> PyRange { + obj.payload::().unwrap().clone() } pub fn init(context: &PyContext) { @@ -236,8 +226,8 @@ fn range_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Err(vm.new_value_error("range with 0 step size".to_string())) } else { Ok(PyObject::new( - PyObjectPayload::Range { - range: RangeType { start, end, step }, + PyObjectPayload::AnyRustValue { + value: Box::new(PyRange { start, end, step }), }, cls.clone(), )) @@ -264,7 +254,12 @@ fn range_reversed(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(PyObject::new( PyObjectPayload::Iterator { position: Cell::new(0), - iterated_obj: PyObject::new(PyObjectPayload::Range { range }, vm.ctx.range_type()), + iterated_obj: PyObject::new( + PyObjectPayload::AnyRustValue { + value: Box::new(range), + }, + vm.ctx.range_type(), + ), }, vm.ctx.iter_type(), )) @@ -329,12 +324,12 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { }; Ok(PyObject::new( - PyObjectPayload::Range { - range: RangeType { + PyObjectPayload::AnyRustValue { + value: Box::new(PyRange { start: new_start, end: new_end, step: new_step, - }, + }), }, vm.ctx.range_type(), )) @@ -360,12 +355,22 @@ fn range_bool(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bool(len > 0)) } -fn range_contains(vm: &mut VirtualMachine, zelf: PyRange, needle: PyObjectRef) -> bool { - if objtype::isinstance(&needle, &vm.ctx.int_type()) { - zelf.contains(&objint::get_value(&needle)) +fn range_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, Some(vm.ctx.range_type())), (needle, None)] + ); + + let range = get_value(zelf); + + let result = if objtype::isinstance(needle, &vm.ctx.int_type()) { + range.contains(&objint::get_value(needle)) } else { false - } + }; + + Ok(vm.ctx.new_bool(result)) } fn range_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 42f715286..2283a0e8d 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -2,8 +2,10 @@ use super::objint; use super::objsequence::PySliceableSequence; use super::objtype; use crate::format::{FormatParseError, FormatPart, FormatString}; +use crate::function::PyRef; use crate::pyobject::{ - FromPyObject, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, + OptArg, PyContext, PyFuncArgs, PyIterable, PyObjectPayload, PyObjectPayload2, PyObjectRef, + PyResult, TypeProtocol, }; use crate::vm::VirtualMachine; use num_traits::ToPrimitive; @@ -16,13 +18,71 @@ extern crate unicode_segmentation; use self::unicode_segmentation::UnicodeSegmentation; -impl FromPyObject for String { - fn typ(ctx: &PyContext) -> Option { - Some(ctx.str_type()) +#[derive(Clone, Debug)] +pub struct PyString { + // TODO: shouldn't be public + pub value: String, +} + +impl PyString { + pub fn endswith( + zelf: PyRef, + suffix: PyRef, + start: OptArg, + end: OptArg, + _vm: &mut VirtualMachine, + ) -> bool { + let start = start.unwrap_or(0); + let end = end.unwrap_or(zelf.value.len()); + zelf.value[start..end].ends_with(&suffix.value) } - fn from_pyobject(obj: PyObjectRef) -> PyResult { - Ok(get_value(&obj)) + pub fn startswith( + zelf: PyRef, + prefix: PyRef, + start: OptArg, + end: OptArg, + _vm: &mut VirtualMachine, + ) -> bool { + let start = start.unwrap_or(0); + let end = end.unwrap_or(zelf.value.len()); + zelf.value[start..end].starts_with(&prefix.value) + } + + fn upper(zelf: PyRef, _vm: &mut VirtualMachine) -> PyString { + PyString { + value: zelf.value.to_uppercase(), + } + } + + fn lower(zelf: PyRef, _vm: &mut VirtualMachine) -> PyString { + PyString { + value: zelf.value.to_lowercase(), + } + } + + fn join( + zelf: PyRef, + iterable: PyIterable>, + vm: &mut VirtualMachine, + ) -> PyResult { + let mut joined = String::new(); + + for (idx, elem) in iterable.iter(vm)?.enumerate() { + let elem = elem?; + if idx != 0 { + joined.push_str(&zelf.value); + } + joined.push_str(&elem.value) + } + + Ok(PyString { value: joined }) + } +} + +impl PyObjectPayload2 for PyString { + fn required_type(ctx: &PyContext) -> PyObjectRef { + ctx.str_type() } } @@ -47,9 +107,9 @@ pub fn init(context: &PyContext) { context.set_attr(&str_type, "__str__", context.new_rustfunc(str_str)); context.set_attr(&str_type, "__repr__", context.new_rustfunc(str_repr)); context.set_attr(&str_type, "format", context.new_rustfunc(str_format)); - context.set_attr(&str_type, "lower", context.new_rustfunc(str_lower)); + context.set_attr(&str_type, "lower", context.new_rustfunc(PyString::lower)); context.set_attr(&str_type, "casefold", context.new_rustfunc(str_casefold)); - context.set_attr(&str_type, "upper", context.new_rustfunc(str_upper)); + context.set_attr(&str_type, "upper", context.new_rustfunc(PyString::upper)); context.set_attr( &str_type, "capitalize", @@ -60,11 +120,15 @@ pub fn init(context: &PyContext) { context.set_attr(&str_type, "strip", context.new_rustfunc(str_strip)); context.set_attr(&str_type, "lstrip", context.new_rustfunc(str_lstrip)); context.set_attr(&str_type, "rstrip", context.new_rustfunc(str_rstrip)); - context.set_attr(&str_type, "endswith", context.new_rustfunc(str_endswith)); + context.set_attr( + &str_type, + "endswith", + context.new_rustfunc(PyString::endswith), + ); context.set_attr( &str_type, "startswith", - context.new_rustfunc(str_startswith), + context.new_rustfunc(PyString::startswith), ); context.set_attr(&str_type, "isalnum", context.new_rustfunc(str_isalnum)); context.set_attr(&str_type, "isnumeric", context.new_rustfunc(str_isnumeric)); @@ -84,7 +148,7 @@ pub fn init(context: &PyContext) { "splitlines", context.new_rustfunc(str_splitlines), ); - context.set_attr(&str_type, "join", context.new_rustfunc(str_join)); + context.set_attr(&str_type, "join", context.new_rustfunc(PyString::join)); context.set_attr(&str_type, "find", context.new_rustfunc(str_find)); context.set_attr(&str_type, "rfind", context.new_rustfunc(str_rfind)); context.set_attr(&str_type, "index", context.new_rustfunc(str_index)); @@ -113,19 +177,11 @@ pub fn init(context: &PyContext) { } pub fn get_value(obj: &PyObjectRef) -> String { - if let PyObjectPayload::String { value } = &obj.payload { - value.to_string() - } else { - panic!("Inner error getting str"); - } + obj.payload::().unwrap().value.clone() } pub fn borrow_value(obj: &PyObjectRef) -> &str { - if let PyObjectPayload::String { value } = &obj.payload { - value.as_str() - } else { - panic!("Inner error getting str"); - } + &obj.payload::().unwrap().value } fn str_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -387,18 +443,6 @@ fn str_mul(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } } -fn str_upper(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]); - let value = get_value(&s).to_uppercase(); - Ok(vm.ctx.new_str(value)) -} - -fn str_lower(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]); - let value = get_value(&s).to_lowercase(); - Ok(vm.ctx.new_str(value)) -} - fn str_capitalize(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]); let value = get_value(&s); @@ -477,10 +521,6 @@ fn str_rstrip(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_str(value)) } -fn str_endswith(_vm: &mut VirtualMachine, zelf: String, suffix: String) -> bool { - zelf.ends_with(&suffix) -} - fn str_isidentifier(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]); let value = get_value(&s); @@ -560,22 +600,6 @@ fn str_zfill(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_str(new_str)) } -fn str_join(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(s, Some(vm.ctx.str_type())), (iterable, None)] - ); - let value = get_value(&s); - let elements: Vec = vm - .extract_elements(iterable)? - .iter() - .map(|w| get_value(&w)) - .collect(); - let joined = elements.join(&value); - Ok(vm.ctx.new_str(joined)) -} - fn str_count(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -865,17 +889,6 @@ fn str_center(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_str(new_str)) } -fn str_startswith(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(s, Some(vm.ctx.str_type())), (pat, Some(vm.ctx.str_type()))] - ); - let value = get_value(&s); - let pat = get_value(&pat); - Ok(vm.ctx.new_bool(value.starts_with(pat.as_str()))) -} - fn str_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index b3d218430..79c516d4c 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -72,9 +72,7 @@ fn super_init(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { }; // Check obj type: - if !(objtype::real_isinstance(vm, &py_obj, &py_type)? - || objtype::real_issubclass(vm, &py_obj, &py_type)?) - { + if !(objtype::isinstance(&py_obj, &py_type) || objtype::issubclass(&py_obj, &py_type)) { return Err(vm.new_type_error( "super(type, obj): obj must be an instance or subtype of type".to_string(), )); diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index 785ca1b87..c800b7f52 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -1,4 +1,3 @@ -use super::objbool; use super::objdict; use super::objstr; use super::objtype; // Required for arg_check! to use isinstance @@ -9,7 +8,6 @@ use crate::pyobject::{ use crate::vm::VirtualMachine; use std::cell::RefCell; use std::collections::HashMap; -use std::rc::Rc; /* * The magical type type @@ -108,23 +106,6 @@ pub fn isinstance(obj: &PyObjectRef, cls: &PyObjectRef) -> bool { mro.into_iter().any(|c| c.is(&cls)) } -/// Determines if `obj` is an instance of `cls`, either directly, indirectly or virtually via the -/// __instancecheck__ magic method. -pub fn real_isinstance( - vm: &mut VirtualMachine, - obj: &PyObjectRef, - cls: &PyObjectRef, -) -> PyResult { - // cpython first does an exact check on the type, although documentation doesn't state that - // https://github.com/python/cpython/blob/a24107b04c1277e3c1105f98aff5bfa3a98b33a0/Objects/abstract.c#L2408 - if Rc::ptr_eq(&obj.typ(), cls) { - Ok(true) - } else { - let ret = vm.call_method(cls, "__instancecheck__", vec![obj.clone()])?; - objbool::boolval(vm, ret) - } -} - fn type_instance_check(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -142,17 +123,6 @@ pub fn issubclass(subclass: &PyObjectRef, cls: &PyObjectRef) -> bool { mro.into_iter().any(|c| c.is(&cls)) } -/// Determines if `subclass` is a subclass of `cls`, either directly, indirectly or virtually via -/// the __subclasscheck__ magic method. -pub fn real_issubclass( - vm: &mut VirtualMachine, - subclass: &PyObjectRef, - cls: &PyObjectRef, -) -> PyResult { - let ret = vm.call_method(cls, "__subclasscheck__", vec![subclass.clone()])?; - objbool::boolval(vm, ret) -} - fn type_subclass_check(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 91fac9168..e581b1fae 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1,3 +1,15 @@ +use std::cell::{Cell, RefCell}; +use std::collections::HashMap; +use std::fmt; +use std::iter; +use std::ops::RangeInclusive; +use std::rc::{Rc, Weak}; + +use num_bigint::BigInt; +use num_bigint::ToBigInt; +use num_complex::Complex64; +use num_traits::{One, Zero}; + use crate::bytecode; use crate::exceptions; use crate::frame::{Frame, Scope, ScopeRef}; @@ -5,11 +17,11 @@ use crate::obj::objbool; use crate::obj::objbytearray; use crate::obj::objbytes; use crate::obj::objcode; -use crate::obj::objcomplex; +use crate::obj::objcomplex::{self, PyComplex}; use crate::obj::objdict; use crate::obj::objenumerate; use crate::obj::objfilter; -use crate::obj::objfloat; +use crate::obj::objfloat::{self, PyFloat}; use crate::obj::objframe; use crate::obj::objfunction; use crate::obj::objgenerator; @@ -30,14 +42,6 @@ use crate::obj::objtuple; use crate::obj::objtype; use crate::obj::objzip; use crate::vm::VirtualMachine; -use num_bigint::BigInt; -use num_bigint::ToBigInt; -use num_complex::Complex64; -use num_traits::{One, Zero}; -use std::cell::{Cell, RefCell}; -use std::collections::HashMap; -use std::fmt; -use std::rc::{Rc, Weak}; /* Python objects and references. @@ -457,22 +461,37 @@ impl PyContext { ) } - pub fn new_float(&self, i: f64) -> PyObjectRef { - PyObject::new(PyObjectPayload::Float { value: i }, self.float_type()) + pub fn new_float(&self, value: f64) -> PyObjectRef { + PyObject::new( + PyObjectPayload::AnyRustValue { + value: Box::new(PyFloat::from(value)), + }, + self.float_type(), + ) } - pub fn new_complex(&self, i: Complex64) -> PyObjectRef { - PyObject::new(PyObjectPayload::Complex { value: i }, self.complex_type()) + pub fn new_complex(&self, value: Complex64) -> PyObjectRef { + PyObject::new( + PyObjectPayload::AnyRustValue { + value: Box::new(PyComplex::from(value)), + }, + self.complex_type(), + ) } pub fn new_str(&self, s: String) -> PyObjectRef { - PyObject::new(PyObjectPayload::String { value: s }, self.str_type()) + PyObject::new( + PyObjectPayload::AnyRustValue { + value: Box::new(objstr::PyString { value: s }), + }, + self.str_type(), + ) } pub fn new_bytes(&self, data: Vec) -> PyObjectRef { PyObject::new( - PyObjectPayload::Bytes { - value: RefCell::new(data), + PyObjectPayload::AnyRustValue { + value: Box::new(objbytes::PyBytes::new(data)), }, self.bytes_type(), ) @@ -480,8 +499,8 @@ impl PyContext { pub fn new_bytearray(&self, data: Vec) -> PyObjectRef { PyObject::new( - PyObjectPayload::Bytes { - value: RefCell::new(data), + PyObjectPayload::AnyRustValue { + value: Box::new(objbytearray::PyByteArray::new(data)), }, self.bytearray_type(), ) @@ -552,28 +571,18 @@ impl PyContext { ) } - pub fn new_rustfunc(&self, factory: F) -> PyObjectRef + pub fn new_rustfunc(&self, f: F) -> PyObjectRef where - F: PyNativeFuncFactory, + F: IntoPyNativeFunc, { PyObject::new( PyObjectPayload::RustFunction { - function: factory.create(self), + function: f.into_func(), }, self.builtin_function_or_method_type(), ) } - pub fn new_rustfunc_from_box( - &self, - function: Box PyResult>, - ) -> PyObjectRef { - PyObject::new( - PyObjectPayload::RustFunction { function }, - self.builtin_function_or_method_type(), - ) - } - pub fn new_frame(&self, code: PyObjectRef, scope: ScopeRef) -> PyObjectRef { PyObject::new( PyObjectPayload::Frame { @@ -925,7 +934,7 @@ impl PyFuncArgs { ) -> Result, PyObjectRef> { match self.get_optional_kwarg(key) { Some(kwarg) => { - if objtype::real_isinstance(vm, &kwarg, &ty)? { + if objtype::isinstance(&kwarg, &ty) { Ok(Some(kwarg)) } else { let expected_ty_name = vm.to_pystr(&ty)?; @@ -939,24 +948,304 @@ impl PyFuncArgs { None => Ok(None), } } -} -pub trait FromPyObject: Sized { - fn typ(ctx: &PyContext) -> Option; - - fn from_pyobject(obj: PyObjectRef) -> PyResult; -} - -impl FromPyObject for PyObjectRef { - fn typ(_ctx: &PyContext) -> Option { - None + /// Serializes these arguments into an iterator starting with the positional + /// arguments followed by keyword arguments. + fn into_iter(self) -> impl Iterator { + self.args.into_iter().map(PyArg::Positional).chain( + self.kwargs + .into_iter() + .map(|(name, value)| PyArg::Keyword(name, value)), + ) } - fn from_pyobject(obj: PyObjectRef) -> PyResult { + /// Binds these arguments to their respective values. + /// + /// If there is an insufficient number of arguments, there are leftover + /// arguments after performing the binding, or if an argument is not of + /// the expected type, a TypeError is raised. + /// + /// If the given `FromArgs` includes any conversions, exceptions raised + /// during the conversion will halt the binding and return the error. + fn bind(self, vm: &mut VirtualMachine) -> PyResult { + let given_args = self.args.len(); + let mut args = self.into_iter().peekable(); + let bound = match T::from_args(vm, &mut args) { + Ok(args) => args, + Err(ArgumentError::TooFewArgs) => { + return Err(vm.new_type_error(format!( + "Expected at least {} arguments ({} given)", + T::arity().start(), + given_args, + ))); + } + Err(ArgumentError::Exception(ex)) => { + return Err(ex); + } + }; + + match args.next() { + None => Ok(bound), + Some(PyArg::Positional(_)) => Err(vm.new_type_error(format!( + "Expected at most {} arguments ({} given)", + T::arity().end(), + given_args, + ))), + Some(PyArg::Keyword(name, _)) => { + Err(vm.new_type_error(format!("Unexpected keyword argument {}", name))) + } + } + } +} + +/// Implemented by any type that can be accepted as a parameter to a built-in +/// function. +/// +pub trait FromArgs: Sized { + /// The range of positional arguments permitted by the function signature. + /// + /// Returns an empty range if not applicable. + fn arity() -> RangeInclusive { + 0..=0 + } + + /// Extracts this item from the next argument(s). + fn from_args( + vm: &mut VirtualMachine, + args: &mut iter::Peekable, + ) -> Result + where + I: Iterator; +} + +/// An iterable Python object. +/// +/// `PyIterable` implements `FromArgs` so that a built-in function can accept +/// an object that is required to conform to the Python iterator protocol. +/// +/// PyIterable can optionally perform type checking and conversions on iterated +/// objects using a generic type parameter that implements `TryFromObject`. +pub struct PyIterable { + method: PyObjectRef, + _item: std::marker::PhantomData, +} + +impl PyIterable { + /// Returns an iterator over this sequence of objects. + /// + /// This operation may fail if an exception is raised while invoking the + /// `__iter__` method of the iterable object. + pub fn iter<'a>(&self, vm: &'a mut VirtualMachine) -> PyResult> { + let iter_obj = vm.invoke( + self.method.clone(), + PyFuncArgs { + args: vec![], + kwargs: vec![], + }, + )?; + + Ok(PyIterator { + vm, + obj: iter_obj, + _item: std::marker::PhantomData, + }) + } +} + +pub struct PyIterator<'a, T> { + vm: &'a mut VirtualMachine, + obj: PyObjectRef, + _item: std::marker::PhantomData, +} + +impl<'a, T> Iterator for PyIterator<'a, T> +where + T: TryFromObject, +{ + type Item = PyResult; + + fn next(&mut self) -> Option { + match self.vm.call_method(&self.obj, "__next__", vec![]) { + Ok(value) => Some(T::try_from_object(self.vm, value)), + Err(err) => { + let stop_ex = self.vm.ctx.exceptions.stop_iteration.clone(); + if objtype::isinstance(&err, &stop_ex) { + None + } else { + Some(Err(err)) + } + } + } + } +} + +impl TryFromObject for PyIterable +where + T: TryFromObject, +{ + fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult { + Ok(PyIterable { + method: vm.get_method(obj, "__iter__")?, + _item: std::marker::PhantomData, + }) + } +} + +impl TryFromObject for PyObjectRef { + fn try_from_object(_vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult { Ok(obj) } } +/// A map of keyword arguments to their values. +/// +/// A built-in function with a `KwArgs` parameter is analagous to a Python +/// function with `*kwargs`. All remaining keyword arguments are extracted +/// (and hence the function will permit an arbitrary number of them). +/// +/// `KwArgs` optionally accepts a generic type parameter to allow type checks +/// or conversions of each argument. +pub struct KwArgs(HashMap); + +impl FromArgs for KwArgs +where + T: TryFromObject, +{ + fn from_args( + vm: &mut VirtualMachine, + args: &mut iter::Peekable, + ) -> Result + where + I: Iterator, + { + let mut kwargs = HashMap::new(); + while let Some(PyArg::Keyword(name, value)) = args.next() { + kwargs.insert(name, T::try_from_object(vm, value)?); + } + Ok(KwArgs(kwargs)) + } +} + +/// A list of positional argument values. +/// +/// A built-in function with a `Args` parameter is analagous to a Python +/// function with `*args`. All remaining positional arguments are extracted +/// (and hence the function will permit an arbitrary number of them). +/// +/// `Args` optionally accepts a generic type parameter to allow type checks +/// or conversions of each argument. +pub struct Args(Vec); + +impl FromArgs for Args +where + T: TryFromObject, +{ + fn from_args( + vm: &mut VirtualMachine, + args: &mut iter::Peekable, + ) -> Result + where + I: Iterator, + { + let mut varargs = Vec::new(); + while let Some(PyArg::Positional(value)) = args.next() { + varargs.push(T::try_from_object(vm, value)?); + } + Ok(Args(varargs)) + } +} + +impl FromArgs for T +where + T: TryFromObject, +{ + fn arity() -> RangeInclusive { + 1..=1 + } + + fn from_args( + vm: &mut VirtualMachine, + args: &mut iter::Peekable, + ) -> Result + where + I: Iterator, + { + if let Some(PyArg::Positional(value)) = args.next() { + Ok(T::try_from_object(vm, value)?) + } else { + Err(ArgumentError::TooFewArgs) + } + } +} + +pub struct OptArg(Option); + +impl std::ops::Deref for OptArg { + type Target = Option; + + fn deref(&self) -> &Option { + &self.0 + } +} + +impl FromArgs for OptArg +where + T: TryFromObject, +{ + fn arity() -> RangeInclusive { + 0..=1 + } + + fn from_args( + vm: &mut VirtualMachine, + args: &mut iter::Peekable, + ) -> Result + where + I: Iterator, + { + Ok(OptArg(if let Some(PyArg::Positional(_)) = args.peek() { + let value = if let Some(PyArg::Positional(value)) = args.next() { + value + } else { + unreachable!() + }; + Some(T::try_from_object(vm, value)?) + } else { + None + })) + } +} + +pub enum PyArg { + Positional(PyObjectRef), + Keyword(String, PyObjectRef), +} + +pub enum ArgumentError { + TooFewArgs, + Exception(PyObjectRef), +} + +impl From for ArgumentError { + fn from(ex: PyObjectRef) -> Self { + ArgumentError::Exception(ex) + } +} + +/// Implemented by any type that can be created from a Python object. +/// +/// Any type that implements `TryFromObject` is automatically `FromArgs`, and +/// so can be accepted as a argument to a built-in function. +pub trait TryFromObject: Sized { + /// Attempt to convert a Python object to a value of this type. + fn try_from_object(vm: &mut VirtualMachine, obj: PyObjectRef) -> PyResult; +} + +/// Implemented by any type that can be returned from a built-in Python function. +/// +/// `IntoPyObject` has a blanket implementation for any built-in object payload, +/// and should be implemented by many primitive Rust types, allowing a built-in +/// function to simply return a `bool` or a `usize` for example. pub trait IntoPyObject { fn into_pyobject(self, ctx: &PyContext) -> PyResult; } @@ -967,206 +1256,170 @@ impl IntoPyObject for PyObjectRef { } } -impl IntoPyObject for PyResult { - fn into_pyobject(self, _ctx: &PyContext) -> PyResult { - self +impl IntoPyObject for PyResult +where + T: IntoPyObject, +{ + fn into_pyobject(self, ctx: &PyContext) -> PyResult { + self.and_then(|res| T::into_pyobject(res, ctx)) } } -pub trait FromPyFuncArgs: Sized { - fn required_params(ctx: &PyContext) -> Vec; - - fn from_py_func_args(args: &mut PyFuncArgs) -> PyResult; +// This allows a built-in function to not return a value, mapping to +// Python's behavior of returning `None` in this situation. +impl IntoPyObject for () { + fn into_pyobject(self, ctx: &PyContext) -> PyResult { + Ok(ctx.none()) + } } +// TODO: Allow a built-in function to return an `Option`, i.e.: +// +// impl IntoPyObject for Option +// +// Option::None should map to a Python `None`. + +// Allows a built-in function to return any built-in object payload without +// explicitly implementing `IntoPyObject`. +impl IntoPyObject for T +where + T: PyObjectPayload2 + Sized, +{ + fn into_pyobject(self, ctx: &PyContext) -> PyResult { + Ok(PyObject::new( + PyObjectPayload::AnyRustValue { + value: Box::new(self), + }, + T::required_type(ctx), + )) + } +} + +// A tuple of types that each implement `FromArgs` represents a sequence of +// arguments that can be bound and passed to a built-in function. +// +// Technically, a tuple can contain tuples, which can contain tuples, and so on, +// so this actually represents a tree of values to be bound from arguments, but +// in practice this is only used for the top-level parameters. macro_rules! tuple_from_py_func_args { ($($T:ident),+) => { - impl<$($T),+> FromPyFuncArgs for ($($T,)+) + impl<$($T),+> FromArgs for ($($T,)+) where - $($T: FromPyFuncArgs),+ + $($T: FromArgs),+ { - fn required_params(ctx: &PyContext) -> Vec { - vec![$($T::required_params(ctx),)+].into_iter().flatten().collect() + fn arity() -> RangeInclusive { + let mut min = 0; + let mut max = 0; + $( + let (start, end) = $T::arity().into_inner(); + min += start; + max += end; + )+ + min..=max } - fn from_py_func_args(args: &mut PyFuncArgs) -> PyResult { - Ok(($($T::from_py_func_args(args)?,)+)) + fn from_args( + vm: &mut VirtualMachine, + args: &mut iter::Peekable + ) -> Result + where + I: Iterator + { + Ok(($($T::from_args(vm, args)?,)+)) } } }; } +// Implement `FromArgs` for up to 5-tuples, allowing built-in functions to bind +// up to 5 top-level parameters (note that `Args`, `KwArgs`, nested tuples, etc. +// count as 1, so this should actually be more than enough). tuple_from_py_func_args!(A); tuple_from_py_func_args!(A, B); tuple_from_py_func_args!(A, B, C); tuple_from_py_func_args!(A, B, C, D); tuple_from_py_func_args!(A, B, C, D, E); -impl FromPyFuncArgs for T -where - T: FromPyObject, -{ - fn required_params(ctx: &PyContext) -> Vec { - vec![Parameter { - kind: PositionalOnly, - typ: T::typ(ctx), - }] - } +/// A built-in Python function. +pub type PyNativeFunc = Box PyResult + 'static>; - fn from_py_func_args(args: &mut PyFuncArgs) -> PyResult { - Self::from_pyobject(args.shift()) - } +/// Implemented by types that are or can generate built-in functions. +/// +/// For example, any function that: +/// +/// - Accepts a sequence of types that implement `FromArgs`, followed by a +/// `&mut VirtualMachine` +/// - Returns some type that implements `IntoPyObject` +/// +/// will generate a `PyNativeFunc` that performs the appropriate type and arity +/// checking, any requested conversions, and then if successful call the function +/// with the bound values. +/// +/// A bare `PyNativeFunc` also implements this trait, allowing the above to be +/// done manually, for rare situations that don't fit into this model. +pub trait IntoPyNativeFunc { + fn into_func(self) -> PyNativeFunc; } -pub type PyNativeFunc = Box PyResult>; - -pub trait PyNativeFuncFactory { - fn create(self, ctx: &PyContext) -> PyNativeFunc; -} - -impl PyNativeFuncFactory for F +impl IntoPyNativeFunc for F where F: Fn(&mut VirtualMachine, PyFuncArgs) -> PyResult + 'static, { - fn create(self, _ctx: &PyContext) -> PyNativeFunc { + fn into_func(self) -> PyNativeFunc { Box::new(self) } } -macro_rules! tuple_py_native_func_factory { - ($($T:ident),+) => { - impl PyNativeFuncFactory<($($T,)+), R> for F +impl IntoPyNativeFunc for PyNativeFunc { + fn into_func(self) -> PyNativeFunc { + self + } +} + +// This is the "magic" that allows rust functions of varying signatures to +// generate native python functions. +// +// Note that this could be done without a macro - it is simply to avoid repetition. +macro_rules! into_py_native_func_tuple { + ($(($n:tt, $T:ident)),+) => { + impl IntoPyNativeFunc<($($T,)+), R> for F where - F: Fn(&mut VirtualMachine, $($T),+) -> R + 'static, - $($T: FromPyFuncArgs,)+ + F: Fn($($T,)+ &mut VirtualMachine) -> R + 'static, + $($T: FromArgs,)+ + ($($T,)+): FromArgs, R: IntoPyObject, { - fn create(self, ctx: &PyContext) -> PyNativeFunc { - let parameters = vec![$($T::required_params(ctx)),+] - .into_iter() - .flatten() - .collect(); - let signature = Signature::new(parameters); + fn into_func(self) -> PyNativeFunc { + Box::new(move |vm, args| { + let ($($n,)+) = args.bind::<($($T,)+)>(vm)?; - Box::new(move |vm, mut args| { - signature.check(vm, &args)?; - - (self)(vm, $($T::from_py_func_args(&mut args)?,)+) - .into_pyobject(&vm.ctx) + (self)($($n,)+ vm).into_pyobject(&vm.ctx) }) } } }; } -tuple_py_native_func_factory!(A); -tuple_py_native_func_factory!(A, B); -tuple_py_native_func_factory!(A, B, C); -tuple_py_native_func_factory!(A, B, C, D); -tuple_py_native_func_factory!(A, B, C, D, E); - -#[derive(Debug)] -pub struct Signature { - positional_params: Vec, - keyword_params: HashMap, -} - -impl Signature { - fn new(params: Vec) -> Self { - let mut positional_params = Vec::new(); - let mut keyword_params = HashMap::new(); - for param in params { - match param.kind { - PositionalOnly => { - positional_params.push(param); - } - KeywordOnly { ref name } => { - keyword_params.insert(name.clone(), param); - } - } - } - - Self { - positional_params, - keyword_params, - } - } - - fn arg_type(&self, pos: usize) -> Option<&PyObjectRef> { - self.positional_params[pos].typ.as_ref() - } - - #[allow(unused)] - fn kwarg_type(&self, name: &str) -> Option<&PyObjectRef> { - self.keyword_params[name].typ.as_ref() - } - - fn check(&self, vm: &mut VirtualMachine, args: &PyFuncArgs) -> PyResult<()> { - // TODO: check arity - - for (pos, arg) in args.args.iter().enumerate() { - if let Some(expected_type) = self.arg_type(pos) { - if !objtype::real_isinstance(vm, arg, expected_type)? { - let arg_typ = arg.typ(); - let expected_type_name = vm.to_pystr(&expected_type)?; - let actual_type = vm.to_pystr(&arg_typ)?; - return Err(vm.new_type_error(format!( - "argument of type {} is required for parameter {} (got: {})", - expected_type_name, - pos + 1, - actual_type - ))); - } - } - } - - Ok(()) - } -} - -#[derive(Debug)] -pub struct Parameter { - typ: Option, - kind: ParameterKind, -} - -#[derive(Debug)] -pub enum ParameterKind { - PositionalOnly, - KeywordOnly { name: String }, -} - -use self::ParameterKind::*; +into_py_native_func_tuple!((a, A)); +into_py_native_func_tuple!((a, A), (b, B)); +into_py_native_func_tuple!((a, A), (b, B), (c, C)); +into_py_native_func_tuple!((a, A), (b, B), (c, C), (d, D)); +into_py_native_func_tuple!((a, A), (b, B), (c, C), (d, D), (e, E)); /// Rather than determining the type of a python object, this enum is more /// a holder for the rust payload of a python object. It is more a carrier /// of rust data for a particular python object. Determine the python type /// by using for example the `.typ()` method on a python object. pub enum PyObjectPayload { - String { - value: String, - }, Integer { value: BigInt, }, - Float { - value: f64, - }, - Complex { - value: Complex64, - }, - Bytes { - value: RefCell>, - }, Sequence { elements: RefCell>, }, Dict { elements: RefCell, }, - Set { - elements: RefCell>, - }, Iterator { position: Cell, iterated_obj: PyObjectRef, @@ -1191,9 +1444,6 @@ pub enum PyObjectPayload { stop: Option, step: Option, }, - Range { - range: objrange::RangeType, - }, MemoryView { obj: PyObjectRef, }, @@ -1226,6 +1476,9 @@ pub enum PyObjectPayload { dict: RefCell, mro: Vec, }, + Set { + elements: RefCell>, + }, WeakRef { referent: PyObjectWeakRef, }, @@ -1233,7 +1486,7 @@ pub enum PyObjectPayload { dict: RefCell, }, RustFunction { - function: Box PyResult>, + function: PyNativeFunc, }, AnyRustValue { value: Box, @@ -1243,17 +1496,12 @@ pub enum PyObjectPayload { impl fmt::Debug for PyObjectPayload { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - PyObjectPayload::String { ref value } => write!(f, "str \"{}\"", value), PyObjectPayload::Integer { ref value } => write!(f, "int {}", value), - PyObjectPayload::Float { ref value } => write!(f, "float {}", value), - PyObjectPayload::Complex { ref value } => write!(f, "complex {}", value), - PyObjectPayload::Bytes { ref value } => write!(f, "bytes/bytearray {:?}", value), PyObjectPayload::MemoryView { ref obj } => write!(f, "bytes/bytearray {:?}", obj), PyObjectPayload::Sequence { .. } => write!(f, "list or tuple"), PyObjectPayload::Dict { .. } => write!(f, "dict"), PyObjectPayload::Set { .. } => write!(f, "set"), PyObjectPayload::WeakRef { .. } => write!(f, "weakref"), - PyObjectPayload::Range { .. } => write!(f, "range"), PyObjectPayload::Iterator { .. } => write!(f, "iterator"), PyObjectPayload::EnumerateIterator { .. } => write!(f, "enumerate"), PyObjectPayload::FilterIterator { .. } => write!(f, "filter"), @@ -1296,6 +1544,20 @@ impl PyObject { pub fn into_ref(self) -> PyObjectRef { Rc::new(self) } + + pub fn payload(&self) -> Option<&T> { + if let PyObjectPayload::AnyRustValue { ref value } = self.payload { + value.downcast_ref() + } else { + None + } + } +} + +// The intention is for this to replace `PyObjectPayload` once everything is +// converted to use `PyObjectPayload::AnyRustvalue`. +pub trait PyObjectPayload2: std::any::Any + fmt::Debug { + fn required_type(ctx: &PyContext) -> PyObjectRef; } #[cfg(test)] diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 077f907d8..12b85e660 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -15,13 +15,13 @@ use num_traits::ToPrimitive; //custom imports use super::os; +use crate::obj::objbytearray::PyByteArray; use crate::obj::objbytes; use crate::obj::objint; use crate::obj::objstr; use crate::pyobject::{ - AttributeProtocol, BufferProtocol, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, - PyResult, TypeProtocol, + AttributeProtocol, BufferProtocol, PyContext, PyFuncArgs, PyObjectRef, PyResult, TypeProtocol, }; use crate::import; @@ -86,8 +86,8 @@ fn buffered_reader_read(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { .map_err(|_| vm.new_value_error("IO Error".to_string()))?; //Copy bytes from the buffer vector into the results vector - if let PyObjectPayload::Bytes { ref value } = buffer.payload { - result.extend(value.borrow().iter().cloned()); + if let Some(bytes) = buffer.payload::() { + result.extend_from_slice(&bytes.value.borrow()); }; let len = vm.get_method(buffer.clone(), &"__len__".to_string()); @@ -169,10 +169,10 @@ fn file_io_readinto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let handle = os::rust_file(raw_fd); let mut f = handle.take(length); - if let PyObjectPayload::Bytes { ref value } = obj.payload { + if let Some(bytes) = obj.payload::() { //TODO: Implement for MemoryView - let mut value_mut = value.borrow_mut(); + let mut value_mut = bytes.value.borrow_mut(); value_mut.clear(); match f.read_to_end(&mut value_mut) { Ok(_) => {} @@ -200,9 +200,9 @@ fn file_io_write(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { //to support windows - i.e. raw file_handles let mut handle = os::rust_file(raw_fd); - match obj.payload { - PyObjectPayload::Bytes { ref value } => { - let value_mut = value.borrow(); + match obj.payload::() { + Some(bytes) => { + let value_mut = bytes.value.borrow(); match handle.write(&value_mut[..]) { Ok(len) => { //reset raw fd on the FileIO object @@ -215,7 +215,7 @@ fn file_io_write(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Err(_) => Err(vm.new_value_error("Error Writing Bytes to Handle".to_string())), } } - _ => Err(vm.new_value_error("Expected Bytes Object".to_string())), + None => Err(vm.new_value_error("Expected Bytes Object".to_string())), } } @@ -346,101 +346,67 @@ pub fn io_open(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } pub fn mk_module(ctx: &PyContext) -> PyObjectRef { - let py_mod = ctx.new_module(&"io".to_string(), ctx.new_scope(None)); - - ctx.set_attr(&py_mod, "open", ctx.new_rustfunc(io_open)); - //IOBase the abstract base class of the IO Module let io_base = ctx.new_class("IOBase", ctx.object()); - ctx.set_attr(&py_mod, "IOBase", io_base.clone()); // IOBase Subclasses let raw_io_base = ctx.new_class("RawIOBase", ctx.object()); - ctx.set_attr(&py_mod, "RawIOBase", raw_io_base.clone()); - let buffered_io_base = { - let buffered_io_base = ctx.new_class("BufferedIOBase", io_base.clone()); - ctx.set_attr( - &buffered_io_base, - "__init__", - ctx.new_rustfunc(buffered_io_base_init), - ); - buffered_io_base - }; - ctx.set_attr(&py_mod, "BufferedIOBase", buffered_io_base.clone()); + let buffered_io_base = py_class!(ctx, "BufferedIOBase", io_base.clone(), { + "__init__" => ctx.new_rustfunc(buffered_io_base_init) + }); //TextIO Base has no public constructor - let text_io_base = { - let text_io_base = ctx.new_class("TextIOBase", io_base.clone()); - ctx.set_attr(&text_io_base, "read", ctx.new_rustfunc(text_io_base_read)); - text_io_base - }; - ctx.set_attr(&py_mod, "TextIOBase", text_io_base.clone()); + let text_io_base = py_class!(ctx, "TextIOBase", io_base.clone(), { + "read" => ctx.new_rustfunc(text_io_base_read) + }); // RawBaseIO Subclasses - let file_io = { - let file_io = ctx.new_class("FileIO", raw_io_base.clone()); - ctx.set_attr(&file_io, "__init__", ctx.new_rustfunc(file_io_init)); - ctx.set_attr(&file_io, "name", ctx.str_type()); - ctx.set_attr(&file_io, "read", ctx.new_rustfunc(file_io_read)); - ctx.set_attr(&file_io, "readinto", ctx.new_rustfunc(file_io_readinto)); - ctx.set_attr(&file_io, "write", ctx.new_rustfunc(file_io_write)); - file_io - }; - ctx.set_attr(&py_mod, "FileIO", file_io.clone()); + let file_io = py_class!(ctx, "FileIO", raw_io_base.clone(), { + "__init__" => ctx.new_rustfunc(file_io_init), + "name" => ctx.str_type(), + "read" => ctx.new_rustfunc(file_io_read), + "readinto" => ctx.new_rustfunc(file_io_readinto), + "write" => ctx.new_rustfunc(file_io_write) + }); // BufferedIOBase Subclasses - let buffered_reader = { - let buffered_reader = ctx.new_class("BufferedReader", buffered_io_base.clone()); - ctx.set_attr( - &buffered_reader, - "read", - ctx.new_rustfunc(buffered_reader_read), - ); - buffered_reader - }; - ctx.set_attr(&py_mod, "BufferedReader", buffered_reader.clone()); + let buffered_reader = py_class!(ctx, "BufferedReader", buffered_io_base.clone(), { + "read" => ctx.new_rustfunc(buffered_reader_read) + }); - let buffered_writer = { - let buffered_writer = ctx.new_class("BufferedWriter", buffered_io_base.clone()); - ctx.set_attr( - &buffered_writer, - "write", - ctx.new_rustfunc(buffered_writer_write), - ); - buffered_writer - }; - ctx.set_attr(&py_mod, "BufferedWriter", buffered_writer.clone()); + let buffered_writer = py_class!(ctx, "BufferedWriter", buffered_io_base.clone(), { + "write" => ctx.new_rustfunc(buffered_writer_write) + }); //TextIOBase Subclass - let text_io_wrapper = { - let text_io_wrapper = ctx.new_class("TextIOWrapper", text_io_base.clone()); - ctx.set_attr( - &text_io_wrapper, - "__init__", - ctx.new_rustfunc(text_io_wrapper_init), - ); - text_io_wrapper - }; - ctx.set_attr(&py_mod, "TextIOWrapper", text_io_wrapper.clone()); + let text_io_wrapper = py_class!(ctx, "TextIOWrapper", text_io_base.clone(), { + "__init__" => ctx.new_rustfunc(text_io_wrapper_init) + }); //StringIO: in-memory text - let string_io = { - let string_io = ctx.new_class("StringIO", text_io_base.clone()); - ctx.set_attr(&string_io, "__init__", ctx.new_rustfunc(string_io_init)); - ctx.set_attr(&string_io, "getvalue", ctx.new_rustfunc(string_io_getvalue)); - string_io - }; - ctx.set_attr(&py_mod, "StringIO", string_io); + let string_io = py_class!(ctx, "StringIO", text_io_base.clone(), { + "__init__" => ctx.new_rustfunc(string_io_init), + "getvalue" => ctx.new_rustfunc(string_io_getvalue) + }); //BytesIO: in-memory bytes - let bytes_io = { - let bytes_io = ctx.new_class("BytesIO", buffered_io_base.clone()); - ctx.set_attr(&bytes_io, "__init__", ctx.new_rustfunc(bytes_io_init)); - ctx.set_attr(&bytes_io, "getvalue", ctx.new_rustfunc(bytes_io_getvalue)); - bytes_io - }; - ctx.set_attr(&py_mod, "BytesIO", bytes_io); + let bytes_io = py_class!(ctx, "BytesIO", buffered_io_base.clone(), { + "__init__" => ctx.new_rustfunc(bytes_io_init), + "getvalue" => ctx.new_rustfunc(bytes_io_getvalue) + }); - py_mod + py_module!(ctx, "io", { + "open" => ctx.new_rustfunc(io_open), + "IOBase" => io_base.clone(), + "RawIOBase" => raw_io_base.clone(), + "BufferedIOBase" => buffered_io_base.clone(), + "TextIOBase" => text_io_base.clone(), + "FileIO" => file_io.clone(), + "BufferedReader" => buffered_reader.clone(), + "BufferedWriter" => buffered_writer.clone(), + "TextIOWrapper" => text_io_wrapper.clone(), + "StringIO" => string_io, + "BytesIO" => bytes_io, + }) } diff --git a/vm/src/stdlib/json.rs b/vm/src/stdlib/json.rs index fef284000..7403c8d8c 100644 --- a/vm/src/stdlib/json.rs +++ b/vm/src/stdlib/json.rs @@ -5,7 +5,11 @@ use serde::de::{DeserializeSeed, Visitor}; use serde::ser::{SerializeMap, SerializeSeq}; use serde_json; -use crate::obj::{objbool, objdict, objfloat, objint, objsequence, objstr, objtype}; +use crate::obj::{ + objbool, objdict, objfloat, objint, objsequence, + objstr::{self, PyString}, + objtype, +}; use crate::pyobject::{ create_type, DictProtocol, PyContext, PyFuncArgs, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, @@ -167,8 +171,8 @@ impl<'de> Visitor<'de> for PyObjectDeserializer<'de> { // than wrapping the given object up and then unwrapping it to determine whether or // not it is a string while let Some((key_obj, value)) = access.next_entry_seed(self.clone(), self.clone())? { - let key = match key_obj.payload { - PyObjectPayload::String { ref value } => value.clone(), + let key: String = match key_obj.payload::() { + Some(PyString { ref value }) => value.clone(), _ => unimplemented!("map keys must be strings"), }; dict.set_item(&self.vm.ctx, &key, value); diff --git a/vm/src/stdlib/pystruct.rs b/vm/src/stdlib/pystruct.rs index 5599e61a4..c96bba736 100644 --- a/vm/src/stdlib/pystruct.rs +++ b/vm/src/stdlib/pystruct.rs @@ -46,31 +46,19 @@ fn get_int(vm: &mut VirtualMachine, arg: &PyObjectRef) -> PyResult { objint::to_int(vm, arg, 10) } -fn pack_i8( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_i8(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_i8().unwrap(); data.write_i8(v).unwrap(); Ok(()) } -fn pack_u8( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_u8(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_u8().unwrap(); data.write_u8(v).unwrap(); Ok(()) } -fn pack_bool( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_bool(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { if objtype::isinstance(&arg, &vm.ctx.bool_type()) { let v = if objbool::get_value(arg) { 1 } else { 0 }; data.write_u8(v).unwrap(); @@ -80,71 +68,43 @@ fn pack_bool( } } -fn pack_i16( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_i16(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_i16().unwrap(); data.write_i16::(v).unwrap(); Ok(()) } -fn pack_u16( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_u16(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_u16().unwrap(); data.write_u16::(v).unwrap(); Ok(()) } -fn pack_i32( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_i32(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_i32().unwrap(); data.write_i32::(v).unwrap(); Ok(()) } -fn pack_u32( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_u32(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_u32().unwrap(); data.write_u32::(v).unwrap(); Ok(()) } -fn pack_i64( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_i64(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_i64().unwrap(); data.write_i64::(v).unwrap(); Ok(()) } -fn pack_u64( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_u64(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { let v = get_int(vm, arg)?.to_u64().unwrap(); data.write_u64::(v).unwrap(); Ok(()) } -fn pack_f32( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_f32(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { if objtype::isinstance(&arg, &vm.ctx.float_type()) { let v = objfloat::get_value(arg) as f32; data.write_f32::(v).unwrap(); @@ -154,11 +114,7 @@ fn pack_f32( } } -fn pack_f64( - vm: &mut VirtualMachine, - arg: &PyObjectRef, - data: &mut Write, -) -> PyResult<()> { +fn pack_f64(vm: &mut VirtualMachine, arg: &PyObjectRef, data: &mut Write) -> PyResult<()> { if objtype::isinstance(&arg, &vm.ctx.float_type()) { let v = objfloat::get_value(arg) as f64; data.write_f64::(v).unwrap(); diff --git a/vm/src/stdlib/re.rs b/vm/src/stdlib/re.rs index 0c0cef934..b262ff76a 100644 --- a/vm/src/stdlib/re.rs +++ b/vm/src/stdlib/re.rs @@ -50,13 +50,11 @@ fn re_search(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } pub fn mk_module(ctx: &PyContext) -> PyObjectRef { - let py_mod = ctx.new_module("re", ctx.new_scope(None)); + let match_type = py_class!(ctx, "Match", ctx.object(), {}); - let match_type = ctx.new_class("Match", ctx.object()); - ctx.set_attr(&py_mod, "Match", match_type); - - ctx.set_attr(&py_mod, "match", ctx.new_rustfunc(re_match)); - ctx.set_attr(&py_mod, "search", ctx.new_rustfunc(re_search)); - - py_mod + py_module!(ctx, "re", { + "Match" => match_type, + "match" => ctx.new_rustfunc(re_match), + "search" => ctx.new_rustfunc(re_search) + }) } diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 30a1bbdc5..4afe7a372 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::io; use std::io::Read; use std::io::Write; -use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}; use std::ops::DerefMut; use crate::obj::objbytes; @@ -53,7 +53,7 @@ impl SocketKind { enum Connection { TcpListener(TcpListener), TcpStream(TcpStream), - // UdpSocket(UdpSocket), + UdpSocket(UdpSocket), } impl Connection { @@ -67,6 +67,21 @@ impl Connection { fn local_addr(&self) -> io::Result { match self { Connection::TcpListener(con) => con.local_addr(), + Connection::UdpSocket(con) => con.local_addr(), + Connection::TcpStream(con) => con.local_addr(), + } + } + + fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + match self { + Connection::UdpSocket(con) => con.recv_from(buf), + _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), + } + } + + fn send_to(&self, buf: &[u8], addr: A) -> io::Result { + match self { + Connection::UdpSocket(con) => con.send_to(buf, addr), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } @@ -76,6 +91,7 @@ impl Read for Connection { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { Connection::TcpStream(con) => con.read(buf), + Connection::UdpSocket(con) => con.recv(buf), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } @@ -85,6 +101,7 @@ impl Write for Connection { fn write(&mut self, buf: &[u8]) -> io::Result { match self { Connection::TcpStream(con) => con.write(buf), + Connection::UdpSocket(con) => con.send(buf), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } @@ -153,12 +170,27 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let mut socket = get_socket(zelf); - if let Ok(stream) = TcpStream::connect(address_string) { - socket.con = Some(Connection::TcpStream(stream)); - Ok(vm.get_none()) - } else { - // TODO: Socket error - Err(vm.new_type_error("socket failed".to_string())) + match socket.socket_kind { + SocketKind::Stream => { + if let Ok(stream) = TcpStream::connect(address_string) { + socket.con = Some(Connection::TcpStream(stream)); + Ok(vm.get_none()) + } else { + // TODO: Socket error + Err(vm.new_type_error("socket failed".to_string())) + } + } + SocketKind::Dgram => { + if let Some(Connection::UdpSocket(con)) = &socket.con { + match con.connect(address_string) { + Ok(_) => Ok(vm.get_none()), + // TODO: Socket error + Err(_) => Err(vm.new_type_error("socket failed".to_string())), + } + } else { + Err(vm.new_type_error("".to_string())) + } + } } } @@ -173,12 +205,25 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { let mut socket = get_socket(zelf); - if let Ok(stream) = TcpListener::bind(address_string) { - socket.con = Some(Connection::TcpListener(stream)); - Ok(vm.get_none()) - } else { - // TODO: Socket error - Err(vm.new_type_error("socket failed".to_string())) + match socket.socket_kind { + SocketKind::Stream => { + if let Ok(stream) = TcpListener::bind(address_string) { + socket.con = Some(Connection::TcpListener(stream)); + Ok(vm.get_none()) + } else { + // TODO: Socket error + Err(vm.new_type_error("socket failed".to_string())) + } + } + SocketKind::Dgram => { + if let Ok(dgram) = UdpSocket::bind(address_string) { + socket.con = Some(Connection::UdpSocket(dgram)); + Ok(vm.get_none()) + } else { + // TODO: Socket error + Err(vm.new_type_error("socket failed".to_string())) + } + } } } @@ -225,8 +270,8 @@ fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { None => return Err(vm.new_type_error("".to_string())), }; - let tcp_stream = match ret { - Ok((socket, _addr)) => socket, + let (tcp_stream, addr) = match ret { + Ok((socket, addr)) => (socket, addr), _ => return Err(vm.new_type_error("".to_string())), }; @@ -243,12 +288,9 @@ fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { zelf.typ(), ); - let elements = RefCell::new(vec![sock_obj, vm.get_none()]); + let addr_tuple = get_addr_tuple(vm, addr)?; - Ok(PyObject::new( - PyObjectPayload::Sequence { elements }, - vm.ctx.tuple_type(), - )) + Ok(vm.ctx.new_tuple(vec![sock_obj, addr_tuple])) } fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -267,6 +309,31 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bytes(buffer)) } +fn socket_recvfrom(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))] + ); + + let mut socket = get_socket(zelf); + + let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()]; + let ret = match socket.con { + Some(ref mut v) => v.recv_from(&mut buffer), + None => return Err(vm.new_type_error("".to_string())), + }; + + let addr = match ret { + Ok((_size, addr)) => addr, + _ => return Err(vm.new_type_error("".to_string())), + }; + + let addr_tuple = get_addr_tuple(vm, addr)?; + + Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple])) +} + fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -282,6 +349,46 @@ fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_none()) } +fn socket_sendto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, None), + (bytes, Some(vm.ctx.bytes_type())), + (address, Some(vm.ctx.tuple_type())) + ] + ); + let address_string = get_address_string(vm, address)?; + + let mut socket = get_socket(zelf); + + match socket.socket_kind { + SocketKind::Dgram => { + match socket.con { + Some(ref mut v) => { + if let Ok(_) = v.send_to(&objbytes::get_value(&bytes), address_string) { + Ok(vm.get_none()) + } else { + Err(vm.new_type_error("socket failed".to_string())) + } + } + None => { + // Doing implicit bind + if let Ok(dgram) = UdpSocket::bind("0.0.0.0:0") { + if let Ok(_) = dgram.send_to(&objbytes::get_value(&bytes), address_string) { + socket.con = Some(Connection::UdpSocket(dgram)); + return Ok(vm.get_none()); + } + } + Err(vm.new_type_error("socket failed".to_string())) + } + } + } + _ => Err(vm.new_not_implemented_error("".to_string())), + } +} + fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, None)]); @@ -300,45 +407,37 @@ fn socket_getsockname(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { }; match addr { - Ok(addr) => { - let port = vm.ctx.new_int(addr.port()); - let ip = vm.ctx.new_str(addr.ip().to_string()); - let elements = RefCell::new(vec![ip, port]); - - Ok(PyObject::new( - PyObjectPayload::Sequence { elements }, - vm.ctx.tuple_type(), - )) - } + Ok(addr) => get_addr_tuple(vm, addr), _ => Err(vm.new_type_error("".to_string())), } } -pub fn mk_module(ctx: &PyContext) -> PyObjectRef { - let py_mod = ctx.new_module(&"socket".to_string(), ctx.new_scope(None)); +fn get_addr_tuple(vm: &mut VirtualMachine, addr: SocketAddr) -> PyResult { + let port = vm.ctx.new_int(addr.port()); + let ip = vm.ctx.new_str(addr.ip().to_string()); - ctx.set_attr(&py_mod, "AF_INET", ctx.new_int(AddressFamily::Inet as i32)); - - ctx.set_attr( - &py_mod, - "SOCK_STREAM", - ctx.new_int(SocketKind::Stream as i32), - ); - - let socket = { - let socket = ctx.new_class("socket", ctx.object()); - ctx.set_attr(&socket, "__new__", ctx.new_rustfunc(socket_new)); - ctx.set_attr(&socket, "connect", ctx.new_rustfunc(socket_connect)); - ctx.set_attr(&socket, "recv", ctx.new_rustfunc(socket_recv)); - ctx.set_attr(&socket, "send", ctx.new_rustfunc(socket_send)); - ctx.set_attr(&socket, "bind", ctx.new_rustfunc(socket_bind)); - ctx.set_attr(&socket, "accept", ctx.new_rustfunc(socket_accept)); - ctx.set_attr(&socket, "listen", ctx.new_rustfunc(socket_listen)); - ctx.set_attr(&socket, "close", ctx.new_rustfunc(socket_close)); - ctx.set_attr(&socket, "getsockname", ctx.new_rustfunc(socket_getsockname)); - socket - }; - ctx.set_attr(&py_mod, "socket", socket.clone()); - - py_mod + Ok(vm.ctx.new_tuple(vec![ip, port])) +} + +pub fn mk_module(ctx: &PyContext) -> PyObjectRef { + let socket = py_class!(ctx, "socket", ctx.object(), { + "__new__" => ctx.new_rustfunc(socket_new), + "connect" => ctx.new_rustfunc(socket_connect), + "recv" => ctx.new_rustfunc(socket_recv), + "send" => ctx.new_rustfunc(socket_send), + "bind" => ctx.new_rustfunc(socket_bind), + "accept" => ctx.new_rustfunc(socket_accept), + "listen" => ctx.new_rustfunc(socket_listen), + "close" => ctx.new_rustfunc(socket_close), + "getsockname" => ctx.new_rustfunc(socket_getsockname), + "sendto" => ctx.new_rustfunc(socket_sendto), + "recvfrom" => ctx.new_rustfunc(socket_recvfrom), + }); + + py_module!(ctx, "socket", { + "AF_INET" => ctx.new_int(AddressFamily::Inet as i32), + "SOCK_STREAM" => ctx.new_int(SocketKind::Stream as i32), + "SOCK_DGRAM" => ctx.new_int(SocketKind::Dgram as i32), + "socket" => socket.clone(), + }) } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 9d91abf9e..60a260479 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -90,6 +90,12 @@ impl VirtualMachine { result } + pub fn current_scope(&self) -> &ScopeRef { + let current_frame = &self.frames[self.frames.len() - 1]; + let frame = objframe::get_value(current_frame); + &frame.scope + } + /// Create a new python string object. pub fn new_str(&self, s: String) -> PyObjectRef { self.ctx.new_str(s) @@ -218,7 +224,7 @@ impl VirtualMachine { &self.ctx } - pub fn get_builtin_scope(&mut self) -> ScopeRef { + pub fn get_builtin_scope(&self) -> ScopeRef { let a2 = &*self.builtins; match a2.payload { PyObjectPayload::Module { ref scope, .. } => scope.clone(), @@ -242,6 +248,26 @@ impl VirtualMachine { self.call_method(obj, "__repr__", vec![]) } + /// Determines if `obj` is an instance of `cls`, either directly, indirectly or virtually via + /// the __instancecheck__ magic method. + pub fn isinstance(&mut self, obj: &PyObjectRef, cls: &PyObjectRef) -> PyResult { + // cpython first does an exact check on the type, although documentation doesn't state that + // https://github.com/python/cpython/blob/a24107b04c1277e3c1105f98aff5bfa3a98b33a0/Objects/abstract.c#L2408 + if Rc::ptr_eq(&obj.typ(), cls) { + Ok(true) + } else { + let ret = self.call_method(cls, "__instancecheck__", vec![obj.clone()])?; + objbool::boolval(self, ret) + } + } + + /// Determines if `subclass` is a subclass of `cls`, either directly, indirectly or virtually + /// via the __subclasscheck__ magic method. + pub fn issubclass(&mut self, subclass: &PyObjectRef, cls: &PyObjectRef) -> PyResult { + let ret = self.call_method(cls, "__subclasscheck__", vec![subclass.clone()])?; + objbool::boolval(self, ret) + } + pub fn call_get_descriptor(&mut self, attr: PyObjectRef, obj: PyObjectRef) -> PyResult { let attr_class = attr.typ(); if let Some(descriptor) = attr_class.get_attr("__get__") { diff --git a/wasm/lib/src/vm_class.rs b/wasm/lib/src/vm_class.rs index cae5f0c2d..38c27892f 100644 --- a/wasm/lib/src/vm_class.rs +++ b/wasm/lib/src/vm_class.rs @@ -301,7 +301,7 @@ impl WASMVirtualMachine { }; scope .locals - .set_item(&vm.ctx, "print", vm.ctx.new_rustfunc_from_box(print_fn)); + .set_item(&vm.ctx, "print", vm.ctx.new_rustfunc(print_fn)); Ok(()) }, )?