Merge pull request #2409 from RustPython/coolreader18/fix-http

Fix BufferedReader over a SocketIO to properly read the right amount
This commit is contained in:
Noah
2021-02-03 11:29:01 -06:00
committed by GitHub
8 changed files with 331 additions and 126 deletions

4
Cargo.lock generated
View File

@@ -793,9 +793,9 @@ dependencies = [
[[package]]
name = "flate2"
version = "1.0.19"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7411863d55df97a419aa64cb4d2f167103ea9d767e2c54a1868b7ac3f6b47129"
checksum = "cd3aec53de10fe96d7d8c565eb17f2c687bb5518a2ec453b5b1252964526abe0"
dependencies = [
"cfg-if 1.0.0",
"crc32fast",

View File

@@ -129,7 +129,7 @@ num_cpus = "1"
[target.'cfg(not(any(target_arch = "wasm32", target_os = "redox")))'.dependencies]
dns-lookup = "1.0"
flate2 = { version = "1.0", features = ["zlib"], default-features = false }
flate2 = { version = "1.0.20", features = ["zlib"], default-features = false }
libz-sys = "1.0"
[target.'cfg(windows)'.dependencies]

View File

@@ -95,7 +95,13 @@ impl Buffer for RcBuffer {
pub trait Buffer: Debug + PyThreadingConstraint {
fn get_options(&self) -> &BufferOptions;
/// Get the full inner buffer of this memory. You probably want [`as_contiguous()`], as
/// `obj_bytes` doesn't take into account the range a memoryview might operate on, among other
/// footguns.
fn obj_bytes(&self) -> BorrowedValue<[u8]>;
/// Get the full inner buffer of this memory, mutably. You probably want
/// [`as_contiguous_mut()`], as `obj_bytes` doesn't take into account the range a memoryview
/// might operate on, among other footguns.
fn obj_bytes_mut(&self) -> BorrowedValueMut<[u8]>;
fn release(&self);

View File

@@ -30,25 +30,6 @@ impl Write for PyWriter<'_> {
}
}
pub fn write_all(
mut buf: &[u8],
mut write: impl FnMut(&[u8]) -> io::Result<usize>,
) -> io::Result<()> {
while !buf.is_empty() {
match write(buf) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write whole buffer",
))
}
Ok(n) => buf = &buf[n..],
Err(e) => return Err(e),
}
}
Ok(())
}
pub fn file_readline(obj: &PyObjectRef, size: Option<usize>, vm: &VirtualMachine) -> PyResult {
let args = size.map_or_else(Vec::new, |size| vec![vm.ctx.new_int(size)]);
let ret = vm.call_method(obj, "readline", args)?;

View File

@@ -903,7 +903,7 @@ mod _io {
// raw file is non-blocking
if remaining > self.buffer.len() {
// can't buffer everything, buffer what we can and error
let buf = rcbuf.obj_bytes();
let buf = rcbuf.as_contiguous().unwrap();
let buffer_len = self.buffer.len();
self.buffer.copy_from_slice(&buf[written..][..buffer_len]);
self.raw_pos = 0;
@@ -926,7 +926,7 @@ mod _io {
self.reset_read();
}
if remaining > 0 {
let buf = rcbuf.obj_bytes();
let buf = rcbuf.as_contiguous().unwrap();
self.buffer[..remaining].copy_from_slice(&buf[written..][..remaining]);
written += remaining;
}
@@ -1025,13 +1025,13 @@ mod _io {
0
};
let buf_end = self.buffer.len();
let res = self.raw_read(Either::A(None), start..buf_end, vm);
if let Ok(Some(n)) = &res {
let new_start = (start + *n) as Offset;
let res = self.raw_read(Either::A(None), start..buf_end, vm)?;
if let Some(n) = res.filter(|n| *n > 0) {
let new_start = (start + n) as Offset;
self.read_end = new_start;
self.raw_pos = new_start;
}
res
Ok(res)
}
fn raw_read(
@@ -1094,7 +1094,7 @@ mod _io {
n, len
)));
}
if self.abs_pos != -1 {
if n > 0 && self.abs_pos != -1 {
self.abs_pos += n as Offset
}
Ok(Some(n as usize))
@@ -1194,10 +1194,10 @@ mod _io {
let n = self.readahead();
let buf_len;
{
let mut b = buf.obj_bytes_mut();
let mut b = buf.as_contiguous_mut().unwrap();
buf_len = b.len();
if n > 0 {
if n as usize > b.len() {
if n as usize >= b.len() {
b.copy_from_slice(&self.buffer[self.pos as usize..][..buf_len]);
self.pos += buf_len as Offset;
return Ok(Some(buf_len));
@@ -1224,7 +1224,7 @@ mod _io {
let n = self.fill_buffer(vm)?;
if let Some(n) = n.filter(|&n| n > 0) {
let n = std::cmp::min(n, remaining);
rcbuf.obj_bytes_mut()[written..][..n]
rcbuf.as_contiguous_mut().unwrap()[written..][..n]
.copy_from_slice(&self.buffer[self.pos as usize..][..n]);
self.pos += n as Offset;
written += n;
@@ -1233,7 +1233,7 @@ mod _io {
}
n
} else {
Some(0)
break;
};
let n = match n {
Some(0) => break,

View File

@@ -1,5 +1,6 @@
use crate::pyobject::{PyObjectRef, PyResult, TryFromObject};
use crate::vm::VirtualMachine;
use std::io;
pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
#[cfg(windows)]
@@ -12,18 +13,18 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
#[cfg(unix)]
mod platform {
pub(super) use libc::{fd_set, select, timeval, FD_ISSET, FD_SET, FD_SETSIZE, FD_ZERO};
pub(super) use std::os::unix::io::RawFd;
pub use libc::{fd_set, select, timeval, FD_ISSET, FD_SET, FD_SETSIZE, FD_ZERO};
pub use std::os::unix::io::RawFd;
}
#[allow(non_snake_case)]
#[cfg(windows)]
mod platform {
pub(super) use winapi::um::winsock2::{fd_set, select, timeval, FD_SETSIZE, SOCKET as RawFd};
pub use winapi::um::winsock2::{fd_set, select, timeval, FD_SETSIZE, SOCKET as RawFd};
// from winsock2.h: https://gist.github.com/piscisaureus/906386#file-winsock2-h-L128-L141
pub(super) unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) {
pub unsafe fn FD_SET(fd: RawFd, set: *mut fd_set) {
let mut i = 0;
for idx in 0..(*set).fd_count as usize {
i = idx;
@@ -39,17 +40,18 @@ mod platform {
}
}
pub(super) unsafe fn FD_ZERO(set: *mut fd_set) {
pub unsafe fn FD_ZERO(set: *mut fd_set) {
(*set).fd_count = 0;
}
pub(super) unsafe fn FD_ISSET(fd: RawFd, set: *mut fd_set) -> bool {
pub unsafe fn FD_ISSET(fd: RawFd, set: *mut fd_set) -> bool {
use winapi::um::winsock2::__WSAFDIsSet;
__WSAFDIsSet(fd as _, set) != 0
}
}
use platform::{timeval, RawFd};
pub use platform::timeval;
use platform::RawFd;
struct Selectable {
obj: PyObjectRef,
@@ -68,8 +70,8 @@ impl TryFromObject for Selectable {
}
}
#[repr(C)]
struct FdSet(platform::fd_set);
#[repr(transparent)]
pub struct FdSet(platform::fd_set);
impl FdSet {
pub fn new() -> FdSet {
@@ -103,6 +105,33 @@ impl FdSet {
}
}
pub fn select(
nfds: libc::c_int,
readfds: &mut FdSet,
writefds: &mut FdSet,
errfds: &mut FdSet,
timeout: Option<&mut timeval>,
) -> io::Result<i32> {
let timeout = match timeout {
Some(tv) => tv as *mut timeval,
None => std::ptr::null_mut(),
};
let ret = unsafe {
platform::select(
nfds,
&mut readfds.0,
&mut writefds.0,
&mut errfds.0,
timeout,
)
};
if ret < 0 {
Err(io::Error::last_os_error())
} else {
Ok(ret)
}
}
fn sec_to_timeval(sec: f64) -> timeval {
timeval {
tv_sec: sec.trunc() as _,
@@ -162,19 +191,14 @@ mod decl {
.max()
.map_or(0, |n| n + 1) as i32;
let (select_res, err) = loop {
loop {
let mut tv = timeout.map(sec_to_timeval);
let timeout_ptr = match tv {
Some(ref mut tv) => tv as *mut _,
None => std::ptr::null_mut(),
};
let res =
unsafe { super::platform::select(nfds, &mut r.0, &mut w.0, &mut x.0, timeout_ptr) };
let res = super::select(nfds, &mut r, &mut w, &mut x, tv.as_mut());
let err = std::io::Error::last_os_error();
if res >= 0 || err.kind() != std::io::ErrorKind::Interrupted {
break (res, err);
match res {
Ok(_) => break,
Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
Err(err) => return Err(err.into_pyexception(vm)),
}
vm.check_signals()?;
@@ -185,14 +209,10 @@ mod decl {
r.clear();
w.clear();
x.clear();
break (0, err);
break;
}
// retry select() if we haven't reached the deadline yet
}
};
if select_res < 0 {
return Err(err.into_pyexception(vm));
}
let set2list = |list: Vec<Selectable>, mut set: FdSet| {

View File

@@ -6,7 +6,7 @@ use socket2::{Domain, Protocol, Socket, Type as SocketType};
use std::convert::TryFrom;
use std::io;
use std::net::{Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr, ToSocketAddrs};
use std::time::Duration;
use std::time::{Duration, Instant};
use crate::builtins::bytes::PyBytesRef;
use crate::builtins::pystr::{PyStr, PyStrRef};
@@ -20,7 +20,7 @@ use crate::pyobject::{
BorrowValue, Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue,
StaticType, TryFromObject,
};
use crate::{py_io, VirtualMachine};
use crate::VirtualMachine;
#[cfg(unix)]
type RawSocket = std::os::unix::io::RawFd;
@@ -59,6 +59,7 @@ pub struct PySocket {
kind: AtomicCell<i32>,
family: AtomicCell<i32>,
proto: AtomicCell<i32>,
timeout: AtomicCell<f64>,
sock: PyRwLock<Socket>,
}
@@ -86,6 +87,7 @@ impl PySocket {
kind: AtomicCell::default(),
family: AtomicCell::default(),
proto: AtomicCell::default(),
timeout: AtomicCell::new(-1.0),
sock: PyRwLock::new(invalid_sock()),
}
.into_ref_with_type(vm, cls)
@@ -128,15 +130,94 @@ impl PySocket {
Ok(())
}
fn sock_op<F, R>(&self, vm: &VirtualMachine, select: SelectKind, f: F) -> PyResult<R>
where
F: FnMut() -> io::Result<R>,
{
let timeout = self.timeout.load();
let timeout = if timeout > 0.0 {
Some(Duration::from_secs_f64(timeout))
} else {
None
};
self.sock_op_timeout(vm, select, timeout, f)
}
fn sock_op_timeout<F, R>(
&self,
vm: &VirtualMachine,
select: SelectKind,
timeout: Option<Duration>,
mut f: F,
) -> PyResult<R>
where
F: FnMut() -> io::Result<R>,
{
let deadline = timeout.map(Deadline::new);
loop {
if deadline.is_some() || matches!(select, SelectKind::Connect) {
let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?;
let res = sock_select(&self.sock(), select, interval);
match res {
Ok(true) => return Err(timeout_error(vm)),
Err(e) if e.kind() == io::ErrorKind::Interrupted => {
vm.check_signals()?;
continue;
}
Err(e) => return Err(convert_sock_error(vm, e)),
Ok(false) => {} // no timeout, continue as normal
}
}
let err = loop {
// loop on interrupt
match f() {
Ok(x) => return Ok(x),
Err(e) if e.kind() == io::ErrorKind::Interrupted => vm.check_signals()?,
Err(e) => break e,
}
};
if timeout.is_some() && err.kind() == io::ErrorKind::WouldBlock {
continue;
}
return Err(convert_sock_error(vm, err));
}
}
#[pymethod]
fn connect(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> {
let sock_addr = get_addr(vm, address, Some(self.family.load()))?;
let res = if let Some(duration) = self.sock().read_timeout().unwrap() {
self.sock().connect_timeout(&sock_addr, duration)
} else {
self.sock().connect(&sock_addr)
let err = match self.sock().connect(&sock_addr) {
Ok(()) => return Ok(()),
Err(e) => e,
};
res.map_err(|err| convert_sock_error(vm, err))
let wait_connect = if err.kind() == io::ErrorKind::Interrupted {
vm.check_signals()?;
self.timeout.load() != 0.0
} else {
self.timeout.load() > 0.0 && err.raw_os_error() == Some(libc::EINPROGRESS)
};
if wait_connect {
// basically, connect() is async, and it registers an "error" on the socket when it's
// done connecting. SelectKind::Connect fills the errorfds fd_set, so if we wake up
// from poll and the error is EISCONN then we know that the connect is done
self.sock_op(vm, SelectKind::Connect, || {
let sock = self.sock();
let err = sock.take_error()?;
match err {
Some(e) if e.raw_os_error() == Some(libc::EISCONN) => Ok(()),
Some(e) => Err(e),
// TODO: is this accurate?
None => Ok(()),
}
})
} else {
Err(convert_sock_error(vm, err))
}
}
#[pymethod]
@@ -158,11 +239,7 @@ impl PySocket {
#[pymethod]
fn _accept(&self, vm: &VirtualMachine) -> PyResult<(RawSocket, AddrTuple)> {
let (sock, addr) = self
.sock()
.accept()
.map_err(|err| convert_sock_error(vm, err))?;
let (sock, addr) = self.sock_op(vm, SelectKind::Read, || self.sock().accept())?;
let fd = into_sock_fileno(sock);
Ok((fd, get_addr_tuple(addr)))
}
@@ -176,10 +253,10 @@ impl PySocket {
) -> PyResult<Vec<u8>> {
let flags = flags.unwrap_or(0);
let mut buffer = vec![0u8; bufsize];
let n = self
.sock()
.recv_with_flags(&mut buffer, flags)
.map_err(|err| convert_sock_error(vm, err))?;
let sock = self.sock();
let n = self.sock_op(vm, SelectKind::Read, || {
sock.recv_with_flags(&mut buffer, flags)
})?;
buffer.truncate(n);
Ok(buffer)
}
@@ -192,8 +269,10 @@ impl PySocket {
vm: &VirtualMachine,
) -> PyResult<usize> {
let flags = flags.unwrap_or(0);
buf.with_ref(|buf| self.sock().recv_with_flags(buf, flags))
.map_err(|err| convert_sock_error(vm, err))
let sock = self.sock();
self.sock_op(vm, SelectKind::Read, || {
buf.with_ref(|buf| sock.recv_with_flags(buf, flags))
})
}
#[pymethod]
@@ -205,10 +284,9 @@ impl PySocket {
) -> PyResult<(Vec<u8>, AddrTuple)> {
let flags = flags.unwrap_or(0);
let mut buffer = vec![0u8; bufsize];
let (n, addr) = self
.sock()
.recv_from_with_flags(&mut buffer, flags)
.map_err(|err| convert_sock_error(vm, err))?;
let (n, addr) = self.sock_op(vm, SelectKind::Read, || {
self.sock().recv_from_with_flags(&mut buffer, flags)
})?;
buffer.truncate(n);
Ok((buffer, get_addr_tuple(addr)))
}
@@ -221,9 +299,9 @@ impl PySocket {
vm: &VirtualMachine,
) -> PyResult<usize> {
let flags = flags.unwrap_or(0);
bytes
.with_ref(|b| self.sock().send_with_flags(b, flags))
.map_err(|err| convert_sock_error(vm, err))
self.sock_op(vm, SelectKind::Write, || {
bytes.with_ref(|b| self.sock().send_with_flags(b, flags))
})
}
#[pymethod]
@@ -234,10 +312,31 @@ impl PySocket {
vm: &VirtualMachine,
) -> PyResult<()> {
let flags = flags.unwrap_or(0);
let sock = self.sock();
bytes
.with_ref(|buf| py_io::write_all(buf, |b| sock.send_with_flags(b, flags)))
.map_err(|err| convert_sock_error(vm, err))
let timeout = self.timeout.load();
let timeout = if timeout > 0.0 {
Some(Duration::from_secs_f64(timeout))
} else {
None
};
let deadline = timeout.map(Deadline::new);
let buf_len = bytes.len();
let mut buf_offset = 0;
// now we have like 3 layers of interrupt loop :)
while buf_offset < buf_len {
let interval = deadline.as_ref().map(|d| d.time_until(vm)).transpose()?;
self.sock_op_timeout(vm, SelectKind::Write, interval, || {
bytes.with_ref(|b| {
let subbuf = &b[buf_offset..];
buf_offset += self.sock().send_with_flags(subbuf, flags)?;
Ok(())
})
})?;
vm.check_signals()?;
}
Ok(())
}
#[pymethod]
@@ -247,13 +346,12 @@ impl PySocket {
address: Address,
flags: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<()> {
) -> PyResult<usize> {
let flags = flags.unwrap_or(0);
let addr = get_addr(vm, address, Some(self.family.load()))?;
bytes
.with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags))
.map_err(|err| convert_sock_error(vm, err))?;
Ok(())
self.sock_op(vm, SelectKind::Write, || {
bytes.with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags))
})
}
#[pymethod]
@@ -290,24 +388,26 @@ impl PySocket {
}
#[pymethod]
fn gettimeout(&self, vm: &VirtualMachine) -> PyResult<Option<f64>> {
let dur = self
.sock()
.read_timeout()
.map_err(|err| convert_sock_error(vm, err))?;
Ok(dur.map(|d| d.as_secs_f64()))
fn gettimeout(&self) -> Option<f64> {
let timeout = self.timeout.load();
if timeout >= 0.0 {
Some(timeout)
} else {
None
}
}
#[pymethod]
fn setblocking(&self, block: bool, vm: &VirtualMachine) -> PyResult<()> {
self.timeout.store(if block { -1.0 } else { 0.0 });
self.sock()
.set_nonblocking(!block)
.map_err(|err| convert_sock_error(vm, err))
}
#[pymethod]
fn getblocking(&self, vm: &VirtualMachine) -> PyResult<bool> {
Ok(self.gettimeout(vm)?.map_or(false, |t| t == 0.0))
fn getblocking(&self) -> bool {
self.timeout.load() != 0.0
}
#[pymethod]
@@ -316,22 +416,14 @@ impl PySocket {
// timeout is 0: non-blocking, no timeout
// otherwise: timeout is timeout, don't change blocking
let (block, timeout) = match timeout {
None => (Some(true), None),
Some(d) if d == Duration::from_secs(0) => (Some(false), None),
Some(d) => (None, Some(d)),
None => (true, -1.0),
Some(d) if d == Duration::from_secs(0) => (false, 0.0),
Some(d) => (true, d.as_secs_f64()),
};
self.timeout.store(timeout);
self.sock()
.set_read_timeout(timeout)
.map_err(|err| convert_sock_error(vm, err))?;
self.sock()
.set_write_timeout(timeout)
.map_err(|err| convert_sock_error(vm, err))?;
if let Some(blocking) = block {
self.sock()
.set_nonblocking(!blocking)
.map_err(|err| convert_sock_error(vm, err))?;
}
Ok(())
.set_nonblocking(!block)
.map_err(|err| convert_sock_error(vm, err))
}
#[pymethod]
@@ -550,6 +642,72 @@ fn _socket_getservbyname(
Ok(vm.ctx.new_int(u16::from_be(port as u16)))
}
#[derive(Copy, Clone)]
enum SelectKind {
Read,
Write,
Connect,
}
/// returns true if timed out
fn sock_select(sock: &Socket, kind: SelectKind, interval: Option<Duration>) -> io::Result<bool> {
let fd = sock_fileno(sock);
#[cfg(unix)]
{
let mut pollfd = libc::pollfd {
fd,
events: match kind {
SelectKind::Read => libc::POLLIN,
SelectKind::Write => libc::POLLOUT,
SelectKind::Connect => libc::POLLOUT | libc::POLLERR,
},
revents: 0,
};
let timeout = match interval {
Some(d) => d.as_millis() as _,
None => -1,
};
let ret = unsafe { libc::poll(&mut pollfd, 1, timeout) };
if ret < 0 {
Err(io::Error::last_os_error())
} else {
Ok(ret == 0)
}
}
#[cfg(windows)]
{
use crate::stdlib::select;
let mut reads = select::FdSet::new();
let mut writes = select::FdSet::new();
let mut errs = select::FdSet::new();
let fd = fd as usize;
match kind {
SelectKind::Read => reads.insert(fd),
SelectKind::Write => writes.insert(fd),
SelectKind::Connect => {
writes.insert(fd);
errs.insert(fd);
}
}
let mut interval = interval.map(|dur| select::timeval {
tv_sec: dur.as_secs() as _,
tv_usec: dur.subsec_micros() as _,
});
select::select(
fd as i32 + 1,
&mut reads,
&mut writes,
&mut errs,
interval.as_mut(),
)
.map(|ret| ret == 0)
}
}
#[derive(FromArgs)]
struct GAIOptions {
#[pyarg(positional)]
@@ -817,6 +975,10 @@ fn convert_sock_error(vm: &VirtualMachine, err: io::Error) -> PyBaseExceptionRef
}
}
fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(TIMEOUT_ERROR.get().unwrap().clone(), "timed out".to_owned())
}
fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String {
match ipv6.to_ipv4() {
Some(v4) if matches!(v4.octets(), [0, 0, _, _]) => format!("::{:x}", u32::from(v4)),
@@ -824,6 +986,24 @@ fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String {
}
}
struct Deadline {
deadline: Instant,
}
impl Deadline {
fn new(timeout: Duration) -> Self {
Self {
deadline: Instant::now() + timeout,
}
}
fn time_until(&self, vm: &VirtualMachine) -> PyResult<Duration> {
self.deadline
.checked_duration_since(Instant::now())
// past the deadline already
.ok_or_else(|| timeout_error(vm))
}
}
rustpython_common::static_cell! {
static TIMEOUT_ERROR: PyTypeRef;
static GAI_ERROR: PyTypeRef;

View File

@@ -89,10 +89,18 @@ mod decl {
Ok(vm.ctx.new_bytes(encoded_bytes))
}
// TODO: validate wbits value here
fn header_from_wbits(wbits: OptionalArg<i8>) -> (bool, u8) {
fn header_from_wbits(
wbits: OptionalArg<i8>,
vm: &VirtualMachine,
) -> PyResult<(Option<bool>, u8)> {
let wbits = wbits.unwrap_or(MAX_WBITS as i8);
(wbits > 0, wbits.abs() as u8)
let header = wbits > 0;
let wbits = wbits.abs() as u8;
match wbits {
9..=15 => Ok((Some(header), wbits)),
25..=31 => Ok((None, wbits - 16)),
_ => Err(vm.new_value_error("Invalid initialization option".to_owned())),
}
}
fn _decompress(
@@ -172,10 +180,13 @@ mod decl {
vm: &VirtualMachine,
) -> PyResult<Vec<u8>> {
data.with_ref(|data| {
let (header, wbits) = header_from_wbits(wbits);
let (header, wbits) = header_from_wbits(wbits, vm)?;
let bufsize = bufsize.unwrap_or(DEF_BUF_SIZE);
let mut d = Decompress::new_with_window_bits(header, wbits);
let mut d = match header {
Some(header) => Decompress::new_with_window_bits(header, wbits),
None => Decompress::new_gzip(wbits),
};
_decompress(data, &mut d, bufsize, None, vm).and_then(|(buf, stream_end)| {
if stream_end {
Ok(buf)
@@ -187,18 +198,21 @@ mod decl {
}
#[pyfunction]
fn decompressobj(args: DecopmressobjArgs, vm: &VirtualMachine) -> PyDecompress {
let (header, wbits) = header_from_wbits(args.wbits);
let mut decompress = Decompress::new_with_window_bits(header, wbits);
fn decompressobj(args: DecopmressobjArgs, vm: &VirtualMachine) -> PyResult<PyDecompress> {
let (header, wbits) = header_from_wbits(args.wbits, vm)?;
let mut decompress = match header {
Some(header) => Decompress::new_with_window_bits(header, wbits),
None => Decompress::new_gzip(wbits),
};
if let OptionalArg::Present(dict) = args.zdict {
dict.with_ref(|d| decompress.set_dictionary(d).unwrap());
}
PyDecompress {
Ok(PyDecompress {
decompress: PyMutex::new(decompress),
eof: AtomicCell::new(false),
unused_data: PyMutex::new(PyBytes::from(vec![]).into_ref(vm)),
unconsumed_tail: PyMutex::new(PyBytes::from(vec![]).into_ref(vm)),
}
})
}
#[pyattr]
#[pyclass(name = "Decompress")]
@@ -350,7 +364,7 @@ mod decl {
_zdict: OptionalArg<PyBytesLike>,
vm: &VirtualMachine,
) -> PyResult<PyCompress> {
let (header, wbits) = header_from_wbits(wbits);
let (header, wbits) = header_from_wbits(wbits, vm)?;
let level = level.unwrap_or(-1);
let level = match level {
@@ -358,7 +372,11 @@ mod decl {
n @ 0..=9 => n as u32,
_ => return Err(vm.new_value_error("invalid initialization option".to_owned())),
};
let compress = Compress::new_with_window_bits(Compression::new(level), header, wbits);
let level = Compression::new(level);
let compress = match header {
Some(header) => Compress::new_with_window_bits(level, header, wbits),
None => Compress::new_gzip(level, wbits),
};
Ok(PyCompress {
inner: PyMutex::new(CompressInner {
compress,