Support chained comparisons

This commit is contained in:
Adrian Wielgosik
2019-04-02 20:50:14 +02:00
parent a77b3489b7
commit 2fb3fc92ec
7 changed files with 141 additions and 43 deletions

View File

@@ -840,6 +840,79 @@ impl Compiler {
Ok(())
}
fn compile_chained_comparison(
&mut self,
vals: &[ast::Expression],
ops: &[ast::Comparison],
) -> Result<(), CompileError> {
assert!(ops.len() > 0);
assert!(vals.len() == ops.len() + 1);
let to_operator = |op: &ast::Comparison| match op {
ast::Comparison::Equal => bytecode::ComparisonOperator::Equal,
ast::Comparison::NotEqual => bytecode::ComparisonOperator::NotEqual,
ast::Comparison::Less => bytecode::ComparisonOperator::Less,
ast::Comparison::LessOrEqual => bytecode::ComparisonOperator::LessOrEqual,
ast::Comparison::Greater => bytecode::ComparisonOperator::Greater,
ast::Comparison::GreaterOrEqual => bytecode::ComparisonOperator::GreaterOrEqual,
ast::Comparison::In => bytecode::ComparisonOperator::In,
ast::Comparison::NotIn => bytecode::ComparisonOperator::NotIn,
ast::Comparison::Is => bytecode::ComparisonOperator::Is,
ast::Comparison::IsNot => bytecode::ComparisonOperator::IsNot,
};
// a == b == c == d
// compile into (pseudocode):
// result = a == b
// if result:
// result = b == c
// if result:
// result = c == d
// initialize lhs outside of loop
self.compile_expression(&vals[0])?;
let break_label = self.new_label();
let last_label = self.new_label();
// for all comparisons except the last (as the last one doesn't need a conditional jump)
let ops_slice = &ops[0..ops.len()];
let vals_slice = &vals[1..ops.len()];
for (op, val) in ops_slice.iter().zip(vals_slice.iter()) {
self.compile_expression(val)?;
// store rhs for the next comparison in chain
self.emit(Instruction::Duplicate);
self.emit(Instruction::Rotate { amount: 3 });
self.emit(Instruction::CompareOperation {
op: to_operator(op),
});
// if comparison result is false, we break with this value; if true, try the next one.
// (CPython compresses these three opcodes into JUMP_IF_FALSE_OR_POP)
self.emit(Instruction::Duplicate);
self.emit(Instruction::JumpIfFalse {
target: break_label,
});
self.emit(Instruction::Pop);
}
// handle the last comparison
self.compile_expression(vals.last().unwrap())?;
self.emit(Instruction::CompareOperation {
op: to_operator(ops.last().unwrap()),
});
self.emit(Instruction::Jump { target: last_label });
// early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs.
self.set_label(break_label);
self.emit(Instruction::Rotate { amount: 2 });
self.emit(Instruction::Pop);
self.set_label(last_label);
Ok(())
}
fn compile_store(&mut self, target: &ast::Expression) -> Result<(), CompileError> {
match target {
ast::Expression::Identifier { name } => {
@@ -1022,24 +1095,8 @@ impl Compiler {
name: name.to_string(),
});
}
ast::Expression::Compare { a, op, b } => {
self.compile_expression(a)?;
self.compile_expression(b)?;
let i = match op {
ast::Comparison::Equal => bytecode::ComparisonOperator::Equal,
ast::Comparison::NotEqual => bytecode::ComparisonOperator::NotEqual,
ast::Comparison::Less => bytecode::ComparisonOperator::Less,
ast::Comparison::LessOrEqual => bytecode::ComparisonOperator::LessOrEqual,
ast::Comparison::Greater => bytecode::ComparisonOperator::Greater,
ast::Comparison::GreaterOrEqual => bytecode::ComparisonOperator::GreaterOrEqual,
ast::Comparison::In => bytecode::ComparisonOperator::In,
ast::Comparison::NotIn => bytecode::ComparisonOperator::NotIn,
ast::Comparison::Is => bytecode::ComparisonOperator::Is,
ast::Comparison::IsNot => bytecode::ComparisonOperator::IsNot,
};
let i = Instruction::CompareOperation { op: i };
self.emit(i);
ast::Expression::Compare { vals, ops } => {
self.compile_chained_comparison(vals, ops)?;
}
ast::Expression::Number { value } => {
let const_value = match value {

View File

@@ -327,14 +327,14 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
node
}
ast::Expression::Compare { a, op, b } => {
ast::Expression::Compare { vals, ops } => {
let node = create_node(vm, "Compare");
let py_a = expression_to_ast(vm, a);
let py_a = expression_to_ast(vm, &vals[0]);
vm.ctx.set_attr(&node, "left", py_a);
// Operator:
let str_op = match op {
let to_operator = |op: &ast::Comparison| match op {
ast::Comparison::Equal => "Eq",
ast::Comparison::NotEqual => "NotEq",
ast::Comparison::Less => "Lt",
@@ -346,10 +346,20 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj
ast::Comparison::Is => "Is",
ast::Comparison::IsNot => "IsNot",
};
let py_ops = vm.ctx.new_list(vec![vm.ctx.new_str(str_op.to_string())]);
let py_ops = vm.ctx.new_list(
ops.iter()
.map(|x| vm.ctx.new_str(to_operator(x).to_string()))
.collect(),
);
vm.ctx.set_attr(&node, "ops", py_ops);
let py_b = vm.ctx.new_list(vec![expression_to_ast(vm, b)]);
let py_b = vm.ctx.new_list(
vals.iter()
.skip(1)
.map(|x| expression_to_ast(vm, x))
.collect(),
);
vm.ctx.set_attr(&node, "comparators", py_b);
node
}