Parse address tuple based on socket family

This commit is contained in:
Noah
2021-02-19 11:21:36 -06:00
parent a46fb496aa
commit 7d99e49fd9

View File

@@ -18,7 +18,7 @@ use crate::exceptions::{IntoPyException, PyBaseExceptionRef};
use crate::function::{FuncArgs, OptionalArg};
use crate::pyobject::{
BorrowValue, Either, IntoPyObject, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue,
StaticType, TryFromObject,
StaticType, TryFromObject, TypeProtocol,
};
use crate::VirtualMachine;
@@ -191,9 +191,68 @@ impl PySocket {
}
}
fn extract_address(
&self,
addr: PyObjectRef,
caller: &str,
vm: &VirtualMachine,
) -> PyResult<socket2::SockAddr> {
let family = self.family.load();
match family {
c::AF_INET => {
let addr = Address::try_from_object(vm, addr)?;
let addr4 = get_addr(vm, addr, |sa| match sa {
SocketAddr::V4(v4) => Some(v4),
_ => None,
})?;
Ok(addr4.into())
}
c::AF_INET6 => {
let tuple: PyTupleRef = addr.downcast().map_err(|obj| {
vm.new_type_error(format!(
"{}(): AF_INET6 address must be tuple, not {}",
caller,
obj.class().name
))
})?;
let tuple = tuple.borrow_value();
match tuple.len() {
2 | 3 | 4 => {}
_ => {
return Err(vm.new_type_error(
"AF_INET6 address must be a tuple (host, port[, flowinfo[, scopeid]])"
.to_owned(),
))
}
}
let addr = Address::from_tuple(tuple, vm)?;
let flowinfo = tuple
.get(2)
.map(|obj| u32::try_from_object(vm, obj.clone()))
.transpose()?;
let scopeid = tuple
.get(3)
.map(|obj| u32::try_from_object(vm, obj.clone()))
.transpose()?;
let mut addr6 = get_addr(vm, addr, |sa| match sa {
SocketAddr::V6(v6) => Some(v6),
_ => None,
})?;
if let Some(fi) = flowinfo {
addr6.set_flowinfo(fi)
}
if let Some(si) = scopeid {
addr6.set_scope_id(si)
}
Ok(addr6.into())
}
_ => Err(vm.new_os_error(format!("{}(): bad family", caller))),
}
}
#[pymethod]
fn connect(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> {
let sock_addr = get_addr(vm, address, Some(self.family.load()))?;
fn connect(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
let sock_addr = self.extract_address(address, "connect", vm)?;
let err = match self.sock().connect(&sock_addr) {
Ok(()) => return Ok(()),
@@ -227,8 +286,8 @@ impl PySocket {
}
#[pymethod]
fn bind(&self, address: Address, vm: &VirtualMachine) -> PyResult<()> {
let sock_addr = get_addr(vm, address, Some(self.family.load()))?;
fn bind(&self, address: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
let sock_addr = self.extract_address(address, "bind", vm)?;
self.sock()
.bind(&sock_addr)
.map_err(|err| convert_sock_error(vm, err))
@@ -349,12 +408,12 @@ impl PySocket {
fn sendto(
&self,
bytes: PyBytesLike,
address: Address,
address: PyObjectRef,
flags: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<usize> {
let flags = flags.unwrap_or(0);
let addr = get_addr(vm, address, Some(self.family.load()))?;
let addr = self.extract_address(address, "sendto", vm)?;
self.sock_op(vm, SelectKind::Write, || {
bytes.with_ref(|b| self.sock().send_to_with_flags(b, &addr, flags))
})
@@ -575,22 +634,27 @@ impl ToSocketAddrs for Address {
impl TryFromObject for Address {
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
let tuple = PyTupleRef::try_from_object(vm, obj)?;
// TODO: parse the tuple based on the family of the socket; extract all the info for inet6
if tuple.borrow_value().len() < 2 {
if tuple.borrow_value().len() != 2 {
Err(vm.new_type_error("Address tuple should have only 2 values".to_owned()))
} else {
let host = PyStrRef::try_from_object(vm, tuple.borrow_value()[0].clone())?;
let host = if host.borrow_value().is_empty() {
PyStr::from("0.0.0.0").into_ref(vm)
} else {
host
};
let port = u16::try_from_object(vm, tuple.borrow_value()[1].clone())?;
Ok(Address { host, port })
Self::from_tuple(tuple.borrow_value(), vm)
}
}
}
impl Address {
fn from_tuple(tuple: &[PyObjectRef], vm: &VirtualMachine) -> PyResult<Self> {
let host = PyStrRef::try_from_object(vm, tuple[0].clone())?;
let host = if host.borrow_value().is_empty() {
PyStr::from("0.0.0.0").into_ref(vm)
} else {
host
};
let port = u16::try_from_object(vm, tuple[1].clone())?;
Ok(Address { host, port })
}
}
fn get_addr_tuple<A: Into<socket2::SockAddr>>(addr: A, vm: &VirtualMachine) -> PyObjectRef {
let addr = addr.into();
match addr.as_std() {
@@ -890,11 +954,8 @@ fn _socket_getnameinfo(
flags: i32,
vm: &VirtualMachine,
) -> PyResult<(String, String)> {
let addr = get_addr(vm, address, None)?;
let nameinfo = addr
.as_std()
.and_then(|addr| dns_lookup::getnameinfo(&addr, flags).ok());
nameinfo.ok_or_else(|| {
let addr = get_addr(vm, address, Some)?;
dns_lookup::getnameinfo(&addr, flags).map_err(|_| {
let error_type = GAI_ERROR.get().unwrap().clone();
vm.new_exception_msg(
error_type,
@@ -903,39 +964,22 @@ fn _socket_getnameinfo(
})
}
fn get_addr(
fn get_addr<R>(
vm: &VirtualMachine,
addr: impl ToSocketAddrs,
domain: Option<i32>,
) -> PyResult<socket2::SockAddr> {
let sock_addr = match addr.to_socket_addrs() {
Ok(mut sock_addrs) => match domain {
None => sock_addrs.next(),
Some(dom) => {
if dom == i32::from(Domain::ipv4()) {
sock_addrs.find(|a| a.is_ipv4())
} else if dom == i32::from(Domain::ipv6()) {
sock_addrs.find(|a| a.is_ipv6())
} else {
sock_addrs.next()
}
}
},
Err(e) => {
let error_type = GAI_ERROR.get().unwrap().clone();
return Err(vm.new_exception_msg(error_type, e.to_string()));
}
};
match sock_addr {
Some(sock_addr) => Ok(sock_addr.into()),
None => {
let error_type = GAI_ERROR.get().unwrap().clone();
Err(vm.new_exception_msg(
error_type,
"nodename nor servname provided, or not known".to_owned(),
))
}
}
filter: impl FnMut(SocketAddr) -> Option<R>,
) -> PyResult<R> {
let mut sock_addrs = addr.to_socket_addrs().map_err(|e| {
let error_type = GAI_ERROR.get().unwrap().clone();
vm.new_exception_msg(error_type, e.to_string())
})?;
sock_addrs.find_map(filter).ok_or_else(|| {
let error_type = GAI_ERROR.get().unwrap().clone();
vm.new_exception_msg(
error_type,
"nodename nor servname provided, or not known".to_owned(),
)
})
}
fn sock_fileno(sock: &Socket) -> RawSocket {