Merge pull request #4004 from youknowone/marshal

redesign marshal
This commit is contained in:
Jeong YunWon
2022-08-07 03:53:14 +09:00
committed by GitHub
4 changed files with 237 additions and 232 deletions

View File

@@ -93,14 +93,10 @@ class FloatTestCase(unittest.TestCase, HelperMixin):
n *= 123.4567
class StringTestCase(unittest.TestCase, HelperMixin):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_unicode(self):
for s in ["", "Andr\xe8 Previn", "abc", " "*10000]:
self.helper(marshal.loads(marshal.dumps(s)))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_string(self):
for s in ["", "Andr\xe8 Previn", "abc", " "*10000]:
self.helper(s)
@@ -159,13 +155,9 @@ class ContainerTestCase(unittest.TestCase, HelperMixin):
'aunicode': "Andr\xe8 Previn"
}
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_dict(self):
self.helper(self.d)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_list(self):
self.helper(list(self.d.items()))
@@ -178,8 +170,6 @@ class ContainerTestCase(unittest.TestCase, HelperMixin):
class BufferTestCase(unittest.TestCase, HelperMixin):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_bytearray(self):
b = bytearray(b"abc")
self.helper(b)
@@ -298,8 +288,6 @@ class BugsTestCase(unittest.TestCase):
testString = 'abc' * size
marshal.dumps(testString)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_invalid_longs(self):
# Issue #7019: marshal.loads shouldn't produce unnormalized PyLongs
invalid_string = b'l\x02\x00\x00\x00\x00\x00\x00\x00'

View File

@@ -70,10 +70,6 @@ impl PyDict {
&self.entries
}
/// Wraps an already-built entries table in a new `PyDict`.
///
/// The table is taken by value and stored as the `entries` field as-is.
pub(crate) fn from_entries(entries: DictContentType) -> Self {
Self { entries }
}
// Used in update and ior.
fn merge_object(
dict: &DictContentType,

View File

@@ -99,22 +99,6 @@ struct DictEntry<T> {
}
static_assertions::assert_eq_size!(DictEntry<PyObjectRef>, Option<DictEntry<PyObjectRef>>);
impl<T: Clone> DictEntry<T> {
pub(crate) fn as_tuple(&self) -> (PyObjectRef, T) {
(self.key.clone(), self.value.clone())
}
}
impl<T: Clone> Dict<T> {
    /// Collects every occupied slot of the entries table as an owned `(key, value)` pair.
    ///
    /// Empty (`None`) slots are skipped; the order of the table is preserved. The read
    /// guard obtained from `self.inner` is held for the duration of the copy.
    pub(crate) fn as_kvpairs(&self) -> Vec<(PyObjectRef, T)> {
        let guard = self.inner.read();
        let mut pairs = Vec::new();
        for slot in guard.entries.iter() {
            if let Some(live) = slot.as_ref() {
                pairs.push(live.as_tuple());
            }
        }
        pairs
    }
}
#[derive(Debug, PartialEq)]
pub struct DictSize {
indices_size: usize,

View File

@@ -4,8 +4,8 @@ pub(crate) use decl::make_module;
mod decl {
use crate::{
builtins::{
dict::DictContentType, PyByteArray, PyBytes, PyCode, PyDict, PyFloat, PyFrozenSet,
PyInt, PyList, PySet, PyStr, PyTuple,
PyBaseExceptionRef, PyByteArray, PyBytes, PyCode, PyDict, PyFloat, PyFrozenSet, PyInt,
PyList, PySet, PyStr, PyTuple,
},
bytecode,
convert::ToPyObject,
@@ -16,121 +16,166 @@ mod decl {
};
/// TODO
/// PyBytes: Currently getting recursion error with match_class!
use ascii::AsciiStr;
use num_bigint::{BigInt, Sign};
use std::ops::Deref;
use std::slice::Iter;
use num_traits::Zero;
const STR_BYTE: u8 = b's';
const INT_BYTE: u8 = b'i';
const FLOAT_BYTE: u8 = b'f';
const BOOL_BYTE: u8 = b'b';
const LIST_BYTE: u8 = b'[';
const TUPLE_BYTE: u8 = b'(';
const DICT_BYTE: u8 = b',';
const SET_BYTE: u8 = b'~';
const FROZEN_SET_BYTE: u8 = b'<';
const BYTE_ARRAY: u8 = b'>';
/// One-byte type code that tags a marshalled value.
///
/// Each variant's discriminant is the ASCII byte used for that kind of value, so
/// `Type::X as u8` yields the tag directly. Entries left commented out are codes this
/// enum does not define.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum Type {
    // Null = b'0',
    // None = b'N',
    False = b'F',
    True = b'T',
    // StopIter = b'S',
    // Ellipsis = b'.',
    Int = b'i',
    Float = b'g',
    // Complex = b'y',
    // Long = b'l',  // i32
    Bytes = b's', // = TYPE_STRING
    // Interned = b't',
    // Ref = b'r',
    Tuple = b'(',
    List = b'[',
    Dict = b'{',
    Code = b'c',
    Str = b'u', // = TYPE_UNICODE
    // Unknown = b'?',
    Set = b'<',
    FrozenSet = b'>',
    // Ascii = b'a',
    // AsciiInterned = b'A',
    // SmallTuple = b')',
    // ShortAscii = b'z',
    // ShortAsciiInterned = b'Z',
}
// const FLAG_REF: u8 = b'\x80';
/// Safely convert usize to 4 le bytes
fn size_to_bytes(x: usize, vm: &VirtualMachine) -> PyResult<[u8; 4]> {
// For marshalling we want to convert lengths to bytes. To save space
// we limit the size to u32 to keep marshalling smaller.
match u32::try_from(x) {
Ok(n) => Ok(n.to_le_bytes()),
Err(_) => {
Err(vm.new_value_error("Size exceeds 2^32 capacity for marshaling.".to_owned()))
}
impl TryFrom<u8> for Type {
type Error = u8;
fn try_from(value: u8) -> Result<Self, u8> {
use Type::*;
Ok(match value {
// b'0' => Null,
// b'N' => None,
b'F' => False,
b'T' => True,
// b'S' => StopIter,
// b'.' => Ellipsis,
b'i' => Int,
b'g' => Float,
// b'y' => Complex,
// b'l' => Long,
b's' => Bytes,
// b't' => Interned,
// b'r' => Ref,
b'(' => Tuple,
b'[' => List,
b'{' => Dict,
b'c' => Code,
b'u' => Str,
// b'?' => Unknown,
b'<' => Set,
b'>' => FrozenSet,
// b'a' => Ascii,
// b'A' => AsciiInterned,
// b')' => SmallTuple,
// b'z' => ShortAscii,
// b'Z' => ShortAsciiInterned,
c => return Err(c),
})
}
}
/// Dumps an iterator of objects into a binary vector.
fn dump_list(pyobjs: Iter<PyObjectRef>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec();
/// Builds the `EOFError` exception used when marshal input ends before a value is complete.
fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
    let eof_error = vm.ctx.exceptions.eof_error.to_owned();
    let message = "marshal data too short".to_owned();
    vm.new_exception_msg(eof_error, message)
}
/// Dumps a sequence of objects into binary vector.
fn dump_seq(
buf: &mut Vec<u8>,
iter: std::slice::Iter<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
write_size(buf, iter.len(), vm)?;
// For each element, dump into binary, then add its length and value.
for element in pyobjs {
let element_bytes: Vec<u8> = _dumps(element.clone(), vm)?;
byte_list.extend(size_to_bytes(element_bytes.len(), vm)?);
byte_list.extend(element_bytes)
for element in iter {
dump_obj(buf, element.clone(), vm)?;
}
Ok(byte_list)
Ok(())
}
/// Dumping helper function to turn a value into bytes.
fn _dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
let r = match_class!(match value {
fn dump_obj(buf: &mut Vec<u8>, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
match_class!(match value {
pyint @ PyInt => {
if pyint.class().is(vm.ctx.types.bool_type) {
let (_, mut bool_bytes) = pyint.as_bigint().to_bytes_le();
bool_bytes.push(BOOL_BYTE);
bool_bytes
} else {
let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le();
let sign_byte = match sign {
Sign::Minus => b'-',
Sign::NoSign => b'0',
Sign::Plus => b'+',
let typ = if pyint.as_bigint().is_zero() {
Type::False
} else {
Type::True
};
// Return as [TYPE, SIGN, uint bytes]
int_bytes.insert(0, sign_byte);
int_bytes.push(INT_BYTE);
int_bytes
buf.push(typ as u8);
} else {
buf.push(Type::Int as u8);
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
let mut len = int_bytes.len() as i32;
if sign == Sign::Minus {
len = -len;
}
buf.extend(len.to_le_bytes());
buf.extend(int_bytes);
}
}
pyfloat @ PyFloat => {
let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec();
float_bytes.push(FLOAT_BYTE);
float_bytes
buf.push(Type::Float as u8);
buf.extend(pyfloat.to_f64().to_le_bytes());
}
pystr @ PyStr => {
let mut str_bytes = pystr.as_str().as_bytes().to_vec();
str_bytes.push(STR_BYTE);
str_bytes
buf.push(Type::Str as u8);
write_size(buf, pystr.as_str().len(), vm)?;
buf.extend(pystr.as_str().as_bytes());
}
pylist @ PyList => {
buf.push(Type::List as u8);
let pylist_items = pylist.borrow_vec();
let mut list_bytes = dump_list(pylist_items.iter(), vm)?;
list_bytes.push(LIST_BYTE);
list_bytes
dump_seq(buf, pylist_items.iter(), vm)?;
}
pyset @ PySet => {
buf.push(Type::Set as u8);
let elements = pyset.elements();
let mut set_bytes = dump_list(elements.iter(), vm)?;
set_bytes.push(SET_BYTE);
set_bytes
dump_seq(buf, elements.iter(), vm)?;
}
pyfrozen @ PyFrozenSet => {
buf.push(Type::FrozenSet as u8);
let elements = pyfrozen.elements();
let mut fset_bytes = dump_list(elements.iter(), vm)?;
fset_bytes.push(FROZEN_SET_BYTE);
fset_bytes
dump_seq(buf, elements.iter(), vm)?;
}
pytuple @ PyTuple => {
let mut tuple_bytes = dump_list(pytuple.iter(), vm)?;
tuple_bytes.push(TUPLE_BYTE);
tuple_bytes
buf.push(Type::Tuple as u8);
dump_seq(buf, pytuple.iter(), vm)?;
}
pydict @ PyDict => {
let key_value_pairs = pydict._as_dict_inner().clone().as_kvpairs();
// Converts list of tuples to PyObjectRefs of tuples
let elements: Vec<PyObjectRef> = key_value_pairs
.into_iter()
.map(|(k, v)| PyTuple::new_ref(vec![k, v], &vm.ctx).to_pyobject(vm))
.collect();
// Converts list of tuples to list, dump into binary
let mut dict_bytes = dump_list(elements.iter(), vm)?;
dict_bytes.push(LIST_BYTE);
dict_bytes.push(DICT_BYTE);
dict_bytes
buf.push(Type::Dict as u8);
write_size(buf, pydict.len(), vm)?;
for (key, value) in pydict {
dump_obj(buf, key, vm)?;
dump_obj(buf, value, vm)?;
}
}
pybyte_array @ PyByteArray => {
let mut pybytes = pybyte_array.borrow_buf_mut();
pybytes.push(BYTE_ARRAY);
pybytes.deref().to_owned()
bytes @ PyByteArray => {
buf.push(Type::Bytes as u8);
let data = bytes.borrow_buf();
write_size(buf, data.len(), vm)?;
buf.extend(&*data);
}
co @ PyCode => {
// Code is default, doesn't have prefix.
co.code.map_clone_bag(&bytecode::BasicBag).to_bytes()
buf.push(Type::Code as u8);
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
write_size(buf, bytes.len(), vm)?;
buf.extend(bytes);
}
_ => {
return Err(vm.new_not_implemented_error(
@@ -138,12 +183,14 @@ mod decl {
));
}
});
Ok(r)
Ok(())
}
#[pyfunction]
fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyBytes> {
Ok(PyBytes::from(_dumps(value, vm)?))
let mut buf = Vec::new();
dump_obj(&mut buf, value, vm)?;
Ok(PyBytes::from(buf))
}
#[pyfunction]
@@ -153,161 +200,151 @@ mod decl {
Ok(())
}
/// Appends `x` to `buf` as a 4-byte little-endian `u32` length field.
///
/// Lengths are deliberately limited to `u32` to keep the marshalled form compact.
///
/// # Errors
/// Returns a `ValueError` when `x` does not fit in a `u32`; `buf` is left untouched in
/// that case.
fn write_size(buf: &mut Vec<u8>, x: usize, vm: &VirtualMachine) -> PyResult<()> {
    match u32::try_from(x) {
        Ok(size) => {
            buf.extend_from_slice(&size.to_le_bytes());
            Ok(())
        }
        Err(_) => {
            Err(vm.new_value_error("Size exceeds 2^32 capacity for marshaling.".to_owned()))
        }
    }
}
/// Read the next 4 bytes of a slice, read as u32, pass as usize.
/// Returns the rest of buffer with the value.
fn eat_length<'a>(bytes: &'a [u8], vm: &VirtualMachine) -> PyResult<(usize, &'a [u8])> {
let (u32_bytes, rest) = bytes.split_at(4);
let length = u32::from_le_bytes(u32_bytes.try_into().map_err(|_| {
vm.new_value_error("Could not read u32 size from byte array".to_owned())
})?);
fn read_size<'a>(buf: &'a [u8], vm: &VirtualMachine) -> PyResult<(usize, &'a [u8])> {
if buf.len() < 4 {
return Err(too_short_error(vm));
}
let (u32_bytes, rest) = buf.split_at(4);
let length = u32::from_le_bytes(u32_bytes.try_into().unwrap());
Ok((length as usize, rest))
}
/// Reads the next length-prefixed element of a marshalled list.
///
/// The element's byte length is read first with `eat_length`; that many bytes are then
/// copied into a fresh `PyBuffer` and decoded with `loads`.
///
/// Returns the decoded object together with the bytes that follow the element.
///
/// # Errors
/// Returns `EOFError` when `buf` is too short to hold the 4-byte length prefix or the
/// element it announces, and propagates any error from `eat_length` or `loads`.
fn next_element_of_list<'a>(
    buf: &'a [u8],
    vm: &VirtualMachine,
) -> PyResult<(PyObjectRef, &'a [u8])> {
    let too_short = || {
        vm.new_exception_msg(
            vm.ctx.exceptions.eof_error.to_owned(),
            "marshal data too short".to_owned(),
        )
    };
    // The length prefix is 4 bytes; splitting a shorter slice at 4 would panic.
    if buf.len() < 4 {
        return Err(too_short());
    }
    let (element_length, element_and_rest) = eat_length(buf, vm)?;
    // `split_at` panics when the index is past the end, so reject truncated input first.
    if element_and_rest.len() < element_length {
        return Err(too_short());
    }
    let (element_buff, rest) = element_and_rest.split_at(element_length);
    let pybuffer = PyBuffer::from_byte_vector(element_buff.to_vec(), vm);
    Ok((loads(pybuffer, vm)?, rest))
}
/// Reads a list (or tuple) from a buffer.
fn read_list(buf: &[u8], vm: &VirtualMachine) -> PyResult<Vec<PyObjectRef>> {
let (expected_array_len, mut buffer) = eat_length(buf, vm)?;
fn load_seq<'b>(buf: &'b [u8], vm: &VirtualMachine) -> PyResult<(Vec<PyObjectRef>, &'b [u8])> {
let (len, mut buf) = read_size(buf, vm)?;
let mut elements: Vec<PyObjectRef> = Vec::new();
while !buffer.is_empty() {
let (element, rest_of_buffer) = next_element_of_list(buffer, vm)?;
for _ in 0..len {
let (element, rest) = load_obj(buf, vm)?;
buf = rest;
elements.push(element);
buffer = rest_of_buffer;
}
debug_assert!(expected_array_len == elements.len());
Ok(elements)
}
/// Builds a PyDict from iterator of tuple objects
pub fn from_tuples(iterable: Iter<PyObjectRef>, vm: &VirtualMachine) -> PyResult<PyDict> {
let dict = DictContentType::default();
for elem in iterable {
let items = match_class!(match elem.clone() {
pytuple @ PyTuple => pytuple.to_vec(),
_ =>
return Err(vm.new_value_error(
"Couldn't unmarshal key:value pair of dictionary".to_owned()
)),
});
// Marshalled tuples are always in format key:value.
dict.insert(vm, &**items.get(0).unwrap(), items.get(1).unwrap().clone())?;
}
Ok(PyDict::from_entries(dict))
Ok((elements, buf))
}
#[pyfunction]
fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
let full_buff = pybuffer.as_contiguous().ok_or_else(|| {
let buf = pybuffer.as_contiguous().ok_or_else(|| {
vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned())
})?;
let (type_indicator, buf) = full_buff.split_last().ok_or_else(|| {
vm.new_exception_msg(
vm.ctx.exceptions.eof_error.to_owned(),
"EOF where object expected.".to_owned(),
)
})?;
match *type_indicator {
BOOL_BYTE => Ok((buf[0] != 0).to_pyobject(vm)),
INT_BYTE => {
let (sign_byte, uint_bytes) = buf
.split_first()
.ok_or_else(|| vm.new_value_error("EOF where object expected.".to_owned()))?;
let sign = match sign_byte {
b'-' => Sign::Minus,
b'0' => Sign::NoSign,
b'+' => Sign::Plus,
_ => {
return Err(vm.new_value_error(
"Unknown sign byte when trying to unmarshal integer".to_owned(),
))
}
let (obj, _) = load_obj(&buf, vm)?;
Ok(obj)
}
fn load_obj<'b>(buf: &'b [u8], vm: &VirtualMachine) -> PyResult<(PyObjectRef, &'b [u8])> {
let (type_indicator, buf) = buf.split_first().ok_or_else(|| too_short_error(vm))?;
let typ = Type::try_from(*type_indicator)
.map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?;
let (obj, buf) = match typ {
Type::True => ((true).to_pyobject(vm), buf),
Type::False => ((false).to_pyobject(vm), buf),
Type::Int => {
if buf.len() < 4 {
return Err(too_short_error(vm));
}
let (len_bytes, buf) = buf.split_at(4);
let len = i32::from_le_bytes(len_bytes.try_into().unwrap());
let (sign, len) = if len < 0 {
(Sign::Minus, (-len) as usize)
} else {
(Sign::Plus, len as usize)
};
let pyint = BigInt::from_bytes_le(sign, uint_bytes);
Ok(pyint.to_pyobject(vm))
if buf.len() < len {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(len);
let int = BigInt::from_bytes_le(sign, bytes);
(int.to_pyobject(vm), buf)
}
FLOAT_BYTE => {
let number = f64::from_le_bytes(match buf[..].try_into() {
Ok(byte_array) => byte_array,
Err(e) => {
return Err(vm.new_value_error(format!(
"Expected float, could not load from bytes. {}",
e
)))
}
});
let pyfloat = PyFloat::from(number);
Ok(pyfloat.to_pyobject(vm))
Type::Float => {
if buf.len() < 8 {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(8);
let number = f64::from_le_bytes(bytes.try_into().unwrap());
(vm.ctx.new_float(number).into(), buf)
}
STR_BYTE => {
let pystr = PyStr::from(match AsciiStr::from_ascii(buf) {
Ok(ascii_str) => ascii_str,
Err(e) => {
return Err(
vm.new_value_error(format!("Cannot unmarshal bytes to string, {}", e))
)
}
});
Ok(pystr.to_pyobject(vm))
Type::Str => {
let (len, buf) = read_size(buf, vm)?;
if buf.len() < len {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(len);
let s = String::from_utf8(bytes.to_vec())
.map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?;
(s.to_pyobject(vm), buf)
}
LIST_BYTE => {
let elements = read_list(buf, vm)?;
Ok(elements.to_pyobject(vm))
Type::List => {
let (elements, buf) = load_seq(buf, vm)?;
(vm.ctx.new_list(elements).into(), buf)
}
SET_BYTE => {
let elements = read_list(buf, vm)?;
Type::Set => {
let (elements, buf) = load_seq(buf, vm)?;
let set = PySet::new_ref(&vm.ctx);
for element in elements {
set.add(element, vm)?;
}
Ok(set.to_pyobject(vm))
(set.to_pyobject(vm), buf)
}
FROZEN_SET_BYTE => {
let elements = read_list(buf, vm)?;
Type::FrozenSet => {
let (elements, buf) = load_seq(buf, vm)?;
let set = PyFrozenSet::from_iter(vm, elements.into_iter())?;
Ok(set.to_pyobject(vm))
(set.to_pyobject(vm), buf)
}
TUPLE_BYTE => {
let elements = read_list(buf, vm)?;
let pytuple = PyTuple::new_ref(elements, &vm.ctx).to_pyobject(vm);
Ok(pytuple)
Type::Tuple => {
let (elements, buf) = load_seq(buf, vm)?;
(vm.ctx.new_tuple(elements).into(), buf)
}
DICT_BYTE => {
let pybuffer = PyBuffer::from_byte_vector(buf[..].to_vec(), vm);
let pydict = match_class!(match loads(pybuffer, vm)? {
pylist @ PyList => from_tuples(pylist.borrow_vec().iter(), vm)?,
_ =>
return Err(vm.new_value_error("Couldn't unmarshal dicitionary.".to_owned())),
});
Ok(pydict.to_pyobject(vm))
Type::Dict => {
let (len, mut buf) = read_size(buf, vm)?;
let dict = vm.ctx.new_dict();
for _ in 0..len {
let (key, rest) = load_obj(buf, vm)?;
let (value, rest) = load_obj(rest, vm)?;
buf = rest;
dict.set_item(key.as_object(), value, vm)?;
}
(dict.into(), buf)
}
BYTE_ARRAY => {
Type::Bytes => {
// Following CPython, after marshaling, byte arrays are converted into bytes.
let byte_array = PyBytes::from(buf[..].to_vec());
Ok(byte_array.to_pyobject(vm))
let (len, buf) = read_size(buf, vm)?;
if buf.len() < len {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(len);
(vm.ctx.new_bytes(bytes.to_vec()).into(), buf)
}
_ => {
Type::Code => {
// If prefix is not identifiable, assume CodeObject, error out if it doesn't match.
let code = bytecode::CodeObject::from_bytes(&full_buff).map_err(|e| match e {
let (len, buf) = read_size(buf, vm)?;
if buf.len() < len {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(len);
let code = bytecode::CodeObject::from_bytes(bytes).map_err(|e| match e {
bytecode::CodeDeserializeError::Eof => vm.new_exception_msg(
vm.ctx.exceptions.eof_error.to_owned(),
"End of file while deserializing bytecode".to_owned(),
),
_ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()),
})?;
Ok(vm.ctx.new_code(code).into())
(vm.ctx.new_code(code).into(), buf)
}
}
};
Ok((obj, buf))
}
#[pyfunction]