forked from Rust-related/RustPython
Merge pull request #856 from palaviv/socket-new-args
Convert socket to new args style
This commit is contained in:
@@ -3,13 +3,11 @@ use std::io;
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
|
||||
use std::ops::Deref;
|
||||
|
||||
use crate::function::PyFuncArgs;
|
||||
use crate::obj::objbytes;
|
||||
use crate::obj::objint;
|
||||
use crate::obj::objsequence::get_elements;
|
||||
use crate::obj::objstr;
|
||||
use crate::obj::objbytes::PyBytesRef;
|
||||
use crate::obj::objint::PyIntRef;
|
||||
use crate::obj::objstr::PyStringRef;
|
||||
use crate::obj::objtuple::PyTupleRef;
|
||||
use crate::pyobject::{PyObjectRef, PyRef, PyResult, PyValue, TryFromObject};
|
||||
use crate::vm::VirtualMachine;
|
||||
|
||||
@@ -161,283 +159,209 @@ impl Socket {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_socket<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Socket> + 'a {
|
||||
obj.payload::<Socket>().unwrap()
|
||||
}
|
||||
|
||||
type SocketRef = PyRef<Socket>;
|
||||
|
||||
fn socket_new(
|
||||
cls: PyClassRef,
|
||||
family: AddressFamily,
|
||||
kind: SocketKind,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<SocketRef> {
|
||||
Socket::new(family, kind).into_ref_with_type(vm, cls)
|
||||
}
|
||||
impl SocketRef {
|
||||
fn new(
|
||||
cls: PyClassRef,
|
||||
family: AddressFamily,
|
||||
kind: SocketKind,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<SocketRef> {
|
||||
Socket::new(family, kind).into_ref_with_type(vm, cls)
|
||||
}
|
||||
|
||||
fn socket_connect(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
|
||||
);
|
||||
fn connect(self, address: Address, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let address_string = address.get_address_string();
|
||||
|
||||
let address_string = get_address_string(vm, address)?;
|
||||
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
match socket.socket_kind {
|
||||
SocketKind::Stream => match TcpStream::connect(address_string) {
|
||||
Ok(stream) => {
|
||||
socket
|
||||
.con
|
||||
.borrow_mut()
|
||||
.replace(Connection::TcpStream(stream));
|
||||
Ok(vm.get_none())
|
||||
match self.socket_kind {
|
||||
SocketKind::Stream => match TcpStream::connect(address_string) {
|
||||
Ok(stream) => {
|
||||
self.con.borrow_mut().replace(Connection::TcpStream(stream));
|
||||
Ok(())
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
SocketKind::Dgram => {
|
||||
if let Some(Connection::UdpSocket(con)) = self.con.borrow().as_ref() {
|
||||
match con.connect(address_string) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
}
|
||||
} else {
|
||||
Err(vm.new_type_error("".to_string()))
|
||||
}
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
SocketKind::Dgram => {
|
||||
if let Some(Connection::UdpSocket(con)) = socket.con.borrow().as_ref() {
|
||||
match con.connect(address_string) {
|
||||
Ok(_) => Ok(vm.get_none()),
|
||||
}
|
||||
}
|
||||
|
||||
fn bind(self, address: Address, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let address_string = address.get_address_string();
|
||||
|
||||
match self.socket_kind {
|
||||
SocketKind::Stream => match TcpListener::bind(address_string) {
|
||||
Ok(stream) => {
|
||||
self.con
|
||||
.borrow_mut()
|
||||
.replace(Connection::TcpListener(stream));
|
||||
Ok(())
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
SocketKind::Dgram => match UdpSocket::bind(address_string) {
|
||||
Ok(dgram) => {
|
||||
self.con.borrow_mut().replace(Connection::UdpSocket(dgram));
|
||||
Ok(())
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn listen(self, _num: PyIntRef, _vm: &VirtualMachine) -> () {}
|
||||
|
||||
fn accept(self, vm: &VirtualMachine) -> PyResult {
|
||||
let ret = match self.con.borrow_mut().as_mut() {
|
||||
Some(v) => v.accept(),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
|
||||
let (tcp_stream, addr) = match ret {
|
||||
Ok((socket, addr)) => (socket, addr),
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
};
|
||||
|
||||
let socket = Socket {
|
||||
address_family: self.address_family,
|
||||
socket_kind: self.socket_kind,
|
||||
con: RefCell::new(Some(Connection::TcpStream(tcp_stream))),
|
||||
}
|
||||
.into_ref(vm);
|
||||
|
||||
let addr_tuple = get_addr_tuple(vm, addr)?;
|
||||
|
||||
Ok(vm.ctx.new_tuple(vec![socket.into_object(), addr_tuple]))
|
||||
}
|
||||
|
||||
fn recv(self, bufsize: PyIntRef, vm: &VirtualMachine) -> PyResult {
|
||||
let mut buffer = vec![0u8; bufsize.as_bigint().to_usize().unwrap()];
|
||||
match self.con.borrow_mut().as_mut() {
|
||||
Some(v) => match v.read_exact(&mut buffer) {
|
||||
Ok(_) => (),
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
Ok(vm.ctx.new_bytes(buffer))
|
||||
}
|
||||
|
||||
fn recvfrom(self, bufsize: PyIntRef, vm: &VirtualMachine) -> PyResult {
|
||||
let mut buffer = vec![0u8; bufsize.as_bigint().to_usize().unwrap()];
|
||||
let ret = match self.con.borrow().as_ref() {
|
||||
Some(v) => v.recv_from(&mut buffer),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
|
||||
let addr = match ret {
|
||||
Ok((_size, addr)) => addr,
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
};
|
||||
|
||||
let addr_tuple = get_addr_tuple(vm, addr)?;
|
||||
|
||||
Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple]))
|
||||
}
|
||||
|
||||
fn send(self, bytes: PyBytesRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
match self.con.borrow_mut().as_mut() {
|
||||
Some(v) => match v.write(&bytes) {
|
||||
Ok(_) => (),
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sendto(self, bytes: PyBytesRef, address: Address, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let address_string = address.get_address_string();
|
||||
|
||||
match self.socket_kind {
|
||||
SocketKind::Dgram => {
|
||||
if let Some(v) = self.con.borrow().as_ref() {
|
||||
return match v.send_to(&bytes, address_string) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
};
|
||||
}
|
||||
// Doing implicit bind
|
||||
match UdpSocket::bind("0.0.0.0:0") {
|
||||
Ok(dgram) => match dgram.send_to(&bytes, address_string) {
|
||||
Ok(_) => {
|
||||
self.con.borrow_mut().replace(Connection::UdpSocket(dgram));
|
||||
Ok(())
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
}
|
||||
} else {
|
||||
Err(vm.new_type_error("".to_string()))
|
||||
}
|
||||
_ => Err(vm.new_not_implemented_error("".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn close(self, _vm: &VirtualMachine) -> () {
|
||||
self.con.borrow_mut().take();
|
||||
}
|
||||
|
||||
fn fileno(self, vm: &VirtualMachine) -> PyResult {
|
||||
let fileno = match self.con.borrow_mut().as_mut() {
|
||||
Some(v) => v.fileno(),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
Ok(vm.ctx.new_int(fileno))
|
||||
}
|
||||
|
||||
fn getsockname(self, vm: &VirtualMachine) -> PyResult {
|
||||
let addr = match self.con.borrow().as_ref() {
|
||||
Some(v) => v.local_addr(),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
|
||||
match addr {
|
||||
Ok(addr) => get_addr_tuple(vm, addr),
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn socket_bind(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
|
||||
);
|
||||
struct Address {
|
||||
host: String,
|
||||
port: usize,
|
||||
}
|
||||
|
||||
let address_string = get_address_string(vm, address)?;
|
||||
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
match socket.socket_kind {
|
||||
SocketKind::Stream => match TcpListener::bind(address_string) {
|
||||
Ok(stream) => {
|
||||
socket
|
||||
.con
|
||||
.borrow_mut()
|
||||
.replace(Connection::TcpListener(stream));
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
SocketKind::Dgram => match UdpSocket::bind(address_string) {
|
||||
Ok(dgram) => {
|
||||
socket
|
||||
.con
|
||||
.borrow_mut()
|
||||
.replace(Connection::UdpSocket(dgram));
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
impl Address {
|
||||
fn get_address_string(self) -> String {
|
||||
format!("{}:{}", self.host, self.port.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_address_string(vm: &VirtualMachine, address: &PyObjectRef) -> Result<String, PyObjectRef> {
|
||||
let args = PyFuncArgs {
|
||||
args: get_elements(address).to_vec(),
|
||||
kwargs: vec![],
|
||||
};
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [
|
||||
(host, Some(vm.ctx.str_type())),
|
||||
(port, Some(vm.ctx.int_type()))
|
||||
]
|
||||
);
|
||||
|
||||
Ok(format!(
|
||||
"{}:{}",
|
||||
objstr::get_value(host),
|
||||
objint::get_value(port).to_string()
|
||||
))
|
||||
}
|
||||
|
||||
fn socket_listen(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(_zelf, None), (_num, Some(vm.ctx.int_type()))]
|
||||
);
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
|
||||
fn socket_accept(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(vm, args, required = [(zelf, None)]);
|
||||
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
let ret = match socket.con.borrow_mut().as_mut() {
|
||||
Some(v) => v.accept(),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
|
||||
let (tcp_stream, addr) = match ret {
|
||||
Ok((socket, addr)) => (socket, addr),
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
};
|
||||
|
||||
let socket = Socket {
|
||||
address_family: socket.address_family,
|
||||
socket_kind: socket.socket_kind,
|
||||
con: RefCell::new(Some(Connection::TcpStream(tcp_stream))),
|
||||
}
|
||||
.into_ref(vm);
|
||||
|
||||
let addr_tuple = get_addr_tuple(vm, addr)?;
|
||||
|
||||
Ok(vm.ctx.new_tuple(vec![socket.into_object(), addr_tuple]))
|
||||
}
|
||||
|
||||
fn socket_recv(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))]
|
||||
);
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
|
||||
match socket.con.borrow_mut().as_mut() {
|
||||
Some(v) => match v.read_exact(&mut buffer) {
|
||||
Ok(_) => (),
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
Ok(vm.ctx.new_bytes(buffer))
|
||||
}
|
||||
|
||||
fn socket_recvfrom(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))]
|
||||
);
|
||||
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
|
||||
let ret = match socket.con.borrow().as_ref() {
|
||||
Some(v) => v.recv_from(&mut buffer),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
|
||||
let addr = match ret {
|
||||
Ok((_size, addr)) => addr,
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
};
|
||||
|
||||
let addr_tuple = get_addr_tuple(vm, addr)?;
|
||||
|
||||
Ok(vm.ctx.new_tuple(vec![vm.ctx.new_bytes(buffer), addr_tuple]))
|
||||
}
|
||||
|
||||
fn socket_send(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(zelf, None), (bytes, Some(vm.ctx.bytes_type()))]
|
||||
);
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
match socket.con.borrow_mut().as_mut() {
|
||||
Some(v) => match v.write(&objbytes::get_value(&bytes)) {
|
||||
Ok(_) => (),
|
||||
Err(s) => return Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
|
||||
fn socket_sendto(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [
|
||||
(zelf, None),
|
||||
(bytes, Some(vm.ctx.bytes_type())),
|
||||
(address, Some(vm.ctx.tuple_type()))
|
||||
]
|
||||
);
|
||||
let address_string = get_address_string(vm, address)?;
|
||||
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
match socket.socket_kind {
|
||||
SocketKind::Dgram => {
|
||||
if let Some(v) = socket.con.borrow().as_ref() {
|
||||
return match v.send_to(&objbytes::get_value(&bytes), address_string) {
|
||||
Ok(_) => Ok(vm.get_none()),
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
};
|
||||
}
|
||||
// Doing implicit bind
|
||||
match UdpSocket::bind("0.0.0.0:0") {
|
||||
Ok(dgram) => match dgram.send_to(&objbytes::get_value(&bytes), address_string) {
|
||||
Ok(_) => {
|
||||
socket
|
||||
.con
|
||||
.borrow_mut()
|
||||
.replace(Connection::UdpSocket(dgram));
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
},
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
}
|
||||
impl TryFromObject for Address {
|
||||
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
|
||||
let tuple = PyTupleRef::try_from_object(vm, obj)?;
|
||||
if tuple.elements.borrow().len() != 2 {
|
||||
Err(vm.new_type_error("Address tuple should have only 2 values".to_string()))
|
||||
} else {
|
||||
Ok(Address {
|
||||
host: PyStringRef::try_from_object(vm, tuple.elements.borrow()[0].clone())?
|
||||
.value
|
||||
.to_string(),
|
||||
port: PyIntRef::try_from_object(vm, tuple.elements.borrow()[1].clone())?
|
||||
.as_bigint()
|
||||
.to_usize()
|
||||
.unwrap(),
|
||||
})
|
||||
}
|
||||
_ => Err(vm.new_not_implemented_error("".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn socket_close(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(vm, args, required = [(zelf, None)]);
|
||||
|
||||
let socket = get_socket(zelf);
|
||||
socket.con.borrow_mut().take();
|
||||
Ok(vm.get_none())
|
||||
}
|
||||
|
||||
fn socket_fileno(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(vm, args, required = [(zelf, None)]);
|
||||
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
let fileno = match socket.con.borrow_mut().as_mut() {
|
||||
Some(v) => v.fileno(),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
Ok(vm.ctx.new_int(fileno))
|
||||
}
|
||||
|
||||
fn socket_getsockname(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(vm, args, required = [(zelf, None)]);
|
||||
let socket = get_socket(zelf);
|
||||
|
||||
let addr = match socket.con.borrow().as_ref() {
|
||||
Some(v) => v.local_addr(),
|
||||
None => return Err(vm.new_type_error("".to_string())),
|
||||
};
|
||||
|
||||
match addr {
|
||||
Ok(addr) => get_addr_tuple(vm, addr),
|
||||
Err(s) => Err(vm.new_os_error(s.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,18 +376,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
|
||||
let ctx = &vm.ctx;
|
||||
|
||||
let socket = py_class!(ctx, "socket", ctx.object(), {
|
||||
"__new__" => ctx.new_rustfunc(socket_new),
|
||||
"connect" => ctx.new_rustfunc(socket_connect),
|
||||
"recv" => ctx.new_rustfunc(socket_recv),
|
||||
"send" => ctx.new_rustfunc(socket_send),
|
||||
"bind" => ctx.new_rustfunc(socket_bind),
|
||||
"accept" => ctx.new_rustfunc(socket_accept),
|
||||
"listen" => ctx.new_rustfunc(socket_listen),
|
||||
"close" => ctx.new_rustfunc(socket_close),
|
||||
"getsockname" => ctx.new_rustfunc(socket_getsockname),
|
||||
"sendto" => ctx.new_rustfunc(socket_sendto),
|
||||
"recvfrom" => ctx.new_rustfunc(socket_recvfrom),
|
||||
"fileno" => ctx.new_rustfunc(socket_fileno),
|
||||
"__new__" => ctx.new_rustfunc(SocketRef::new),
|
||||
"connect" => ctx.new_rustfunc(SocketRef::connect),
|
||||
"recv" => ctx.new_rustfunc(SocketRef::recv),
|
||||
"send" => ctx.new_rustfunc(SocketRef::send),
|
||||
"bind" => ctx.new_rustfunc(SocketRef::bind),
|
||||
"accept" => ctx.new_rustfunc(SocketRef::accept),
|
||||
"listen" => ctx.new_rustfunc(SocketRef::listen),
|
||||
"close" => ctx.new_rustfunc(SocketRef::close),
|
||||
"getsockname" => ctx.new_rustfunc(SocketRef::getsockname),
|
||||
"sendto" => ctx.new_rustfunc(SocketRef::sendto),
|
||||
"recvfrom" => ctx.new_rustfunc(SocketRef::recvfrom),
|
||||
"fileno" => ctx.new_rustfunc(SocketRef::fileno),
|
||||
});
|
||||
|
||||
py_module!(vm, "socket", {
|
||||
|
||||
Reference in New Issue
Block a user