Merge pull request #5700 from arihant2math/match-cleanup

Clean up match statement codegen
This commit is contained in:
Noa
2025-04-18 13:56:57 -05:00
committed by GitHub
6 changed files with 58 additions and 145 deletions

View File

@@ -1944,32 +1944,35 @@ impl Compiler<'_> {
n: Option<&Identifier>,
pc: &mut PatternContext,
) -> CompileResult<()> {
// If no name is provided, simply pop the top of the stack.
if n.is_none() {
emit!(self, Instruction::Pop);
return Ok(());
match n {
// If no name is provided, simply pop the top of the stack.
None => {
emit!(self, Instruction::Pop);
Ok(())
}
Some(name) => {
// Check if the name is forbidden for storing.
if self.forbidden_name(name.as_str(), NameUsage::Store)? {
return Err(self.compile_error_forbidden_name(name.as_str()));
}
// Ensure we don't store the same name twice.
// TODO: maybe pc.stores should be a set?
if pc.stores.contains(&name.to_string()) {
return Err(
self.error(CodegenErrorType::DuplicateStore(name.as_str().to_string()))
);
}
// Calculate how many items to rotate:
let rotations = pc.on_top + pc.stores.len() + 1;
self.pattern_helper_rotate(rotations)?;
// Append the name to the captured stores.
pc.stores.push(name.to_string());
Ok(())
}
}
let name = n.unwrap();
// Check if the name is forbidden for storing.
if self.forbidden_name(name.as_str(), NameUsage::Store)? {
return Err(self.compile_error_forbidden_name(name.as_str()));
}
// Ensure we don't store the same name twice.
if pc.stores.contains(&name.to_string()) {
return Err(self.error(CodegenErrorType::DuplicateStore(name.as_str().to_string())));
}
// Calculate how many items to rotate:
// the count is the number of items to preserve on top plus the current stored names,
// plus one for the new value.
let rotations = pc.on_top + pc.stores.len() + 1;
self.pattern_helper_rotate(rotations)?;
// Append the name to the captured stores.
pc.stores.push(name.to_string());
Ok(())
}
fn pattern_unpack_helper(&mut self, elts: &[Pattern]) -> CompileResult<()> {
@@ -2155,10 +2158,7 @@ impl Compiler<'_> {
for ident in attrs.iter().take(n_attrs).skip(i + 1) {
let other = ident.as_str();
if attr == other {
todo!();
// return Err(self.compiler_error(
// &format!("attribute name repeated in class pattern: {}", attr),
// ));
return Err(self.error(CodegenErrorType::RepeatedAttributePattern));
}
}
}
@@ -2185,16 +2185,6 @@ impl Compiler<'_> {
let nargs = patterns.len();
let n_attrs = kwd_attrs.len();
let nkwd_patterns = kwd_patterns.len();
// Validate that keyword attribute names and patterns match in length.
if n_attrs != nkwd_patterns {
let msg = format!(
"kwd_attrs ({}) / kwd_patterns ({}) length mismatch in class pattern",
n_attrs, nkwd_patterns
);
unreachable!("{}", msg);
}
// Check for too many sub-patterns.
if nargs > u32::MAX as usize || (nargs + n_attrs).saturating_sub(1) > i32::MAX as usize {
@@ -2223,6 +2213,8 @@ impl Compiler<'_> {
});
}
use bytecode::TestOperator::*;
// Emit instructions:
// 1. Load the new tuple of attribute names.
self.emit_load_const(ConstantData::Tuple {
@@ -2235,7 +2227,7 @@ impl Compiler<'_> {
// 4. Load None.
self.emit_load_const(ConstantData::None);
// 5. Compare with IS_OP 1.
emit!(self, Instruction::IsOperation(true));
emit!(self, Instruction::TestOperation { op: IsNot });
// At this point the TOS is a tuple of (nargs + n_attrs) attributes (or None).
pc.on_top += 1;
@@ -2253,20 +2245,12 @@ impl Compiler<'_> {
pc.on_top -= 1;
// Process each sub-pattern.
for i in 0..total {
// Decrement the on_top counter as each sub-pattern is processed.
for subpattern in patterns.iter().chain(kwd_patterns.iter()) {
// Decrement the on_top counter as each sub-pattern is processed
// (on_top should be zero at the end of the algorithm as a sanity check).
pc.on_top -= 1;
let subpattern = if i < nargs {
// Positional sub-pattern.
&patterns[i]
} else {
// Keyword sub-pattern.
&kwd_patterns[i - nargs]
};
if subpattern.is_wildcard() {
// For wildcard patterns, simply pop the top of the stack.
emit!(self, Instruction::Pop);
continue;
}
// Compile the subpattern without irrefutability checks.
self.compile_pattern_subpattern(subpattern, pc)?;
@@ -2351,7 +2335,7 @@ impl Compiler<'_> {
// emit!(self, Instruction::CopyItem { index: 1_u32 });
// self.emit_load_const(ConstantData::None);
// // TODO: should be is
// emit!(self, Instruction::IsOperation(true));
// emit!(self, Instruction::TestOperation::IsNot);
// self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?;
// // Unpack the tuple of values.
@@ -2428,15 +2412,16 @@ impl Compiler<'_> {
} else {
let control_vec = control.as_ref().unwrap();
if nstores != control_vec.len() {
todo!();
// return self.compiler_error("alternative patterns bind different names");
return Err(self.error(CodegenErrorType::ConflictingNameBindPattern));
} else if nstores > 0 {
// Check that the names occur in the same order.
for icontrol in (0..nstores).rev() {
let name = &control_vec[icontrol];
// Find the index of `name` in the current stores.
let istores = pc.stores.iter().position(|n| n == name).unwrap();
// .ok_or_else(|| self.compiler_error("alternative patterns bind different names"))?;
let istores =
pc.stores.iter().position(|n| n == name).ok_or_else(|| {
self.error(CodegenErrorType::ConflictingNameBindPattern)
})?;
if icontrol != istores {
// The orders differ; we must reorder.
assert!(istores < icontrol, "expected istores < icontrol");
@@ -2480,14 +2465,14 @@ impl Compiler<'_> {
self.switch_to_block(end);
// Adjust the final captures.
let nstores = control.as_ref().unwrap().len();
let nrots = nstores + 1 + pc.on_top + pc.stores.len();
for i in 0..nstores {
let n_stores = control.as_ref().unwrap().len();
let n_rots = n_stores + 1 + pc.on_top + pc.stores.len();
for i in 0..n_stores {
// Rotate the capture to its proper place.
self.pattern_helper_rotate(nrots)?;
self.pattern_helper_rotate(n_rots)?;
let name = &control.as_ref().unwrap()[i];
// Check for duplicate binding.
if pc.stores.iter().any(|n| n == name) {
if pc.stores.contains(name) {
return Err(self.error(CodegenErrorType::DuplicateStore(name.to_string())));
}
pc.stores.push(name.clone());
@@ -4608,23 +4593,6 @@ for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
self.assertIs(ex, stop_exc)
else:
self.fail(f'{stop_exc} was suppressed')
"
));
}
#[test]
fn test_match() {
assert_dis_snapshot!(compile_exec(
"\
class Test:
pass
t = Test()
match t:
case Test():
assert True
case _:
assert False
"
));
}

View File

@@ -65,6 +65,8 @@ pub enum CodegenErrorType {
ForbiddenName,
DuplicateStore(String),
UnreachablePattern(PatternUnreachableReason),
RepeatedAttributePattern,
ConflictingNameBindPattern,
NotImplementedYet, // RustPython marker for unimplemented features
}
@@ -119,6 +121,12 @@ impl fmt::Display for CodegenErrorType {
UnreachablePattern(reason) => {
write!(f, "{reason} makes remaining patterns unreachable")
}
RepeatedAttributePattern => {
write!(f, "attribute name repeated in class pattern")
}
ConflictingNameBindPattern => {
write!(f, "alternative patterns bind different names")
}
NotImplementedYet => {
write!(f, "RustPython does not implement this feature yet")
}

View File

@@ -244,6 +244,9 @@ impl CodeInfo {
let instr_display = instr.display(display_arg, self);
eprint!("{instr_display}: {depth} {effect:+} => ");
}
if effect < 0 && depth < effect.unsigned_abs() {
panic!("The stack will underflow at {depth} with {effect} effect on {instr:?}");
}
let new_depth = depth.checked_add_signed(effect).unwrap();
if DEBUG {
eprintln!("{new_depth}");

View File

@@ -1,53 +0,0 @@
---
source: compiler/codegen/src/compile.rs
expression: "compile_exec(\"\\\nclass Test:\n pass\n\nt = Test()\nmatch t:\n case Test():\n assert True\n case _:\n assert False\n\")"
---
2 0 LoadBuildClass
1 LoadConst (<code object Test at ??? file "source_path", line 1>): 1 0 LoadGlobal (0, __name__)
1 StoreLocal (1, __module__)
2 LoadConst ("Test")
3 StoreLocal (2, __qualname__)
4 LoadConst (None)
5 StoreLocal (3, __doc__)
2 6 ReturnConst (None)
2 LoadConst ("Test")
3 MakeFunction (MakeFunctionFlags(0x0))
4 LoadConst ("Test")
5 CallFunctionPositional(2)
6 StoreLocal (0, Test)
4 7 LoadNameAny (0, Test)
8 CallFunctionPositional(0)
9 StoreLocal (1, t)
5 10 LoadNameAny (1, t)
11 CopyItem (1)
6 12 LoadNameAny (0, Test)
13 LoadConst (())
14 MatchClass (0)
15 CopyItem (1)
16 LoadConst (None)
17 IsOperation (true)
18 JumpIfFalse (27)
19 UnpackSequence (0)
20 Pop
7 21 LoadConst (True)
22 JumpIfTrue (26)
23 LoadGlobal (2, AssertionError)
24 CallFunctionPositional(0)
25 Raise (Raise)
>> 26 Jump (35)
>> 27 Pop
28 Pop
9 29 LoadConst (False)
30 JumpIfTrue (34)
31 LoadGlobal (2, AssertionError)
32 CallFunctionPositional(0)
33 Raise (Raise)
>> 34 Jump (35)
>> 35 ReturnConst (None)

View File

@@ -437,9 +437,6 @@ pub enum Instruction {
TestOperation {
op: Arg<TestOperator>,
},
/// If the argument is true, perform IS NOT. Otherwise perform the IS operation.
// TODO: duplication of TestOperator::{Is,IsNot}. Fix later.
IsOperation(Arg<bool>),
CompareOperation {
op: Arg<ComparisonOperator>,
},
@@ -1227,8 +1224,7 @@ impl Instruction {
BinaryOperation { .. }
| BinaryOperationInplace { .. }
| TestOperation { .. }
| CompareOperation { .. }
| IsOperation(..) => -1,
| CompareOperation { .. } => -1,
BinarySubscript => -1,
CopyItem { .. } => 1,
Pop => -1,
@@ -1436,7 +1432,6 @@ impl Instruction {
BinarySubscript => w!(BinarySubscript),
LoadAttr { idx } => w!(LoadAttr, name = idx),
TestOperation { op } => w!(TestOperation, ?op),
IsOperation(neg) => w!(IsOperation, neg),
CompareOperation { op } => w!(CompareOperation, ?op),
CopyItem { index } => w!(CopyItem, index),
Pop => w!(Pop),

View File

@@ -851,14 +851,6 @@ impl ExecutingFrame<'_> {
bytecode::Instruction::UnaryOperation { op } => self.execute_unary_op(vm, op.get(arg)),
bytecode::Instruction::TestOperation { op } => self.execute_test(vm, op.get(arg)),
bytecode::Instruction::CompareOperation { op } => self.execute_compare(vm, op.get(arg)),
bytecode::Instruction::IsOperation(neg) => {
let a = self.pop_value();
let b = self.pop_value();
// xor with neg to invert the result if needed
let result = vm.ctx.new_bool(a.is(b.as_ref()) ^ neg.get(arg));
self.push_value(result.into());
Ok(None)
}
bytecode::Instruction::ReturnValue => {
let value = self.pop_value();
self.unwind_blocks(vm, UnwindReason::Returning { value })