Replace special cases in boolval with __bool__ method on types.

This commit is contained in:
Adam Kelly
2019-03-23 09:31:42 +00:00
parent 26a238085f
commit 9ebbde8126
6 changed files with 25 additions and 24 deletions

View File

@@ -6,12 +6,7 @@ use crate::pyobject::{
};
use crate::vm::VirtualMachine;
use super::objdict::PyDict;
use super::objfloat::PyFloat;
use super::objint::PyInt;
use super::objlist::PyList;
use super::objstr::PyString;
use super::objtuple::PyTuple;
use super::objtype;
impl IntoPyObject for bool {
@@ -27,25 +22,6 @@ impl TryFromObject for bool {
}
pub fn boolval(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<bool> {
if let Some(s) = obj.payload::<PyString>() {
return Ok(!s.value.is_empty());
}
if let Some(value) = obj.payload::<PyFloat>() {
return Ok(*value != PyFloat::from(0.0));
}
if let Some(dict) = obj.payload::<PyDict>() {
return Ok(!dict.entries.borrow().is_empty());
}
if let Some(i) = obj.payload::<PyInt>() {
return Ok(!i.value.is_zero());
}
if let Some(list) = obj.payload::<PyList>() {
return Ok(!list.elements.borrow().is_empty());
}
if let Some(tuple) = obj.payload::<PyTuple>() {
return Ok(!tuple.elements.borrow().is_empty());
}
Ok(if let Ok(f) = vm.get_method(obj.clone(), "__bool__") {
let bool_res = vm.invoke(f, PyFuncArgs::default())?;
match bool_res.payload::<PyInt>() {

View File

@@ -175,6 +175,10 @@ fn dict_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
}
impl PyDictRef {
fn bool(self, _vm: &VirtualMachine) -> bool {
!self.entries.borrow().is_empty()
}
fn len(self, _vm: &VirtualMachine) -> usize {
self.entries.borrow().len()
}
@@ -300,6 +304,7 @@ impl PyDictRef {
pub fn init(context: &PyContext) {
extend_class!(context, &context.dict_type, {
"__bool__" => context.new_rustfunc(PyDictRef::bool),
"__len__" => context.new_rustfunc(PyDictRef::len),
"__contains__" => context.new_rustfunc(PyDictRef::contains),
"__delitem__" => context.new_rustfunc(PyDictRef::delitem),

View File

@@ -120,6 +120,10 @@ impl PyFloatRef {
}
}
fn bool(self, _vm: &VirtualMachine) -> bool {
self.value != 0.0
}
fn divmod(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if objtype::isinstance(&other, &vm.ctx.float_type())
|| objtype::isinstance(&other, &vm.ctx.int_type())
@@ -366,6 +370,7 @@ pub fn init(context: &PyContext) {
"__abs__" => context.new_rustfunc(PyFloatRef::abs),
"__add__" => context.new_rustfunc(PyFloatRef::add),
"__radd__" => context.new_rustfunc(PyFloatRef::add),
"__bool__" => context.new_rustfunc(PyFloatRef::bool),
"__divmod__" => context.new_rustfunc(PyFloatRef::divmod),
"__floordiv__" => context.new_rustfunc(PyFloatRef::floordiv),
"__new__" => context.new_rustfunc(PyFloatRef::new_float),

View File

@@ -94,6 +94,10 @@ impl PyListRef {
}
}
fn bool(self, _vm: &VirtualMachine) -> bool {
!self.elements.borrow().is_empty()
}
fn clear(self, _vm: &VirtualMachine) {
self.elements.borrow_mut().clear();
}
@@ -419,6 +423,7 @@ pub fn init(context: &PyContext) {
extend_class!(context, list_type, {
"__add__" => context.new_rustfunc(PyListRef::add),
"__iadd__" => context.new_rustfunc(PyListRef::iadd),
"__bool__" => context.new_rustfunc(PyListRef::bool),
"__contains__" => context.new_rustfunc(PyListRef::contains),
"__eq__" => context.new_rustfunc(PyListRef::eq),
"__lt__" => context.new_rustfunc(PyListRef::lt),

View File

@@ -57,6 +57,10 @@ impl PyStringRef {
}
}
fn bool(self, _vm: &VirtualMachine) -> bool {
!self.value.is_empty()
}
fn eq(self, rhs: PyObjectRef, vm: &VirtualMachine) -> bool {
if objtype::isinstance(&rhs, &vm.ctx.str_type()) {
self.value == get_value(&rhs)
@@ -632,6 +636,7 @@ pub fn init(context: &PyContext) {
extend_class!(context, str_type, {
"__add__" => context.new_rustfunc(PyStringRef::add),
"__bool__" => context.new_rustfunc(PyStringRef::bool),
"__contains__" => context.new_rustfunc(PyStringRef::contains),
"__doc__" => context.new_str(str_doc.to_string()),
"__eq__" => context.new_rustfunc(PyStringRef::eq),

View File

@@ -101,6 +101,10 @@ impl PyTupleRef {
}
}
fn bool(self, _vm: &VirtualMachine) -> bool {
!self.elements.borrow().is_empty()
}
fn count(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
let mut count: usize = 0;
for element in self.elements.borrow().iter() {
@@ -230,6 +234,7 @@ tuple(iterable) -> tuple initialized from iterable's items
If the argument is a tuple, the return value is the same object.";
context.set_attr(tuple_type, "__add__", context.new_rustfunc(PyTupleRef::add));
context.set_attr(tuple_type, "__bool__", context.new_rustfunc(PyTupleRef::bool));
context.set_attr(tuple_type, "__eq__", context.new_rustfunc(PyTupleRef::eq));
context.set_attr(tuple_type,"__contains__",context.new_rustfunc(PyTupleRef::contains));
context.set_attr(tuple_type,"__getitem__",context.new_rustfunc(PyTupleRef::getitem));