Merge pull request #1818 from palaviv/struct-fixes

Struct fixes
This commit is contained in:
Jeong YunWon
2020-03-21 01:14:21 +09:00
committed by GitHub
4 changed files with 193 additions and 62 deletions

View File

@@ -78,8 +78,6 @@ class StructTest(unittest.TestCase):
self.assertEqual(int(100 * dp), int(100 * d))
self.assertEqual(tp, t)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_new_features(self):
# Test some of the new features in detail
# (format, argument, big-endian result, little-endian result, asymmetric)
@@ -167,8 +165,6 @@ class StructTest(unittest.TestCase):
self.assertGreaterEqual(struct.calcsize('n'), struct.calcsize('i'))
self.assertGreaterEqual(struct.calcsize('n'), struct.calcsize('P'))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_integers(self):
# Integer tests (bBhHiIlLqQnN).
import binascii
@@ -338,8 +334,6 @@ class StructTest(unittest.TestCase):
assertStructError(struct.pack, format, 0)
assertStructError(struct.unpack, format, b"")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_p_code(self):
# Test p ("Pascal string") code.
for code, input, expected, expectedback in [
@@ -391,8 +385,6 @@ class StructTest(unittest.TestCase):
big = math.ldexp(big, 127 - 24)
self.assertRaises(OverflowError, struct.pack, ">f", big)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_1530559(self):
for code, byteorder in iter_integer_formats():
format = byteorder + code
@@ -495,8 +487,6 @@ class StructTest(unittest.TestCase):
value, = struct.unpack('>I', data)
self.assertEqual(value, 0x12345678)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_bool(self):
class ExplodingBool(object):
def __bool__(self):

View File

@@ -1231,51 +1231,87 @@ where
}
}
/// Escape-sequence state tracked by `lex_byte` while it scans a bytes literal.
#[derive(Debug)]
enum EscapeMode {
/// A backslash was just consumed; the next character selects the escape.
NORMAL,
/// Inside a `\x` escape, collecting hexadecimal digits.
HEX,
/// Inside an octal escape, collecting at most three octal digits.
OCTET,
}
fn lex_byte(s: String) -> Result<Vec<u8>, LexicalErrorType> {
let mut res = vec![];
let mut escape = false; //flag if previous was \
let mut hex_on = false; // hex mode on or off
let mut hex_value = String::new();
let mut escape: Option<EscapeMode> = None;
let mut escape_buffer = String::new();
for c in s.chars() {
if hex_on {
if c.is_ascii_hexdigit() {
if hex_value.is_empty() {
hex_value.push(c);
continue;
let mut chars_iter = s.chars();
let mut next_char = chars_iter.next();
while let Some(c) = next_char {
match escape {
Some(EscapeMode::OCTET) => {
if let '0'..='7' = c {
escape_buffer.push(c);
next_char = chars_iter.next();
if escape_buffer.len() < 3 {
continue;
}
}
res.push(u8::from_str_radix(&escape_buffer, 8).unwrap());
escape = None;
escape_buffer.clear();
}
Some(EscapeMode::HEX) => {
if c.is_ascii_hexdigit() {
if escape_buffer.is_empty() {
escape_buffer.push(c);
} else {
escape_buffer.push(c);
res.push(u8::from_str_radix(&escape_buffer, 16).unwrap());
escape = None;
escape_buffer.clear();
}
next_char = chars_iter.next();
} else {
hex_value.push(c);
res.push(u8::from_str_radix(&hex_value, 16).unwrap());
hex_on = false;
hex_value.clear();
return Err(LexicalErrorType::StringError);
}
} else {
return Err(LexicalErrorType::StringError);
}
} else {
match (c, escape) {
('\\', true) => res.push(b'\\'),
('\\', false) => {
escape = true;
continue;
Some(EscapeMode::NORMAL) => {
match c {
'\\' => res.push(b'\\'),
'x' => {
escape = Some(EscapeMode::HEX);
next_char = chars_iter.next();
continue;
}
't' => res.push(b'\t'),
'n' => res.push(b'\n'),
'r' => res.push(b'\r'),
'0'..='7' => {
escape = Some(EscapeMode::OCTET);
continue;
}
x => {
res.push(b'\\');
res.push(x as u8);
}
}
('x', true) => hex_on = true,
('x', false) => res.push(b'x'),
('t', true) => res.push(b'\t'),
('t', false) => res.push(b't'),
('n', true) => res.push(b'\n'),
('n', false) => res.push(b'n'),
('r', true) => res.push(b'\r'),
('r', false) => res.push(b'r'),
(x, true) => {
res.push(b'\\');
res.push(x as u8);
}
(x, false) => res.push(x as u8),
escape = None;
next_char = chars_iter.next();
}
None => {
match c {
'\\' => escape = Some(EscapeMode::NORMAL),
x => res.push(x as u8),
}
next_char = chars_iter.next();
}
escape = false;
}
}
match escape {
Some(EscapeMode::OCTET) => res.push(u8::from_str_radix(&escape_buffer, 8).unwrap()),
Some(EscapeMode::HEX) => return Err(LexicalErrorType::StringError),
_ => (),
}
Ok(res)
}
@@ -1713,4 +1749,19 @@ mod tests {
]
)
}
/// Octal escapes in a bytes literal take at most three digits:
/// `\43` is `#`, `\4` is byte 0x04, and `\1234` is `\123` (`S`) followed by a
/// literal `4`.
#[test]
fn test_escape_octet() {
    let lexed = lex_source(r##"b'\43a\4\1234'"##);
    let expected = vec![
        Tok::Bytes {
            value: b"#a\x04S4".to_vec(),
        },
        Tok::Newline,
    ];
    assert_eq!(lexed, expected)
}
}

View File

@@ -47,3 +47,27 @@ assert struct.calcsize("<L4B") == 8
assert struct.Struct('3B').pack(65, 66, 67) == bytes([65, 66, 67])
class Indexable(object):
    """Non-int object whose integer value is exposed only via ``__index__``."""

    def __init__(self, value):
        self._value = value

    def __index__(self):
        return self._value


# Each row is (format, argument, expected packed bytes):
#   * integer codes accept any object implementing __index__
#   * 's' copies an exact fit, truncates long input and NUL-pads short input
#   * '?' packs the truthiness of its argument
for fmt, arg, expected in [
    ('B', Indexable(65), bytes([65])),
    ('5s', b"test1", b"test1"),
    ('3s', b"test2", b"tes"),
    ('7s', b"test3", b"test3\0\0"),
    ('?', True, b'\1'),
    ('?', [], b'\0'),
]:
    data = struct.pack(fmt, arg)
    assert data == expected

View File

@@ -12,14 +12,15 @@
use byteorder::{ReadBytesExt, WriteBytesExt};
use num_bigint::BigInt;
use num_traits::ToPrimitive;
use std::cmp;
use std::io::{Cursor, Read, Write};
use std::iter::Peekable;
use crate::exceptions::PyBaseExceptionRef;
use crate::function::Args;
use crate::obj::{
objbytes::PyBytesRef, objstr::PyString, objstr::PyStringRef, objtuple::PyTuple,
objtype::PyClassRef,
objbool::IntoPyBool, objbytes::PyBytesRef, objstr::PyString, objstr::PyStringRef,
objtuple::PyTuple, objtype::PyClassRef,
};
use crate::pyobject::{Either, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject};
use crate::VirtualMachine;
@@ -224,12 +225,38 @@ fn is_supported_format_character(c: char) -> bool {
}
}
/// Converts `arg` to `T` by way of `vm.to_index`.
///
/// When `vm.to_index` yields nothing, a struct error with the message
/// "required argument is not an integer" is returned. An error produced by
/// the index conversion itself, or by `T::try_from_object`, is propagated
/// unchanged.
fn get_int_or_index<T>(vm: &VirtualMachine, arg: &PyObjectRef) -> PyResult<T>
where
    T: TryFromObject,
{
    let index = vm.to_index(arg).ok_or_else(|| {
        new_struct_error(vm, "required argument is not an integer".to_owned())
    })?;
    T::try_from_object(vm, index?.into_object())
}
macro_rules! make_pack_no_endianess {
($T:ty) => {
paste::item! {
fn [<pack_ $T>](vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> {
let v = $T::try_from_object(vm, arg.clone())?;
data.[<write_$T>](v).unwrap();
data.[<write_$T>](get_int_or_index(vm, arg)?).unwrap();
Ok(())
}
}
};
}
macro_rules! make_pack_with_endianess_int {
($T:ty) => {
paste::item! {
fn [<pack_ $T>]<Endianness>(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()>
where
Endianness: byteorder::ByteOrder,
{
data.[<write_$T>]::<Endianness>(get_int_or_index(vm, arg)?).unwrap();
Ok(())
}
}
@@ -253,17 +280,17 @@ macro_rules! make_pack_with_endianess {
make_pack_no_endianess!(i8);
make_pack_no_endianess!(u8);
make_pack_with_endianess!(i16);
make_pack_with_endianess!(u16);
make_pack_with_endianess!(i32);
make_pack_with_endianess!(u32);
make_pack_with_endianess!(i64);
make_pack_with_endianess!(u64);
make_pack_with_endianess_int!(i16);
make_pack_with_endianess_int!(u16);
make_pack_with_endianess_int!(i32);
make_pack_with_endianess_int!(u32);
make_pack_with_endianess_int!(i64);
make_pack_with_endianess_int!(u64);
make_pack_with_endianess!(f32);
make_pack_with_endianess!(f64);
fn pack_bool(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> {
let v = if bool::try_from_object(vm, arg.clone())? {
let v = if IntoPyBool::try_from_object(vm, arg.clone())?.to_bool() {
1
} else {
0
@@ -280,7 +307,7 @@ fn pack_isize<Endianness>(
where
Endianness: byteorder::ByteOrder,
{
let v = isize::try_from_object(vm, arg.clone())?;
let v: isize = get_int_or_index(vm, arg)?;
match std::mem::size_of::<isize>() {
8 => data.write_i64::<Endianness>(v as i64).unwrap(),
4 => data.write_i32::<Endianness>(v as i32).unwrap(),
@@ -297,7 +324,7 @@ fn pack_usize<Endianness>(
where
Endianness: byteorder::ByteOrder,
{
let v = isize::try_from_object(vm, arg.clone())?;
let v: usize = get_int_or_index(vm, arg)?;
match std::mem::size_of::<usize>() {
8 => data.write_u64::<Endianness>(v as u64).unwrap(),
4 => data.write_u32::<Endianness>(v as u32).unwrap(),
@@ -312,8 +339,29 @@ fn pack_string(
data: &mut dyn Write,
length: usize,
) -> PyResult<()> {
let v = PyBytesRef::try_from_object(vm, arg.clone())?;
match data.write_all(&v.get_value()[..length]) {
let mut v = PyBytesRef::try_from_object(vm, arg.clone())?
.get_value()
.to_vec();
v.resize(length, 0);
match data.write_all(&v) {
Ok(_) => Ok(()),
Err(e) => Err(new_struct_error(vm, format!("{:?}", e))),
}
}
fn pack_pascal(
vm: &VirtualMachine,
arg: &PyObjectRef,
data: &mut dyn Write,
length: usize,
) -> PyResult<()> {
let mut v = PyBytesRef::try_from_object(vm, arg.clone())?
.get_value()
.to_vec();
let string_length = cmp::min(cmp::min(v.len(), 255), length - 1);
data.write_u8(string_length as u8).unwrap();
v.resize(length - 1, 0);
match data.write_all(&v) {
Ok(_) => Ok(()),
Err(e) => Err(new_struct_error(vm, format!("{:?}", e))),
}
@@ -356,10 +404,14 @@ where
'N' | 'P' => pack_usize::<Endianness>,
'f' => pack_f32::<Endianness>,
'd' => pack_f64::<Endianness>,
's' | 'p' => {
's' => {
pack_string(vm, &args[0], data, code.repeat as usize)?;
return Ok(1);
}
'p' => {
pack_pascal(vm, &args[0], data, code.repeat as usize)?;
return Ok(1);
}
'x' => {
for _ in 0..code.repeat as usize {
data.write_u8(0).unwrap();
@@ -528,6 +580,16 @@ fn unpack_string(vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) -> PyResu
Ok(vm.ctx.new_bytes(buf))
}
/// Reads one Pascal-string (`p`) field of `length` bytes from `rdr`.
///
/// The first byte of the field is a length prefix; the bytes after it are the
/// payload. The returned bytes object holds the first `prefix` payload bytes,
/// clamped to the payload actually present so that a prefix larger than the
/// field cannot index out of bounds. A zero-width field yields empty bytes.
///
/// # Errors
///
/// Returns a struct error when the read fails or when fewer than `length`
/// bytes are available.
fn unpack_pascal(vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) -> PyResult {
    let short_buffer =
        || new_struct_error(vm, format!("unpack requires a buffer of {} bytes", length));

    let mut handle = rdr.take(length as u64);
    let mut buf: Vec<u8> = Vec::new();
    handle.read_to_end(&mut buf).map_err(|_| short_buffer())?;

    // `Take::read_to_end` succeeds on a short read, so a truncated buffer has
    // to be detected explicitly.
    if buf.len() < length as usize {
        return Err(short_buffer());
    }

    let payload: &[u8] = match buf.split_first() {
        // Never trust the prefix beyond the bytes that are really there.
        Some((&declared, rest)) => &rest[..rest.len().min(declared as usize)],
        // Zero-width field: there is no prefix byte to read.
        None => &[],
    };
    Ok(vm.ctx.new_bytes(payload.to_vec()))
}
fn struct_unpack(fmt: PyStringRef, buffer: PyBytesRef, vm: &VirtualMachine) -> PyResult<PyTuple> {
let fmt_str = fmt.as_str();
let format_spec = FormatSpec::parse(fmt_str).map_err(|e| new_struct_error(vm, e))?;
@@ -563,10 +625,14 @@ where
unpack_empty(vm, rdr, code.repeat);
return Ok(());
}
's' | 'p' => {
's' => {
items.push(unpack_string(vm, rdr, code.repeat)?);
return Ok(());
}
'p' => {
items.push(unpack_pascal(vm, rdr, code.repeat)?);
return Ok(());
}
c => {
panic!("Unsupported format code {:?}", c);
}