use std::cell::RefCell; use std::io; use std::io::Read; use std::io::Write; use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}; use std::ops::Deref; use crate::function::PyFuncArgs; use crate::obj::objbytes; use crate::obj::objbytes::PyBytesRef; use crate::obj::objint; use crate::obj::objint::PyIntRef; use crate::obj::objstr; use crate::obj::objtuple::PyTupleRef; use crate::pyobject::{PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; use crate::vm::VirtualMachine; use crate::obj::objtype::PyClassRef; use num_traits::ToPrimitive; #[derive(Debug, Copy, Clone)] enum AddressFamily { Unix = 1, Inet = 2, Inet6 = 3, } impl TryFromObject for AddressFamily { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { match i32::try_from_object(vm, obj)? { 1 => Ok(AddressFamily::Unix), 2 => Ok(AddressFamily::Inet), 3 => Ok(AddressFamily::Inet6), value => Err(vm.new_os_error(format!("Unknown address family value: {}", value))), } } } #[derive(Debug, Copy, Clone)] enum SocketKind { Stream = 1, Dgram = 2, } impl TryFromObject for SocketKind { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { match i32::try_from_object(vm, obj)? { 1 => Ok(SocketKind::Stream), 2 => Ok(SocketKind::Dgram), value => Err(vm.new_os_error(format!("Unknown socket kind value: {}", value))), } } } #[derive(Debug)] enum Connection { TcpListener(TcpListener), TcpStream(TcpStream), UdpSocket(UdpSocket), } impl Connection { fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> { match self { Connection::TcpListener(con) => con.accept(), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } fn local_addr(&self) -> io::Result { match self { Connection::TcpListener(con) => con.local_addr(), Connection::UdpSocket(con) => con.local_addr(), Connection::TcpStream(con) => con.local_addr(), } } fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { match self { Connection::UdpSocket(con) => con.recv_from(buf), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } fn send_to(&self, buf: &[u8], addr: A) -> io::Result { match self { Connection::UdpSocket(con) => con.send_to(buf, addr), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } #[cfg(unix)] fn fileno(&self) -> i64 { use std::os::unix::io::AsRawFd; let raw_fd = match self { Connection::TcpListener(con) => con.as_raw_fd(), Connection::UdpSocket(con) => con.as_raw_fd(), Connection::TcpStream(con) => con.as_raw_fd(), }; raw_fd as i64 } #[cfg(windows)] fn fileno(&self) -> i64 { use std::os::windows::io::AsRawSocket; let raw_fd = match self { Connection::TcpListener(con) => con.as_raw_socket(), Connection::UdpSocket(con) => con.as_raw_socket(), Connection::TcpStream(con) => con.as_raw_socket(), }; raw_fd as i64 } #[cfg(all(not(unix), not(windows)))] fn fileno(&self) -> i64 { unimplemented!(); } } impl Read for Connection { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { Connection::TcpStream(con) => con.read(buf), Connection::UdpSocket(con) => con.recv(buf), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } } impl Write for Connection { fn write(&mut self, buf: &[u8]) -> io::Result { match self { Connection::TcpStream(con) => con.write(buf), Connection::UdpSocket(con) => con.send(buf), _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } fn flush(&mut self) -> io::Result<()> { Ok(()) } } #[derive(Debug)] pub struct Socket { address_family: AddressFamily, socket_kind: SocketKind, con: RefCell>, } impl PyValue for Socket { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("socket", "socket") } } impl Socket { fn new(address_family: AddressFamily, socket_kind: SocketKind) -> Socket { Socket { address_family, socket_kind, con: RefCell::new(None), } } } fn get_socket<'a>(obj: &'a PyObjectRef) -> impl Deref + 'a { obj.payload::().unwrap() } type SocketRef = PyRef; impl SocketRef { fn new( cls: PyClassRef, family: AddressFamily, kind: SocketKind, vm: &VirtualMachine, ) -> PyResult { Socket::new(family, kind).into_ref_with_type(vm, cls) } fn connect(self, address: PyTupleRef, vm: &VirtualMachine) -> PyResult { let address_string = get_address_string(vm, address)?; match self.socket_kind { SocketKind::Stream => match TcpStream::connect(address_string) { Ok(stream) => { self.con.borrow_mut().replace(Connection::TcpStream(stream)); Ok(vm.get_none()) } Err(s) => Err(vm.new_os_error(s.to_string())), }, SocketKind::Dgram => { if let Some(Connection::UdpSocket(con)) = self.con.borrow().as_ref() { match con.connect(address_string) { Ok(_) => Ok(vm.get_none()), Err(s) => Err(vm.new_os_error(s.to_string())), } } else { Err(vm.new_type_error("".to_string())) } } } } fn bind(self, address: PyTupleRef, vm: &VirtualMachine) -> PyResult { let address_string = get_address_string(vm, address)?; match self.socket_kind { SocketKind::Stream => match TcpListener::bind(address_string) { Ok(stream) => { self.con .borrow_mut() .replace(Connection::TcpListener(stream)); Ok(vm.get_none()) } Err(s) => Err(vm.new_os_error(s.to_string())), }, SocketKind::Dgram => match UdpSocket::bind(address_string) { Ok(dgram) => { self.con.borrow_mut().replace(Connection::UdpSocket(dgram)); Ok(vm.get_none()) } Err(s) => Err(vm.new_os_error(s.to_string())), }, } } fn sendto(self, bytes: PyBytesRef, address: PyTupleRef, vm: &VirtualMachine) -> PyResult { let address_string = get_address_string(vm, address)?; match self.socket_kind { SocketKind::Dgram => { if let Some(v) = self.con.borrow().as_ref() { return match v.send_to(&bytes, address_string) { Ok(_) => Ok(vm.get_none()), Err(s) => Err(vm.new_os_error(s.to_string())), }; } // Doing implicit bind match UdpSocket::bind("0.0.0.0:0") { Ok(dgram) => match dgram.send_to(&bytes, address_string) { Ok(_) => { self.con.borrow_mut().replace(Connection::UdpSocket(dgram)); Ok(vm.get_none()) } Err(s) => Err(vm.new_os_error(s.to_string())), }, Err(s) => Err(vm.new_os_error(s.to_string())), } } _ => Err(vm.new_not_implemented_error("".to_string())), } } fn listen(self, _num: PyIntRef, _vm: &VirtualMachine) -> () {} fn accept(self, vm: &VirtualMachine) -> PyResult { let ret = match self.con.borrow_mut().as_mut() { Some(v) => v.accept(), None => return Err(vm.new_type_error("".to_string())), }; let (tcp_stream, addr) = match ret { Ok((socket, addr)) => (socket, addr), Err(s) => return Err(vm.new_os_error(s.to_string())), }; let socket = Socket { address_family: self.address_family, socket_kind: self.socket_kind, con: RefCell::new(Some(Connection::TcpStream(tcp_stream))), } .into_ref(vm); let addr_tuple = get_addr_tuple(vm, addr)?; Ok(vm.ctx.new_tuple(vec![socket.into_object(), addr_tuple])) } fn recv(self, bufsize: PyIntRef, vm: &VirtualMachine) -> PyResult { let mut buffer = vec![0u8; bufsize.as_bigint().to_usize().unwrap()]; match self.con.borrow_mut().as_mut() { Some(v) => match v.read_exact(&mut buffer) { Ok(_) => (), Err(s) => return Err(vm.new_os_error(s.to_string())), }, None => return Err(vm.new_type_error("".to_string())), }; Ok(vm.ctx.new_bytes(buffer)) } } fn get_address_string(vm: &VirtualMachine, address: PyTupleRef) -> Result { let args = PyFuncArgs { args: address.elements.borrow().to_vec(), kwargs: vec![], }; arg_check!( vm, args, required = [ (host, Some(vm.ctx.str_type())), (port, Some(vm.ctx.int_type())) ] ); Ok(format!( "{}:{}", objstr::get_value(host), objint::get_value(port).to_string() )) } fn socket_recvfrom(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))] ); let socket = get_socket(zelf); let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()]; let ret = match socket.con.borrow().as_ref() { Some(v) => v.recv_from(&mut buffer), None => return Err(vm.new_type_error("".to_string())), }; let addr = match ret { Ok((_size, addr)) => addr, Err(s) => return Err(vm.new_os_error(s.to_string())), }; let addr_tuple = get_addr_tuple(vm, addr)?; Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple])) } fn socket_send(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, required = [(zelf, None), (bytes, Some(vm.ctx.bytes_type()))] ); let socket = get_socket(zelf); match socket.con.borrow_mut().as_mut() { Some(v) => match v.write(&objbytes::get_value(&bytes)) { Ok(_) => (), Err(s) => return Err(vm.new_os_error(s.to_string())), }, None => return Err(vm.new_type_error("".to_string())), }; Ok(vm.get_none()) } fn socket_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, None)]); let socket = get_socket(zelf); socket.con.borrow_mut().take(); Ok(vm.get_none()) } fn socket_fileno(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, None)]); let socket = get_socket(zelf); let fileno = match socket.con.borrow_mut().as_mut() { Some(v) => v.fileno(), None => return Err(vm.new_type_error("".to_string())), }; Ok(vm.ctx.new_int(fileno)) } fn socket_getsockname(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, None)]); let socket = get_socket(zelf); let addr = match socket.con.borrow().as_ref() { Some(v) => v.local_addr(), None => return Err(vm.new_type_error("".to_string())), }; match addr { Ok(addr) => get_addr_tuple(vm, addr), Err(s) => Err(vm.new_os_error(s.to_string())), } } fn get_addr_tuple(vm: &VirtualMachine, addr: SocketAddr) -> PyResult { let port = vm.ctx.new_int(addr.port()); let ip = vm.ctx.new_str(addr.ip().to_string()); Ok(vm.ctx.new_tuple(vec![ip, port])) } pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; let socket = py_class!(ctx, "socket", ctx.object(), { "__new__" => ctx.new_rustfunc(SocketRef::new), "connect" => ctx.new_rustfunc(SocketRef::connect), "recv" => ctx.new_rustfunc(SocketRef::recv), "send" => ctx.new_rustfunc(socket_send), "bind" => ctx.new_rustfunc(SocketRef::bind), "accept" => ctx.new_rustfunc(SocketRef::accept), "listen" => ctx.new_rustfunc(SocketRef::listen), "close" => ctx.new_rustfunc(socket_close), "getsockname" => ctx.new_rustfunc(socket_getsockname), "sendto" => ctx.new_rustfunc(SocketRef::sendto), "recvfrom" => ctx.new_rustfunc(socket_recvfrom), "fileno" => ctx.new_rustfunc(socket_fileno), }); py_module!(vm, "socket", { "AF_INET" => ctx.new_int(AddressFamily::Inet as i32), "SOCK_STREAM" => ctx.new_int(SocketKind::Stream as i32), "SOCK_DGRAM" => ctx.new_int(SocketKind::Dgram as i32), "socket" => socket, }) }