diff --git a/vm/src/anystr.rs b/vm/src/anystr.rs index ba7cd17d2..d7ad454f8 100644 --- a/vm/src/anystr.rs +++ b/vm/src/anystr.rs @@ -65,6 +65,10 @@ impl StartsEndsWithArgs { let range = adjust_indices(self.start, self.end, len); (self.affix, range) } + + pub fn has_subrange(&self) -> bool { + self.start.is_none() && self.end.is_none() + } } fn saturate_to_isize(py_int: PyIntRef) -> isize { @@ -194,6 +198,7 @@ pub trait AnyStr<'s>: 's { fn py_startsendswith( &self, args: StartsEndsWithArgs, + len: usize, func_name: &str, py_type_name: &str, func: F, @@ -203,11 +208,11 @@ pub trait AnyStr<'s>: 's { T: TryFromObject, F: Fn(&Self, &T) -> bool, { - let (affix, value) = if args.start.is_none() && args.end.is_none() { + let (affix, value) = if args.has_subrange() { // If it doesn't have subrange, it uses bytes operation. - (args.affix, self.get_bytes(0..self.bytes_len())) + (args.affix, self.get_bytes(0..len)) } else { - let (affix, range) = args.get_value(self.chars_len()); + let (affix, range) = args.get_value(len); if !range.is_normal() { return Ok(false); } diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 4b0fa0054..2ac4c9fd7 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -432,6 +432,7 @@ impl PyByteArray { fn endswith(&self, options: anystr::StartsEndsWithArgs, vm: &VirtualMachine) -> PyResult { self.borrow_buf().py_startsendswith( options, + self.len(), "endswith", "bytes", |s, x: &PyBytesInner| s.ends_with(&x.elements[..]), @@ -447,6 +448,7 @@ impl PyByteArray { ) -> PyResult { self.borrow_buf().py_startsendswith( options, + self.len(), "startswith", "bytes", |s, x: &PyBytesInner| s.starts_with(&x.elements[..]), diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index 8e6a090e0..4209b2a5a 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -277,6 +277,7 @@ impl PyBytes { fn endswith(&self, options: anystr::StartsEndsWithArgs, vm: &VirtualMachine) -> PyResult { self.inner.elements[..].py_startsendswith( options, + self.len(), "endswith", "bytes", |s, x: &PyBytesInner| s.ends_with(&x.elements[..]), @@ -292,6 +293,7 @@ impl PyBytes { ) -> PyResult { self.inner.elements[..].py_startsendswith( options, + self.len(), "startswith", "bytes", |s, x: &PyBytesInner| s.starts_with(&x.elements[..]), diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index ec3ea2ae7..fd3f45ce7 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -688,8 +688,15 @@ impl PyStr { #[pymethod] fn endswith(&self, args: anystr::StartsEndsWithArgs, vm: &VirtualMachine) -> PyResult { + let len = if args.has_subrange() { + self.byte_len() + } else { + self.char_len() + }; + self.as_str().py_startsendswith( args, + len, "endswith", "str", |s, x: &PyStrRef| s.ends_with(x.as_str()), @@ -699,8 +706,15 @@ impl PyStr { #[pymethod] fn startswith(&self, args: anystr::StartsEndsWithArgs, vm: &VirtualMachine) -> PyResult { + let len = if args.has_subrange() { + self.byte_len() + } else { + self.char_len() + }; + self.as_str().py_startsendswith( args, + len, "startswith", "str", |s, x: &PyStrRef| s.starts_with(x.as_str()),