From 86b60faa61d80458e6304db5b6ad596da52f4183 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 2 Mar 2019 17:12:13 +0200 Subject: [PATCH] Add socket.{sendto, recvfrom} --- tests/snippets/stdlib_socket.py | 6 ++- vm/src/stdlib/socket.rs | 87 +++++++++++++++++++++++++++++---- 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/tests/snippets/stdlib_socket.py b/tests/snippets/stdlib_socket.py index 2a3c3a94e2..af29b451c7 100644 --- a/tests/snippets/stdlib_socket.py +++ b/tests/snippets/stdlib_socket.py @@ -41,8 +41,12 @@ sock1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock1.bind(("127.0.0.1", 0)) sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) -sock2.bind(("127.0.0.1", 0)) +sock2.sendto(MESSAGE_A, sock1.getsockname()) +(recv_a, addr) = sock1.recvfrom(len(MESSAGE_A)) +assert recv_a == MESSAGE_A + +sock2.bind(("127.0.0.1", 0)) sock1.connect(("127.0.0.1", sock2.getsockname()[1])) sock2.connect(("127.0.0.1", sock1.getsockname()[1])) diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index e85b22d815..37dc8a6fa2 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -71,6 +71,13 @@ impl Connection { _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } + + fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + match self { + Connection::UdpSocket(con) => con.recv_from(buf), + _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), + } + } } impl Read for Connection { @@ -298,6 +305,34 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_bytes(buffer)) } +fn socket_recvfrom(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))] + ); + + let mut socket = get_socket(zelf); + + let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()]; + let ret = match socket.con { + Some(ref mut v) => v.recv_from(&mut buffer), + None => return Err(vm.new_type_error("".to_string())), + }; + + let addr = match ret { + Ok((_size, addr)) => addr, + _ => return Err(vm.new_type_error("".to_string())), + }; + + let elements = RefCell::new(vec![vm.ctx.new_bytes(buffer), get_addr_tuple(vm, addr)?]); + + Ok(PyObject::new( + PyObjectPayload::Sequence { elements }, + vm.ctx.tuple_type(), + )) +} + fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -313,6 +348,34 @@ fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.get_none()) } +fn socket_sendto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [ + (zelf, None), + (bytes, Some(vm.ctx.bytes_type())), + (address, Some(vm.ctx.tuple_type())) + ] + ); + let address_string = get_address_string(vm, address)?; + + let socket = get_socket(zelf); + + match socket.socket_kind { + SocketKind::Dgram => { + // We can't do sendto without bind in std::net::UdpSocket + if let Ok(dgram) = UdpSocket::bind("0.0.0.0:0") { + if let Ok(_) = dgram.send_to(&objbytes::get_value(&bytes), address_string) { + return Ok(vm.get_none()); + } + } + Err(vm.new_type_error("socket failed".to_string())) + } + _ => Err(vm.new_not_implemented_error("".to_string())), + } +} + fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(zelf, None)]); @@ -331,20 +394,22 @@ fn socket_getsockname(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { }; match addr { - Ok(addr) => { - let port = vm.ctx.new_int(addr.port()); - let ip = vm.ctx.new_str(addr.ip().to_string()); - let elements = RefCell::new(vec![ip, port]); - - Ok(PyObject::new( - PyObjectPayload::Sequence { elements }, - vm.ctx.tuple_type(), - )) - } + Ok(addr) => get_addr_tuple(vm, addr), _ => Err(vm.new_type_error("".to_string())), } } +fn get_addr_tuple(vm: &mut VirtualMachine, addr: SocketAddr) -> PyResult { + let port = vm.ctx.new_int(addr.port()); + let ip = vm.ctx.new_str(addr.ip().to_string()); + let elements = RefCell::new(vec![ip, port]); + + Ok(PyObject::new( + PyObjectPayload::Sequence { elements }, + vm.ctx.tuple_type(), + )) +} + pub fn mk_module(ctx: &PyContext) -> PyObjectRef { let py_mod = ctx.new_module(&"socket".to_string(), ctx.new_scope(None)); @@ -369,6 +434,8 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&socket, "listen", ctx.new_rustfunc(socket_listen)); ctx.set_attr(&socket, "close", ctx.new_rustfunc(socket_close)); ctx.set_attr(&socket, "getsockname", ctx.new_rustfunc(socket_getsockname)); + ctx.set_attr(&socket, "sendto", ctx.new_rustfunc(socket_sendto)); + ctx.set_attr(&socket, "recvfrom", ctx.new_rustfunc(socket_recvfrom)); socket }; ctx.set_attr(&py_mod, "socket", socket.clone());