diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index b1d5f66c7..acba43101 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -19,9 +19,15 @@ use crate::obj::objbytes::PyBytesRef; use crate::obj::objstr::PyStringRef; use crate::obj::objtuple::PyTupleRef; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; +use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; use crate::vm::VirtualMachine; +#[cfg(unix)] +type RawSocket = std::os::unix::io::RawFd; +#[cfg(windows)] +type RawSocket = std::os::windows::raw::SOCKET; + +#[pyclass] #[derive(Debug)] pub struct PySocket { timeout: Cell>, @@ -34,15 +40,17 @@ impl PyValue for PySocket { } } -type PySocketRef = PyRef; +pub type PySocketRef = PyRef; -impl PySocketRef { +#[pyimpl] +impl PySocket { + #[pyslot(new)] fn new( cls: PyClassRef, domain: OptionalArg, socket_type: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult> { let domain = domain.unwrap_or(libc::AF_INET); let socket_type = socket_type.unwrap_or(libc::SOCK_STREAM); let domain = match domain { @@ -66,7 +74,8 @@ impl PySocketRef { .into_ref_with_type(vm, cls) } - fn connect(self, address: Address, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn connect(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> { let sock_addr = get_addr(vm, address)?; let res = if let Some(duration) = self.timeout.get() { self.sock.connect_timeout(&sock_addr, duration) @@ -76,14 +85,16 @@ impl PySocketRef { res.map_err(|err| convert_io_error(vm, err)) } - fn bind(self, address: Address, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn bind(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> { let sock_addr = get_addr(vm, address)?; self.sock .bind(&sock_addr) .map_err(|err| convert_io_error(vm, err)) } - fn listen(self, backlog: OptionalArg, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn listen(&self, backlog: OptionalArg, vm: &VirtualMachine) -> PyResult<()> { let backlog = backlog.unwrap_or(128); let backlog = if backlog < 0 { 0 } else { backlog }; self.sock @@ -91,7 +102,8 @@ impl PySocketRef { .map_err(|err| convert_io_error(vm, err)) } - fn accept(self, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn accept(&self, vm: &VirtualMachine) -> PyResult { let (sock, addr) = self .sock .accept() @@ -108,7 +120,8 @@ impl PySocketRef { Ok(vm.ctx.new_tuple(vec![socket.into_object(), addr_tuple])) } - fn recv(self, bufsize: usize, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn recv(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult { let mut buffer = vec![0u8; bufsize]; match self.sock.recv(&mut buffer) { Ok(_) => Ok(vm.ctx.new_bytes(buffer)), @@ -120,7 +133,8 @@ impl PySocketRef { } } - fn recvfrom(self, bufsize: usize, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn recvfrom(&self, bufsize: usize, vm: &VirtualMachine) -> PyResult { let mut buffer = vec![0u8; bufsize]; let addr = match self.sock.recv_from(&mut buffer) { Ok((_, addr)) => addr, @@ -136,7 +150,8 @@ impl PySocketRef { Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple])) } - fn send(self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn send(&self, bytes: PyBytesLike, vm: &VirtualMachine) -> PyResult { match self.sock.send(bytes.to_cow().as_ref()) { Ok(i) => Ok(i), Err(ref e) if e.kind() == io::ErrorKind::TimedOut => { @@ -147,7 +162,8 @@ impl PySocketRef { } } - fn sendto(self, bytes: PyBytesLike, address: Address, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn sendto(&self, bytes: PyBytesLike, address: Address, vm: &VirtualMachine) -> PyResult<()> { let addr = get_addr(vm, address)?; self.sock .send_to(bytes.to_cow().as_ref(), &addr) @@ -155,7 +171,8 @@ impl PySocketRef { Ok(()) } - fn close(self, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn close(&self, vm: &VirtualMachine) -> PyResult<()> { let fd = self.clone().fileno(vm); #[cfg(unix)] let ret = unsafe { libc::close(fd) }; @@ -173,18 +190,22 @@ impl PySocketRef { } } - #[cfg(unix)] - fn fileno(self, _vm: &VirtualMachine) -> std::os::unix::io::RawFd { - use std::os::unix::io::AsRawFd; - self.sock.as_raw_fd() - } - #[cfg(windows)] - fn fileno(self, _vm: &VirtualMachine) -> std::os::windows::raw::SOCKET { - use std::os::windows::io::AsRawSocket; - self.sock.as_raw_socket() + #[pymethod] + fn fileno(&self, _vm: &VirtualMachine) -> RawSocket { + #[cfg(unix)] + { + use std::os::unix::io::AsRawFd; + self.sock.as_raw_fd() + } + #[cfg(windows)] + { + use std::os::windows::io::AsRawSocket; + self.sock.as_raw_socket() + } } - fn getsockname(self, vm: &VirtualMachine) -> PyResult { + #[pymethod] + fn getsockname(&self, vm: &VirtualMachine) -> PyResult { let addr = self .sock .local_addr() @@ -193,11 +214,13 @@ impl PySocketRef { Ok(get_addr_tuple(vm, addr)) } - fn gettimeout(self, _vm: &VirtualMachine) -> Option { + #[pymethod] + fn gettimeout(&self, _vm: &VirtualMachine) -> Option { self.timeout.get().map(duration_to_f64) } - fn setblocking(self, block: bool, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn setblocking(&self, block: bool, vm: &VirtualMachine) -> PyResult<()> { self.sock .set_nonblocking(!block) .map_err(|err| convert_io_error(vm, err))?; @@ -209,11 +232,13 @@ impl PySocketRef { Ok(()) } - fn getblocking(self, _vm: &VirtualMachine) -> bool { + #[pymethod] + fn getblocking(&self, _vm: &VirtualMachine) -> bool { self.timeout.get().map_or(false, |d| d.as_secs() != 0) } - fn settimeout(self, timeout: f64, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn settimeout(&self, timeout: f64, vm: &VirtualMachine) -> PyResult<()> { let secs: u64 = timeout.trunc() as u64; let nanos: u32 = (timeout.fract() * 1e9) as u32; let duration = Duration::new(secs, nanos); @@ -227,7 +252,8 @@ impl PySocketRef { Ok(()) } - fn shutdown(self, how: i32, vm: &VirtualMachine) -> PyResult<()> { + #[pymethod] + fn shutdown(&self, how: i32, vm: &VirtualMachine) -> PyResult<()> { let how = match how { libc::SHUT_RD => Shutdown::Read, libc::SHUT_WR => Shutdown::Write, @@ -346,26 +372,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let socket_timeout = ctx.new_class("socket.timeout", vm.ctx.exceptions.os_error.clone()); let socket_gaierror = ctx.new_class("socket.gaierror", vm.ctx.exceptions.os_error.clone()); - let socket = py_class!(ctx, "socket", ctx.object(), { - (slot new) => PySocketRef::new, - "connect" => ctx.new_rustfunc(PySocketRef::connect), - "recv" => ctx.new_rustfunc(PySocketRef::recv), - "send" => ctx.new_rustfunc(PySocketRef::send), - "bind" => ctx.new_rustfunc(PySocketRef::bind), - "accept" => ctx.new_rustfunc(PySocketRef::accept), - "listen" => ctx.new_rustfunc(PySocketRef::listen), - "close" => ctx.new_rustfunc(PySocketRef::close), - "getsockname" => ctx.new_rustfunc(PySocketRef::getsockname), - "sendto" => ctx.new_rustfunc(PySocketRef::sendto), - "recvfrom" => ctx.new_rustfunc(PySocketRef::recvfrom), - "fileno" => ctx.new_rustfunc(PySocketRef::fileno), - "getblocking" => ctx.new_rustfunc(PySocketRef::getblocking), - "setblocking" => ctx.new_rustfunc(PySocketRef::setblocking), - "gettimeout" => ctx.new_rustfunc(PySocketRef::gettimeout), - "settimeout" => ctx.new_rustfunc(PySocketRef::settimeout), - "shutdown" => ctx.new_rustfunc(PySocketRef::shutdown), - }); - let module = py_module!(vm, "socket", { "error" => ctx.exceptions.os_error.clone(), "timeout" => socket_timeout, @@ -377,7 +383,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "SHUT_RD" => ctx.new_int(libc::SHUT_RD), "SHUT_WR" => ctx.new_int(libc::SHUT_WR), "SHUT_RDWR" => ctx.new_int(libc::SHUT_RDWR), - "socket" => socket, + "socket" => PySocket::make_class(ctx), "inet_aton" => ctx.new_rustfunc(socket_inet_aton), "inet_ntoa" => ctx.new_rustfunc(socket_inet_ntoa), "gethostname" => ctx.new_rustfunc(socket_gethostname),