diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 3ebec9b3d..77ebabaf0 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -36,3 +36,6 @@ b = ' hallo ' assert b.strip() == 'hallo' assert b.lstrip() == 'hallo ' assert b.rstrip() == ' hallo' + +c = 'hallo' +assert c.capitalize() == 'Hallo' diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 009a0d9c5..b3ecde4ff 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -18,6 +18,7 @@ pub fn init(context: &PyContext) { str_type.set_attr("__repr__", context.new_rustfunc(str_repr)); str_type.set_attr("lower", context.new_rustfunc(str_lower)); str_type.set_attr("upper", context.new_rustfunc(str_upper)); + str_type.set_attr("capitalize", context.new_rustfunc(str_capitalize)); str_type.set_attr("split", context.new_rustfunc(str_split)); str_type.set_attr("strip", context.new_rustfunc(str_strip)); str_type.set_attr("lstrip", context.new_rustfunc(str_lstrip)); @@ -141,6 +142,14 @@ fn str_lower(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_str(value)) } +fn str_capitalize(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(s, Some(vm.ctx.str_type()))]); + let value = get_value(&s); + let (first_part, lower_str) = value.split_at(1); + let capitalized = format!("{}{}", first_part.to_uppercase().to_string(), lower_str); + Ok(vm.ctx.new_str(capitalized)) +} + fn str_split(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm,