From 8b8f542e23e39c2f0135cee099ffc9327d1c5f06 Mon Sep 17 00:00:00 2001 From: Hyunji Kim Date: Thu, 15 Aug 2019 11:29:48 +0900 Subject: [PATCH] add __enter__ and __exit__ for socket --- tests/snippets/stdlib_socket.py | 2 ++ vm/src/stdlib/socket.rs | 49 ++++++++++++++++++++------------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/tests/snippets/stdlib_socket.py b/tests/snippets/stdlib_socket.py index f34c83e47..d5727bb25 100644 --- a/tests/snippets/stdlib_socket.py +++ b/tests/snippets/stdlib_socket.py @@ -135,3 +135,5 @@ assert socket.inet_ntoa(b"\xff\xff\xff\xff")=="255.255.255.255" with assertRaises(OSError): socket.inet_ntoa(b"\xff\xff\xff\xff\xff") +with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + pass diff --git a/vm/src/stdlib/socket.rs b/vm/src/stdlib/socket.rs index 1a111d7ea..c83b75656 100644 --- a/vm/src/stdlib/socket.rs +++ b/vm/src/stdlib/socket.rs @@ -11,6 +11,7 @@ use gethostname::gethostname; use byteorder::{BigEndian, ByteOrder}; +use crate::function::PyFuncArgs; use crate::obj::objbytes::PyBytesRef; use crate::obj::objint::PyIntRef; use crate::obj::objstr::PyStringRef; @@ -181,6 +182,14 @@ impl SocketRef { Socket::new(family, kind).into_ref_with_type(vm, cls) } + fn enter(self, _vm: &VirtualMachine) -> SocketRef { + self + } + + fn exit(self, _args: PyFuncArgs, _vm: &VirtualMachine) { + self.close(_vm) + } + fn connect(self, address: Address, vm: &VirtualMachine) -> PyResult<()> { let address_string = address.get_address_string(); @@ -425,29 +434,31 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; let socket = py_class!(ctx, "socket", ctx.object(), { - "__new__" => ctx.new_rustfunc(SocketRef::new), - "connect" => ctx.new_rustfunc(SocketRef::connect), - "recv" => ctx.new_rustfunc(SocketRef::recv), - "send" => ctx.new_rustfunc(SocketRef::send), - "bind" => ctx.new_rustfunc(SocketRef::bind), - "accept" => ctx.new_rustfunc(SocketRef::accept), - "listen" => ctx.new_rustfunc(SocketRef::listen), - "close" => ctx.new_rustfunc(SocketRef::close), - "getsockname" => ctx.new_rustfunc(SocketRef::getsockname), - "sendto" => ctx.new_rustfunc(SocketRef::sendto), - "recvfrom" => ctx.new_rustfunc(SocketRef::recvfrom), - "fileno" => ctx.new_rustfunc(SocketRef::fileno), + "__new__" => ctx.new_rustfunc(SocketRef::new), + "__enter__" => ctx.new_rustfunc(SocketRef::enter), + "__exit__" => ctx.new_rustfunc(SocketRef::exit), + "connect" => ctx.new_rustfunc(SocketRef::connect), + "recv" => ctx.new_rustfunc(SocketRef::recv), + "send" => ctx.new_rustfunc(SocketRef::send), + "bind" => ctx.new_rustfunc(SocketRef::bind), + "accept" => ctx.new_rustfunc(SocketRef::accept), + "listen" => ctx.new_rustfunc(SocketRef::listen), + "close" => ctx.new_rustfunc(SocketRef::close), + "getsockname" => ctx.new_rustfunc(SocketRef::getsockname), + "sendto" => ctx.new_rustfunc(SocketRef::sendto), + "recvfrom" => ctx.new_rustfunc(SocketRef::recvfrom), + "fileno" => ctx.new_rustfunc(SocketRef::fileno), }); let module = py_module!(vm, "socket", { "AF_INET" => ctx.new_int(AddressFamily::Inet as i32), "SOCK_STREAM" => ctx.new_int(SocketKind::Stream as i32), - "SOCK_DGRAM" => ctx.new_int(SocketKind::Dgram as i32), - "socket" => socket, - "inet_aton" => ctx.new_rustfunc(socket_inet_aton), - "inet_ntoa" => ctx.new_rustfunc(socket_inet_ntoa), - "gethostname" => ctx.new_rustfunc(socket_gethostname), - "htonl" => ctx.new_rustfunc(socket_htonl), + "SOCK_DGRAM" => ctx.new_int(SocketKind::Dgram as i32), + "socket" => socket, + "inet_aton" => ctx.new_rustfunc(socket_inet_aton), + "inet_ntoa" => ctx.new_rustfunc(socket_inet_ntoa), + "gethostname" => ctx.new_rustfunc(socket_gethostname), + "htonl" => ctx.new_rustfunc(socket_htonl), }); extend_module_platform_specific(vm, module) @@ -464,7 +475,7 @@ fn extend_module_platform_specific(vm: &VirtualMachine, module: PyObjectRef) -> #[cfg(not(target_os = "redox"))] extend_module!(vm, module, { - "sethostname" => ctx.new_rustfunc(socket_sethostname), + "sethostname" => ctx.new_rustfunc(socket_sethostname), }); module