diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py new file mode 100644 index 000000000..a0642258c --- /dev/null +++ b/Lib/test/test_marshal.py @@ -0,0 +1,91 @@ +import unittest +import marshal + +class MarshalTests(unittest.TestCase): + """ + Testing each data type is done with two tests + Test dumps data == expected_bytes + Test load(dumped data) == data + """ + + def dump_then_load(self, data): + return marshal.loads(marshal.dumps(data)) + + def test_dumps_int(self): + self.assertEqual(marshal.dumps(0), b'i0\x00') + self.assertEqual(marshal.dumps(-1), b'i-\x01') + self.assertEqual(marshal.dumps(1), b'i+\x01') + self.assertEqual(marshal.dumps(100000000), b'i+\x00\xe1\xf5\x05') + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0), 0) + self.assertEqual(self.dump_then_load(-1), -1) + self.assertEqual(self.dump_then_load(1), 1) + self.assertEqual(self.dump_then_load(100000000), 100000000) + + def test_dumps_float(self): + self.assertEqual(marshal.dumps(0.0), b'f\x00\x00\x00\x00\x00\x00\x00\x00') + self.assertEqual(marshal.dumps(-10.0), b'f\x00\x00\x00\x00\x00\x00$\xc0') + self.assertEqual(marshal.dumps(10.0), b'f\x00\x00\x00\x00\x00\x00$@') + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0.0), 0.0) + self.assertEqual(self.dump_then_load(-10.0), -10.0) + self.assertEqual(self.dump_then_load(10), 10) + + def test_dumps_str(self): + self.assertEqual(marshal.dumps(""), b's') + self.assertEqual(marshal.dumps("Hello, World"), b'sHello, World') + + def test_dump_and_load_str(self): + self.assertEqual(self.dump_then_load(""), "") + self.assertEqual(self.dump_then_load("Hello, World"), "Hello, World") + + def test_dumps_list(self): + # Lists have to print the length of every element + # so when marshelling and unmarshelling we know how many bytes to search + # all usize values are converted to u32 to handle different architecture sizes. + self.assertEqual(marshal.dumps([]), b'[\x00\x00\x00\x00') + self.assertEqual( + marshal.dumps([1, "hello", 1.0]), + b'[\x03\x00\x00\x00\x03\x00\x00\x00i+\x01\x06\x00\x00\x00shello\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\xf0?', + ) + self.assertEqual( + marshal.dumps([[0], ['a','b']]), + b'[\x02\x00\x00\x00\x0c\x00\x00\x00[\x01\x00\x00\x00\x03\x00\x00\x00i0\x00\x11\x00\x00\x00[\x02\x00\x00\x00\x02\x00\x00\x00sa\x02\x00\x00\x00sb', + ) + + def test_dump_and_load_list(self): + self.assertEqual(self.dump_then_load([]), []) + self.assertEqual(self.dump_then_load([1, "hello", 1.0]), [1, "hello", 1.0]) + self.assertEqual(self.dump_then_load([[0], ['a','b']]),[[0], ['a','b']]) + + def test_dumps_tuple(self): + self.assertEqual(marshal.dumps(()), b'(\x00\x00\x00\x00') + self.assertEqual( + marshal.dumps((1, "hello", 1.0)), + b'(\x03\x00\x00\x00\x03\x00\x00\x00i+\x01\x06\x00\x00\x00shello\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\xf0?' + ) + + def test_dump_and_load_tuple(self): + self.assertEqual(self.dump_then_load(()), ()) + self.assertEqual(self.dump_then_load((1, "hello", 1.0)), (1, "hello", 1.0)) + + def test_dumps_dict(self): + self.assertEqual(marshal.dumps({}), b',[\x00\x00\x00\x00') + self.assertEqual( + marshal.dumps({'a':1, 1:'a'}), + b',[\x02\x00\x00\x00\x12\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sa\x03\x00\x00\x00i+\x01\x12\x00\x00\x00(\x02\x00\x00\x00\x03\x00\x00\x00i+\x01\x02\x00\x00\x00sa' + ) + self.assertEqual( + marshal.dumps({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), + b',[\x02\x00\x00\x00+\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sa\x1c\x00\x00\x00,[\x01\x00\x00\x00\x12\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sb\x03\x00\x00\x00i+\x02<\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sc-\x00\x00\x00[\x04\x00\x00\x00\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\x10@\x03\x00\x00\x00i+\x06\x03\x00\x00\x00i+\t' + ) + + def test_dump_and_load_dict(self): + self.assertEqual(self.dump_then_load({}), {}) + self.assertEqual(self.dump_then_load({'a':1, 1:'a'}), {'a':1, 1:'a'}) + self.assertEqual(self.dump_then_load({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), {'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index d5d1c6e09..b589e121a 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -69,6 +69,10 @@ impl PyDict { &self.entries } + pub(crate) fn from_entries(entries: DictContentType) -> Self { + Self { entries } + } + #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { PyDict::default().into_pyresult_with_type(vm, cls) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index e29b95b1a..e3784fb4a 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -114,6 +114,28 @@ struct DictEntry { value: T, } +impl DictEntry { + pub(crate) fn as_tuple(&self) -> (PyObjectRef, T) { + (self.key.clone(), self.value.clone()) + } +} + +impl Dict { + pub(crate) fn as_kvpairs(&self) -> Vec<(PyObjectRef, T)> { + let entries = &self.inner.read().entries; + entries + .into_iter() + .filter_map(|entry| { + if let Some(dict_entry) = entry { + Some(dict_entry.as_tuple()) + } else { + None + } + }) + .collect() + } +} + #[derive(Debug, PartialEq)] pub struct DictSize { indices_size: usize, diff --git a/vm/src/protocol/buffer.rs b/vm/src/protocol/buffer.rs index c69b4efc2..6ca3ed380 100644 --- a/vm/src/protocol/buffer.rs +++ b/vm/src/protocol/buffer.rs @@ -10,7 +10,7 @@ use crate::{ }, sliceable::wrap_index, types::{Constructor, Unconstructible}, - PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, + PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine, }; use std::{borrow::Cow, fmt::Debug, ops::Range}; @@ -63,6 +63,15 @@ impl PyBuffer { .then(|| unsafe { self.contiguous_mut_unchecked() }) } + pub fn from_byte_vector(bytes: Vec, vm: &VirtualMachine) -> Self { + let bytes_len = bytes.len(); + PyBuffer::new( + PyValue::into_object(VecBuffer::from(bytes), vm), + BufferDescriptor::simple(bytes_len, true), + &VEC_BUFFER_METHODS, + ) + } + /// # Safety /// assume the buffer is contiguous pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<[u8]> { diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index f7ae60eb1..211d1713a 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -2,23 +2,113 @@ pub(crate) use decl::make_module; #[pymodule(name = "marshal")] mod decl { + /// TODO add support for Booleans, Sets, etc + use ascii::AsciiStr; + use num_bigint::{BigInt, Sign}; + use std::ops::Deref; + use std::slice::Iter; + use crate::{ - builtins::{PyBytes, PyCode}, + builtins::{ + dict::DictContentType, PyBytes, PyCode, PyDict, PyFloat, PyInt, PyList, PyStr, PyTuple, + }, bytecode, - function::ArgBytesLike, + common::borrow::BorrowedValue, + function::{ArgBytesLike, IntoPyObject}, + protocol::PyBuffer, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; + const STR_BYTE: u8 = b's'; + const INT_BYTE: u8 = b'i'; + const FLOAT_BYTE: u8 = b'f'; + const LIST_BYTE: u8 = b'['; + const TUPLE_BYTE: u8 = b'('; + const DICT_BYTE: u8 = b','; + + /// Safely convert usize to 4 le bytes + fn size_to_bytes(x: usize, vm: &VirtualMachine) -> PyResult<[u8; 4]> { + // For marshalling we want to convert lengths to bytes. To save space + // we limit the size to u32 to keep marshalling smaller. + match u32::try_from(x) { + Ok(n) => Ok(n.to_le_bytes()), + Err(_) => { + Err(vm.new_value_error("Size exceeds 2^32 capacity for marshalling.".to_owned())) + } + } + } + + /// Dumps a iterator of objects into binary vector. + fn dump_list(pyobjs: Iter, vm: &VirtualMachine) -> PyResult> { + let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec(); + // For each element, dump into binary, then add its length and value. + for element in pyobjs { + let element_bytes: PyBytes = dumps(element.clone(), vm)?; + byte_list.extend(size_to_bytes(element_bytes.len(), vm)?); + byte_list.extend_from_slice(element_bytes.deref()) + } + Ok(byte_list) + } + #[pyfunction] fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { let r = match_class!(match value { - co @ PyCode => { - PyBytes::from(co.code.map_clone_bag(&bytecode::BasicBag).to_bytes()) + pyint @ PyInt => { + let (sign, uint_bytes) = pyint.as_bigint().to_bytes_le(); + let sign_byte = match sign { + Sign::Minus => b'-', + Sign::NoSign => b'0', + Sign::Plus => b'+', + }; + // Return as [TYPE, SIGN, uint bytes] + PyBytes::from([vec![INT_BYTE, sign_byte], uint_bytes].concat()) } - _ => + pyfloat @ PyFloat => { + let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec(); + float_bytes.insert(0, FLOAT_BYTE); + PyBytes::from(float_bytes) + } + pystr @ PyStr => { + let mut str_bytes = pystr.as_str().as_bytes().to_vec(); + str_bytes.insert(0, STR_BYTE); + PyBytes::from(str_bytes) + } + pylist @ PyList => { + let pylist_items = pylist.borrow_vec(); + let mut list_bytes = dump_list(pylist_items.iter(), vm)?; + list_bytes.insert(0, LIST_BYTE); + PyBytes::from(list_bytes) + } + pytuple @ PyTuple => { + let mut tuple_bytes = dump_list(pytuple.as_slice().iter(), vm)?; + tuple_bytes.insert(0, TUPLE_BYTE); + PyBytes::from(tuple_bytes) + } + pydict @ PyDict => { + let key_value_pairs = pydict._as_dict_inner().clone().as_kvpairs(); + // Converts list of tuples to PyObjectRefs of tuples + let elements: Vec = key_value_pairs + .into_iter() + .map(|(k, v)| { + PyTuple::new_ref(vec![k, v], &vm.ctx).into_pyobject(vm) + }) + .collect(); + // Converts list of tuples to list, dump into binary + let mut dict_bytes = dump_list(elements.iter(), vm)?; + dict_bytes.insert(0, LIST_BYTE); + dict_bytes.insert(0, DICT_BYTE); + PyBytes::from(dict_bytes) + } + co @ PyCode => { + // Code is default, doesn't have prefix. + let code_bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); + PyBytes::from(code_bytes) + } + _ => { return Err(vm.new_not_implemented_error( - "TODO: not implemented yet or marshal unsupported type".to_owned() - )), + "TODO: not implemented yet or marshal unsupported type".to_owned(), + )); + } }); Ok(r) } @@ -30,25 +120,159 @@ mod decl { Ok(()) } - #[pyfunction] - fn loads(code_bytes: ArgBytesLike, vm: &VirtualMachine) -> PyResult { - let buf = &*code_bytes.borrow_buf(); - let code = bytecode::CodeObject::from_bytes(buf).map_err(|e| match e { - bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( - vm.ctx.exceptions.eof_error.clone(), - "end of file while deserializing bytecode".to_owned(), - ), - _ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), - })?; - Ok(PyCode { - code: vm.map_codeobj(code), - }) + /// Read the next 4 bytes of a slice, convert to u32. + /// Side effect: increasing position pointer by 4. + fn eat_u32(bytes: &[u8], position: &mut usize, vm: &VirtualMachine) -> PyResult { + let length_as_u32 = + u32::from_le_bytes(match bytes[*position..(*position + 4)].try_into() { + Ok(length_as_u32) => length_as_u32, + Err(_) => { + return Err( + vm.new_buffer_error("Could not read u32 size from byte array".to_owned()) + ) + } + }); + *position += 4; + Ok(length_as_u32) + } + + /// Reads next element from a python list. First by getting element size + /// then by building a pybuffer and "loading" the pyobject. + /// Moves the position pointer past the element. + fn next_element_of_list( + buf: &BorrowedValue<[u8]>, + position: &mut usize, + vm: &VirtualMachine, + ) -> PyResult { + // Read size of the current element from buffer. + let element_length = eat_u32(buf, position, vm)? as usize; + // Create pybuffer consisting of the data in the next element. + let pybuffer = + PyBuffer::from_byte_vector(buf[*position..(*position + element_length)].to_vec(), vm); + // Move position pointer past element. + *position += element_length; + // Return marshalled element. + loads(pybuffer, vm) + } + + /// Reads a list (or tuple) from a buffer. + fn read_list(buf: &BorrowedValue<[u8]>, vm: &VirtualMachine) -> PyResult> { + let mut position = 1; + let expected_array_len = eat_u32(buf, &mut position, vm)? as usize; + // Read each element in list, incrementing position pointer to reflect position in the buffer. + let mut elements: Vec = Vec::new(); + while position < buf.len() { + elements.push(next_element_of_list(buf, &mut position, vm)?); + } + debug_assert!(expected_array_len == elements.len()); + debug_assert!(buf.len() == position); + Ok(elements) + } + + /// Builds a PyDict from iterator of tuple objects + pub fn from_tuples(iterable: Iter, vm: &VirtualMachine) -> PyResult { + let dict = DictContentType::default(); + for elem in iterable { + let items = match_class!(match elem.clone() { + pytuple @ PyTuple => pytuple.as_slice().to_vec(), + _ => + return Err(vm.new_value_error( + "Couldn't unmarshal key:value pair of dictionary".to_owned() + )), + }); + // Marshalled tuples are always in format key:value. + dict.insert( + vm, + items.get(0).unwrap().clone(), + items.get(1).unwrap().clone(), + )?; + } + Ok(PyDict::from_entries(dict)) } #[pyfunction] - fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult { + let buf = &pybuffer.as_contiguous().ok_or_else(|| { + vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned()) + })?; + match buf[0] { + INT_BYTE => { + let sign = match buf[1] { + b'-' => Sign::Minus, + b'0' => Sign::NoSign, + b'+' => Sign::Plus, + _ => { + return Err(vm.new_value_error( + "Unknown sign byte when trying to unmarshal integer".to_owned(), + )) + } + }; + let pyint = BigInt::from_bytes_le(sign, &buf[2..buf.len()]); + Ok(pyint.into_pyobject(vm)) + } + FLOAT_BYTE => { + let number = f64::from_le_bytes(match buf[1..buf.len()].try_into() { + Ok(byte_array) => byte_array, + Err(e) => { + return Err(vm.new_value_error(format!( + "Expected float, could not load from bytes. {}", + e + ))) + } + }); + let pyfloat = PyFloat::from(number); + Ok(pyfloat.into_pyobject(vm)) + } + STR_BYTE => { + let pystr = PyStr::from(match AsciiStr::from_ascii(&buf[1..buf.len()]) { + Ok(ascii_str) => ascii_str, + Err(e) => { + return Err( + vm.new_value_error(format!("Cannot unmarshal bytes to string, {}", e)) + ) + } + }); + Ok(pystr.into_pyobject(vm)) + } + LIST_BYTE => { + let elements = read_list(buf, vm)?; + Ok(elements.into_pyobject(vm)) + } + TUPLE_BYTE => { + let elements = read_list(buf, vm)?; + let pytuple = PyTuple::new_ref(elements, &vm.ctx).into_pyobject(vm); + Ok(pytuple) + } + DICT_BYTE => { + let pybuffer = PyBuffer::from_byte_vector(buf[1..buf.len()].to_vec(), vm); + let pydict = match_class!(match loads(pybuffer, vm)? { + pylist @ PyList => from_tuples(pylist.borrow_vec().iter(), vm)?, + _ => + return Err(vm.new_value_error("Couldn't unmarshal dicitionary.".to_owned())), + }); + Ok(pydict.into_pyobject(vm)) + } + _ => { + // If prefix is not identifiable, assume CodeObject, error out if it doesn't match. + let code = bytecode::CodeObject::from_bytes(&buf).map_err(|e| match e { + bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( + vm.ctx.exceptions.eof_error.clone(), + "End of file while deserializing bytecode".to_owned(), + ), + _ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), + })?; + Ok(PyCode { + code: vm.map_codeobj(code), + } + .into_pyobject(vm)) + } + } + } + + #[pyfunction] + fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult { let read_res = vm.call_method(&f, "read", ())?; let bytes = ArgBytesLike::try_from_object(vm, read_res)?; - loads(bytes, vm) + loads(PyBuffer::from(bytes), vm) } }