From aadaf18219ccbc8ddf0182f1bddd7fd582ff336d Mon Sep 17 00:00:00 2001 From: Jake Armendariz Date: Wed, 16 Mar 2022 18:09:33 -0700 Subject: [PATCH] Marshaling sets, frozen sets, bytearr, and changes to testing --- extra_tests/snippets/test_marshal.py | 69 ++++++++++++++++++---------- vm/src/builtins/set.rs | 12 +++++ vm/src/stdlib/marshal.rs | 56 +++++++++++++++++++--- 3 files changed, 108 insertions(+), 29 deletions(-) diff --git a/extra_tests/snippets/test_marshal.py b/extra_tests/snippets/test_marshal.py index eee468333..2fea7f5c3 100644 --- a/extra_tests/snippets/test_marshal.py +++ b/extra_tests/snippets/test_marshal.py @@ -9,34 +9,57 @@ class MarshalTests(unittest.TestCase): def dump_then_load(self, data): return marshal.loads(marshal.dumps(data)) - def test_dump_and_load_int(self): - self.assertEqual(self.dump_then_load(0), 0) - self.assertEqual(self.dump_then_load(-1), -1) - self.assertEqual(self.dump_then_load(1), 1) - self.assertEqual(self.dump_then_load(100000000), 100000000) + def _test_marshal(self, data): + self.assertEqual(self.dump_then_load(data), data) - def test_dump_and_load_int(self): - self.assertEqual(self.dump_then_load(0.0), 0.0) - self.assertEqual(self.dump_then_load(-10.0), -10.0) - self.assertEqual(self.dump_then_load(10), 10) + def test_marshal_int(self): + self._test_marshal(0) + self._test_marshal(-1) + self._test_marshal(1) + self._test_marshal(100000000) - def test_dump_and_load_str(self): - self.assertEqual(self.dump_then_load(""), "") - self.assertEqual(self.dump_then_load("Hello, World"), "Hello, World") + def test_marshal_float(self): + self._test_marshal(0.0) + self._test_marshal(-10.0) + self._test_marshal(10.0) - def test_dump_and_load_list(self): - self.assertEqual(self.dump_then_load([]), []) - self.assertEqual(self.dump_then_load([1, "hello", 1.0]), [1, "hello", 1.0]) - self.assertEqual(self.dump_then_load([[0], ['a','b']]),[[0], ['a','b']]) + def test_marshal_str(self): + self._test_marshal("") + 
self._test_marshal("Hello, World") - def test_dump_and_load_tuple(self): - self.assertEqual(self.dump_then_load(()), ()) - self.assertEqual(self.dump_then_load((1, "hello", 1.0)), (1, "hello", 1.0)) + def test_marshal_list(self): + self._test_marshal([]) + self._test_marshal([1, "hello", 1.0]) + self._test_marshal([[0], ['a','b']]) - def test_dump_and_load_dict(self): - self.assertEqual(self.dump_then_load({}), {}) - self.assertEqual(self.dump_then_load({'a':1, 1:'a'}), {'a':1, 1:'a'}) - self.assertEqual(self.dump_then_load({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), {'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}) + def test_marshal_tuple(self): + self._test_marshal(()) + self._test_marshal((1, "hello", 1.0)) + + def test_marshal_dict(self): + self._test_marshal({}) + self._test_marshal({'a':1, 1:'a'}) + self._test_marshal({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}) + + def test_marshal_set(self): + self._test_marshal(set()) + self._test_marshal({1, 2, 3}) + self._test_marshal({1, 'a', 'b'}) + + def test_marshal_frozen_set(self): + self._test_marshal(frozenset()) + self._test_marshal(frozenset({1, 2, 3})) + self._test_marshal(frozenset({1, 'a', 'b'})) + + def test_marshal_bytearray(self): + self.assertEqual( + self.dump_then_load(bytearray([])), + bytearray(b''), + ) + self.assertEqual( + self.dump_then_load(bytearray([1, 2])), + bytearray(b'\x01\x02'), + ) if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index e07145412..a4102ac45 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -32,6 +32,12 @@ pub struct PySet { pub(super) inner: PySetInner, } +impl PySet { + pub fn elements(zelf: PyRef<Self>) -> Vec<PyObjectRef> { + zelf.inner.elements() + } +} + /// frozenset() -> empty frozenset object /// frozenset(iterable) -> frozenset object /// @@ -42,6 +48,12 @@ pub struct PyFrozenSet { inner: PySetInner, } +impl PyFrozenSet { + pub fn elements(&self) -> Vec<PyObjectRef> { + self.inner.elements() + } +} + impl fmt::Debug for 
PySet { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // TODO: implement more detailed, non-recursive Debug formatter diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index d5a84f0c1..764f7242f 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -2,20 +2,24 @@ pub(crate) use decl::make_module; #[pymodule(name = "marshal")] mod decl { - /// TODO add support for Booleans, Sets, etc - use ascii::AsciiStr; - use num_bigint::{BigInt, Sign}; - use std::slice::Iter; - use crate::{ builtins::{ - dict::DictContentType, PyBytes, PyCode, PyDict, PyFloat, PyInt, PyList, PyStr, PyTuple, + dict::DictContentType, PyByteArray, PyBytes, PyCode, PyDict, PyFloat, PyFrozenSet, + PyInt, PyList, PySet, PyStr, PyTuple, }, bytecode, function::{ArgBytesLike, IntoPyObject}, protocol::PyBuffer, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; + // TODO: not yet handled by `_dumps`: + // PyBool: hard because match_class! matches a bool as an int, and + // try_to_bool also converts the ints 0 and 1 to false and true. + // PyBytes: a match_class! arm for it hits the "recursion limit reached" error. + use ascii::AsciiStr; + use num_bigint::{BigInt, Sign}; + use std::ops::Deref; + use std::slice::Iter; const STR_BYTE: u8 = b's'; const INT_BYTE: u8 = b'i'; @@ -23,6 +27,9 @@ mod decl { const LIST_BYTE: u8 = b'['; const TUPLE_BYTE: u8 = b'('; const DICT_BYTE: u8 = b','; + const SET_BYTE: u8 = b'~'; + const FROZEN_SET_BYTE: u8 = b'<'; + const BYTE_ARRAY: u8 = b'>'; /// Safely convert usize to 4 le bytes fn size_to_bytes(x: usize, vm: &VirtualMachine) -> PyResult<[u8; 4]> { @@ -48,9 +55,11 @@ mod decl { Ok(byte_list) } + /// Dumping helper: turns a value into bytes, choosing the encoding from the value's class. fn _dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<Vec<u8>> { let r = match_class!(match value { pyint @ PyInt => { + // Could not convert to boolean. 
Assume integer let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le(); let sign_byte = match sign { Sign::Minus => b'-', @@ -78,6 +87,18 @@ mod decl { list_bytes.push(LIST_BYTE); list_bytes } + pyset @ PySet => { + let items = PySet::elements(pyset); + let mut list_bytes = dump_list(items.iter(), vm)?; + list_bytes.push(SET_BYTE); + list_bytes + } + pyfrozen @ PyFrozenSet => { + let items = pyfrozen.elements(); + let mut list_bytes = dump_list(items.iter(), vm)?; + list_bytes.push(FROZEN_SET_BYTE); + list_bytes + } pytuple @ PyTuple => { let mut tuple_bytes = dump_list(pytuple.as_slice().iter(), vm)?; tuple_bytes.push(TUPLE_BYTE); @@ -96,6 +117,11 @@ mod decl { dict_bytes.push(DICT_BYTE); dict_bytes } + pybyte_array @ PyByteArray => { + let mut pybytes = pybyte_array.borrow_buf_mut().deref().clone(); + pybytes.push(BYTE_ARRAY); + pybytes + } co @ PyCode => { // Code is default, doesn't have prefix. co.code.map_clone_bag(&bytecode::BasicBag).to_bytes() @@ -235,6 +261,19 @@ mod decl { let elements = read_list(buf, vm)?; Ok(elements.into_pyobject(vm)) } + SET_BYTE => { + let elements = read_list(buf, vm)?; + let set = PySet::new_ref(&vm.ctx); + for element in elements { + set.add(element, vm)?; + } + Ok(set.into_pyobject(vm)) + } + FROZEN_SET_BYTE => { + let elements = read_list(buf, vm)?; + let set = PyFrozenSet::from_iter(vm, elements.into_iter())?; + Ok(set.into_pyobject(vm)) + } TUPLE_BYTE => { let elements = read_list(buf, vm)?; let pytuple = PyTuple::new_ref(elements, &vm.ctx).into_pyobject(vm); @@ -249,6 +288,11 @@ mod decl { }); Ok(pydict.into_pyobject(vm)) } + BYTE_ARRAY => { + // Following CPython, a marshaled bytearray is loaded back as a bytes object. + let byte_array = PyBytes::from(buf[..].to_vec()); + Ok(byte_array.into_pyobject(vm)) + } _ => { // If prefix is not identifiable, assume CodeObject, error out if it doesn't match. let code = bytecode::CodeObject::from_bytes(&full_buff).map_err(|e| match e {