use std::cell::Cell; use std::mem::size_of; use std::ops::Deref; use super::objbyteinner::{ ByteInnerExpandtabsOptions, ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, ByteInnerPosition, ByteInnerSplitOptions, ByteInnerSplitlinesOptions, ByteInnerTranslateOptions, PyByteInner, }; use super::objint::PyIntRef; use super::objiter; use super::objslice::PySliceRef; use super::objstr::{PyString, PyStringRef}; use super::objtuple::PyTupleRef; use super::objtype::PyClassRef; use crate::cformat::CFormatString; use crate::function::OptionalArg; use crate::obj::objstr::do_cformat_string; use crate::pyhash; use crate::pyobject::{ Either, IntoPyObject, PyArithmaticValue::{self, *}, PyClassImpl, PyComparisonValue, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; use std::str::FromStr; /// "bytes(iterable_of_ints) -> bytes\n\ /// bytes(string, encoding[, errors]) -> bytes\n\ /// bytes(bytes_or_buffer) -> immutable copy of bytes_or_buffer\n\ /// bytes(int) -> bytes object of size given by the parameter initialized with null bytes\n\ /// bytes() -> empty bytes object\n\nConstruct an immutable array of bytes from:\n \ /// - an iterable yielding integers in range(256)\n \ /// - a text string encoded using the specified encoding\n \ /// - any object implementing the buffer API.\n \ /// - an integer"; #[pyclass(name = "bytes")] #[derive(Clone, Debug)] pub struct PyBytes { inner: PyByteInner, } pub type PyBytesRef = PyRef; impl PyBytes { pub fn new(elements: Vec) -> Self { PyBytes { inner: PyByteInner { elements }, } } pub fn get_value(&self) -> &[u8] { &self.inner.elements } } impl From> for PyBytes { fn from(elements: Vec) -> PyBytes { PyBytes::new(elements) } } impl IntoPyObject for Vec { fn into_pyobject(self, vm: &VirtualMachine) -> PyResult { Ok(vm.ctx.new_bytes(self)) } } impl Deref for PyBytes { type Target = [u8]; fn deref(&self) -> &[u8] { &self.inner.elements } } impl PyValue for PyBytes { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.bytes_type() } } pub(crate) fn init(context: &PyContext) { PyBytes::extend_class(context, &context.types.bytes_type); let bytes_type = &context.types.bytes_type; extend_class!(context, bytes_type, { "maketrans" => context.new_method(PyByteInner::maketrans), }); PyBytesIterator::extend_class(context, &context.types.bytesiterator_type); } #[pyimpl] impl PyBytes { #[pyslot] fn tp_new( cls: PyClassRef, options: ByteInnerNewOptions, vm: &VirtualMachine, ) -> PyResult { PyBytes { inner: options.get_value(vm)?, } .into_ref_with_type(vm, cls) } #[pymethod(name = "__repr__")] fn repr(&self, vm: &VirtualMachine) -> PyResult { Ok(vm.new_str(format!("b'{}'", self.inner.repr()?))) } #[pymethod(name = "__len__")] fn len(&self, _vm: &VirtualMachine) -> usize { self.inner.len() } #[pymethod(name = "__eq__")] fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.inner.eq(other, vm) } #[pymethod(name = "__ge__")] fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.inner.ge(other, vm) } #[pymethod(name = "__le__")] fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.inner.le(other, vm) } #[pymethod(name = "__gt__")] fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.inner.gt(other, vm) } #[pymethod(name = "__lt__")] fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { self.inner.lt(other, vm) } #[pymethod(name = "__hash__")] fn hash(&self, _vm: &VirtualMachine) -> pyhash::PyHash { self.inner.hash() } #[pymethod(name = "__iter__")] fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyBytesIterator { PyBytesIterator { position: Cell::new(0), bytes: zelf, } } #[pymethod(name = "__sizeof__")] fn sizeof(&self, _vm: &VirtualMachine) -> PyResult { Ok(size_of::() + self.inner.elements.len() * size_of::()) } #[pymethod(name = "__add__")] fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyArithmaticValue { if let Ok(other) = PyByteInner::try_from_object(vm, other) { Implemented(self.inner.add(other).into()) } else { NotImplemented } } #[pymethod(name = "__contains__")] fn contains( &self, needle: Either, vm: &VirtualMachine, ) -> PyResult { self.inner.contains(needle, vm) } #[pymethod(name = "__getitem__")] fn getitem(&self, needle: Either, vm: &VirtualMachine) -> PyResult { self.inner.getitem(needle, vm) } #[pymethod(name = "isalnum")] fn isalnum(&self, _vm: &VirtualMachine) -> bool { self.inner.isalnum() } #[pymethod(name = "isalpha")] fn isalpha(&self, _vm: &VirtualMachine) -> bool { self.inner.isalpha() } #[pymethod(name = "isascii")] fn isascii(&self, _vm: &VirtualMachine) -> bool { self.inner.isascii() } #[pymethod(name = "isdigit")] fn isdigit(&self, _vm: &VirtualMachine) -> bool { self.inner.isdigit() } #[pymethod(name = "islower")] fn islower(&self, _vm: &VirtualMachine) -> bool { self.inner.islower() } #[pymethod(name = "isspace")] fn isspace(&self, _vm: &VirtualMachine) -> bool { self.inner.isspace() } #[pymethod(name = "isupper")] fn isupper(&self, _vm: &VirtualMachine) -> bool { self.inner.isupper() } #[pymethod(name = "istitle")] fn istitle(&self, _vm: &VirtualMachine) -> bool { self.inner.istitle() } #[pymethod(name = "lower")] fn lower(&self, _vm: &VirtualMachine) -> PyBytes { self.inner.lower().into() } #[pymethod(name = "upper")] fn upper(&self, _vm: &VirtualMachine) -> PyBytes { self.inner.upper().into() } #[pymethod(name = "capitalize")] fn capitalize(&self, _vm: &VirtualMachine) -> PyBytes { self.inner.capitalize().into() } #[pymethod(name = "swapcase")] fn swapcase(&self, _vm: &VirtualMachine) -> PyBytes { self.inner.swapcase().into() } #[pymethod(name = "hex")] fn hex(&self, _vm: &VirtualMachine) -> String { self.inner.hex() } #[pymethod] fn fromhex(string: PyStringRef, vm: &VirtualMachine) -> PyResult { Ok(PyByteInner::fromhex(string.as_str(), vm)?.into()) } #[pymethod(name = "center")] fn center(&self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { Ok(self.inner.center(options, vm)?.into()) } #[pymethod(name = "ljust")] fn ljust(&self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { Ok(self.inner.ljust(options, vm)?.into()) } #[pymethod(name = "rjust")] fn rjust(&self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { Ok(self.inner.rjust(options, vm)?.into()) } #[pymethod(name = "count")] fn count(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { self.inner.count(options, vm) } #[pymethod(name = "join")] fn join(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult { Ok(self.inner.join(iter, vm)?.into()) } #[pymethod(name = "endswith")] fn endswith( &self, suffix: Either, start: OptionalArg, end: OptionalArg, vm: &VirtualMachine, ) -> PyResult { self.inner.startsendswith(suffix, start, end, true, vm) } #[pymethod(name = "startswith")] fn startswith( &self, prefix: Either, start: OptionalArg, end: OptionalArg, vm: &VirtualMachine, ) -> PyResult { self.inner.startsendswith(prefix, start, end, false, vm) } #[pymethod(name = "find")] fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { self.inner.find(options, false, vm) } #[pymethod(name = "index")] fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { let res = self.inner.find(options, false, vm)?; if res == -1 { return Err(vm.new_value_error("substring not found".to_string())); } Ok(res) } #[pymethod(name = "rfind")] fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { self.inner.find(options, true, vm) } #[pymethod(name = "rindex")] fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { let res = self.inner.find(options, true, vm)?; if res == -1 { return Err(vm.new_value_error("substring not found".to_string())); } Ok(res) } #[pymethod(name = "translate")] fn translate( &self, options: ByteInnerTranslateOptions, vm: &VirtualMachine, ) -> PyResult { Ok(self.inner.translate(options, vm)?.into()) } #[pymethod(name = "strip")] fn strip(&self, chars: OptionalArg, _vm: &VirtualMachine) -> PyResult { Ok(self.inner.strip(chars, ByteInnerPosition::All)?.into()) } #[pymethod(name = "lstrip")] fn lstrip(&self, chars: OptionalArg, _vm: &VirtualMachine) -> PyResult { Ok(self.inner.strip(chars, ByteInnerPosition::Left)?.into()) } #[pymethod(name = "rstrip")] fn rstrip(&self, chars: OptionalArg, _vm: &VirtualMachine) -> PyResult { Ok(self.inner.strip(chars, ByteInnerPosition::Right)?.into()) } #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self .inner .split(options, false)? .iter() .map(|x| vm.ctx.new_bytes(x.to_vec())) .collect::>(); Ok(vm.ctx.new_list(as_bytes)) } #[pymethod(name = "rsplit")] fn rsplit(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self .inner .split(options, true)? .iter() .map(|x| vm.ctx.new_bytes(x.to_vec())) .collect::>(); Ok(vm.ctx.new_list(as_bytes)) } #[pymethod(name = "partition")] fn partition(&self, sep: PyObjectRef, vm: &VirtualMachine) -> PyResult { let sepa = PyByteInner::try_from_object(vm, sep.clone())?; let (left, right) = self.inner.partition(&sepa, false)?; Ok(vm .ctx .new_tuple(vec![vm.ctx.new_bytes(left), sep, vm.ctx.new_bytes(right)])) } #[pymethod(name = "rpartition")] fn rpartition(&self, sep: PyObjectRef, vm: &VirtualMachine) -> PyResult { let sepa = PyByteInner::try_from_object(vm, sep.clone())?; let (left, right) = self.inner.partition(&sepa, true)?; Ok(vm .ctx .new_tuple(vec![vm.ctx.new_bytes(left), sep, vm.ctx.new_bytes(right)])) } #[pymethod(name = "expandtabs")] fn expandtabs(&self, options: ByteInnerExpandtabsOptions, _vm: &VirtualMachine) -> PyBytes { self.inner.expandtabs(options).into() } #[pymethod(name = "splitlines")] fn splitlines(&self, options: ByteInnerSplitlinesOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self .inner .splitlines(options) .iter() .map(|x| vm.ctx.new_bytes(x.to_vec())) .collect::>(); Ok(vm.ctx.new_list(as_bytes)) } #[pymethod(name = "zfill")] fn zfill(&self, width: PyIntRef, _vm: &VirtualMachine) -> PyBytes { self.inner.zfill(width).into() } #[pymethod(name = "replace")] fn replace( &self, old: PyByteInner, new: PyByteInner, count: OptionalArg, _vm: &VirtualMachine, ) -> PyResult { Ok(self.inner.replace(old, new, count)?.into()) } #[pymethod(name = "title")] fn title(&self, _vm: &VirtualMachine) -> PyBytes { self.inner.title().into() } #[pymethod(name = "__mul__")] fn repeat(&self, n: isize, _vm: &VirtualMachine) -> PyBytes { self.inner.repeat(n).into() } #[pymethod(name = "__rmul__")] fn rmul(&self, n: isize, vm: &VirtualMachine) -> PyBytes { self.repeat(n, vm) } fn do_cformat( &self, vm: &VirtualMachine, format_string: CFormatString, values_obj: PyObjectRef, ) -> PyResult { let final_string = do_cformat_string(vm, format_string, values_obj)?; Ok(vm .ctx .new_bytes(final_string.as_str().as_bytes().to_owned())) } #[pymethod(name = "__mod__")] fn modulo(&self, values: PyObjectRef, vm: &VirtualMachine) -> PyResult { let format_string_text = std::str::from_utf8(&self.inner.elements).unwrap(); let format_string = CFormatString::from_str(format_string_text) .map_err(|err| vm.new_value_error(err.to_string()))?; self.do_cformat(vm, format_string, values.clone()) } #[pymethod(name = "__rmod__")] fn rmod(&self, _values: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { vm.ctx.not_implemented() } /// Return a string decoded from the given bytes. /// Default encoding is 'utf-8'. /// Default errors is 'strict', meaning that encoding errors raise a UnicodeError. /// Other possible values are 'ignore', 'replace' /// For a list of possible encodings, /// see https://docs.python.org/3/library/codecs.html#standard-encodings /// currently, only 'utf-8' and 'ascii' emplemented #[pymethod(name = "decode")] fn decode( zelf: PyRef, encoding: OptionalArg, errors: OptionalArg, vm: &VirtualMachine, ) -> PyResult { let encoding = encoding.into_option(); vm.decode(zelf.into_object(), encoding.clone(), errors.into_option())? .downcast::() .map_err(|obj| { vm.new_type_error(format!( "'{}' decoder returned '{}' instead of 'str'; use codecs.encode() to \ encode arbitrary types", encoding.as_ref().map_or("utf-8", |s| s.as_str()), obj.class().name, )) }) } } #[pyclass] #[derive(Debug)] pub struct PyBytesIterator { position: Cell, bytes: PyBytesRef, } impl PyValue for PyBytesIterator { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.bytesiterator_type() } } #[pyimpl] impl PyBytesIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { if self.position.get() < self.bytes.inner.len() { let ret = self.bytes[self.position.get()]; self.position.set(self.position.get() + 1); Ok(ret) } else { Err(objiter::new_stop_iteration(vm)) } } #[pymethod(name = "__iter__")] fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } }