Reimplement bytes fromhex

This commit is contained in:
Kangzhi Shi
2020-10-11 15:14:25 +02:00
parent b2c0a69386
commit 21a9e05abc
2 changed files with 40 additions and 27 deletions

View File

@@ -377,8 +377,6 @@ class BaseBytesTest:
self.assertNotIn(f(b"dab"), b)
self.assertNotIn(f(b"abd"), b)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_fromhex(self):
self.assertRaises(TypeError, self.type2test.fromhex)
self.assertRaises(TypeError, self.type2test.fromhex, 1)

View File

@@ -557,35 +557,46 @@ impl PyBytesInner {
}
pub fn fromhex(string: &str, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
// first check for invalid character
for (i, c) in string.char_indices() {
if !c.is_digit(16) && !c.is_whitespace() {
return Err(vm.new_value_error(format!(
"non-hexadecimal number found in fromhex() arg at position {}",
i
)));
let mut iter = string.bytes().enumerate();
let mut bytes: Vec<u8> = Vec::with_capacity(string.len() / 2);
let i = loop {
let (i, b) = match iter.next() {
Some(val) => val,
None => {
return Ok(bytes);
}
};
if is_py_ascii_whitespace(b) {
continue;
}
}
// strip white spaces
let stripped = string.split_whitespace().collect::<String>();
let top = match b {
b'0'..=b'9' => b - b'0',
b'a'..=b'f' => 10 + b - b'a',
b'A'..=b'F' => 10 + b - b'A',
_ => break i,
};
// Hex is evaluated on 2 digits
if stripped.len() % 2 != 0 {
return Err(vm.new_value_error(format!(
"non-hexadecimal number found in fromhex() arg at position {}",
stripped.len() - 1
)));
}
let (i, b) = match iter.next() {
Some(val) => val,
None => break i,
};
// parse even string
Ok(stripped
.chars()
.collect::<Vec<char>>()
.chunks(2)
.map(|x| x.to_vec().iter().collect::<String>())
.map(|x| u8::from_str_radix(&x, 16).unwrap())
.collect::<Vec<u8>>())
let bot = match b {
b'0'..=b'9' => b - b'0',
b'a'..=b'f' => 10 + b - b'a',
b'A'..=b'F' => 10 + b - b'A',
_ => break i,
};
bytes.push((top << 4) + bot);
};
Err(vm.new_value_error(format!(
"non-hexadecimal number found in fromhex() arg at position {}",
i
)))
}
#[inline]
@@ -1330,3 +1341,7 @@ pub fn bytes_to_hex(
Ok(hex_impl_no_sep(bytes))
}
}
const fn is_py_ascii_whitespace(b: u8) -> bool {
matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B')
}