diff --git a/tests/snippets/bytes.py b/tests/snippets/bytes.py index ec8001213..755707a76 100644 --- a/tests/snippets/bytes.py +++ b/tests/snippets/bytes.py @@ -162,7 +162,7 @@ with assertRaises(TypeError): with assertRaises(TypeError): b"b".center(2, b"ba") b"kok".center(5, bytearray(b"x")) -b"kok".center(-5,) +b"kok".center(-5) # count assert b"azeazerazeazopia".count(b"aze") == 3 @@ -191,3 +191,19 @@ assert ( ) with assertRaises(TypeError): b"".join((b"km", "kl")) + + +# endswith startswith +assert b"abcde".endswith(b"de") +assert b"abcde".endswith(b"") +assert not b"abcde".endswith(b"zx") +assert b"abcde".endswith(b"bc", 0, 3) +assert not b"abcde".endswith(b"bc", 2, 3) +assert b"abcde".endswith((b"c", b"de")) + +assert b"abcde".startswith(b"ab") +assert b"abcde".startswith(b"") +assert not b"abcde".startswith(b"zx") +assert b"abcde".startswith(b"cd", 2) +assert not b"abcde".startswith(b"cd", 1, 4) +assert b"abcde".startswith((b"a", b"bc")) diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 26f33c2e6..cc4111a8b 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -12,7 +12,8 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use super::objint; -use super::objsequence::PySliceableSequence; +use super::objtype; +use super::objsequence::{PySliceableSequence, is_valid_slice_arg}; use crate::obj::objint::PyInt; use num_integer::Integer; use num_traits::ToPrimitive; @@ -21,6 +22,7 @@ use super::objbytearray::{get_value as get_value_bytearray, PyByteArray}; use super::objbytes::PyBytes; use super::objmemory::PyMemoryView; use super::objnone::PyNone; +use super::objsequence; #[derive(Debug, Default, Clone)] pub struct PyByteInner { @@ -495,6 +497,61 @@ impl PyByteInner { Ok(vm.ctx.new_bytes(refs)) } + + pub fn startsendswith( + &self, + arg: PyObjectRef, + start: OptionalArg, + end: OptionalArg, + endswith: bool, // true for endswith, false for startswith + vm: &VirtualMachine, + ) -> PyResult { + let suff = if objtype::isinstance(&arg, &vm.ctx.tuple_type()) { + let mut flatten = vec![]; + for v in objsequence::get_elements(&arg).to_vec() { + match try_as_bytes_like(&v) { + None => { + return Err(vm.new_type_error(format!( + "a bytes-like object is required, not {}", + &v.class().name, + ))); + } + Some(value) => flatten.extend(value), + } + } + flatten + } else { + match try_as_bytes_like(&arg) { + Some(value) => value, + None => { + return Err(vm.new_type_error(format!( + "endswith first arg must be bytes or a tuple of bytes, not {}", + arg + ))); + } + } + }; + + if suff.is_empty() { + return Ok(vm.new_bool(true)); + } + let range = self.elements.get_slice_range( + &is_valid_slice_arg(start, vm)?, + &is_valid_slice_arg(end, vm)?, + ); + + if range.end - range.start < suff.len() { + return Ok(vm.new_bool(false)); + } + + let offset = if endswith { + (range.end - suff.len())..range.end + } else { + 0..suff.len() + }; + + Ok(vm.new_bool(suff.as_slice() == &self.elements.do_slice(range)[offset])) + } } pub fn try_as_byte(obj: &PyObjectRef) -> Option> { diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 851c8d8ab..0c6e5c4b9 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -258,6 +258,28 @@ impl PyBytesRef { fn join(self, iter: PyIterable, vm: &VirtualMachine) -> PyResult { self.inner.join(iter, vm) } + + #[pymethod(name = "endswith")] + fn endswith( + self, + suffix: PyObjectRef, + start: OptionalArg, + end: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + self.inner.startsendswith(suffix, start, end, true, vm) + } + + #[pymethod(name = "startswith")] + fn startswith( + self, + suffix: PyObjectRef, + start: OptionalArg, + end: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + self.inner.startsendswith(suffix, start, end, false, vm) + } } #[derive(Debug)] diff --git a/vm/src/obj/objsequence.rs b/vm/src/obj/objsequence.rs index 5594ac858..af2034370 100644 --- a/vm/src/obj/objsequence.rs +++ b/vm/src/obj/objsequence.rs @@ -1,3 +1,5 @@ +use crate::function::OptionalArg; +use crate::obj::objnone::PyNone; use std::cell::RefCell; use std::marker::Sized; use std::ops::{Deref, DerefMut, Range}; @@ -371,3 +373,20 @@ pub fn get_mut_elements<'a>(obj: &'a PyObjectRef) -> impl DerefMut, + vm: &VirtualMachine, +) -> Result, PyObjectRef> { + if let OptionalArg::Present(value) = arg { + match_class!(value, + i @ PyInt => Ok(Some(i.as_bigint().clone())), + _obj @ PyNone => Ok(None), + _=> {return Err(vm.new_type_error("slice indices must be integers or None or have an __index__ method".to_string()));} + // TODO: check for an __index__ method + ) + } else { + Ok(None) + } +}