Add socket.{sendto, recvfrom}

This commit is contained in:
Aviv Palivoda
2019-03-02 17:12:13 +02:00
parent e5d1d11c3e
commit 86b60faa61
2 changed files with 82 additions and 11 deletions

View File

@@ -41,8 +41,12 @@ sock1 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock1.bind(("127.0.0.1", 0))
sock2 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock2.bind(("127.0.0.1", 0))
sock2.sendto(MESSAGE_A, sock1.getsockname())
(recv_a, addr) = sock1.recvfrom(len(MESSAGE_A))
assert recv_a == MESSAGE_A
sock2.bind(("127.0.0.1", 0))
sock1.connect(("127.0.0.1", sock2.getsockname()[1]))
sock2.connect(("127.0.0.1", sock1.getsockname()[1]))

View File

@@ -71,6 +71,13 @@ impl Connection {
_ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")),
}
}
fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
match self {
Connection::UdpSocket(con) => con.recv_from(buf),
_ => Err(io::Error::new(io::ErrorKind::Other, "oh no!")),
}
}
}
impl Read for Connection {
@@ -298,6 +305,34 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(vm.ctx.new_bytes(buffer))
}
fn socket_recvfrom(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))]
);
let mut socket = get_socket(zelf);
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
let ret = match socket.con {
Some(ref mut v) => v.recv_from(&mut buffer),
None => return Err(vm.new_type_error("".to_string())),
};
let addr = match ret {
Ok((_size, addr)) => addr,
_ => return Err(vm.new_type_error("".to_string())),
};
let elements = RefCell::new(vec![vm.ctx.new_bytes(buffer), get_addr_tuple(vm, addr)?]);
Ok(PyObject::new(
PyObjectPayload::Sequence { elements },
vm.ctx.tuple_type(),
))
}
fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
@@ -313,6 +348,34 @@ fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(vm.get_none())
}
fn socket_sendto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [
(zelf, None),
(bytes, Some(vm.ctx.bytes_type())),
(address, Some(vm.ctx.tuple_type()))
]
);
let address_string = get_address_string(vm, address)?;
let socket = get_socket(zelf);
match socket.socket_kind {
SocketKind::Dgram => {
// We can't do sendto without bind in std::net::UdpSocket
if let Ok(dgram) = UdpSocket::bind("0.0.0.0:0") {
if let Ok(_) = dgram.send_to(&objbytes::get_value(&bytes), address_string) {
return Ok(vm.get_none());
}
}
Err(vm.new_type_error("socket failed".to_string()))
}
_ => Err(vm.new_not_implemented_error("".to_string())),
}
}
fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(zelf, None)]);
@@ -331,20 +394,22 @@ fn socket_getsockname(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
};
match addr {
Ok(addr) => {
let port = vm.ctx.new_int(addr.port());
let ip = vm.ctx.new_str(addr.ip().to_string());
let elements = RefCell::new(vec![ip, port]);
Ok(PyObject::new(
PyObjectPayload::Sequence { elements },
vm.ctx.tuple_type(),
))
}
Ok(addr) => get_addr_tuple(vm, addr),
_ => Err(vm.new_type_error("".to_string())),
}
}
fn get_addr_tuple(vm: &mut VirtualMachine, addr: SocketAddr) -> PyResult {
let port = vm.ctx.new_int(addr.port());
let ip = vm.ctx.new_str(addr.ip().to_string());
let elements = RefCell::new(vec![ip, port]);
Ok(PyObject::new(
PyObjectPayload::Sequence { elements },
vm.ctx.tuple_type(),
))
}
pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
let py_mod = ctx.new_module(&"socket".to_string(), ctx.new_scope(None));
@@ -369,6 +434,8 @@ pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
ctx.set_attr(&socket, "listen", ctx.new_rustfunc(socket_listen));
ctx.set_attr(&socket, "close", ctx.new_rustfunc(socket_close));
ctx.set_attr(&socket, "getsockname", ctx.new_rustfunc(socket_getsockname));
ctx.set_attr(&socket, "sendto", ctx.new_rustfunc(socket_sendto));
ctx.set_attr(&socket, "recvfrom", ctx.new_rustfunc(socket_recvfrom));
socket
};
ctx.set_attr(&py_mod, "socket", socket.clone());