diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index f1769ec09..5ef3e16c3 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -18,7 +18,7 @@ use crate::exceptions::{IntoPyException, PyBaseExceptionRef}; use crate::function::{FuncArgs, OptionalArg}; use crate::pyobject::{ BorrowValue, Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, - StaticType, TryFromObject, + StaticType, TryFromObject, TypeProtocol, }; use crate::VirtualMachine; @@ -191,9 +191,68 @@ impl PySocket { } } + fn extract_address( + &self, + addr: PyObjectRef, + caller: &str, + vm: &VirtualMachine, + ) -> PyResult { + let family = self.family.load(); + match family { + c::AF_INET => { + let addr = Address::try_from_object(vm, addr)?; + let addr4 = get_addr(vm, addr, |sa| match sa { + SocketAddr::V4(v4) => Some(v4), + _ => None, + })?; + Ok(addr4.into()) + } + c::AF_INET6 => { + let tuple: PyTupleRef = addr.downcast().map_err(|obj| { + vm.new_type_error(format!( + "{}(): AF_INET6 address must be tuple, not {}", + caller, + obj.class().name + )) + })?; + let tuple = tuple.borrow_value(); + match tuple.len() { + 2 | 3 | 4 => {} + _ => { + return Err(vm.new_type_error( + "AF_INET6 address must be a tuple (host, port[, flowinfo[, scopeid]])" + .to_owned(), + )) + } + } + let addr = Address::from_tuple(tuple, vm)?; + let flowinfo = tuple + .get(2) + .map(|obj| u32::try_from_object(vm, obj.clone())) + .transpose()?; + let scopeid = tuple + .get(3) + .map(|obj| u32::try_from_object(vm, obj.clone())) + .transpose()?; + let mut addr6 = get_addr(vm, addr, |sa| match sa { + SocketAddr::V6(v6) => Some(v6), + _ => None, + })?; + if let Some(fi) = flowinfo { + addr6.set_flowinfo(fi) + } + if let Some(si) = scopeid { + addr6.set_scope_id(si) + } + Ok(addr6.into()) + } + _ => Err(vm.new_os_error(format!("{}(): bad family", caller))), + } + } + #[pymethod] - fn connect(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> { - let sock_addr = get_addr(vm, address, Some(self.family.load()))?; + fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let sock_addr = self.extract_address(address, "connect", vm)?; let err = match self.sock().connect(&sock_addr) { Ok(()) => return Ok(()), @@ -227,8 +286,8 @@ impl PySocket { } #[pymethod] - fn bind(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> { - let sock_addr = get_addr(vm, address, Some(self.family.load()))?; + fn bind(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let sock_addr = self.extract_address(address, "bind", vm)?; self.sock() .bind(&sock_addr) .map_err(|err| convert_sock_error(vm, err)) @@ -349,12 +408,12 @@ impl PySocket { fn sendto( &self, bytes: PyBytesLike, - address: Address, + address: PyObjectRef, flags: OptionalArg, vm: &VirtualMachine, ) -> PyResult { let flags = flags.unwrap_or(0); - let addr = get_addr(vm, address, Some(self.family.load()))?; + let addr = self.extract_address(address, "sendto", vm)?; self.sock_op(vm, SelectKind::Write, || { bytes.with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags)) }) @@ -575,22 +634,27 @@ impl ToSocketAddrs for Address { impl TryFromObject for Address { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let tuple = PyTupleRef::try_from_object(vm, obj)?; - // TODO: parse the tuple based on the family of the socket; extract all the info for inet6 - if tuple.borrow_value().len() < 2 { + if tuple.borrow_value().len() != 2 { Err(vm.new_type_error("Address tuple should have only 2 values".to_owned())) } else { - let host = PyStrRef::try_from_object(vm, tuple.borrow_value()[0].clone())?; - let host = if host.borrow_value().is_empty() { - PyStr::from("0.0.0.0").into_ref(vm) - } else { - host - }; - let port = u16::try_from_object(vm, tuple.borrow_value()[1].clone())?; - Ok(Address { host, port }) + Self::from_tuple(tuple.borrow_value(), vm) } } } +impl Address { + fn from_tuple(tuple: &[PyObjectRef], vm: &VirtualMachine) -> PyResult { + let host = PyStrRef::try_from_object(vm, tuple[0].clone())?; + let host = if host.borrow_value().is_empty() { + PyStr::from("0.0.0.0").into_ref(vm) + } else { + host + }; + let port = u16::try_from_object(vm, tuple[1].clone())?; + Ok(Address { host, port }) + } +} + fn get_addr_tuple>(addr: A, vm: &VirtualMachine) -> PyObjectRef { let addr = addr.into(); match addr.as_std() { @@ -890,11 +954,8 @@ fn _socket_getnameinfo( flags: i32, vm: &VirtualMachine, ) -> PyResult<(String, String)> { - let addr = get_addr(vm, address, None)?; - let nameinfo = addr - .as_std() - .and_then(|addr| dns_lookup::getnameinfo(&addr, flags).ok()); - nameinfo.ok_or_else(|| { + let addr = get_addr(vm, address, Some)?; + dns_lookup::getnameinfo(&addr, flags).map_err(|_| { let error_type = GAI_ERROR.get().unwrap().clone(); vm.new_exception_msg( error_type, @@ -903,39 +964,22 @@ fn _socket_getnameinfo( }) } -fn get_addr( +fn get_addr( vm: &VirtualMachine, addr: impl ToSocketAddrs, - domain: Option, -) -> PyResult { - let sock_addr = match addr.to_socket_addrs() { - Ok(mut sock_addrs) => match domain { - None => sock_addrs.next(), - Some(dom) => { - if dom == i32::from(Domain::ipv4()) { - sock_addrs.find(|a| a.is_ipv4()) - } else if dom == i32::from(Domain::ipv6()) { - sock_addrs.find(|a| a.is_ipv6()) - } else { - sock_addrs.next() - } - } - }, - Err(e) => { - let error_type = GAI_ERROR.get().unwrap().clone(); - return Err(vm.new_exception_msg(error_type, e.to_string())); - } - }; - match sock_addr { - Some(sock_addr) => Ok(sock_addr.into()), - None => { - let error_type = GAI_ERROR.get().unwrap().clone(); - Err(vm.new_exception_msg( - error_type, - "nodename nor servname provided, or not known".to_owned(), - )) - } - } + filter: impl FnMut(SocketAddr) -> Option, +) -> PyResult { + let mut sock_addrs = addr.to_socket_addrs().map_err(|e| { + let error_type = GAI_ERROR.get().unwrap().clone(); + vm.new_exception_msg(error_type, e.to_string()) + })?; + sock_addrs.find_map(filter).ok_or_else(|| { + let error_type = GAI_ERROR.get().unwrap().clone(); + vm.new_exception_msg( + error_type, + "nodename nor servname provided, or not known".to_owned(), + ) + }) } fn sock_fileno(sock: &Socket) -> RawSocket {