mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Basic membership implementation
This commit is contained in:
62
tests/snippets/membership.py
Normal file
62
tests/snippets/membership.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# test lists
|
||||
assert 3 in [1, 2, 3]
|
||||
assert 3 not in [1, 2]
|
||||
|
||||
assert not (3 in [1, 2])
|
||||
assert not (3 not in [1, 2, 3])
|
||||
|
||||
# test strings
|
||||
assert "foo" in "foobar"
|
||||
assert "whatever" not in "foobar"
|
||||
|
||||
# test bytes
|
||||
# TODO: uncomment this when bytes are implemented
|
||||
# assert b"foo" in b"foobar"
|
||||
# assert b"whatever" not in b"foobar"
|
||||
|
||||
# test tuple
|
||||
assert 1 in (1, 2)
|
||||
assert 3 not in (1, 2)
|
||||
|
||||
# test set
|
||||
# TODO: uncomment this when sets are implemented
|
||||
# assert 1 in set(1, 2)
|
||||
# assert 3 not in set(1, 2)
|
||||
|
||||
# test dicts
|
||||
# TODO: test dicts when keys other than strings will be allowed
|
||||
assert "a" in {"a": 0, "b": 0}
|
||||
assert "c" not in {"a": 0, "b": 0}
|
||||
|
||||
# test iter
|
||||
assert 3 in iter([1, 2, 3])
|
||||
assert 3 not in iter([1, 2])
|
||||
|
||||
# test sequence
|
||||
# TODO: uncomment this when ranges are usable
|
||||
# assert 1 in range(0, 2)
|
||||
# assert 3 not in range(0, 2)
|
||||
|
||||
# test __contains__ in user objects
|
||||
class MyNotContainingClass():
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
1 in MyNotContainingClass()
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
assert False, "TypeError not raised"
|
||||
|
||||
|
||||
class MyContainingClass():
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __contains__(self, something):
|
||||
return something == self.value
|
||||
|
||||
|
||||
assert 2 in MyContainingClass(2)
|
||||
assert 1 not in MyContainingClass(2)
|
||||
@@ -63,6 +63,26 @@ fn dict_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
Ok(vm.new_str(s))
|
||||
}
|
||||
|
||||
pub fn dict_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [
|
||||
(dict, Some(vm.ctx.dict_type())),
|
||||
(needle, Some(vm.ctx.str_type()))
|
||||
]
|
||||
);
|
||||
|
||||
let needle = objstr::get_value(&needle);
|
||||
for element in get_elements(dict).iter() {
|
||||
if &needle == element.0 {
|
||||
return Ok(vm.new_bool(true));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vm.new_bool(false))
|
||||
}
|
||||
|
||||
pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: PyObjectRef) {
|
||||
(*dict_type.borrow_mut()).kind = PyObjectKind::Class {
|
||||
name: String::from("dict"),
|
||||
@@ -75,6 +95,7 @@ pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type:
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref dict_type = context.dict_type;
|
||||
dict_type.set_attr("__len__", context.new_rustfunc(dict_len));
|
||||
dict_type.set_attr("__contains__", context.new_rustfunc(dict_contains));
|
||||
dict_type.set_attr("__new__", context.new_rustfunc(dict_new));
|
||||
dict_type.set_attr("__repr__", context.new_rustfunc(dict_repr));
|
||||
}
|
||||
|
||||
@@ -53,6 +53,26 @@ fn iter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
Ok(iter.clone())
|
||||
}
|
||||
|
||||
fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(iter, Some(vm.ctx.iter_type())), (needle, None)]
|
||||
);
|
||||
loop {
|
||||
match vm.call_method(&iter, "__next__", vec![]) {
|
||||
Ok(element) => {
|
||||
if &element == needle {
|
||||
return Ok(vm.new_bool(true));
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(_) => return Ok(vm.new_bool(false)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]);
|
||||
|
||||
@@ -86,7 +106,8 @@ fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref iter_type = context.iter_type;
|
||||
iter_type.set_attr("__new__", context.new_rustfunc(iter_new));
|
||||
iter_type.set_attr("__contains__", context.new_rustfunc(iter_contains));
|
||||
iter_type.set_attr("__iter__", context.new_rustfunc(iter_iter));
|
||||
iter_type.set_attr("__new__", context.new_rustfunc(iter_new));
|
||||
iter_type.set_attr("__next__", context.new_rustfunc(iter_next));
|
||||
}
|
||||
|
||||
@@ -162,10 +162,27 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
trace!("list.len called with: {:?}", args);
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(list, Some(vm.ctx.list_type())), (x, None)]
|
||||
);
|
||||
for element in get_elements(list).iter() {
|
||||
if x == element {
|
||||
return Ok(vm.new_bool(true));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vm.new_bool(false))
|
||||
}
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref list_type = context.list_type;
|
||||
list_type.set_attr("__eq__", context.new_rustfunc(list_eq));
|
||||
list_type.set_attr("__add__", context.new_rustfunc(list_add));
|
||||
list_type.set_attr("__contains__", context.new_rustfunc(contains));
|
||||
list_type.set_attr("__eq__", context.new_rustfunc(list_eq));
|
||||
list_type.set_attr("__len__", context.new_rustfunc(list_len));
|
||||
list_type.set_attr("__new__", context.new_rustfunc(list_new));
|
||||
list_type.set_attr("__repr__", context.new_rustfunc(list_repr));
|
||||
|
||||
@@ -8,8 +8,9 @@ use super::objtype;
|
||||
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref str_type = context.str_type;
|
||||
str_type.set_attr("__eq__", context.new_rustfunc(str_eq));
|
||||
str_type.set_attr("__add__", context.new_rustfunc(str_add));
|
||||
str_type.set_attr("__eq__", context.new_rustfunc(str_eq));
|
||||
str_type.set_attr("__contains__", context.new_rustfunc(str_contains));
|
||||
str_type.set_attr("__len__", context.new_rustfunc(str_len));
|
||||
str_type.set_attr("__mul__", context.new_rustfunc(str_mul));
|
||||
str_type.set_attr("__new__", context.new_rustfunc(str_new));
|
||||
@@ -197,6 +198,20 @@ fn str_startswith(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
Ok(vm.ctx.new_bool(value.starts_with(pat.as_str())))
|
||||
}
|
||||
|
||||
fn str_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [
|
||||
(s, Some(vm.ctx.str_type())),
|
||||
(needle, Some(vm.ctx.str_type()))
|
||||
]
|
||||
);
|
||||
let value = get_value(&s);
|
||||
let needle = get_value(&needle);
|
||||
Ok(vm.ctx.new_bool(value.contains(needle.as_str())))
|
||||
}
|
||||
|
||||
// TODO: should with following format
|
||||
// class str(object='')
|
||||
// class str(object=b'', encoding='utf-8', errors='strict')
|
||||
|
||||
@@ -50,6 +50,21 @@ fn tuple_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
Ok(vm.new_str(s))
|
||||
}
|
||||
|
||||
pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
|
||||
arg_check!(
|
||||
vm,
|
||||
args,
|
||||
required = [(tuple, Some(vm.ctx.tuple_type())), (x, None)]
|
||||
);
|
||||
for element in get_elements(tuple).iter() {
|
||||
if x == element {
|
||||
return Ok(vm.new_bool(true));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vm.new_bool(false))
|
||||
}
|
||||
|
||||
pub fn get_elements(obj: &PyObjectRef) -> Vec<PyObjectRef> {
|
||||
if let PyObjectKind::Tuple { elements } = &obj.borrow().kind {
|
||||
elements.to_vec()
|
||||
@@ -61,6 +76,7 @@ pub fn get_elements(obj: &PyObjectRef) -> Vec<PyObjectRef> {
|
||||
pub fn init(context: &PyContext) {
|
||||
let ref tuple_type = context.tuple_type;
|
||||
tuple_type.set_attr("__eq__", context.new_rustfunc(tuple_eq));
|
||||
tuple_type.set_attr("__contains__", context.new_rustfunc(tuple_contains));
|
||||
tuple_type.set_attr("__len__", context.new_rustfunc(tuple_len));
|
||||
tuple_type.set_attr("__repr__", context.new_rustfunc(tuple_repr));
|
||||
}
|
||||
|
||||
30
vm/src/vm.rs
30
vm/src/vm.rs
@@ -580,6 +580,33 @@ impl VirtualMachine {
|
||||
a.get_id()
|
||||
}
|
||||
|
||||
// https://docs.python.org/3/reference/expressions.html#membership-test-operations
|
||||
fn _membership(&mut self, needle: PyObjectRef, haystack: &PyObjectRef) -> PyResult {
|
||||
self.call_method(&haystack, "__contains__", vec![needle])
|
||||
// TODO: implement __iter__ and __getitem__ cases when __contains__ is
|
||||
// not implemented.
|
||||
}
|
||||
|
||||
fn _in(&mut self, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult {
|
||||
match self._membership(needle, &haystack) {
|
||||
Ok(found) => Ok(found),
|
||||
Err(_) => Err(self.new_type_error(format!(
|
||||
"{} has no __contains__ method",
|
||||
objtype::get_type_name(&haystack.typ())
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn _not_in(&mut self, needle: PyObjectRef, haystack: PyObjectRef) -> PyResult {
|
||||
match self._membership(needle, &haystack) {
|
||||
Ok(found) => Ok(self.ctx.new_bool(!objbool::get_value(&found))),
|
||||
Err(_) => Err(self.new_type_error(format!(
|
||||
"{} has no __contains__ method",
|
||||
objtype::get_type_name(&haystack.typ())
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn _is(&self, a: PyObjectRef, b: PyObjectRef) -> bool {
|
||||
// Pointer equal:
|
||||
a.is(&b)
|
||||
@@ -603,7 +630,8 @@ impl VirtualMachine {
|
||||
&bytecode::ComparisonOperator::GreaterOrEqual => self._ge(a, b),
|
||||
&bytecode::ComparisonOperator::Is => Ok(self.ctx.new_bool(self._is(a, b))),
|
||||
&bytecode::ComparisonOperator::IsNot => self._is_not(a, b),
|
||||
_ => panic!("NOT IMPL {:?}", op),
|
||||
&bytecode::ComparisonOperator::In => self._in(a, b),
|
||||
&bytecode::ComparisonOperator::NotIn => self._not_in(a, b),
|
||||
};
|
||||
match result {
|
||||
Ok(value) => {
|
||||
|
||||
Reference in New Issue
Block a user