diff --git a/tests/snippets/bytes.py b/tests/snippets/bytes.py index 1ad23ea18..d2a60943c 100644 --- a/tests/snippets/bytes.py +++ b/tests/snippets/bytes.py @@ -131,3 +131,30 @@ try: except ValueError as e: str(e) == "non-hexadecimal number found in fromhex() arg at position 1" +# center +assert [b"koki".center(i, b"|") for i in range(3, 10)] == [ + b"koki", + b"koki", + b"|koki", + b"|koki|", + b"||koki|", + b"||koki||", + b"|||koki||", +] + +assert [b"kok".center(i, b"|") for i in range(2, 10)] == [ + b"kok", + b"kok", + b"kok|", + b"|kok|", + b"|kok||", + b"||kok||", + b"||kok|||", + b"|||kok|||", +] +b"kok".center(4) == b" kok" # " test no arg" +with assertRaises(TypeError): + b"b".center(2, "a") +with assertRaises(TypeError): + b"b".center(2, b"ba") +b"kok".center(5, bytearray(b"x")) diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 31511db11..adcc106bb 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -55,6 +55,29 @@ assert b.rstrip() == ' hallo' c = 'hallo' assert c.capitalize() == 'Hallo' assert c.center(11, '-') == '---hallo---' +assert ["koki".center(i, "|") for i in range(3, 10)] == [ + "koki", + "koki", + "|koki", + "|koki|", + "||koki|", + "||koki||", + "|||koki||", +] + + +assert ["kok".center(i, "|") for i in range(2, 10)] == [ + "kok", + "kok", + "kok|", + "|kok|", + "|kok||", + "||kok||", + "||kok|||", + "|||kok|||", +] + + # assert c.isascii() assert c.index('a') == 1 assert c.rindex('l') == 3 diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 9bedb0bf0..64ac24a3d 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -1,4 +1,5 @@ use crate::pyobject::PyObjectRef; +use num_bigint::BigInt; use crate::function::OptionalArg; @@ -13,8 +14,12 @@ use std::hash::{Hash, Hasher}; use super::objint; use super::objsequence::PySliceableSequence; use crate::obj::objint::PyInt; +use num_integer::Integer; use num_traits::ToPrimitive; +use super::objbytearray::{get_value as get_value_bytearray, PyByteArray}; +use super::objbytes::PyBytes; + #[derive(Debug, Default, Clone)] pub struct PyByteInner { pub elements: Vec, @@ -355,4 +360,39 @@ impl PyByteInner { .map(|x| x.unwrap()) .collect::>()) } + + pub fn center(&self, width: &BigInt, fillbyte: u8, _vm: &VirtualMachine) -> Vec { + let width = width.to_usize().unwrap(); + + // adjust right et left side + if width <= self.len() { + return self.elements.clone(); + } + let diff: usize = width - self.len(); + let mut ln: usize = diff / 2; + let mut rn: usize = ln; + + if diff.is_odd() && self.len() % 2 == 0 { + ln += 1 + } + + if diff.is_odd() && self.len() % 2 != 0 { + rn += 1 + } + + // merge all + let mut res = vec![fillbyte; ln]; + res.extend_from_slice(&self.elements[..]); + res.extend_from_slice(&vec![fillbyte; rn][..]); + + res + } +} + +pub fn is_byte(obj: &PyObjectRef) -> Option> { + match_class!(obj.clone(), + + i @ PyBytes => Some(i.get_value().to_vec()), + j @ PyByteArray => Some(get_value_bytearray(&j.as_object()).to_vec()), + _ => None) } diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 953e5b2bb..f183738c3 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -7,7 +7,7 @@ use std::ops::Deref; use crate::function::OptionalArg; use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; -use super::objbyteinner::PyByteInner; +use super::objbyteinner::{is_byte, PyByteInner}; use super::objiter; use super::objslice::PySlice; use super::objtype::PyClassRef; @@ -34,6 +34,10 @@ impl PyBytes { inner: PyByteInner { elements }, } } + + pub fn get_value(&self) -> &[u8] { + &self.inner.elements + } } impl Deref for PyBytes { @@ -229,6 +233,42 @@ impl PyBytesRef { obj => Err(vm.new_type_error(format!("fromhex() argument must be str, not {}", obj ))) ) } + + #[pymethod(name = "center")] + fn center( + self, + width: PyObjectRef, + fillbyte: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let sym = if let OptionalArg::Present(v) = fillbyte { + match is_byte(&v) { + Some(x) => { + if x.len() == 1 { + x[0] + } else { + return Err(vm.new_type_error(format!( + "center() argument 2 must be a byte string of length 1, not {}", + &v + ))); + } + } + None => { + return Err(vm.new_type_error(format!( + "center() argument 2 must be a byte string of length 1, not {}", + &v + ))); + } + } + } else { + 32 // default is space + }; + + match_class!(width, + i @PyInt => Ok(vm.ctx.new_bytes(self.inner.center(i.as_bigint(), sym, vm))), + obj => {Err(vm.new_type_error(format!("{} cannot be interpreted as an integer", obj)))} + ) + } } #[derive(Debug)] diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 62cc923b1..2d5f1f8e1 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -691,8 +691,22 @@ impl PyStringRef { ) -> PyResult { let value = &self.value; let rep_char = Self::get_fill_char(&rep, vm)?; - let left_buff: usize = (len - value.len()) / 2; - let right_buff = len - value.len() - left_buff; + let value_len = self.value.chars().count(); + + if len <= value_len { + return Ok(value.to_string()); + } + let diff: usize = len - value_len; + let mut left_buff: usize = diff / 2; + let mut right_buff: usize = left_buff; + + if diff % 2 != 0 && value_len % 2 == 0 { + left_buff += 1 + } + + if diff % 2 != 0 && value_len % 2 != 0 { + right_buff += 1 + } Ok(format!( "{}{}{}", rep_char.repeat(left_buff),