diff --git a/vm/src/codecs.rs b/vm/src/codecs.rs
index 0bf906d42..e6ad84e10 100644
--- a/vm/src/codecs.rs
+++ b/vm/src/codecs.rs
@@ -163,6 +163,10 @@ impl CodecsRegistry {
                 "namereplace",
                 ctx.new_function("namereplace_errors", namereplace_errors),
             ),
+            (
+                "surrogatepass",
+                ctx.new_function("surrogatepass_errors", surrogatepass_errors),
+            ),
         ];
         let errors = std::array::IntoIter::new(errors)
             .map(|(name, f)| (name.to_owned(), f))
@@ -459,3 +463,204 @@ fn namereplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String
         Err(bad_err_type(err, vm))
     }
 }
+
+/// Encodings for which the `surrogatepass` error handler knows the byte layout.
+///
+/// `Unknown` stands for every encoding name `get_standard_encoding` does not
+/// recognise; the handler treats it as "not supported".
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum StandardEncoding {
+    Utf8,
+    Utf16Be,
+    Utf16Le,
+    Utf32Be,
+    Utf32Le,
+    Unknown,
+}
+
+/// Classifies a codec name as one of the UTF encodings handled by `surrogatepass`.
+///
+/// Matching is ASCII case-insensitive. Accepted spellings are `utf`, an optional
+/// `-`/`_`, then `8`, `16` or `32`; after `16`/`32` an optional `-`/`_` plus
+/// `be`/`le` may follow, and a missing byte order selects the target's native
+/// one. `cp_utf8` is accepted as another name for UTF-8.
+///
+/// Returns `(byte_length, encoding)` where `byte_length` is 3 for UTF-8, 2 for
+/// UTF-16 and 4 for UTF-32. Unrecognised names yield `StandardEncoding::Unknown`;
+/// the length is then 0, or 2/4 when only the byte-order suffix was unknown, so
+/// callers must check the encoding before trusting the length.
+fn get_standard_encoding(encoding: &str) -> (usize, StandardEncoding) {
+    // Drops at most one leading `-` or `_`.
+    fn skip_separator(name: &str) -> &str {
+        name.strip_prefix(|c| c == '-' || c == '_').unwrap_or(name)
+    }
+
+    // Resolves the text after the `16`/`32` width into a concrete byte order.
+    fn byte_order(
+        suffix: &str,
+        little: StandardEncoding,
+        big: StandardEncoding,
+    ) -> StandardEncoding {
+        if suffix.is_empty() {
+            // No explicit byte order: use the compilation target's.
+            return if cfg!(target_endian = "big") { big } else { little };
+        }
+        match skip_separator(suffix) {
+            "be" => big,
+            "le" => little,
+            _ => StandardEncoding::Unknown,
+        }
+    }
+
+    // Lower-case once; every name matched below is plain ASCII.
+    let lowered = encoding.to_ascii_lowercase();
+    if let Some(rest) = lowered.strip_prefix("utf") {
+        let rest = skip_separator(rest);
+        if rest == "8" {
+            (3, StandardEncoding::Utf8)
+        } else if let Some(suffix) = rest.strip_prefix("16") {
+            let order = byte_order(suffix, StandardEncoding::Utf16Le, StandardEncoding::Utf16Be);
+            (2, order)
+        } else if let Some(suffix) = rest.strip_prefix("32") {
+            let order = byte_order(suffix, StandardEncoding::Utf32Le, StandardEncoding::Utf32Be);
+            (4, order)
+        } else {
+            (0, StandardEncoding::Unknown)
+        }
+    } else if lowered == "cp_utf8" {
+        // Compare against the lower-case spelling: a lower-cased name can never
+        // equal "CP_UTF8", which made this alias unreachable before.
+        (3, StandardEncoding::Utf8)
+    } else {
+        (0, StandardEncoding::Unknown)
+    }
+}
+
+/// `surrogatepass` codec error handler.
+///
+/// Accepts a `UnicodeEncodeError` or a decode error (as judged by
+/// `is_decode_err`) and returns `(replacement, resume_position)`.
+///
+/// * Encoding: every character in the error range must be a surrogate code
+///   point (`0xd800..=0xdfff`). Each one is rendered as the `\xNN` escapes of
+///   the bytes it occupies in the target encoding; the resume position is
+///   `range.end`.
+/// * Decoding: a single surrogate is read from the data beginning at
+///   `range.start` and rendered as a `\x` escape of its code point; the resume
+///   position is `range.start` plus the encoded width of one surrogate.
+///
+/// # Errors
+///
+/// Returns the original exception when the error's encoding is not recognised
+/// by `get_standard_encoding` or the offending data is not a surrogate, and
+/// `bad_err_type` when `err` is neither an encode nor a decode error.
+fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String, usize)> {
+    // Appends the low eight bits of `value` to `out` as a two-digit `\xNN` escape.
+    fn push_byte_escape(out: &mut String, value: u32) {
+        use std::fmt::Write;
+        // Formatting into a `String` cannot fail.
+        write!(out, "\\x{:02x}", value & 0xff).unwrap();
+    }
+
+    if err.isinstance(&vm.ctx.exceptions.unicode_encode_error) {
+        let range = extract_unicode_error_range(&err, vm)?;
+        let s = PyStrRef::try_from_object(vm, vm.get_attribute(err.clone(), "object")?)?;
+        let s_encoding = PyStrRef::try_from_object(vm, vm.get_attribute(err.clone(), "encoding")?)?;
+        let (_, standard_encoding) = get_standard_encoding(s_encoding.as_str());
+        if let StandardEncoding::Unknown = standard_encoding {
+            // Not supported, fail with original exception
+            return Err(err.downcast().unwrap());
+        }
+        let s_after_start =
+            crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or("");
+        let num_chars = range.len();
+        let mut out = String::with_capacity(num_chars * 4);
+        // NOTE(review): a Rust `char` is never in the surrogate range, so as long as
+        // the object is read through `as_str()` this loop can only reject — confirm.
+        for c in s_after_start.chars().take(num_chars).map(|x| x as u32) {
+            if !(0xd800..=0xdfff).contains(&c) {
+                // Not a surrogate, fail with original exception
+                return Err(err.downcast().unwrap());
+            }
+            // `push_byte_escape` masks to one byte, so passing `c` emits its low byte.
+            match standard_encoding {
+                StandardEncoding::Utf8 => {
+                    push_byte_escape(&mut out, 0xe0 | (c >> 12));
+                    push_byte_escape(&mut out, 0x80 | ((c >> 6) & 0x3f));
+                    push_byte_escape(&mut out, 0x80 | (c & 0x3f));
+                }
+                StandardEncoding::Utf16Le => {
+                    push_byte_escape(&mut out, c);
+                    push_byte_escape(&mut out, c >> 8);
+                }
+                StandardEncoding::Utf16Be => {
+                    push_byte_escape(&mut out, c >> 8);
+                    push_byte_escape(&mut out, c);
+                }
+                StandardEncoding::Utf32Le => {
+                    push_byte_escape(&mut out, c);
+                    push_byte_escape(&mut out, c >> 8);
+                    push_byte_escape(&mut out, c >> 16);
+                    push_byte_escape(&mut out, c >> 24);
+                }
+                StandardEncoding::Utf32Be => {
+                    push_byte_escape(&mut out, c >> 24);
+                    push_byte_escape(&mut out, c >> 16);
+                    push_byte_escape(&mut out, c >> 8);
+                    push_byte_escape(&mut out, c);
+                }
+                StandardEncoding::Unknown => unreachable!(), // NOTE: RUSTPYTHON, should've bailed out earlier
+            }
+        }
+        Ok((out, range.end))
+    } else if is_decode_err(&err, vm) {
+        let range = extract_unicode_error_range(&err, vm)?;
+        // NOTE(review): a decode error's `object` is presumably bytes, in which case
+        // this `PyStrRef` conversion would fail — confirm.
+        let s = PyStrRef::try_from_object(vm, vm.get_attribute(err.clone(), "object")?)?;
+        let s_encoding = PyStrRef::try_from_object(vm, vm.get_attribute(err.clone(), "encoding")?)?;
+        let (byte_length, standard_encoding) = get_standard_encoding(s_encoding.as_str());
+        if let StandardEncoding::Unknown = standard_encoding {
+            // Not supported, fail with original exception
+            return Err(err.downcast().unwrap());
+        }
+        let mut c: u32 = 0;
+        // Try decoding a single surrogate character. If there are more,
+        // let the codec call us again.
+        let s_after_start = crate::common::str::try_get_chars(s.as_str(), range.start..)
+            .unwrap_or("")
+            .as_bytes();
+        // `s_after_start` already begins at `range.start`, so its length is the number
+        // of bytes left; subtracting `range.start` again could underflow.
+        if s_after_start.len() >= byte_length {
+            let byte = |i: usize| u32::from(s_after_start[i]);
+            match standard_encoding {
+                StandardEncoding::Utf8 => {
+                    if (byte(0) & 0xf0) == 0xe0
+                        && (byte(1) & 0xc0) == 0x80
+                        && (byte(2) & 0xc0) == 0x80
+                    {
+                        // it's a three-byte code
+                        c = ((byte(0) & 0x0f) << 12) + ((byte(1) & 0x3f) << 6) + (byte(2) & 0x3f);
+                    }
+                }
+                StandardEncoding::Utf16Le => c = (byte(1) << 8) | byte(0),
+                StandardEncoding::Utf16Be => c = (byte(0) << 8) | byte(1),
+                StandardEncoding::Utf32Le => {
+                    c = (byte(3) << 24) | (byte(2) << 16) | (byte(1) << 8) | byte(0);
+                }
+                StandardEncoding::Utf32Be => {
+                    c = (byte(0) << 24) | (byte(1) << 16) | (byte(2) << 8) | byte(3);
+                }
+                StandardEncoding::Unknown => unreachable!(), // NOTE: RUSTPYTHON, should've bailed out earlier
+            }
+        }
+        if !(0xd800..=0xdfff).contains(&c) {
+            // Not a surrogate, fail with original exception
+            return Err(err.downcast().unwrap());
+        }
+        Ok((format!("\\x{:x?}", c), range.start + byte_length))
+    } else {
+        Err(bad_err_type(err, vm))
+    }
+}