Merge branch 'master' into dict_into_iter

This commit is contained in:
Adam Kelly
2019-04-11 08:28:12 +01:00
9 changed files with 903 additions and 265 deletions

View File

@@ -3,6 +3,7 @@
pub mod objbool;
pub mod objbuiltinfunc;
pub mod objbytearray;
pub mod objbyteinner;
pub mod objbytes;
pub mod objclassmethod;
pub mod objcode;

349
vm/src/obj/objbyteinner.rs Normal file
View File

@@ -0,0 +1,349 @@
use crate::pyobject::PyObjectRef;
use crate::function::OptionalArg;
use crate::vm::VirtualMachine;
use crate::pyobject::{PyResult, TypeProtocol};
use crate::obj::objstr::PyString;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use super::objint;
use super::objsequence::PySliceableSequence;
use crate::obj::objint::PyInt;
use num_traits::ToPrimitive;
#[derive(Debug, Default, Clone)]
pub struct PyByteInner {
pub elements: Vec<u8>,
}
impl PyByteInner {
pub fn new(
val_option: OptionalArg<PyObjectRef>,
enc_option: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<PyByteInner> {
// First handle bytes(string, encoding[, errors])
if let OptionalArg::Present(enc) = enc_option {
if let OptionalArg::Present(eval) = val_option {
if let Ok(input) = eval.downcast::<PyString>() {
if let Ok(encoding) = enc.clone().downcast::<PyString>() {
if &encoding.value.to_lowercase() == "utf8"
|| &encoding.value.to_lowercase() == "utf-8"
// TODO: different encoding
{
return Ok(PyByteInner {
elements: input.value.as_bytes().to_vec(),
});
} else {
return Err(
vm.new_value_error(format!("unknown encoding: {}", encoding.value)), //should be lookup error
);
}
} else {
return Err(vm.new_type_error(format!(
"bytes() argument 2 must be str, not {}",
enc.class().name
)));
}
} else {
return Err(vm.new_type_error("encoding without a string argument".to_string()));
}
} else {
return Err(vm.new_type_error("encoding without a string argument".to_string()));
}
// Only one argument
} else {
let value = if let OptionalArg::Present(ival) = val_option {
match_class!(ival.clone(),
i @ PyInt => {
let size = objint::get_value(&i.into_object()).to_usize().unwrap();
Ok(vec![0; size])},
_l @ PyString=> {return Err(vm.new_type_error("string argument without an encoding".to_string()));},
obj => {
let elements = vm.extract_elements(&obj).or_else(|_| {Err(vm.new_type_error(format!(
"cannot convert {} object to bytes", obj.class().name)))});
let mut data_bytes = vec![];
for elem in elements.unwrap(){
let v = objint::to_int(vm, &elem, 10)?;
if let Some(i) = v.to_u8() {
data_bytes.push(i);
} else {
return Err(vm.new_value_error("bytes must be in range(0, 256)".to_string()));
}
}
Ok(data_bytes)
}
)
} else {
Ok(vec![])
};
match value {
Ok(val) => Ok(PyByteInner { elements: val }),
Err(err) => Err(err),
}
}
}
pub fn repr(&self) -> PyResult<String> {
let mut res = String::with_capacity(self.elements.len());
for i in self.elements.iter() {
match i {
0..=8 => res.push_str(&format!("\\x0{}", i)),
9 => res.push_str("\\t"),
10 => res.push_str("\\n"),
13 => res.push_str("\\r"),
32..=126 => res.push(*(i) as char),
_ => res.push_str(&format!("\\x{:x}", i)),
}
}
Ok(res)
}
pub fn len(&self) -> usize {
self.elements.len()
}
pub fn is_empty(&self) -> bool {
self.elements.len() == 0
}
pub fn eq(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult {
if self.elements == other.elements {
Ok(vm.new_bool(true))
} else {
Ok(vm.new_bool(false))
}
}
pub fn ge(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult {
if self.elements >= other.elements {
Ok(vm.new_bool(true))
} else {
Ok(vm.new_bool(false))
}
}
pub fn le(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult {
if self.elements <= other.elements {
Ok(vm.new_bool(true))
} else {
Ok(vm.new_bool(false))
}
}
pub fn gt(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult {
if self.elements > other.elements {
Ok(vm.new_bool(true))
} else {
Ok(vm.new_bool(false))
}
}
pub fn lt(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult {
if self.elements < other.elements {
Ok(vm.new_bool(true))
} else {
Ok(vm.new_bool(false))
}
}
pub fn hash(&self) -> usize {
let mut hasher = DefaultHasher::new();
self.elements.hash(&mut hasher);
hasher.finish() as usize
}
pub fn add(&self, other: &PyByteInner, _vm: &VirtualMachine) -> Vec<u8> {
let elements: Vec<u8> = self
.elements
.iter()
.chain(other.elements.iter())
.cloned()
.collect();
elements
}
pub fn contains_bytes(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult {
for (n, i) in self.elements.iter().enumerate() {
if n + other.len() <= self.len()
&& *i == other.elements[0]
&& &self.elements[n..n + other.len()] == other.elements.as_slice()
{
return Ok(vm.new_bool(true));
}
}
Ok(vm.new_bool(false))
}
pub fn contains_int(&self, int: &PyInt, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
if let Some(int) = int.as_bigint().to_u8() {
if self.elements.contains(&int) {
Ok(vm.new_bool(true))
} else {
Ok(vm.new_bool(false))
}
} else {
Err(vm.new_value_error("byte must be in range(0, 256)".to_string()))
}
}
pub fn getitem_int(&self, int: &PyInt, vm: &VirtualMachine) -> PyResult {
if let Some(idx) = self.elements.get_pos(int.as_bigint().to_i32().unwrap()) {
Ok(vm.new_int(self.elements[idx]))
} else {
Err(vm.new_index_error("index out of range".to_string()))
}
}
pub fn getitem_slice(&self, slice: &PyObjectRef, vm: &VirtualMachine) -> PyResult {
Ok(vm
.ctx
.new_bytes(self.elements.get_slice_items(vm, slice).unwrap()))
}
pub fn isalnum(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_bool(
!self.elements.is_empty()
&& self
.elements
.iter()
.all(|x| char::from(*x).is_alphanumeric()),
))
}
pub fn isalpha(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_bool(
!self.elements.is_empty()
&& self.elements.iter().all(|x| char::from(*x).is_alphabetic()),
))
}
pub fn isascii(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_bool(
!self.elements.is_empty() && self.elements.iter().all(|x| char::from(*x).is_ascii()),
))
}
pub fn isdigit(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_bool(
!self.elements.is_empty() && self.elements.iter().all(|x| char::from(*x).is_digit(10)),
))
}
pub fn islower(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_bool(
!self.elements.is_empty()
&& self
.elements
.iter()
.filter(|x| !char::from(**x).is_whitespace())
.all(|x| char::from(*x).is_lowercase()),
))
}
pub fn isspace(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_bool(
!self.elements.is_empty()
&& self.elements.iter().all(|x| char::from(*x).is_whitespace()),
))
}
pub fn isupper(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_bool(
!self.elements.is_empty()
&& self
.elements
.iter()
.filter(|x| !char::from(**x).is_whitespace())
.all(|x| char::from(*x).is_uppercase()),
))
}
pub fn istitle(&self, vm: &VirtualMachine) -> PyResult {
if self.elements.is_empty() {
return Ok(vm.new_bool(false));
}
let mut iter = self.elements.iter().peekable();
let mut prev_cased = false;
while let Some(c) = iter.next() {
let current = char::from(*c);
let next = if let Some(k) = iter.peek() {
char::from(**k)
} else if current.is_uppercase() {
return Ok(vm.new_bool(!prev_cased));
} else {
return Ok(vm.new_bool(prev_cased));
};
let is_cased = current.to_uppercase().next().unwrap() != current
|| current.to_lowercase().next().unwrap() != current;
if (is_cased && next.is_uppercase() && !prev_cased)
|| (!is_cased && next.is_lowercase())
{
return Ok(vm.new_bool(false));
}
prev_cased = is_cased;
}
Ok(vm.new_bool(true))
}
pub fn lower(&self, _vm: &VirtualMachine) -> Vec<u8> {
self.elements.to_ascii_lowercase()
}
pub fn upper(&self, _vm: &VirtualMachine) -> Vec<u8> {
self.elements.to_ascii_uppercase()
}
pub fn hex(&self, vm: &VirtualMachine) -> PyResult {
let bla = self
.elements
.iter()
.map(|x| format!("{:02x}", x))
.collect::<String>();
Ok(vm.ctx.new_str(bla))
}
pub fn fromhex(string: String, vm: &VirtualMachine) -> Result<Vec<u8>, PyObjectRef> {
// first check for invalid character
for (i, c) in string.char_indices() {
if !c.is_digit(16) && !c.is_whitespace() {
return Err(vm.new_value_error(format!(
"non-hexadecimal number found in fromhex() arg at position {}",
i
)));
}
}
// strip white spaces
let stripped = string.split_whitespace().collect::<String>();
// Hex is evaluated on 2 digits
if stripped.len() % 2 != 0 {
return Err(vm.new_value_error(format!(
"non-hexadecimal number found in fromhex() arg at position {}",
stripped.len() - 1
)));
}
// parse even string
Ok(stripped
.chars()
.collect::<Vec<char>>()
.chunks(2)
.map(|x| x.to_vec().iter().collect::<String>())
.map(|x| u8::from_str_radix(&x, 16))
.map(|x| x.unwrap())
.collect::<Vec<u8>>())
}
}

View File

@@ -1,27 +1,38 @@
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use crate::obj::objint::PyInt;
use crate::obj::objstr::PyString;
use crate::vm::VirtualMachine;
use core::cell::Cell;
use std::ops::Deref;
use num_traits::ToPrimitive;
use crate::function::OptionalArg;
use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue};
use crate::vm::VirtualMachine;
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
use super::objint;
use super::objbyteinner::PyByteInner;
use super::objiter;
use super::objslice::PySlice;
use super::objtype::PyClassRef;
#[derive(Debug)]
/// "bytes(iterable_of_ints) -> bytes\n\
/// bytes(string, encoding[, errors]) -> bytes\n\
/// bytes(bytes_or_buffer) -> immutable copy of bytes_or_buffer\n\
/// bytes(int) -> bytes object of size given by the parameter initialized with null bytes\n\
/// bytes() -> empty bytes object\n\nConstruct an immutable array of bytes from:\n \
/// - an iterable yielding integers in range(256)\n \
/// - a text string encoded using the specified encoding\n \
/// - any object implementing the buffer API.\n \
/// - an integer";
#[pyclass(name = "bytes", __inside_vm)]
#[derive(Clone, Debug)]
pub struct PyBytes {
value: Vec<u8>,
inner: PyByteInner,
}
pub type PyBytesRef = PyRef<PyBytes>;
impl PyBytes {
pub fn new(data: Vec<u8>) -> Self {
PyBytes { value: data }
pub fn new(elements: Vec<u8>) -> Self {
PyBytes {
inner: PyByteInner { elements },
}
}
}
@@ -29,7 +40,7 @@ impl Deref for PyBytes {
type Target = [u8];
fn deref(&self) -> &[u8] {
&self.value
&self.inner.elements
}
}
@@ -39,131 +50,180 @@ impl PyValue for PyBytes {
}
}
// Binary data support
pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Vec<u8>> + 'a {
&obj.payload::<PyBytes>().unwrap().inner.elements
}
// Fill bytes class methods:
pub fn init(context: &PyContext) {
let bytes_doc =
"bytes(iterable_of_ints) -> bytes\n\
bytes(string, encoding[, errors]) -> bytes\n\
bytes(bytes_or_buffer) -> immutable copy of bytes_or_buffer\n\
bytes(int) -> bytes object of size given by the parameter initialized with null bytes\n\
bytes() -> empty bytes object\n\nConstruct an immutable array of bytes from:\n \
- an iterable yielding integers in range(256)\n \
- a text string encoded using the specified encoding\n \
- any object implementing the buffer API.\n \
- an integer";
extend_class!(context, &context.bytes_type, {
"__new__" => context.new_rustfunc(bytes_new),
"__eq__" => context.new_rustfunc(PyBytesRef::eq),
"__lt__" => context.new_rustfunc(PyBytesRef::lt),
"__le__" => context.new_rustfunc(PyBytesRef::le),
"__gt__" => context.new_rustfunc(PyBytesRef::gt),
"__ge__" => context.new_rustfunc(PyBytesRef::ge),
"__hash__" => context.new_rustfunc(PyBytesRef::hash),
"__repr__" => context.new_rustfunc(PyBytesRef::repr),
"__len__" => context.new_rustfunc(PyBytesRef::len),
"__iter__" => context.new_rustfunc(PyBytesRef::iter),
"__doc__" => context.new_str(bytes_doc.to_string())
});
PyBytesRef::extend_class(context, &context.bytes_type);
let bytes_type = &context.bytes_type;
extend_class!(context, bytes_type, {
"fromhex" => context.new_rustfunc(PyBytesRef::fromhex),
});
let bytesiterator_type = &context.bytesiterator_type;
extend_class!(context, bytesiterator_type, {
"__next__" => context.new_rustfunc(PyBytesIteratorRef::next),
"__iter__" => context.new_rustfunc(PyBytesIteratorRef::iter),
});
}
fn bytes_new(
cls: PyClassRef,
val_option: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<PyBytesRef> {
// Create bytes data:
let value = if let OptionalArg::Present(ival) = val_option {
let elements = vm.extract_elements(&ival)?;
let mut data_bytes = vec![];
for elem in elements.iter() {
let v = objint::to_int(vm, elem, 10)?;
data_bytes.push(v.to_u8().unwrap());
}
data_bytes
// return Err(vm.new_type_error("Cannot construct bytes".to_string()));
} else {
vec![]
};
PyBytes::new(value).into_ref_with_type(vm, cls)
"__next__" => context.new_rustfunc(PyBytesIteratorRef::next),
"__iter__" => context.new_rustfunc(PyBytesIteratorRef::iter),
});
}
#[pyimpl(__inside_vm)]
impl PyBytesRef {
fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if let Ok(other) = other.downcast::<PyBytes>() {
vm.ctx.new_bool(self.value == other.value)
} else {
vm.ctx.not_implemented()
#[pymethod(name = "__new__")]
fn bytes_new(
cls: PyClassRef,
val_option: OptionalArg<PyObjectRef>,
enc_option: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<PyBytesRef> {
PyBytes {
inner: PyByteInner::new(val_option, enc_option, vm)?,
}
.into_ref_with_type(vm, cls)
}
fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if let Ok(other) = other.downcast::<PyBytes>() {
vm.ctx.new_bool(self.value >= other.value)
} else {
vm.ctx.not_implemented()
}
}
fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if let Ok(other) = other.downcast::<PyBytes>() {
vm.ctx.new_bool(self.value > other.value)
} else {
vm.ctx.not_implemented()
}
}
fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if let Ok(other) = other.downcast::<PyBytes>() {
vm.ctx.new_bool(self.value <= other.value)
} else {
vm.ctx.not_implemented()
}
}
fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if let Ok(other) = other.downcast::<PyBytes>() {
vm.ctx.new_bool(self.value < other.value)
} else {
vm.ctx.not_implemented()
}
#[pymethod(name = "__repr__")]
fn repr(self, vm: &VirtualMachine) -> PyResult {
Ok(vm.new_str(format!("b'{}'", self.inner.repr()?)))
}
#[pymethod(name = "__len__")]
fn len(self, _vm: &VirtualMachine) -> usize {
self.value.len()
self.inner.len()
}
fn hash(self, _vm: &VirtualMachine) -> u64 {
let mut hasher = DefaultHasher::new();
self.value.hash(&mut hasher);
hasher.finish()
#[pymethod(name = "__eq__")]
fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(other,
bytes @ PyBytes => self.inner.eq(&bytes.inner, vm),
_ => Ok(vm.ctx.not_implemented()))
}
fn repr(self, _vm: &VirtualMachine) -> String {
// TODO: don't just unwrap
let data = String::from_utf8(self.value.clone()).unwrap();
format!("b'{}'", data)
#[pymethod(name = "__ge__")]
fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(other,
bytes @ PyBytes => self.inner.ge(&bytes.inner, vm),
_ => Ok(vm.ctx.not_implemented()))
}
#[pymethod(name = "__le__")]
fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(other,
bytes @ PyBytes => self.inner.le(&bytes.inner, vm),
_ => Ok(vm.ctx.not_implemented()))
}
#[pymethod(name = "__gt__")]
fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(other,
bytes @ PyBytes => self.inner.gt(&bytes.inner, vm),
_ => Ok(vm.ctx.not_implemented()))
}
#[pymethod(name = "__lt__")]
fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(other,
bytes @ PyBytes => self.inner.lt(&bytes.inner, vm),
_ => Ok(vm.ctx.not_implemented()))
}
#[pymethod(name = "__hash__")]
fn hash(self, _vm: &VirtualMachine) -> usize {
self.inner.hash()
}
#[pymethod(name = "__iter__")]
fn iter(self, _vm: &VirtualMachine) -> PyBytesIterator {
PyBytesIterator {
position: Cell::new(0),
bytes: self,
}
}
}
pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref<Target = Vec<u8>> + 'a {
&obj.payload::<PyBytes>().unwrap().value
#[pymethod(name = "__add__")]
fn add(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(other,
bytes @ PyBytes => Ok(vm.ctx.new_bytes(self.inner.add(&bytes.inner, vm))),
_ => Ok(vm.ctx.not_implemented()))
}
#[pymethod(name = "__contains__")]
fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(needle,
bytes @ PyBytes => self.inner.contains_bytes(&bytes.inner, vm),
int @ PyInt => self.inner.contains_int(&int, vm),
obj => Err(vm.new_type_error(format!("a bytes-like object is required, not {}", obj))))
}
#[pymethod(name = "__getitem__")]
fn getitem(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(needle,
int @ PyInt => self.inner.getitem_int(&int, vm),
slice @ PySlice => self.inner.getitem_slice(slice.as_object(), vm),
obj => Err(vm.new_type_error(format!("byte indices must be integers or slices, not {}", obj))))
}
#[pymethod(name = "isalnum")]
fn isalnum(self, vm: &VirtualMachine) -> PyResult {
self.inner.isalnum(vm)
}
#[pymethod(name = "isalpha")]
fn isalpha(self, vm: &VirtualMachine) -> PyResult {
self.inner.isalpha(vm)
}
#[pymethod(name = "isascii")]
fn isascii(self, vm: &VirtualMachine) -> PyResult {
self.inner.isascii(vm)
}
#[pymethod(name = "isdigit")]
fn isdigit(self, vm: &VirtualMachine) -> PyResult {
self.inner.isdigit(vm)
}
#[pymethod(name = "islower")]
fn islower(self, vm: &VirtualMachine) -> PyResult {
self.inner.islower(vm)
}
#[pymethod(name = "isspace")]
fn isspace(self, vm: &VirtualMachine) -> PyResult {
self.inner.isspace(vm)
}
#[pymethod(name = "isupper")]
fn isupper(self, vm: &VirtualMachine) -> PyResult {
self.inner.isupper(vm)
}
#[pymethod(name = "istitle")]
fn istitle(self, vm: &VirtualMachine) -> PyResult {
self.inner.istitle(vm)
}
#[pymethod(name = "lower")]
fn lower(self, vm: &VirtualMachine) -> PyResult {
Ok(vm.ctx.new_bytes(self.inner.lower(vm)))
}
#[pymethod(name = "upper")]
fn upper(self, vm: &VirtualMachine) -> PyResult {
Ok(vm.ctx.new_bytes(self.inner.upper(vm)))
}
#[pymethod(name = "hex")]
fn hex(self, vm: &VirtualMachine) -> PyResult {
self.inner.hex(vm)
}
// #[pymethod(name = "fromhex")]
fn fromhex(string: PyObjectRef, vm: &VirtualMachine) -> PyResult {
match_class!(string,
s @ PyString => {
match PyByteInner::fromhex(s.to_string(), vm) {
Ok(x) => Ok(vm.ctx.new_bytes(x)),
Err(y) => Err(y)}},
obj => Err(vm.new_type_error(format!("fromhex() argument must be str, not {}", obj )))
)
}
}
#[derive(Debug)]
@@ -182,7 +242,7 @@ type PyBytesIteratorRef = PyRef<PyBytesIterator>;
impl PyBytesIteratorRef {
fn next(self, vm: &VirtualMachine) -> PyResult<u8> {
if self.position.get() < self.bytes.value.len() {
if self.position.get() < self.bytes.inner.len() {
let ret = self.bytes[self.position.get()];
self.position.set(self.position.get() + 1);
Ok(ret)

View File

@@ -296,6 +296,11 @@ macro_rules! dict_iterator {
fn iter(&self, _vm: &VirtualMachine) -> $iter_name {
$iter_name::new(self.dict.clone())
}
#[pymethod(name = "__len__")]
fn len(&self, vm: &VirtualMachine) -> usize {
self.dict.clone().len(vm)
}
}
impl PyValue for $name {

View File

@@ -8,7 +8,8 @@ use num_traits::{Pow, Signed, ToPrimitive, Zero};
use crate::format::FormatSpec;
use crate::function::OptionalArg;
use crate::pyobject::{
IntoPyObject, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
TypeProtocol,
};
use crate::vm::VirtualMachine;
@@ -17,6 +18,21 @@ use super::objstr::{PyString, PyStringRef};
use super::objtype;
use crate::obj::objtype::PyClassRef;
/// int(x=0) -> integer
/// int(x, base=10) -> integer
///
/// Convert a number or string to an integer, or return 0 if no arguments
/// are given. If x is a number, return x.__int__(). For floating point
/// numbers, this truncates towards zero.
///
/// If x is not a number or if base is given, then x must be a string,
/// bytes, or bytearray instance representing an integer literal in the
/// given base. The literal can be preceded by '+' or '-' and be surrounded
/// by whitespace. The base defaults to 10. Valid bases are 0 and 2-36.
/// Base 0 means to interpret the base from the string as an integer literal.
/// >>> int('0b100', base=0)
/// 4
#[pyclass(__inside_vm)]
#[derive(Debug)]
pub struct PyInt {
value: BigInt,
@@ -95,7 +111,9 @@ impl_try_from_object_int!(
(u64, to_u64),
);
#[pyimpl(__inside_vm)]
impl PyInt {
#[pymethod(name = "__eq__")]
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_bool(self.value == *get_value(&other))
@@ -104,6 +122,7 @@ impl PyInt {
}
}
#[pymethod(name = "__ne__")]
fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_bool(self.value != *get_value(&other))
@@ -112,6 +131,7 @@ impl PyInt {
}
}
#[pymethod(name = "__lt__")]
fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_bool(self.value < *get_value(&other))
@@ -120,6 +140,7 @@ impl PyInt {
}
}
#[pymethod(name = "__le__")]
fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_bool(self.value <= *get_value(&other))
@@ -128,6 +149,7 @@ impl PyInt {
}
}
#[pymethod(name = "__gt__")]
fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_bool(self.value > *get_value(&other))
@@ -136,6 +158,7 @@ impl PyInt {
}
}
#[pymethod(name = "__ge__")]
fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_bool(self.value >= *get_value(&other))
@@ -144,6 +167,7 @@ impl PyInt {
}
}
#[pymethod(name = "__add__")]
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_int((&self.value) + get_value(&other))
@@ -152,6 +176,12 @@ impl PyInt {
}
}
#[pymethod(name = "__radd__")]
fn radd(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
self.add(other, vm)
}
#[pymethod(name = "__sub__")]
fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_int((&self.value) - get_value(&other))
@@ -160,6 +190,7 @@ impl PyInt {
}
}
#[pymethod(name = "__rsub__")]
fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_int(get_value(&other) - (&self.value))
@@ -168,6 +199,7 @@ impl PyInt {
}
}
#[pymethod(name = "__mul__")]
fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_int((&self.value) * get_value(&other))
@@ -176,6 +208,12 @@ impl PyInt {
}
}
#[pymethod(name = "__rmul__")]
fn rmul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
self.mul(other, vm)
}
#[pymethod(name = "__truediv__")]
fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
div_ints(vm, &self.value, &get_value(&other))
@@ -184,6 +222,7 @@ impl PyInt {
}
}
#[pymethod(name = "__rtruediv__")]
fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
div_ints(vm, &get_value(&other), &self.value)
@@ -192,6 +231,7 @@ impl PyInt {
}
}
#[pymethod(name = "__floordiv__")]
fn floordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
let v2 = get_value(&other);
@@ -205,6 +245,7 @@ impl PyInt {
}
}
#[pymethod(name = "__lshift__")]
fn lshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if !objtype::isinstance(&other, &vm.ctx.int_type()) {
return Ok(vm.ctx.not_implemented());
@@ -224,6 +265,7 @@ impl PyInt {
}
}
#[pymethod(name = "__rshift__")]
fn rshift(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if !objtype::isinstance(&other, &vm.ctx.int_type()) {
return Ok(vm.ctx.not_implemented());
@@ -243,6 +285,7 @@ impl PyInt {
}
}
#[pymethod(name = "__xor__")]
fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_int((&self.value) ^ get_value(&other))
@@ -251,6 +294,7 @@ impl PyInt {
}
}
#[pymethod(name = "__rxor__")]
fn rxor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_int(get_value(&other) ^ (&self.value))
@@ -259,6 +303,7 @@ impl PyInt {
}
}
#[pymethod(name = "__or__")]
fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
vm.ctx.new_int((&self.value) | get_value(&other))
@@ -267,6 +312,7 @@ impl PyInt {
}
}
#[pymethod(name = "__and__")]
fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
let v2 = get_value(&other);
@@ -276,6 +322,7 @@ impl PyInt {
}
}
#[pymethod(name = "__pow__")]
fn pow(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
let v2 = get_value(&other).to_u32().unwrap();
@@ -288,6 +335,7 @@ impl PyInt {
}
}
#[pymethod(name = "__mod__")]
fn mod_(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
let v2 = get_value(&other);
@@ -301,6 +349,7 @@ impl PyInt {
}
}
#[pymethod(name = "__divmod__")]
fn divmod(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if objtype::isinstance(&other, &vm.ctx.int_type()) {
let v2 = get_value(&other);
@@ -317,20 +366,24 @@ impl PyInt {
}
}
#[pymethod(name = "__neg__")]
fn neg(&self, _vm: &VirtualMachine) -> BigInt {
-(&self.value)
}
#[pymethod(name = "__hash__")]
fn hash(&self, _vm: &VirtualMachine) -> u64 {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
self.value.hash(&mut hasher);
hasher.finish()
}
#[pymethod(name = "__abs__")]
fn abs(&self, _vm: &VirtualMachine) -> BigInt {
self.value.abs()
}
#[pymethod(name = "__round__")]
fn round(
zelf: PyRef<Self>,
_precision: OptionalArg<PyObjectRef>,
@@ -339,14 +392,17 @@ impl PyInt {
zelf
}
#[pymethod(name = "__int__")]
fn int(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pymethod(name = "__pos__")]
fn pos(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pymethod(name = "__float__")]
fn float(&self, vm: &VirtualMachine) -> PyResult<PyFloat> {
self.value
.to_f64()
@@ -354,30 +410,37 @@ impl PyInt {
.ok_or_else(|| vm.new_overflow_error("int too large to convert to float".to_string()))
}
#[pymethod(name = "__trunc__")]
fn trunc(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pymethod(name = "__floor__")]
fn floor(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pymethod(name = "__ceil__")]
fn ceil(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pymethod(name = "__index__")]
fn index(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pymethod(name = "__invert__")]
fn invert(&self, _vm: &VirtualMachine) -> BigInt {
!(&self.value)
}
#[pymethod(name = "__repr__")]
fn repr(&self, _vm: &VirtualMachine) -> String {
self.value.to_string()
}
#[pymethod(name = "__format__")]
fn format(&self, spec: PyStringRef, vm: &VirtualMachine) -> PyResult<String> {
let format_spec = FormatSpec::parse(&spec.value);
match format_spec.format_int(&self.value) {
@@ -386,22 +449,27 @@ impl PyInt {
}
}
#[pymethod(name = "__bool__")]
fn bool(&self, _vm: &VirtualMachine) -> bool {
!self.value.is_zero()
}
#[pymethod]
fn bit_length(&self, _vm: &VirtualMachine) -> usize {
self.value.bits()
}
#[pymethod]
fn conjugate(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pyproperty]
fn real(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyIntRef {
zelf
}
#[pyproperty]
fn imag(&self, _vm: &VirtualMachine) -> usize {
0
}
@@ -500,68 +568,9 @@ fn div_ints(vm: &VirtualMachine, i1: &BigInt, i2: &BigInt) -> PyResult {
}
}
#[rustfmt::skip] // to avoid line splitting
pub fn init(context: &PyContext) {
let int_doc = "int(x=0) -> integer
int(x, base=10) -> integer
Convert a number or string to an integer, or return 0 if no arguments
are given. If x is a number, return x.__int__(). For floating point
numbers, this truncates towards zero.
If x is not a number or if base is given, then x must be a string,
bytes, or bytearray instance representing an integer literal in the
given base. The literal can be preceded by '+' or '-' and be surrounded
by whitespace. The base defaults to 10. Valid bases are 0 and 2-36.
Base 0 means to interpret the base from the string as an integer literal.
>>> int('0b100', base=0)
4";
let int_type = &context.int_type;
extend_class!(context, int_type, {
"__doc__" => context.new_str(int_doc.to_string()),
"__eq__" => context.new_rustfunc(PyInt::eq),
"__ne__" => context.new_rustfunc(PyInt::ne),
"__lt__" => context.new_rustfunc(PyInt::lt),
"__le__" => context.new_rustfunc(PyInt::le),
"__gt__" => context.new_rustfunc(PyInt::gt),
"__ge__" => context.new_rustfunc(PyInt::ge),
"__abs__" => context.new_rustfunc(PyInt::abs),
"__add__" => context.new_rustfunc(PyInt::add),
"__radd__" => context.new_rustfunc(PyInt::add),
"__and__" => context.new_rustfunc(PyInt::and),
"__divmod__" => context.new_rustfunc(PyInt::divmod),
"__float__" => context.new_rustfunc(PyInt::float),
"__round__" => context.new_rustfunc(PyInt::round),
"__ceil__" => context.new_rustfunc(PyInt::ceil),
"__floor__" => context.new_rustfunc(PyInt::floor),
"__index__" => context.new_rustfunc(PyInt::index),
"__trunc__" => context.new_rustfunc(PyInt::trunc),
"__int__" => context.new_rustfunc(PyInt::int),
"__floordiv__" => context.new_rustfunc(PyInt::floordiv),
"__hash__" => context.new_rustfunc(PyInt::hash),
"__lshift__" => context.new_rustfunc(PyInt::lshift),
"__rshift__" => context.new_rustfunc(PyInt::rshift),
PyInt::extend_class(context, &context.int_type);
extend_class!(context, &context.int_type, {
"__new__" => context.new_rustfunc(int_new),
"__mod__" => context.new_rustfunc(PyInt::mod_),
"__mul__" => context.new_rustfunc(PyInt::mul),
"__rmul__" => context.new_rustfunc(PyInt::mul),
"__or__" => context.new_rustfunc(PyInt::or),
"__neg__" => context.new_rustfunc(PyInt::neg),
"__pos__" => context.new_rustfunc(PyInt::pos),
"__pow__" => context.new_rustfunc(PyInt::pow),
"__repr__" => context.new_rustfunc(PyInt::repr),
"__sub__" => context.new_rustfunc(PyInt::sub),
"__rsub__" => context.new_rustfunc(PyInt::rsub),
"__format__" => context.new_rustfunc(PyInt::format),
"__truediv__" => context.new_rustfunc(PyInt::truediv),
"__rtruediv__" => context.new_rustfunc(PyInt::rtruediv),
"__xor__" => context.new_rustfunc(PyInt::xor),
"__rxor__" => context.new_rustfunc(PyInt::rxor),
"__bool__" => context.new_rustfunc(PyInt::bool),
"__invert__" => context.new_rustfunc(PyInt::invert),
"bit_length" => context.new_rustfunc(PyInt::bit_length),
"conjugate" => context.new_rustfunc(PyInt::conjugate),
"real" => context.new_property(PyInt::real),
"imag" => context.new_property(PyInt::imag)
});
}