Basic membership implementation

This commit is contained in:
Éloi Rivard
2018-10-17 13:20:32 +02:00
parent 7f263d14f3
commit e415ee1fbf
7 changed files with 184 additions and 4 deletions

View File

@@ -0,0 +1,62 @@
# test lists
assert 3 in [1, 2, 3]
assert 3 not in [1, 2]
assert not (3 in [1, 2])
assert not (3 not in [1, 2, 3])
# test strings
assert "foo" in "foobar"
assert "whatever" not in "foobar"
# test bytes
# TODO: uncomment this when bytes are implemented
# assert b"foo" in b"foobar"
# assert b"whatever" not in b"foobar"
# test tuple
assert 1 in (1, 2)
assert 3 not in (1, 2)
# test set
# TODO: uncomment this when sets are implemented
# assert 1 in set(1, 2)
# assert 3 not in set(1, 2)
# test dicts
# TODO: test dicts when keys other than strings will be allowed
assert "a" in {"a": 0, "b": 0}
assert "c" not in {"a": 0, "b": 0}
# test iter
assert 3 in iter([1, 2, 3])
assert 3 not in iter([1, 2])
# test sequence
# TODO: uncomment this when ranges are usable
# assert 1 in range(0, 2)
# assert 3 not in range(0, 2)
# test __contains__ in user objects
class MyNotContainingClass():
pass
try:
1 in MyNotContainingClass()
except TypeError:
pass
else:
assert False, "TypeError not raised"
class MyContainingClass():
def __init__(self, value):
self.value = value
def __contains__(self, something):
return something == self.value
assert 2 in MyContainingClass(2)
assert 1 not in MyContainingClass(2)

View File

@@ -63,6 +63,26 @@ fn dict_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(vm.new_str(s))
}
pub fn dict_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [
(dict, Some(vm.ctx.dict_type())),
(needle, Some(vm.ctx.str_type()))
]
);
let needle = objstr::get_value(&needle);
for element in get_elements(dict).iter() {
if &needle == element.0 {
return Ok(vm.new_bool(true));
}
}
Ok(vm.new_bool(false))
}
pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: PyObjectRef) {
(*dict_type.borrow_mut()).kind = PyObjectKind::Class {
name: String::from("dict"),
@@ -75,6 +95,7 @@ pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type:
pub fn init(context: &PyContext) {
let ref dict_type = context.dict_type;
dict_type.set_attr("__len__", context.new_rustfunc(dict_len));
dict_type.set_attr("__contains__", context.new_rustfunc(dict_contains));
dict_type.set_attr("__new__", context.new_rustfunc(dict_new));
dict_type.set_attr("__repr__", context.new_rustfunc(dict_repr));
}

View File

@@ -53,6 +53,26 @@ fn iter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(iter.clone())
}
fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(iter, Some(vm.ctx.iter_type())), (needle, None)]
);
loop {
match vm.call_method(&iter, "__next__", vec![]) {
Ok(element) => {
if &element == needle {
return Ok(vm.new_bool(true));
} else {
continue;
}
}
Err(_) => return Ok(vm.new_bool(false)),
}
}
}
fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]);
@@ -86,7 +106,8 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
pub fn init(context: &PyContext) {
let ref iter_type = context.iter_type;
iter_type.set_attr("__new__", context.new_rustfunc(iter_new));
iter_type.set_attr("__contains__", context.new_rustfunc(iter_contains));
iter_type.set_attr("__iter__", context.new_rustfunc(iter_iter));
iter_type.set_attr("__new__", context.new_rustfunc(iter_new));
iter_type.set_attr("__next__", context.new_rustfunc(iter_next));
}

View File

@@ -162,10 +162,27 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
}
}
pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
trace!("list.len called with: {:?}", args);
arg_check!(
vm,
args,
required = [(list, Some(vm.ctx.list_type())), (x, None)]
);
for element in get_elements(list).iter() {
if x == element {
return Ok(vm.new_bool(true));
}
}
Ok(vm.new_bool(false))
}
pub fn init(context: &PyContext) {
let ref list_type = context.list_type;
list_type.set_attr("__eq__", context.new_rustfunc(list_eq));
list_type.set_attr("__add__", context.new_rustfunc(list_add));
list_type.set_attr("__contains__", context.new_rustfunc(contains));
list_type.set_attr("__eq__", context.new_rustfunc(list_eq));
list_type.set_attr("__len__", context.new_rustfunc(list_len));
list_type.set_attr("__new__", context.new_rustfunc(list_new));
list_type.set_attr("__repr__", context.new_rustfunc(list_repr));

View File

@@ -8,8 +8,9 @@ use super::objtype;
pub fn init(context: &PyContext) {
let ref str_type = context.str_type;
str_type.set_attr("__eq__", context.new_rustfunc(str_eq));
str_type.set_attr("__add__", context.new_rustfunc(str_add));
str_type.set_attr("__eq__", context.new_rustfunc(str_eq));
str_type.set_attr("__contains__", context.new_rustfunc(str_contains));
str_type.set_attr("__len__", context.new_rustfunc(str_len));
str_type.set_attr("__mul__", context.new_rustfunc(str_mul));
str_type.set_attr("__new__", context.new_rustfunc(str_new));
@@ -197,6 +198,20 @@ fn str_startswith(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(vm.ctx.new_bool(value.starts_with(pat.as_str())))
}
fn str_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [
(s, Some(vm.ctx.str_type())),
(needle, Some(vm.ctx.str_type()))
]
);
let value = get_value(&s);
let needle = get_value(&needle);
Ok(vm.ctx.new_bool(value.contains(needle.as_str())))
}
// TODO: should with following format
// class str(object='')
// class str(object=b'', encoding='utf-8', errors='strict')

View File

@@ -50,6 +50,21 @@ fn tuple_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(vm.new_str(s))
}
pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(tuple, Some(vm.ctx.tuple_type())), (x, None)]
);
for element in get_elements(tuple).iter() {
if x == element {
return Ok(vm.new_bool(true));
}
}
Ok(vm.new_bool(false))
}
pub fn get_elements(obj: &PyObjectRef) -> Vec<PyObjectRef> {
if let PyObjectKind::Tuple { elements } = &obj.borrow().kind {
elements.to_vec()
@@ -61,6 +76,7 @@ pub fn get_elements(obj: &PyObjectRef) -> Vec<PyObjectRef> {
pub fn init(context: &PyContext) {
let ref tuple_type = context.tuple_type;
tuple_type.set_attr("__eq__", context.new_rustfunc(tuple_eq));
tuple_type.set_attr("__contains__", context.new_rustfunc(tuple_contains));
tuple_type.set_attr("__len__", context.new_rustfunc(tuple_len));
tuple_type.set_attr("__repr__", context.new_rustfunc(tuple_repr));
}

View File

@@ -580,6 +580,33 @@ impl VirtualMachine {
a.get_id()
}
// https://docs.python.org/3/reference/expressions.html#membership-test-operations
fn _membership(&mut self, needle: PyObjectRef, haystack: &PyObjectRef) -> PyResult {
self.call_method(&haystack, "__contains__", vec![needle])
// TODO: implement __iter__ and __getitem__ cases when __contains__ is
// not implemented.
}
fn _in(&mut self, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult {
match self._membership(needle, &haystack) {
Ok(found) => Ok(found),
Err(_) => Err(self.new_type_error(format!(
"{} has no __contains__ method",
objtype::get_type_name(&haystack.typ())
))),
}
}
fn _not_in(&mut self, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult {
match self._membership(needle, &haystack) {
Ok(found) => Ok(self.ctx.new_bool(!objbool::get_value(&found))),
Err(_) => Err(self.new_type_error(format!(
"{} has no __contains__ method",
objtype::get_type_name(&haystack.typ())
))),
}
}
fn _is(&self, a: PyObjectRef, b: PyObjectRef) -> bool {
// Pointer equal:
a.is(&b)
@@ -603,7 +630,8 @@ impl VirtualMachine {
&bytecode::ComparisonOperator::GreaterOrEqual => self._ge(a, b),
&bytecode::ComparisonOperator::Is => Ok(self.ctx.new_bool(self._is(a, b))),
&bytecode::ComparisonOperator::IsNot => self._is_not(a, b),
_ => panic!("NOT IMPL {:?}", op),
&bytecode::ComparisonOperator::In => self._in(a, b),
&bytecode::ComparisonOperator::NotIn => self._not_in(a, b),
};
match result {
Ok(value) => {