From f09f75ac8d5c1190aa6d7d89ecf652f786e62bc0 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 2 Mar 2019 17:45:26 +0200 Subject: [PATCH] Do implicit bind in socket.sendto --- tests/snippets/stdlib_socket.py | 28 +++++++++++++++++++++------- vm/src/stdlib/socket.rs | 33 ++++++++++++++++++++++++++------- 2 files changed, 47 insertions(+), 14 deletions(-) diff --git a/tests/snippets/stdlib_socket.py b/tests/snippets/stdlib_socket.py index 54ae180c3..f984515ea 100644 --- a/tests/snippets/stdlib_socket.py +++ b/tests/snippets/stdlib_socket.py @@ -44,18 +44,32 @@ sock1.bind(("127.0.0.1", 0)) sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock2.sendto(MESSAGE_A, sock1.getsockname()) -(recv_a, addr) = sock1.recvfrom(len(MESSAGE_A)) +(recv_a, addr1) = sock1.recvfrom(len(MESSAGE_A)) assert recv_a == MESSAGE_A -sock2.bind(("127.0.0.1", 0)) -sock1.connect(("127.0.0.1", sock2.getsockname()[1])) -sock2.connect(("127.0.0.1", sock1.getsockname()[1])) +sock2.sendto(MESSAGE_B, sock1.getsockname()) +(recv_b, addr2) = sock1.recvfrom(len(MESSAGE_B)) +assert recv_b == MESSAGE_B +assert addr1[0] == addr2[0] +assert addr1[1] == addr2[1] + +sock2.close() + +sock3 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +sock3.bind(("127.0.0.1", 0)) +sock3.sendto(MESSAGE_A, sock1.getsockname()) +(recv_a, addr) = sock1.recvfrom(len(MESSAGE_A)) +assert recv_a == MESSAGE_A +assert addr == sock3.getsockname() + +sock1.connect(("127.0.0.1", sock3.getsockname()[1])) +sock3.connect(("127.0.0.1", sock1.getsockname()[1])) sock1.send(MESSAGE_A) -sock2.send(MESSAGE_B) -recv_a = sock2.recv(len(MESSAGE_A)) +sock3.send(MESSAGE_B) +recv_a = sock3.recv(len(MESSAGE_A)) recv_b = sock1.recv(len(MESSAGE_B)) assert recv_a == MESSAGE_A assert recv_b == MESSAGE_B sock1.close() -sock2.close() +sock3.close() diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index f4b88b85f..d0ff61b37 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use std::io; use std::io::Read; use std::io::Write; -use std::net::{SocketAddr, TcpListener, TcpStream, UdpSocket}; +use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}; use std::ops::DerefMut; use crate::obj::objbytes; @@ -78,6 +78,13 @@ impl Connection { _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), } } + + fn send_to(&self, buf: &[u8], addr: A) -> io::Result { + match self { + Connection::UdpSocket(con) => con.send_to(buf, addr), + _ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")), + } + } } impl Read for Connection { @@ -360,17 +367,29 @@ fn socket_sendto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ); let address_string = get_address_string(vm, address)?; - let socket = get_socket(zelf); + let mut socket = get_socket(zelf); match socket.socket_kind { SocketKind::Dgram => { - // We can't do sendto without bind in std::net::UdpSocket - if let Ok(dgram) = UdpSocket::bind("0.0.0.0:0") { - if let Ok(_) = dgram.send_to(&objbytes::get_value(&bytes), address_string) { - return Ok(vm.get_none()); + match socket.con { + Some(ref mut v) => { + if let Ok(_) = v.send_to(&objbytes::get_value(&bytes), address_string) { + Ok(vm.get_none()) + } else { + Err(vm.new_type_error("socket failed".to_string())) + } + } + None => { + // Doing implicit bind + if let Ok(dgram) = UdpSocket::bind("0.0.0.0:0") { + if let Ok(_) = dgram.send_to(&objbytes::get_value(&bytes), address_string) { + socket.con = Some(Connection::UdpSocket(dgram)); + return Ok(vm.get_none()); + } + } + Err(vm.new_type_error("socket failed".to_string())) } } - Err(vm.new_type_error("socket failed".to_string())) } _ => Err(vm.new_not_implemented_error("".to_string())), }