Move payload boxing into PyObject::new

This commit is contained in:
Joey Hain
2019-03-10 20:14:02 -07:00
parent 4510489bba
commit 053ceb1a30
25 changed files with 210 additions and 181 deletions

View File

@@ -49,10 +49,7 @@ pub fn compile(
let code = compiler.pop_code_object();
trace!("Compilation completed: {:?}", code);
Ok(PyObject::new(
Box::new(objcode::PyCode::new(code)),
code_type,
))
Ok(PyObject::new(objcode::PyCode::new(code), code_type))
}
pub enum Mode {

View File

@@ -407,8 +407,7 @@ impl Frame {
let stop = out[1].take();
let step = if out.len() == 3 { out[2].take() } else { None };
let obj =
PyObject::new(Box::new(PySlice { start, stop, step }), vm.ctx.slice_type());
let obj = PyObject::new(PySlice { start, stop, step }, vm.ctx.slice_type());
self.push_value(obj);
Ok(None)
}
@@ -702,9 +701,7 @@ impl Frame {
}
bytecode::Instruction::LoadBuildClass => {
let rustfunc = PyObject::new(
Box::new(PyBuiltinFunction::new(Box::new(
builtins::builtin_build_class_,
))),
PyBuiltinFunction::new(Box::new(builtins::builtin_build_class_)),
vm.ctx.type_type(),
);
self.push_value(rustfunc);

View File

@@ -172,10 +172,7 @@ fn bytearray_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
} else {
vec![]
};
Ok(PyObject::new(
Box::new(PyByteArray::new(value)),
cls.clone(),
))
Ok(PyObject::new(PyByteArray::new(value), cls.clone()))
}
fn bytesarray_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

View File

@@ -94,7 +94,7 @@ fn bytes_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
vec![]
};
Ok(PyObject::new(Box::new(PyBytes::new(value)), cls.clone()))
Ok(PyObject::new(PyBytes::new(value), cls.clone()))
}
fn bytes_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -203,10 +203,10 @@ fn bytes_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(obj, Some(vm.ctx.bytes_type()))]);
let iter_obj = PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: obj.clone(),
}),
},
vm.ctx.iter_type(),
);

View File

@@ -89,7 +89,7 @@ fn complex_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let value = Complex64::new(real, imag);
Ok(PyObject::new(Box::new(PyComplex { value }), cls.clone()))
Ok(PyObject::new(PyComplex { value }, cls.clone()))
}
fn complex_real(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

View File

@@ -251,10 +251,10 @@ fn dict_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let key_list = vm.ctx.new_list(keys);
let iter_obj = PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: key_list,
}),
},
vm.ctx.iter_type(),
);
@@ -271,10 +271,10 @@ fn dict_values(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let values_list = vm.ctx.new_list(values);
let iter_obj = PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: values_list,
}),
},
vm.ctx.iter_type(),
);
@@ -291,10 +291,10 @@ fn dict_items(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let items_list = vm.ctx.new_list(items);
let iter_obj = PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: items_list,
}),
},
vm.ctx.iter_type(),
);

View File

@@ -36,10 +36,10 @@ fn enumerate_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
};
let iterator = objiter::get_iter(vm, iterable)?;
Ok(PyObject::new(
Box::new(PyEnumerate {
PyEnumerate {
counter: RefCell::new(counter),
iterator,
}),
},
cls.clone(),
))
}

View File

@@ -26,10 +26,10 @@ fn filter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
);
let iterator = objiter::get_iter(vm, iterable)?;
Ok(PyObject::new(
Box::new(PyFilter {
PyFilter {
predicate: function.clone(),
iterator,
}),
},
cls.clone(),
))
}

View File

@@ -188,7 +188,7 @@ impl PyFloatRef {
let type_name = objtype::get_type_name(&arg.typ());
return Err(vm.new_type_error(format!("can't convert {} to float", type_name)));
};
Ok(PyObject::new(Box::new(PyFloat { value }), cls.clone()))
Ok(PyObject::new(PyFloat { value }, cls.clone()))
}
fn mod_(self, other: PyObjectRef, vm: &mut VirtualMachine) -> PyResult {

View File

@@ -40,7 +40,7 @@ pub fn init(context: &PyContext) {
pub fn new_generator(vm: &mut VirtualMachine, frame: PyObjectRef) -> PyResult {
Ok(PyObject::new(
Box::new(PyGenerator { frame }),
PyGenerator { frame },
vm.ctx.generator_type.clone(),
))
}

View File

@@ -105,7 +105,7 @@ fn int_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Some(val) => to_int(vm, val, base)?,
None => Zero::zero(),
};
Ok(PyObject::new(Box::new(PyInt::new(val)), cls.clone()))
Ok(PyObject::new(PyInt::new(val), cls.clone()))
}
// Casting function:

View File

@@ -111,10 +111,10 @@ impl PyListRef {
fn iter(self, vm: &mut VirtualMachine) -> PyObjectRef {
PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: self.into_object(),
}),
},
vm.ctx.iter_type(),
)
}
@@ -302,10 +302,7 @@ fn list_new(
vec![]
};
Ok(PyObject::new(
Box::new(PyList::from(elements)),
cls.into_object(),
))
Ok(PyObject::new(PyList::from(elements), cls.into_object()))
}
fn quicksort(

View File

@@ -30,10 +30,10 @@ fn map_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
.map(|iterable| objiter::get_iter(vm, iterable))
.collect::<Result<Vec<_>, _>>()?;
Ok(PyObject::new(
Box::new(PyMap {
PyMap {
mapper: function.clone(),
iterators,
}),
},
cls.clone(),
))
}

View File

@@ -18,9 +18,9 @@ pub fn new_memory_view(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(cls, None), (bytes_object, None)]);
vm.ctx.set_attr(&cls, "obj", bytes_object.clone());
Ok(PyObject::new(
Box::new(PyMemoryView {
PyMemoryView {
obj: bytes_object.clone(),
}),
},
cls.clone(),
))
}

View File

@@ -137,7 +137,7 @@ impl<'a, T> PropertyBuilder<'a, T> {
deleter: None,
};
PyObject::new(Box::new(payload), self.ctx.property_type())
PyObject::new(payload, self.ctx.property_type())
} else {
let payload = PyReadOnlyProperty {
getter: self.getter.expect(
@@ -145,7 +145,7 @@ impl<'a, T> PropertyBuilder<'a, T> {
),
};
PyObject::new(Box::new(payload), self.ctx.readonly_property_type())
PyObject::new(payload, self.ctx.readonly_property_type())
}
}
}

View File

@@ -227,10 +227,7 @@ fn range_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
if step.is_zero() {
Err(vm.new_value_error("range with 0 step size".to_string()))
} else {
Ok(PyObject::new(
Box::new(PyRange { start, end, step }),
cls.clone(),
))
Ok(PyObject::new(PyRange { start, end, step }, cls.clone()))
}
}
@@ -238,10 +235,10 @@ fn range_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(range, Some(vm.ctx.range_type()))]);
Ok(PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: range.clone(),
}),
},
vm.ctx.iter_type(),
))
}
@@ -252,10 +249,10 @@ fn range_reversed(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let range = get_value(zelf).reversed();
Ok(PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: PyObject::new(Box::new(range), vm.ctx.range_type()),
}),
iterated_obj: PyObject::new(range, vm.ctx.range_type()),
},
vm.ctx.iter_type(),
))
}
@@ -318,11 +315,11 @@ fn range_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
};
Ok(PyObject::new(
Box::new(PyRange {
PyRange {
start: new_start,
end: new_end,
step: new_step,
}),
},
vm.ctx.range_type(),
))
} else {

View File

@@ -165,16 +165,12 @@ pub fn get_item(
if subscript.payload::<PySlice>().is_some() {
if sequence.payload::<PyList>().is_some() {
Ok(PyObject::new(
Box::new(PyList::from(
elements.to_vec().get_slice_items(vm, &subscript)?,
)),
PyList::from(elements.to_vec().get_slice_items(vm, &subscript)?),
sequence.typ(),
))
} else if sequence.payload::<PyTuple>().is_some() {
Ok(PyObject::new(
Box::new(PyTuple::from(
elements.to_vec().get_slice_items(vm, &subscript)?,
)),
PyTuple::from(elements.to_vec().get_slice_items(vm, &subscript)?),
sequence.typ(),
))
} else {

View File

@@ -168,9 +168,9 @@ fn set_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
};
Ok(PyObject::new(
Box::new(PySet {
PySet {
elements: RefCell::new(elements),
}),
},
cls.clone(),
))
}
@@ -187,9 +187,9 @@ fn set_copy(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(s, Some(vm.ctx.set_type()))]);
let elements = get_elements(s);
Ok(PyObject::new(
Box::new(PySet {
PySet {
elements: RefCell::new(elements),
}),
},
vm.ctx.set_type(),
))
}
@@ -341,9 +341,9 @@ fn set_union(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
elements.extend(get_elements(other).clone());
Ok(PyObject::new(
Box::new(PySet {
PySet {
elements: RefCell::new(elements),
}),
},
vm.ctx.set_type(),
))
}
@@ -383,9 +383,9 @@ fn set_symmetric_difference(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResu
}
Ok(PyObject::new(
Box::new(PySet {
PySet {
elements: RefCell::new(elements),
}),
},
vm.ctx.set_type(),
))
}
@@ -423,9 +423,9 @@ fn set_combine_inner(
}
Ok(PyObject::new(
Box::new(PySet {
PySet {
elements: RefCell::new(elements),
}),
},
vm.ctx.set_type(),
))
}
@@ -555,10 +555,10 @@ fn set_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let items = get_elements(zelf).values().cloned().collect();
let set_list = vm.ctx.new_list(items);
let iter_obj = PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: set_list,
}),
},
vm.ctx.iter_type(),
);

View File

@@ -54,11 +54,11 @@ fn slice_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
}
}?;
Ok(PyObject::new(
Box::new(PySlice {
PySlice {
start: start.map(|x| objint::get_value(x)),
stop: stop.map(|x| objint::get_value(x)),
step: step.map(|x| objint::get_value(x)),
}),
},
cls.clone(),
))
}

View File

@@ -126,10 +126,10 @@ impl PyTupleRef {
fn iter(self, vm: &mut VirtualMachine) -> PyObjectRef {
PyObject::new(
Box::new(PyIteratorValue {
PyIteratorValue {
position: Cell::new(0),
iterated_obj: self.into_object(),
}),
},
vm.ctx.iter_type(),
)
}
@@ -213,10 +213,7 @@ fn tuple_new(
vec![]
};
Ok(PyObject::new(
Box::new(PyTuple::from(elements)),
cls.into_object(),
))
Ok(PyObject::new(PyTuple::from(elements), cls.into_object()))
}
#[rustfmt::skip] // to avoid line splitting

View File

@@ -24,7 +24,7 @@ fn zip_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
.iter()
.map(|iterable| objiter::get_iter(vm, iterable))
.collect::<Result<Vec<_>, _>>()?;
Ok(PyObject::new(Box::new(PyZip { iterators }), cls.clone()))
Ok(PyObject::new(PyZip { iterators }, cls.clone()))
}
fn zip_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

View File

@@ -161,6 +161,24 @@ pub fn create_type(
objtype::new(type_type.clone(), name, vec![base.clone()], dict).unwrap()
}
#[derive(Debug)]
pub struct PyNotImplemented;
impl PyValue for PyNotImplemented {
fn required_type(ctx: &PyContext) -> PyObjectRef {
ctx.not_implemented().typ()
}
}
#[derive(Debug)]
pub struct PyEllipsis;
impl PyValue for PyEllipsis {
fn required_type(ctx: &PyContext) -> PyObjectRef {
ctx.ellipsis_type.clone()
}
}
// Basic objects:
impl PyContext {
pub fn new() -> Self {
@@ -214,19 +232,19 @@ impl PyContext {
let exceptions = exceptions::ExceptionZoo::new(&type_type, &object_type, &dict_type);
let none = PyObject::new(
Box::new(objnone::PyNone),
objnone::PyNone,
create_type("NoneType", &type_type, &object_type, &dict_type),
);
let ellipsis = PyObject::new(Box::new(()), ellipsis_type.clone());
let ellipsis = PyObject::new(PyEllipsis, ellipsis_type.clone());
let not_implemented = PyObject::new(
Box::new(()),
PyNotImplemented,
create_type("NotImplementedType", &type_type, &object_type, &dict_type),
);
let true_value = PyObject::new(Box::new(PyInt::new(BigInt::one())), bool_type.clone());
let false_value = PyObject::new(Box::new(PyInt::new(BigInt::zero())), bool_type.clone());
let true_value = PyObject::new(PyInt::new(BigInt::one()), bool_type.clone());
let false_value = PyObject::new(PyInt::new(BigInt::zero()), bool_type.clone());
let context = PyContext {
bool_type,
memoryview_type,
@@ -464,30 +482,27 @@ impl PyContext {
}
pub fn new_int<T: ToBigInt>(&self, i: T) -> PyObjectRef {
PyObject::new(Box::new(PyInt::new(i)), self.int_type())
PyObject::new(PyInt::new(i), self.int_type())
}
pub fn new_float(&self, value: f64) -> PyObjectRef {
PyObject::new(Box::new(PyFloat::from(value)), self.float_type())
PyObject::new(PyFloat::from(value), self.float_type())
}
pub fn new_complex(&self, value: Complex64) -> PyObjectRef {
PyObject::new(Box::new(PyComplex::from(value)), self.complex_type())
PyObject::new(PyComplex::from(value), self.complex_type())
}
pub fn new_str(&self, s: String) -> PyObjectRef {
PyObject::new(Box::new(objstr::PyString { value: s }), self.str_type())
PyObject::new(objstr::PyString { value: s }, self.str_type())
}
pub fn new_bytes(&self, data: Vec<u8>) -> PyObjectRef {
PyObject::new(Box::new(objbytes::PyBytes::new(data)), self.bytes_type())
PyObject::new(objbytes::PyBytes::new(data), self.bytes_type())
}
pub fn new_bytearray(&self, data: Vec<u8>) -> PyObjectRef {
PyObject::new(
Box::new(objbytearray::PyByteArray::new(data)),
self.bytearray_type(),
)
PyObject::new(objbytearray::PyByteArray::new(data), self.bytearray_type())
}
pub fn new_bool(&self, b: bool) -> PyObjectRef {
@@ -499,21 +514,21 @@ impl PyContext {
}
pub fn new_tuple(&self, elements: Vec<PyObjectRef>) -> PyObjectRef {
PyObject::new(Box::new(PyTuple::from(elements)), self.tuple_type())
PyObject::new(PyTuple::from(elements), self.tuple_type())
}
pub fn new_list(&self, elements: Vec<PyObjectRef>) -> PyObjectRef {
PyObject::new(Box::new(PyList::from(elements)), self.list_type())
PyObject::new(PyList::from(elements), self.list_type())
}
pub fn new_set(&self) -> PyObjectRef {
// Initialized empty, as calling __hash__ is required for adding each object to the set
// which requires a VM context - this is done in the objset code itself.
PyObject::new(Box::new(PySet::default()), self.set_type())
PyObject::new(PySet::default(), self.set_type())
}
pub fn new_dict(&self) -> PyObjectRef {
PyObject::new(Box::new(PyDict::default()), self.dict_type())
PyObject::new(PyDict::default(), self.dict_type())
}
pub fn new_class(&self, name: &str, base: PyObjectRef) -> PyObjectRef {
@@ -526,10 +541,10 @@ impl PyContext {
pub fn new_module(&self, name: &str, dict: PyObjectRef) -> PyObjectRef {
PyObject::new(
Box::new(PyModule {
PyModule {
name: name.to_string(),
dict,
}),
},
self.module_type.clone(),
)
}
@@ -539,13 +554,13 @@ impl PyContext {
F: IntoPyNativeFunc<T, R>,
{
PyObject::new(
Box::new(PyBuiltinFunction::new(f.into_func())),
PyBuiltinFunction::new(f.into_func()),
self.builtin_function_or_method_type(),
)
}
pub fn new_frame(&self, code: PyObjectRef, scope: Scope) -> PyObjectRef {
PyObject::new(Box::new(Frame::new(code, scope)), self.frame_type())
PyObject::new(Frame::new(code, scope), self.frame_type())
}
pub fn new_property<F, T, R>(&self, f: F) -> PyObjectRef
@@ -556,7 +571,7 @@ impl PyContext {
}
pub fn new_code_object(&self, code: bytecode::CodeObject) -> PyObjectRef {
PyObject::new(Box::new(objcode::PyCode::new(code)), self.code_type())
PyObject::new(objcode::PyCode::new(code), self.code_type())
}
pub fn new_function(
@@ -566,16 +581,13 @@ impl PyContext {
defaults: PyObjectRef,
) -> PyObjectRef {
PyObject::new(
Box::new(PyFunction::new(code_obj, scope, defaults)),
PyFunction::new(code_obj, scope, defaults),
self.function_type(),
)
}
pub fn new_bound_method(&self, function: PyObjectRef, object: PyObjectRef) -> PyObjectRef {
PyObject::new(
Box::new(PyMethod::new(object, function)),
self.bound_method_type(),
)
PyObject::new(PyMethod::new(object, function), self.bound_method_type())
}
pub fn new_instance(&self, class: PyObjectRef, dict: Option<PyAttributes>) -> PyObjectRef {
@@ -682,13 +694,10 @@ pub struct PyRef<T> {
_payload: PhantomData<T>,
}
impl<T> PyRef<T>
where
T: PyValue,
{
impl<T: PyValue> PyRef<T> {
pub fn new(ctx: &PyContext, payload: T) -> Self {
PyRef {
obj: PyObject::new(Box::new(payload), T::required_type(ctx)),
obj: PyObject::new(payload, T::required_type(ctx)),
_payload: PhantomData,
}
}
@@ -697,7 +706,7 @@ where
let required_type = T::required_type(&vm.ctx);
if objtype::issubclass(&cls.obj, &required_type) {
Ok(PyRef {
obj: PyObject::new(Box::new(payload), cls.obj),
obj: PyObject::new(payload, cls.obj),
_payload: PhantomData,
})
} else {
@@ -1366,7 +1375,7 @@ where
T: PyValue + Sized,
{
fn into_pyobject(self, ctx: &PyContext) -> PyResult {
Ok(PyObject::new(Box::new(self), T::required_type(ctx)))
Ok(PyObject::new(self, T::required_type(ctx)))
}
}
@@ -1511,11 +1520,11 @@ impl PyValue for PyIteratorValue {
}
impl PyObject {
pub fn new(payload: Box<dyn Any>, typ: PyObjectRef) -> PyObjectRef {
pub fn new<T: PyValue>(payload: T, typ: PyObjectRef) -> PyObjectRef {
PyObject {
payload,
typ: Some(typ),
dict: Some(RefCell::new(PyAttributes::new())),
payload: Box::new(payload),
}
.into_ref()
}

View File

@@ -11,9 +11,18 @@ use regex::{Match, Regex};
use std::path::PathBuf;
use crate::obj::objstr;
use crate::pyobject::{PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, TypeProtocol};
use crate::pyobject::{
PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, PyValue, TypeProtocol,
};
use crate::VirtualMachine;
impl PyValue for Regex {
fn required_type(_ctx: &PyContext) -> PyObjectRef {
// TODO
unimplemented!()
}
}
/// Create the python `re` module with all its members.
pub fn mk_module(ctx: &PyContext) -> PyObjectRef {
let match_type = py_class!(ctx, "Match", ctx.object(), {
@@ -95,11 +104,19 @@ fn make_regex(vm: &mut VirtualMachine, pattern: &PyObjectRef) -> PyResult<Regex>
}
/// Inner data for a match object.
#[derive(Debug)]
struct PyMatch {
start: usize,
end: usize,
}
impl PyValue for PyMatch {
fn required_type(_ctx: &PyContext) -> PyObjectRef {
// TODO
unimplemented!()
}
}
/// Take a found regular expression and convert it to proper match object.
fn create_match(vm: &mut VirtualMachine, match_value: &Match) -> PyResult {
// Return match object:
@@ -116,7 +133,7 @@ fn create_match(vm: &mut VirtualMachine, match_value: &Match) -> PyResult {
end: match_value.end(),
};
Ok(PyObject::new(Box::new(match_value), match_class.clone()))
Ok(PyObject::new(match_value, match_class.clone()))
}
/// Compile a regular expression into a Pattern object.
@@ -134,7 +151,7 @@ fn re_compile(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let module = import::import_module(vm, PathBuf::default(), "re").unwrap();
let pattern_class = vm.ctx.get_attr(&module, "Pattern").unwrap();
Ok(PyObject::new(Box::new(regex), pattern_class.clone()))
Ok(PyObject::new(regex, pattern_class.clone()))
}
fn pattern_match(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {

View File

@@ -3,18 +3,20 @@ use std::io;
use std::io::Read;
use std::io::Write;
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket};
use std::ops::DerefMut;
use std::ops::Deref;
use crate::obj::objbytes;
use crate::obj::objint;
use crate::obj::objsequence::get_elements;
use crate::obj::objstr;
use crate::pyobject::{PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, TypeProtocol};
use crate::pyobject::{
PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, PyValue, TypeProtocol,
};
use crate::vm::VirtualMachine;
use num_traits::ToPrimitive;
#[derive(Copy, Clone)]
#[derive(Debug, Copy, Clone)]
enum AddressFamily {
Unix = 1,
Inet = 2,
@@ -32,7 +34,7 @@ impl AddressFamily {
}
}
#[derive(Copy, Clone)]
#[derive(Debug, Copy, Clone)]
enum SocketKind {
Stream = 1,
Dgram = 2,
@@ -48,6 +50,7 @@ impl SocketKind {
}
}
#[derive(Debug)]
enum Connection {
TcpListener(TcpListener),
TcpStream(TcpStream),
@@ -108,10 +111,18 @@ impl Write for Connection {
}
}
#[derive(Debug)]
pub struct Socket {
address_family: AddressFamily,
socket_kind: SocketKind,
con: Option<Connection>,
con: RefCell<Option<Connection>>,
}
impl PyValue for Socket {
fn required_type(_ctx: &PyContext) -> PyObjectRef {
// TODO
unimplemented!()
}
}
impl Socket {
@@ -119,16 +130,13 @@ impl Socket {
Socket {
address_family,
socket_kind,
con: None,
con: RefCell::new(None),
}
}
}
fn get_socket<'a>(obj: &'a PyObjectRef) -> impl DerefMut<Target = Socket> + 'a {
if let Some(socket) = obj.payload.downcast_ref::<RefCell<Socket>>() {
return socket.borrow_mut();
}
panic!("Inner error getting socket {:?}", obj);
fn get_socket<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Socket> + 'a {
obj.payload::<Socket>().unwrap()
}
fn socket_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -146,9 +154,10 @@ fn socket_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
AddressFamily::from_i32(vm, objint::get_value(family_int).to_i32().unwrap())?;
let kind = SocketKind::from_i32(vm, objint::get_value(kind_int).to_i32().unwrap())?;
let socket = RefCell::new(Socket::new(address_family, kind));
Ok(PyObject::new(Box::new(socket), cls.clone()))
Ok(PyObject::new(
Socket::new(address_family, kind),
cls.clone(),
))
}
fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
@@ -160,18 +169,21 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let address_string = get_address_string(vm, address)?;
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
match socket.socket_kind {
SocketKind::Stream => match TcpStream::connect(address_string) {
Ok(stream) => {
socket.con = Some(Connection::TcpStream(stream));
socket
.con
.borrow_mut()
.replace(Connection::TcpStream(stream));
Ok(vm.get_none())
}
Err(s) => Err(vm.new_os_error(s.to_string())),
},
SocketKind::Dgram => {
if let Some(Connection::UdpSocket(con)) = &socket.con {
if let Some(Connection::UdpSocket(con)) = socket.con.borrow().as_ref() {
match con.connect(address_string) {
Ok(_) => Ok(vm.get_none()),
Err(s) => Err(vm.new_os_error(s.to_string())),
@@ -192,19 +204,25 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
let address_string = get_address_string(vm, address)?;
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
match socket.socket_kind {
SocketKind::Stream => match TcpListener::bind(address_string) {
Ok(stream) => {
socket.con = Some(Connection::TcpListener(stream));
socket
.con
.borrow_mut()
.replace(Connection::TcpListener(stream));
Ok(vm.get_none())
}
Err(s) => Err(vm.new_os_error(s.to_string())),
},
SocketKind::Dgram => match UdpSocket::bind(address_string) {
Ok(dgram) => {
socket.con = Some(Connection::UdpSocket(dgram));
socket
.con
.borrow_mut()
.replace(Connection::UdpSocket(dgram));
Ok(vm.get_none())
}
Err(s) => Err(vm.new_os_error(s.to_string())),
@@ -248,10 +266,10 @@ fn socket_listen(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(zelf, None)]);
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
let ret = match socket.con {
Some(ref mut v) => v.accept(),
let ret = match socket.con.borrow_mut().as_mut() {
Some(v) => v.accept(),
None => return Err(vm.new_type_error("".to_string())),
};
@@ -260,13 +278,13 @@ fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Err(s) => return Err(vm.new_os_error(s.to_string())),
};
let socket = RefCell::new(Socket {
let socket = Socket {
address_family: socket.address_family,
socket_kind: socket.socket_kind,
con: Some(Connection::TcpStream(tcp_stream)),
});
con: RefCell::new(Some(Connection::TcpStream(tcp_stream))),
};
let sock_obj = PyObject::new(Box::new(socket), zelf.typ());
let sock_obj = PyObject::new(socket, zelf.typ());
let addr_tuple = get_addr_tuple(vm, addr)?;
@@ -279,11 +297,11 @@ fn socket_recv(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
args,
required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))]
);
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
match socket.con {
Some(ref mut v) => match v.read_exact(&mut buffer) {
match socket.con.borrow_mut().as_mut() {
Some(v) => match v.read_exact(&mut buffer) {
Ok(_) => (),
Err(s) => return Err(vm.new_os_error(s.to_string())),
},
@@ -299,11 +317,11 @@ fn socket_recvfrom(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
required = [(zelf, None), (bufsize, Some(vm.ctx.int_type()))]
);
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
let mut buffer = vec![0u8; objint::get_value(bufsize).to_usize().unwrap()];
let ret = match socket.con {
Some(ref mut v) => v.recv_from(&mut buffer),
let ret = match socket.con.borrow().as_ref() {
Some(v) => v.recv_from(&mut buffer),
None => return Err(vm.new_type_error("".to_string())),
};
@@ -323,10 +341,10 @@ fn socket_send(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
args,
required = [(zelf, None), (bytes, Some(vm.ctx.bytes_type()))]
);
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
match socket.con {
Some(ref mut v) => match v.write(&objbytes::get_value(&bytes)) {
match socket.con.borrow_mut().as_mut() {
Some(v) => match v.write(&objbytes::get_value(&bytes)) {
Ok(_) => (),
Err(s) => return Err(vm.new_os_error(s.to_string())),
},
@@ -347,30 +365,29 @@ fn socket_sendto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
);
let address_string = get_address_string(vm, address)?;
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
match socket.socket_kind {
SocketKind::Dgram => {
match socket.con {
Some(ref mut v) => match v.send_to(&objbytes::get_value(&bytes), address_string) {
if let Some(v) = socket.con.borrow().as_ref() {
return match v.send_to(&objbytes::get_value(&bytes), address_string) {
Ok(_) => Ok(vm.get_none()),
Err(s) => Err(vm.new_os_error(s.to_string())),
},
None => {
// Doing implicit bind
match UdpSocket::bind("0.0.0.0:0") {
Ok(dgram) => {
match dgram.send_to(&objbytes::get_value(&bytes), address_string) {
Ok(_) => {
socket.con = Some(Connection::UdpSocket(dgram));
Ok(vm.get_none())
}
Err(s) => Err(vm.new_os_error(s.to_string())),
}
}
Err(s) => Err(vm.new_os_error(s.to_string())),
};
}
// Doing implicit bind
match UdpSocket::bind("0.0.0.0:0") {
Ok(dgram) => match dgram.send_to(&objbytes::get_value(&bytes), address_string) {
Ok(_) => {
socket
.con
.borrow_mut()
.replace(Connection::UdpSocket(dgram));
Ok(vm.get_none())
}
}
Err(s) => Err(vm.new_os_error(s.to_string())),
},
Err(s) => Err(vm.new_os_error(s.to_string())),
}
}
_ => Err(vm.new_not_implemented_error("".to_string())),
@@ -380,17 +397,17 @@ fn socket_sendto(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
fn socket_close(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(zelf, None)]);
let mut socket = get_socket(zelf);
socket.con = None;
let socket = get_socket(zelf);
socket.con.borrow_mut().take();
Ok(vm.get_none())
}
fn socket_getsockname(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(zelf, None)]);
let mut socket = get_socket(zelf);
let socket = get_socket(zelf);
let addr = match socket.con {
Some(ref mut v) => v.local_addr(),
let addr = match socket.con.borrow().as_ref() {
Some(v) => v.local_addr(),
None => return Err(vm.new_type_error("".to_string())),
};