diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index da384eea36..44b0d04d09 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -179,10 +179,12 @@ impl PySocket { #[pymethod] fn recv(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult> { let mut buffer = vec![0u8; bufsize]; - match self.sock.borrow_mut().read_exact(&mut buffer) { - Ok(()) => Ok(buffer), - Err(err) => Err(convert_sock_error(vm, err)), - } + let n = self + .sock() + .recv(&mut buffer) + .map_err(|err| convert_sock_error(vm, err))?; + buffer.truncate(n); + Ok(buffer) } #[pymethod] @@ -196,10 +198,12 @@ impl PySocket { #[pymethod] fn recvfrom(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult<(Vec, AddrTuple)> { let mut buffer = vec![0u8; bufsize]; - match self.sock().recv_from(&mut buffer) { - Ok((_, addr)) => Ok((buffer, get_addr_tuple(addr))), - Err(err) => Err(convert_sock_error(vm, err)), - } + let (n, addr) = self + .sock() + .recv_from(&mut buffer) + .map_err(|err| convert_sock_error(vm, err))?; + buffer.truncate(n); + Ok((buffer, get_addr_tuple(addr))) } #[pymethod] @@ -276,16 +280,73 @@ impl PySocket { } #[pymethod] - fn settimeout(&self, timeout: Option, vm: &VirtualMachine) -> PyResult<()> { + fn settimeout(&self, timeout: Option>, vm: &VirtualMachine) -> PyResult<()> { + let mut block = if timeout.is_none() { Some(true) } else { None }; + let timeout = timeout.and_then(|n| { + let dur = match n { + Either::A(f) => Duration::from_secs_f64(f), + Either::B(i) => Duration::from_secs(i), + }; + if dur == Duration::from_secs(0) { + block = Some(false); + None + } else { + Some(dur) + } + }); self.sock() - .set_read_timeout(timeout.map(Duration::from_secs_f64)) + .set_read_timeout(timeout) .map_err(|err| convert_sock_error(vm, err))?; self.sock() - .set_write_timeout(timeout.map(Duration::from_secs_f64)) + .set_write_timeout(timeout) .map_err(|err| convert_sock_error(vm, err))?; + if let Some(blocking) = block { + self.sock() + .set_nonblocking(!blocking) + .map_err(|err| convert_sock_error(vm, err))?; + } Ok(()) } + #[pymethod] + fn setsockopt( + &self, + level: i32, + name: i32, + value: Option>, + optlen: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<()> { + let fd = sock_fileno(&self.sock()); + let ret = match (value, optlen) { + (Some(Either::A(b)), OptionalArg::Missing) => b.with_ref(|b| unsafe { + c::setsockopt(fd, level, name, b.as_ptr() as *const _, b.len() as _) + }), + (Some(Either::B(ref val)), OptionalArg::Missing) => unsafe { + c::setsockopt( + fd, + level, + name, + val as *const i32 as *const _, + std::mem::size_of::() as _, + ) + }, + (None, OptionalArg::Present(optlen)) => unsafe { + c::setsockopt(fd, level, name, std::ptr::null(), optlen as _) + }, + _ => { + return Err( + vm.new_type_error("expected the value arg xor the optlen arg".to_string()) + ); + } + }; + if ret < 0 { + Err(convert_sock_error(vm, io::Error::last_os_error())) + } else { + Ok(()) + } + } + #[pymethod] fn shutdown(&self, how: i32, vm: &VirtualMachine) -> PyResult<()> { let how = match how { @@ -570,6 +631,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "IPPROTO_IPIP" => ctx.new_int(c::IPPROTO_IP), "IPPROTO_IPV6" => ctx.new_int(c::IPPROTO_IPV6), "IPPROTO_NONE" => ctx.new_int(c::IPPROTO_NONE), + "SOL_SOCKET" => ctx.new_int(c::SOL_SOCKET), + "SO_REUSEADDR" => ctx.new_int(c::SO_REUSEADDR), + "TCP_NODELAY" => ctx.new_int(c::TCP_NODELAY), + "SO_BROADCAST" => ctx.new_int(c::SO_BROADCAST), "socket" => PySocket::make_class(ctx), "inet_aton" => ctx.new_rustfunc(socket_inet_aton), "inet_ntoa" => ctx.new_rustfunc(socket_inet_ntoa),