From 2fb3fc92ecaa13d031bdd3dd2e83ccf778660b46 Mon Sep 17 00:00:00 2001 From: Adrian Wielgosik Date: Tue, 2 Apr 2019 20:50:14 +0200 Subject: [PATCH] Support chained comparisons --- parser/src/ast.rs | 5 +- parser/src/parser.rs | 36 ++++++++------ parser/src/python.lalrpop | 10 +++- tests/snippets/ast_snippet.py | 7 +++ tests/snippets/comparisons.py | 13 +++++ vm/src/compile.rs | 93 ++++++++++++++++++++++++++++------- vm/src/stdlib/ast.rs | 20 ++++++-- 7 files changed, 141 insertions(+), 43 deletions(-) create mode 100644 tests/snippets/comparisons.py diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 257ceadf2..ece0e9798 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -159,9 +159,8 @@ pub enum Expression { value: Box, }, Compare { - a: Box, - op: Comparison, - b: Box, + vals: Vec, + ops: Vec, }, Attribute { value: Box, diff --git a/parser/src/parser.rs b/parser/src/parser.rs index a119740ec..4c37f2842 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -460,26 +460,30 @@ mod tests { }, ifs: vec![ ast::Expression::Compare { - a: Box::new(ast::Expression::Identifier { - name: "a".to_string() - }), - op: ast::Comparison::Less, - b: Box::new(ast::Expression::Number { - value: ast::Number::Integer { - value: BigInt::from(5) + vals: vec![ + ast::Expression::Identifier { + name: "a".to_string() + }, + ast::Expression::Number { + value: ast::Number::Integer { + value: BigInt::from(5) + } } - }), + ], + ops: vec![ast::Comparison::Less], }, ast::Expression::Compare { - a: Box::new(ast::Expression::Identifier { - name: "a".to_string() - }), - op: ast::Comparison::Greater, - b: Box::new(ast::Expression::Number { - value: ast::Number::Integer { - value: BigInt::from(10) + vals: vec![ + ast::Expression::Identifier { + name: "a".to_string() + }, + ast::Expression::Number { + value: ast::Number::Integer { + value: BigInt::from(10) + } } - }), + ], + ops: vec![ast::Comparison::Greater], }, ], } diff --git a/parser/src/python.lalrpop b/parser/src/python.lalrpop index 83020e926..f4c81f1d8 100644 --- a/parser/src/python.lalrpop +++ b/parser/src/python.lalrpop @@ -680,7 +680,15 @@ NotTest: ast::Expression = { }; Comparison: ast::Expression = { - => ast::Expression::Compare { a: Box::new(e1), op: op, b: Box::new(e2) }, + => { + let mut vals = vec![e]; + let mut ops = vec![]; + for x in comparisons { + ops.push(x.0); + vals.push(x.1); + } + ast::Expression::Compare { vals, ops } + }, => e, }; diff --git a/tests/snippets/ast_snippet.py b/tests/snippets/ast_snippet.py index 43bf74756..dad767b45 100644 --- a/tests/snippets/ast_snippet.py +++ b/tests/snippets/ast_snippet.py @@ -21,3 +21,10 @@ print(foo.body[0].value.func.id) assert foo.body[0].value.func.id == 'print' assert foo.body[0].lineno == 3 assert foo.body[1].lineno == 4 + +n = ast.parse("3 < 4 > 5\n") +assert n.body[0].value.left.n == 3 +assert 'Lt' in str(n.body[0].value.ops[0]) +assert 'Gt' in str(n.body[0].value.ops[1]) +assert n.body[0].value.comparators[0].n == 4 +assert n.body[0].value.comparators[1].n == 5 diff --git a/tests/snippets/comparisons.py b/tests/snippets/comparisons.py new file mode 100644 index 000000000..1a4597ef6 --- /dev/null +++ b/tests/snippets/comparisons.py @@ -0,0 +1,13 @@ + +assert 1 < 2 +assert 1 < 2 < 3 +assert 5 == 5 == 5 +assert (5 == 5) == True +assert 5 == 5 != 4 == 4 > 3 > 2 < 3 <= 3 != 0 == 0 + +assert not 1 > 2 +assert not 5 == 5 == True +assert not 5 == 5 != 5 == 5 +assert not 1 < 2 < 3 > 4 +assert not 1 < 2 > 3 < 4 +assert not 1 > 2 < 3 < 4 diff --git a/vm/src/compile.rs b/vm/src/compile.rs index f79980af9..fe295f424 100644 --- a/vm/src/compile.rs +++ b/vm/src/compile.rs @@ -840,6 +840,79 @@ impl Compiler { Ok(()) } + fn compile_chained_comparison( + &mut self, + vals: &[ast::Expression], + ops: &[ast::Comparison], + ) -> Result<(), CompileError> { + assert!(ops.len() > 0); + assert!(vals.len() == ops.len() + 1); + + let to_operator = |op: &ast::Comparison| match op { + ast::Comparison::Equal => bytecode::ComparisonOperator::Equal, + ast::Comparison::NotEqual => bytecode::ComparisonOperator::NotEqual, + ast::Comparison::Less => bytecode::ComparisonOperator::Less, + ast::Comparison::LessOrEqual => bytecode::ComparisonOperator::LessOrEqual, + ast::Comparison::Greater => bytecode::ComparisonOperator::Greater, + ast::Comparison::GreaterOrEqual => bytecode::ComparisonOperator::GreaterOrEqual, + ast::Comparison::In => bytecode::ComparisonOperator::In, + ast::Comparison::NotIn => bytecode::ComparisonOperator::NotIn, + ast::Comparison::Is => bytecode::ComparisonOperator::Is, + ast::Comparison::IsNot => bytecode::ComparisonOperator::IsNot, + }; + + // a == b == c == d + // compile into (pseudocode): + // result = a == b + // if result: + // result = b == c + // if result: + // result = c == d + + // initialize lhs outside of loop + self.compile_expression(&vals[0])?; + + let break_label = self.new_label(); + let last_label = self.new_label(); + + // for all comparisons except the last (as the last one doesn't need a conditional jump) + let ops_slice = &ops[0..ops.len()]; + let vals_slice = &vals[1..ops.len()]; + for (op, val) in ops_slice.iter().zip(vals_slice.iter()) { + self.compile_expression(val)?; + // store rhs for the next comparison in chain + self.emit(Instruction::Duplicate); + self.emit(Instruction::Rotate { amount: 3 }); + + self.emit(Instruction::CompareOperation { + op: to_operator(op), + }); + + // if comparison result is false, we break with this value; if true, try the next one. + // (CPython compresses these three opcodes into JUMP_IF_FALSE_OR_POP) + self.emit(Instruction::Duplicate); + self.emit(Instruction::JumpIfFalse { + target: break_label, + }); + self.emit(Instruction::Pop); + } + + // handle the last comparison + self.compile_expression(vals.last().unwrap())?; + self.emit(Instruction::CompareOperation { + op: to_operator(ops.last().unwrap()), + }); + self.emit(Instruction::Jump { target: last_label }); + + // early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs. + self.set_label(break_label); + self.emit(Instruction::Rotate { amount: 2 }); + self.emit(Instruction::Pop); + + self.set_label(last_label); + Ok(()) + } + fn compile_store(&mut self, target: &ast::Expression) -> Result<(), CompileError> { match target { ast::Expression::Identifier { name } => { @@ -1022,24 +1095,8 @@ impl Compiler { name: name.to_string(), }); } - ast::Expression::Compare { a, op, b } => { - self.compile_expression(a)?; - self.compile_expression(b)?; - - let i = match op { - ast::Comparison::Equal => bytecode::ComparisonOperator::Equal, - ast::Comparison::NotEqual => bytecode::ComparisonOperator::NotEqual, - ast::Comparison::Less => bytecode::ComparisonOperator::Less, - ast::Comparison::LessOrEqual => bytecode::ComparisonOperator::LessOrEqual, - ast::Comparison::Greater => bytecode::ComparisonOperator::Greater, - ast::Comparison::GreaterOrEqual => bytecode::ComparisonOperator::GreaterOrEqual, - ast::Comparison::In => bytecode::ComparisonOperator::In, - ast::Comparison::NotIn => bytecode::ComparisonOperator::NotIn, - ast::Comparison::Is => bytecode::ComparisonOperator::Is, - ast::Comparison::IsNot => bytecode::ComparisonOperator::IsNot, - }; - let i = Instruction::CompareOperation { op: i }; - self.emit(i); + ast::Expression::Compare { vals, ops } => { + self.compile_chained_comparison(vals, ops)?; } ast::Expression::Number { value } => { let const_value = match value { diff --git a/vm/src/stdlib/ast.rs b/vm/src/stdlib/ast.rs index 0011407ab..91631b0c9 100644 --- a/vm/src/stdlib/ast.rs +++ b/vm/src/stdlib/ast.rs @@ -327,14 +327,14 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj node } - ast::Expression::Compare { a, op, b } => { + ast::Expression::Compare { vals, ops } => { let node = create_node(vm, "Compare"); - let py_a = expression_to_ast(vm, a); + let py_a = expression_to_ast(vm, &vals[0]); vm.ctx.set_attr(&node, "left", py_a); // Operator: - let str_op = match op { + let to_operator = |op: &ast::Comparison| match op { ast::Comparison::Equal => "Eq", ast::Comparison::NotEqual => "NotEq", ast::Comparison::Less => "Lt", @@ -346,10 +346,20 @@ fn expression_to_ast(vm: &VirtualMachine, expression: &ast::Expression) -> PyObj ast::Comparison::Is => "Is", ast::Comparison::IsNot => "IsNot", }; - let py_ops = vm.ctx.new_list(vec![vm.ctx.new_str(str_op.to_string())]); + let py_ops = vm.ctx.new_list( + ops.iter() + .map(|x| vm.ctx.new_str(to_operator(x).to_string())) + .collect(), + ); + vm.ctx.set_attr(&node, "ops", py_ops); - let py_b = vm.ctx.new_list(vec![expression_to_ast(vm, b)]); + let py_b = vm.ctx.new_list( + vals.iter() + .skip(1) + .map(|x| expression_to_ast(vm, x)) + .collect(), + ); vm.ctx.set_attr(&node, "comparators", py_b); node }