mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-09 22:49:57 +09:00
Merge pull request #1309 from Lynskylate/extend-socket
Add settimeout and setblocking for socket module
This commit is contained in:
@@ -137,3 +137,13 @@ with assertRaises(OSError):
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
pass
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
|
||||
listener.bind(("127.0.0.1", 0))
|
||||
listener.listen(1)
|
||||
connector = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
connector.connect(("127.0.0.1", listener.getsockname()[1]))
|
||||
(connection, addr) = listener.accept()
|
||||
connection.settimeout(1.0)
|
||||
with assertRaises(OSError):
|
||||
connection.recv(len(MESSAGE_A))
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::io;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::net::{Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
|
||||
use std::time::Duration;
|
||||
|
||||
#[cfg(all(unix, not(target_os = "redox")))]
|
||||
use nix::unistd::sethostname;
|
||||
@@ -122,6 +123,27 @@ impl Connection {
|
||||
fn fileno(&self) -> i64 {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
fn setblocking(&mut self, value: bool) -> io::Result<()> {
|
||||
match self {
|
||||
Connection::TcpListener(con) => con.set_nonblocking(!value),
|
||||
Connection::UdpSocket(con) => con.set_nonblocking(!value),
|
||||
Connection::TcpStream(con) => con.set_nonblocking(!value),
|
||||
}
|
||||
}
|
||||
|
||||
fn settimeout(&mut self, duration: Duration) -> io::Result<()> {
|
||||
match self {
|
||||
// net
|
||||
Connection::TcpListener(_con) => Ok(()),
|
||||
Connection::UdpSocket(con) => con
|
||||
.set_read_timeout(Some(duration))
|
||||
.and_then(|_| con.set_write_timeout(Some(duration))),
|
||||
Connection::TcpStream(con) => con
|
||||
.set_read_timeout(Some(duration))
|
||||
.and_then(|_| con.set_write_timeout(Some(duration))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for Connection {
|
||||
@@ -152,6 +174,7 @@ pub struct Socket {
|
||||
address_family: AddressFamily,
|
||||
socket_kind: SocketKind,
|
||||
con: RefCell<Option<Connection>>,
|
||||
timeout: RefCell<Option<Duration>>,
|
||||
}
|
||||
|
||||
impl PyValue for Socket {
|
||||
@@ -166,6 +189,7 @@ impl Socket {
|
||||
address_family,
|
||||
socket_kind,
|
||||
con: RefCell::new(None),
|
||||
timeout: RefCell::new(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -194,13 +218,41 @@ impl SocketRef {
|
||||
let address_string = address.get_address_string();
|
||||
|
||||
match self.socket_kind {
|
||||
SocketKind::Stream => match TcpStream::connect(address_string) {
|
||||
Ok(stream) => {
|
||||
self.con.borrow_mut().replace(Connection::TcpStream(stream));
|
||||
Ok(())
|
||||
SocketKind::Stream => {
|
||||
let con = if let Some(duration) = self.timeout.borrow().as_ref() {
|
||||
let sock_addr = match address_string.to_socket_addrs() {
|
||||
Ok(mut sock_addrs) => {
|
||||
if sock_addrs.len() == 0 {
|
||||
let error_type = vm.class("socket", "gaierror");
|
||||
return Err(vm.new_exception(
|
||||
error_type,
|
||||
"nodename nor servname provided, or not known".to_string(),
|
||||
));
|
||||
} else {
|
||||
sock_addrs.next().unwrap()
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let error_type = vm.class("socket", "gaierror");
|
||||
return Err(vm.new_exception(error_type, e.to_string()));
|
||||
}
|
||||
};
|
||||
TcpStream::connect_timeout(&sock_addr, *duration)
|
||||
} else {
|
||||
TcpStream::connect(address_string)
|
||||
};
|
||||
match con {
|
||||
Ok(stream) => {
|
||||
self.con.borrow_mut().replace(Connection::TcpStream(stream));
|
||||
Ok(())
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
|
||||
let socket_timeout = vm.class("socket", "timeout");
|
||||
Err(vm.new_exception(socket_timeout, "Timed out".to_string()))
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
}
|
||||
SocketKind::Dgram => {
|
||||
if let Some(Connection::UdpSocket(con)) = self.con.borrow().as_ref() {
|
||||
match con.connect(address_string) {
|
||||
@@ -254,6 +306,7 @@ impl SocketRef {
|
||||
address_family: self.address_family,
|
||||
socket_kind: self.socket_kind,
|
||||
con: RefCell::new(Some(Connection::TcpStream(tcp_stream))),
|
||||
timeout: RefCell::new(None),
|
||||
}
|
||||
.into_ref(vm);
|
||||
|
||||
@@ -267,6 +320,10 @@ impl SocketRef {
|
||||
match self.con.borrow_mut().as_mut() {
|
||||
Some(v) => match v.read_exact(&mut buffer) {
|
||||
Ok(_) => (),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
|
||||
let socket_timeout = vm.class("socket", "timeout");
|
||||
return Err(vm.new_exception(socket_timeout, "Timed out".to_string()));
|
||||
}
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
@@ -295,9 +352,13 @@ impl SocketRef {
|
||||
match self.con.borrow_mut().as_mut() {
|
||||
Some(v) => match v.write(&bytes) {
|
||||
Ok(_) => (),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {
|
||||
let socket_timeout = vm.class("socket", "timeout");
|
||||
return Err(vm.new_exception(socket_timeout, "Timed out".to_string()));
|
||||
}
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
None => return Err(vm.new_type_error("Socket is not connected".to_string())),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
@@ -352,6 +413,75 @@ impl SocketRef {
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn gettimeout(self, _vm: &VirtualMachine) -> PyResult<Option<f64>> {
|
||||
match self.timeout.borrow().as_ref() {
|
||||
Some(duration) => Ok(Some(duration.as_secs() as f64)),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn setblocking(self, block: Option<bool>, vm: &VirtualMachine) -> PyResult<()> {
|
||||
match block {
|
||||
Some(value) => {
|
||||
if value {
|
||||
self.timeout.replace(None);
|
||||
} else {
|
||||
self.timeout.borrow_mut().replace(Duration::from_secs(0));
|
||||
}
|
||||
if let Some(conn) = self.con.borrow_mut().as_mut() {
|
||||
return match conn.setblocking(value) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(vm.new_os_error(err.to_string())),
|
||||
};
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Avoid converting None to bool
|
||||
Err(vm.new_type_error("an bool is required".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn getblocking(self, _vm: &VirtualMachine) -> PyResult<Option<bool>> {
|
||||
match self.timeout.borrow().as_ref() {
|
||||
Some(duration) => {
|
||||
if duration.as_secs() != 0 {
|
||||
Ok(Some(true))
|
||||
} else {
|
||||
Ok(Some(false))
|
||||
}
|
||||
}
|
||||
None => Ok(Some(true)),
|
||||
}
|
||||
}
|
||||
|
||||
fn settimeout(self, timeout: Option<f64>, vm: &VirtualMachine) -> PyResult<()> {
|
||||
match timeout {
|
||||
Some(timeout) => {
|
||||
self.timeout
|
||||
.borrow_mut()
|
||||
.replace(Duration::from_secs(timeout as u64));
|
||||
|
||||
let block = timeout > 0.0;
|
||||
|
||||
if let Some(conn) = self.con.borrow_mut().as_mut() {
|
||||
conn.setblocking(block)
|
||||
.and_then(|_| conn.settimeout(Duration::from_secs(timeout as u64)))
|
||||
.map_err(|err| vm.new_os_error(err.to_string()))
|
||||
.map(|_| ())
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.timeout.replace(None);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Address {
|
||||
@@ -432,6 +562,8 @@ fn socket_htonl(host: PyIntRef, vm: &VirtualMachine) -> PyResult {
|
||||
|
||||
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let ctx = &vm.ctx;
|
||||
let socket_timeout = ctx.new_class("socket.timeout", vm.ctx.exceptions.os_error.clone());
|
||||
let socket_gaierror = ctx.new_class("socket.gaierror", vm.ctx.exceptions.os_error.clone());
|
||||
|
||||
let socket = py_class!(ctx, "socket", ctx.object(), {
|
||||
"__new__" => ctx.new_rustfunc(SocketRef::new),
|
||||
@@ -448,9 +580,16 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
"sendto" => ctx.new_rustfunc(SocketRef::sendto),
|
||||
"recvfrom" => ctx.new_rustfunc(SocketRef::recvfrom),
|
||||
"fileno" => ctx.new_rustfunc(SocketRef::fileno),
|
||||
"getblocking" => ctx.new_rustfunc(SocketRef::getblocking),
|
||||
"setblocking" => ctx.new_rustfunc(SocketRef::setblocking),
|
||||
"gettimeout" => ctx.new_rustfunc(SocketRef::gettimeout),
|
||||
"settimeout" => ctx.new_rustfunc(SocketRef::settimeout),
|
||||
});
|
||||
|
||||
let module = py_module!(vm, "socket", {
|
||||
"error" => ctx.exceptions.os_error.clone(),
|
||||
"timeout" => socket_timeout,
|
||||
"gaierror" => socket_gaierror,
|
||||
"AF_INET" => ctx.new_int(AddressFamily::Inet as i32),
|
||||
"SOCK_STREAM" => ctx.new_int(SocketKind::Stream as i32),
|
||||
"SOCK_DGRAM" => ctx.new_int(SocketKind::Dgram as i32),
|
||||
|
||||
Reference in New Issue
Block a user