diff --git a/tests/snippets/stdlib_socket.py b/tests/snippets/stdlib_socket.py index d5727bb25..ddd8d6ccc 100644 --- a/tests/snippets/stdlib_socket.py +++ b/tests/snippets/stdlib_socket.py @@ -137,3 +137,13 @@ with assertRaises(OSError): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: pass + +with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener: + listener.bind(("127.0.0.1", 0)) + listener.listen(1) + connector = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connector.connect(("127.0.0.1", listener.getsockname()[1])) + (connection, addr) = listener.accept() + connection.settimeout(1.0) + with assertRaises(OSError): + connection.recv(len(MESSAGE_A)) diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index c83b75656..e0fd0a505 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -3,6 +3,7 @@ use std::io; use std::io::Read; use std::io::Write; use std::net::{Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}; +use std::time::Duration; #[cfg(all(unix, not(target_os = "redox")))] use nix::unistd::sethostname; @@ -122,6 +123,27 @@ impl Connection { fn fileno(&self) -> i64 { unimplemented!(); } + + fn setblocking(&mut self, value: bool) -> io::Result<()> { + match self { + Connection::TcpListener(con) => con.set_nonblocking(!value), + Connection::UdpSocket(con) => con.set_nonblocking(!value), + Connection::TcpStream(con) => con.set_nonblocking(!value), + } + } + + fn settimeout(&mut self, duration: Duration) -> io::Result<()> { + match self { + // net + Connection::TcpListener(_con) => Ok(()), + Connection::UdpSocket(con) => con + .set_read_timeout(Some(duration)) + .and_then(|_| con.set_write_timeout(Some(duration))), + Connection::TcpStream(con) => con + .set_read_timeout(Some(duration)) + .and_then(|_| con.set_write_timeout(Some(duration))), + } + } } impl Read for Connection { @@ -152,6 +174,7 @@ pub struct Socket { address_family: AddressFamily, socket_kind: SocketKind, con: RefCell>, + timeout: RefCell>, } impl PyValue for Socket { @@ -166,6 +189,7 @@ impl Socket { address_family, socket_kind, con: RefCell::new(None), + timeout: RefCell::new(None), } } } @@ -194,13 +218,41 @@ impl SocketRef { let address_string = address.get_address_string(); match self.socket_kind { - SocketKind::Stream => match TcpStream::connect(address_string) { - Ok(stream) => { - self.con.borrow_mut().replace(Connection::TcpStream(stream)); - Ok(()) + SocketKind::Stream => { + let con = if let Some(duration) = self.timeout.borrow().as_ref() { + let sock_addr = match address_string.to_socket_addrs() { + Ok(mut sock_addrs) => { + if sock_addrs.len() == 0 { + let error_type = vm.class("socket", "gaierror"); + return Err(vm.new_exception( + error_type, + "nodename nor servname provided, or not known".to_string(), + )); + } else { + sock_addrs.next().unwrap() + } + } + Err(e) => { + let error_type = vm.class("socket", "gaierror"); + return Err(vm.new_exception(error_type, e.to_string())); + } + }; + TcpStream::connect_timeout(&sock_addr, *duration) + } else { + TcpStream::connect(address_string) + }; + match con { + Ok(stream) => { + self.con.borrow_mut().replace(Connection::TcpStream(stream)); + Ok(()) + } + Err(ref e) if e.kind() == io::ErrorKind::TimedOut => { + let socket_timeout = vm.class("socket", "timeout"); + Err(vm.new_exception(socket_timeout, "Timed out".to_string())) + } + Err(s) => Err(vm.new_os_error(s.to_string())), } - Err(s) => Err(vm.new_os_error(s.to_string())), - }, + } SocketKind::Dgram => { if let Some(Connection::UdpSocket(con)) = self.con.borrow().as_ref() { match con.connect(address_string) { @@ -254,6 +306,7 @@ impl SocketRef { address_family: self.address_family, socket_kind: self.socket_kind, con: RefCell::new(Some(Connection::TcpStream(tcp_stream))), + timeout: RefCell::new(None), } .into_ref(vm); @@ -267,6 +320,10 @@ impl SocketRef { match self.con.borrow_mut().as_mut() { Some(v) => match v.read_exact(&mut buffer) { Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::TimedOut => { + let socket_timeout = vm.class("socket", "timeout"); + return Err(vm.new_exception(socket_timeout, "Timed out".to_string())); + } Err(s) => return Err(vm.new_os_error(s.to_string())), }, None => return Err(vm.new_type_error("".to_string())), @@ -295,9 +352,13 @@ impl SocketRef { match self.con.borrow_mut().as_mut() { Some(v) => match v.write(&bytes) { Ok(_) => (), + Err(ref e) if e.kind() == io::ErrorKind::TimedOut => { + let socket_timeout = vm.class("socket", "timeout"); + return Err(vm.new_exception(socket_timeout, "Timed out".to_string())); + } Err(s) => return Err(vm.new_os_error(s.to_string())), }, - None => return Err(vm.new_type_error("".to_string())), + None => return Err(vm.new_type_error("Socket is not connected".to_string())), }; Ok(()) } @@ -352,6 +413,75 @@ impl SocketRef { Err(s) => Err(vm.new_os_error(s.to_string())), } } + + fn gettimeout(self, _vm: &VirtualMachine) -> PyResult> { + match self.timeout.borrow().as_ref() { + Some(duration) => Ok(Some(duration.as_secs() as f64)), + None => Ok(None), + } + } + + fn setblocking(self, block: Option, vm: &VirtualMachine) -> PyResult<()> { + match block { + Some(value) => { + if value { + self.timeout.replace(None); + } else { + self.timeout.borrow_mut().replace(Duration::from_secs(0)); + } + if let Some(conn) = self.con.borrow_mut().as_mut() { + return match conn.setblocking(value) { + Ok(_) => Ok(()), + Err(err) => Err(vm.new_os_error(err.to_string())), + }; + } else { + Ok(()) + } + } + None => { + // Avoid converting None to bool + Err(vm.new_type_error("an bool is required".to_string())) + } + } + } + + fn getblocking(self, _vm: &VirtualMachine) -> PyResult> { + match self.timeout.borrow().as_ref() { + Some(duration) => { + if duration.as_secs() != 0 { + Ok(Some(true)) + } else { + Ok(Some(false)) + } + } + None => Ok(Some(true)), + } + } + + fn settimeout(self, timeout: Option, vm: &VirtualMachine) -> PyResult<()> { + match timeout { + Some(timeout) => { + self.timeout + .borrow_mut() + .replace(Duration::from_secs(timeout as u64)); + + let block = timeout > 0.0; + + if let Some(conn) = self.con.borrow_mut().as_mut() { + conn.setblocking(block) + .and_then(|_| conn.settimeout(Duration::from_secs(timeout as u64))) + .map_err(|err| vm.new_os_error(err.to_string())) + .map(|_| ()) + } else { + Ok(()) + } + } + None => { + self.timeout.replace(None); + Ok(()) + } + } + } } struct Address { @@ -432,6 +562,8 @@ fn socket_htonl(host: PyIntRef, vm: &VirtualMachine) -> PyResult { pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; + let socket_timeout = ctx.new_class("socket.timeout", vm.ctx.exceptions.os_error.clone()); + let socket_gaierror = ctx.new_class("socket.gaierror", vm.ctx.exceptions.os_error.clone()); let socket = py_class!(ctx, "socket", ctx.object(), { "__new__" => ctx.new_rustfunc(SocketRef::new), @@ -448,9 +580,16 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "sendto" => ctx.new_rustfunc(SocketRef::sendto), "recvfrom" => ctx.new_rustfunc(SocketRef::recvfrom), "fileno" => ctx.new_rustfunc(SocketRef::fileno), + "getblocking" => ctx.new_rustfunc(SocketRef::getblocking), + "setblocking" => ctx.new_rustfunc(SocketRef::setblocking), + "gettimeout" => ctx.new_rustfunc(SocketRef::gettimeout), + "settimeout" => ctx.new_rustfunc(SocketRef::settimeout), }); let module = py_module!(vm, "socket", { + "error" => ctx.exceptions.os_error.clone(), + "timeout" => socket_timeout, + "gaierror" => socket_gaierror, "AF_INET" => ctx.new_int(AddressFamily::Inet as i32), "SOCK_STREAM" => ctx.new_int(SocketKind::Stream as i32), "SOCK_DGRAM" => ctx.new_int(SocketKind::Dgram as i32),