Fix TextIOWrapper.reconfigure coder

This commit is contained in:
Jeong YunWon
2024-07-28 18:50:49 +09:00
committed by Jeong, YunWon
parent aa5eba9723
commit 87b84a83cc

View File

@@ -2229,32 +2229,7 @@ mod _io {
let has_read1 = vm.get_attribute_opt(buffer.clone(), "read1")?.is_some();
let seekable = vm.call_method(&buffer, "seekable", ())?.try_to_bool(vm)?;
let codec = vm.state.codec_registry.lookup(encoding.as_str(), vm)?;
let encoder = if vm.call_method(&buffer, "writable", ())?.try_to_bool(vm)? {
let incremental_encoder =
codec.get_incremental_encoder(Some(errors.clone()), vm)?;
let encoding_name = vm.get_attribute_opt(incremental_encoder.clone(), "name")?;
let encodefunc = encoding_name.and_then(|name| {
let name = name.payload::<PyStr>()?;
match name.as_str() {
"utf-8" => Some(textio_encode_utf8 as EncodeFunc),
_ => None,
}
});
Some((incremental_encoder, encodefunc))
} else {
None
};
let decoder = if vm.call_method(&buffer, "readable", ())?.try_to_bool(vm)? {
let incremental_decoder =
codec.get_incremental_decoder(Some(errors.clone()), vm)?;
// TODO: wrap in IncrementalNewlineDecoder if newlines == Universal | Passthrough
Some(incremental_decoder)
} else {
None
};
let (encoder, decoder) = Self::find_coder(&buffer, encoding.as_str(), &errors, vm)?;
*data = Some(TextIOData {
buffer,
@@ -2296,16 +2271,59 @@ mod _io {
PyThreadMutexGuard::try_map(lock, |x| x.as_mut())
.map_err(|_| vm.new_value_error("I/O operation on uninitialized object".to_owned()))
}
#[allow(clippy::type_complexity)]
fn find_coder(
buffer: &PyObject,
encoding: &str,
errors: &Py<PyStr>,
vm: &VirtualMachine,
) -> PyResult<(
Option<(PyObjectRef, Option<EncodeFunc>)>,
Option<PyObjectRef>,
)> {
let codec = vm.state.codec_registry.lookup(encoding, vm)?;
let encoder = if vm.call_method(buffer, "writable", ())?.try_to_bool(vm)? {
let incremental_encoder =
codec.get_incremental_encoder(Some(errors.to_owned()), vm)?;
let encoding_name = vm.get_attribute_opt(incremental_encoder.clone(), "name")?;
let encodefunc = encoding_name.and_then(|name| {
let name = name.payload::<PyStr>()?;
match name.as_str() {
"utf-8" => Some(textio_encode_utf8 as EncodeFunc),
_ => None,
}
});
Some((incremental_encoder, encodefunc))
} else {
None
};
let decoder = if vm.call_method(buffer, "readable", ())?.try_to_bool(vm)? {
let incremental_decoder =
codec.get_incremental_decoder(Some(errors.to_owned()), vm)?;
// TODO: wrap in IncrementalNewlineDecoder if newlines == Universal | Passthrough
Some(incremental_decoder)
} else {
None
};
Ok((encoder, decoder))
}
}
#[pyclass(with(Constructor, Initializer), flags(BASETYPE))]
impl TextIOWrapper {
#[pymethod]
fn reconfigure(&self, args: TextIOWrapperArgs) {
fn reconfigure(&self, args: TextIOWrapperArgs, vm: &VirtualMachine) -> PyResult<()> {
let mut data = self.data.lock().unwrap();
if let Some(data) = data.as_mut() {
if let Some(encoding) = args.encoding {
let (encoder, decoder) =
Self::find_coder(&data.buffer, encoding.as_str(), &data.errors, vm)?;
data.encoding = encoding;
data.encoder = encoder;
data.decoder = decoder;
}
if let Some(errors) = args.errors {
data.errors = errors;
@@ -2320,6 +2338,7 @@ mod _io {
data.write_through = write_through;
}
}
Ok(())
}
#[pymethod]
fn seekable(&self, vm: &VirtualMachine) -> PyResult {