From 7996a10116681e9f85eda03413d5011b805e577f Mon Sep 17 00:00:00 2001 From: Jonathan Rotter Date: Mon, 27 May 2024 02:54:56 -0500 Subject: [PATCH] await in list comprehension (#5334) * check if comprehension element contains await * force execution to pause in async gen --- Lib/test/test_asyncgen.py | 2 - compiler/codegen/src/compile.rs | 184 ++++++++++++++++-- .../snippets/syntax_async_comprehension.py | 171 ++++++++++++++++ 3 files changed, 344 insertions(+), 13 deletions(-) create mode 100644 extra_tests/snippets/syntax_async_comprehension.py diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 183d887b7..f97316ade 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -1568,8 +1568,6 @@ class AsyncGenAsyncioTest(unittest.TestCase): self.assertIn('unhandled exception during asyncio.run() shutdown', message['message']) - # TODO: RUSTPYTHON; TypeError: object async_generator can't be used in 'await' expression - @unittest.expectedFailure def test_async_gen_expression_01(self): async def arange(n): for i in range(n): diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 7b912f43d..dc9f8a976 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -87,6 +87,14 @@ impl CompileContext { } } +#[derive(Debug, Clone, Copy, PartialEq)] +enum ComprehensionType { + Generator, + List, + Set, + Dict, +} + /// Compile an located_ast::Mod produced from rustpython_parser::parse() pub fn compile_top( ast: &located_ast::Mod, @@ -2431,6 +2439,8 @@ impl Compiler { ); Ok(()) }, + ComprehensionType::List, + Self::contains_await(elt), )?; } Expr::SetComp(located_ast::ExprSetComp { @@ -2452,6 +2462,8 @@ impl Compiler { ); Ok(()) }, + ComprehensionType::Set, + Self::contains_await(elt), )?; } Expr::DictComp(located_ast::ExprDictComp { @@ -2480,19 +2492,28 @@ impl Compiler { Ok(()) }, + ComprehensionType::Dict, + Self::contains_await(key) || Self::contains_await(value), )?; } 
Expr::GeneratorExp(located_ast::ExprGeneratorExp { elt, generators, .. }) => { - self.compile_comprehension("<genexpr>", None, generators, &|compiler| { - compiler.compile_comprehension_element(elt)?; - compiler.mark_generator(); - emit!(compiler, Instruction::YieldValue); - emit!(compiler, Instruction::Pop); + self.compile_comprehension( + "<genexpr>", + None, + generators, + &|compiler| { + compiler.compile_comprehension_element(elt)?; + compiler.mark_generator(); + emit!(compiler, Instruction::YieldValue); + emit!(compiler, Instruction::Pop); - Ok(()) - })?; + Ok(()) + }, + ComprehensionType::Generator, + Self::contains_await(elt), + )?; } Expr::Starred(_) => { return Err(self.error(CodegenErrorType::InvalidStarExpr)); @@ -2744,9 +2765,35 @@ impl Compiler { init_collection: Option<Instruction>, generators: &[located_ast::Comprehension], compile_element: &dyn Fn(&mut Self) -> CompileResult<()>, + comprehension_type: ComprehensionType, + element_contains_await: bool, ) -> CompileResult<()> { let prev_ctx = self.ctx; - let is_async = generators.iter().any(|g| g.is_async); + let has_an_async_gen = generators.iter().any(|g| g.is_async); + + // async comprehensions are allowed in various contexts: + // - list/set/dict comprehensions in async functions + // - always for generator expressions + // Note: generators have to be treated specially since their async version is a fundamentally + // different type (aiter vs iter) instead of just an awaitable. + + // for if it actually is async, we check if any generator is async or if the element contains await + + // if the element expression contains await, but the context doesn't allow for async, + // then we continue on here with is_async=false and will produce a syntax error once the await is hit + + let is_async_list_set_dict_comprehension = comprehension_type + != ComprehensionType::Generator + && (has_an_async_gen || element_contains_await) // does it have to be async?
(uses await or async for) + && prev_ctx.func == FunctionContext::AsyncFunction; // is it allowed to be async? (in an async function) + + let is_async_generator_comprehension = comprehension_type == ComprehensionType::Generator + && (has_an_async_gen || element_contains_await); + + // since one is for generators, and one for not generators, they should never both be true + debug_assert!(!(is_async_list_set_dict_comprehension && is_async_generator_comprehension)); + + let is_async = is_async_list_set_dict_comprehension || is_async_generator_comprehension; self.ctx = CompileContext { loop_data: None, @@ -2838,7 +2885,7 @@ impl Compiler { // End of for loop: self.switch_to_block(after_block); - if is_async { + if has_an_async_gen { emit!(self, Instruction::EndAsyncFor); } } @@ -2877,7 +2924,7 @@ impl Compiler { self.compile_expression(&generators[0].iter)?; // Get iterator / turn item into an iterator - if is_async { + if has_an_async_gen { emit!(self, Instruction::GetAIter); } else { emit!(self, Instruction::GetIter); @@ -2885,11 +2932,15 @@ impl Compiler { // Call just created function: emit!(self, Instruction::CallFunctionPositional { nargs: 1 }); - if is_async { + if is_async_list_set_dict_comprehension { + // async, but not a generator and not an async for + // in this case, we end up with an awaitable + // that evaluates to the list/set/dict, so here we add an await emit!(self, Instruction::GetAwaitable); self.emit_load_const(ConstantData::None); emit!(self, Instruction::YieldFrom); } + Ok(()) } @@ -3016,6 +3067,117 @@ impl Compiler { fn mark_generator(&mut self) { self.current_code_info().flags |= bytecode::CodeFlags::IS_GENERATOR } + + /// Whether the expression contains an await expression and + /// thus requires the function to be async. 
+ /// Async with and async for are statements, so I won't check for them here + fn contains_await(expression: &located_ast::Expr) -> bool { + use located_ast::*; + + match &expression { + Expr::Call(ExprCall { + func, + args, + keywords, + .. + }) => { + Self::contains_await(func) + || args.iter().any(Self::contains_await) + || keywords.iter().any(|kw| Self::contains_await(&kw.value)) + } + Expr::BoolOp(ExprBoolOp { values, .. }) => values.iter().any(Self::contains_await), + Expr::BinOp(ExprBinOp { left, right, .. }) => { + Self::contains_await(left) || Self::contains_await(right) + } + Expr::Subscript(ExprSubscript { value, slice, .. }) => { + Self::contains_await(value) || Self::contains_await(slice) + } + Expr::UnaryOp(ExprUnaryOp { operand, .. }) => Self::contains_await(operand), + Expr::Attribute(ExprAttribute { value, .. }) => Self::contains_await(value), + Expr::Compare(ExprCompare { + left, comparators, .. + }) => Self::contains_await(left) || comparators.iter().any(Self::contains_await), + Expr::Constant(ExprConstant { .. }) => false, + Expr::List(ExprList { elts, .. }) => elts.iter().any(Self::contains_await), + Expr::Tuple(ExprTuple { elts, .. }) => elts.iter().any(Self::contains_await), + Expr::Set(ExprSet { elts, .. }) => elts.iter().any(Self::contains_await), + Expr::Dict(ExprDict { keys, values, .. }) => { + keys.iter() + .any(|key| key.as_ref().map_or(false, Self::contains_await)) + || values.iter().any(Self::contains_await) + } + Expr::Slice(ExprSlice { + lower, upper, step, .. + }) => { + lower.as_ref().map_or(false, |l| Self::contains_await(l)) + || upper.as_ref().map_or(false, |u| Self::contains_await(u)) + || step.as_ref().map_or(false, |s| Self::contains_await(s)) + } + Expr::Yield(ExprYield { value, .. }) => { + value.as_ref().map_or(false, |v| Self::contains_await(v)) + } + Expr::Await(ExprAwait { .. }) => true, + Expr::YieldFrom(ExprYieldFrom { value, .. }) => Self::contains_await(value), + Expr::JoinedStr(ExprJoinedStr { values, .. 
}) => { + values.iter().any(Self::contains_await) + } + Expr::FormattedValue(ExprFormattedValue { + value, + conversion: _, + format_spec, + .. + }) => { + Self::contains_await(value) + || format_spec + .as_ref() + .map_or(false, |fs| Self::contains_await(fs)) + } + Expr::Name(located_ast::ExprName { .. }) => false, + Expr::Lambda(located_ast::ExprLambda { body, .. }) => Self::contains_await(body), + Expr::ListComp(located_ast::ExprListComp { + elt, generators, .. + }) => { + Self::contains_await(elt) + || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + } + Expr::SetComp(located_ast::ExprSetComp { + elt, generators, .. + }) => { + Self::contains_await(elt) + || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + } + Expr::DictComp(located_ast::ExprDictComp { + key, + value, + generators, + .. + }) => { + Self::contains_await(key) + || Self::contains_await(value) + || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + } + Expr::GeneratorExp(located_ast::ExprGeneratorExp { + elt, generators, .. + }) => { + Self::contains_await(elt) + || generators.iter().any(|gen| Self::contains_await(&gen.iter)) + } + Expr::Starred(expr) => Self::contains_await(&expr.value), + Expr::IfExp(located_ast::ExprIfExp { + test, body, orelse, .. 
+ }) => { + Self::contains_await(test) + || Self::contains_await(body) + || Self::contains_await(orelse) + } + + Expr::NamedExpr(located_ast::ExprNamedExpr { + target, + value, + range: _, + }) => Self::contains_await(target) || Self::contains_await(value), + } + } } trait EmitArg { diff --git a/extra_tests/snippets/syntax_async_comprehension.py b/extra_tests/snippets/syntax_async_comprehension.py new file mode 100644 index 000000000..7d1d02672 --- /dev/null +++ b/extra_tests/snippets/syntax_async_comprehension.py @@ -0,0 +1,171 @@ +import asyncio +from types import GeneratorType, AsyncGeneratorType + + +async def f_async(x): + await asyncio.sleep(0.001) + return x + + +def f_iter(): + for i in range(5): + yield i + + +async def f_aiter(): + for i in range(5): + await asyncio.sleep(0.001) + yield i + + +async def run_async(): + # list + x = [i for i in range(5)] + assert isinstance(x, list) + for i, e in enumerate(x): + assert e == i + + x = [await f_async(i) for i in range(5)] + assert isinstance(x, list) + for i, e in enumerate(x): + assert e == i + + x = [e async for e in f_aiter()] + assert isinstance(x, list) + for i, e in enumerate(x): + assert e == i + + x = [await f_async(i) async for i in f_aiter()] + assert isinstance(x, list) + for i, e in enumerate(x): + assert e == i + + # set + x = {i for i in range(5)} + assert isinstance(x, set) + for e in x: + assert e in range(5) + assert x == {0, 1, 2, 3, 4} + + x = {await f_async(i) for i in range(5)} + assert isinstance(x, set) + for e in x: + assert e in range(5) + assert x == {0, 1, 2, 3, 4} + + x = {e async for e in f_aiter()} + assert isinstance(x, set) + for e in x: + assert e in range(5) + assert x == {0, 1, 2, 3, 4} + + x = {await f_async(i) async for i in f_aiter()} + assert isinstance(x, set) + for e in x: + assert e in range(5) + assert x == {0, 1, 2, 3, 4} + + # dict + x = {i: i for i in range(5)} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 
3: 3, 4: 4} + + x = {await f_async(i): i for i in range(5)} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + x = {i: await f_async(i) for i in range(5)} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + x = {await f_async(i): await f_async(i) for i in range(5)} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + x = {i: i async for i in f_aiter()} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + x = {await f_async(i): i async for i in f_aiter()} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + x = {i: await f_async(i) async for i in f_aiter()} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + x = {await f_async(i): await f_async(i) async for i in f_aiter()} + assert isinstance(x, dict) + for k, v in x.items(): + assert k == v + assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + # generator + x = (i for i in range(5)) + assert isinstance(x, GeneratorType) + for i, e in enumerate(x): + assert e == i + + x = (await f_async(i) for i in range(5)) + assert isinstance(x, AsyncGeneratorType) + i = 0 + async for e in x: + assert e == i + i += 1 + + x = (e async for e in f_aiter()) + assert isinstance(x, AsyncGeneratorType) + i = 0 + async for e in x: + assert i == e + i += 1 + + x = (await f_async(i) async for i in f_aiter()) + assert isinstance(x, AsyncGeneratorType) + i = 0 + async for e in x: + assert i == e + i += 1 + + +def run_sync(): + async def test_async_for(x): + i = 0 + async for e in x: + assert e == i + i += 1 + + x = (i for i in range(5)) + assert isinstance(x, GeneratorType) + for i, e in enumerate(x): + assert e == i + + x = (await f_async(i) for i in 
range(5)) + assert isinstance(x, AsyncGeneratorType) + asyncio.run(test_async_for(x), debug=True) + + x = (e async for e in f_aiter()) + assert isinstance(x, AsyncGeneratorType) + asyncio.run(test_async_for(x), debug=True) + + x = (await f_async(i) async for i in f_aiter()) + assert isinstance(x, AsyncGeneratorType) + asyncio.run(test_async_for(x), debug=True) + + +asyncio.run(run_async(), debug=True) +run_sync()