await in list comprehension (#5334)

* check if comprehension element contains await

* force execution to pause in async gen
This commit is contained in:
Jonathan Rotter
2024-05-27 02:54:56 -05:00
committed by GitHub
parent db4562f67d
commit 7996a10116
3 changed files with 344 additions and 13 deletions

View File

@@ -1568,8 +1568,6 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.assertIn('unhandled exception during asyncio.run() shutdown',
message['message'])
# TODO: RUSTPYTHON; TypeError: object async_generator can't be used in 'await' expression
@unittest.expectedFailure
def test_async_gen_expression_01(self):
async def arange(n):
for i in range(n):

View File

@@ -87,6 +87,14 @@ impl CompileContext {
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
enum ComprehensionType {
Generator,
List,
Set,
Dict,
}
/// Compile an located_ast::Mod produced from rustpython_parser::parse()
pub fn compile_top(
ast: &located_ast::Mod,
@@ -2431,6 +2439,8 @@ impl Compiler {
);
Ok(())
},
ComprehensionType::List,
Self::contains_await(elt),
)?;
}
Expr::SetComp(located_ast::ExprSetComp {
@@ -2452,6 +2462,8 @@ impl Compiler {
);
Ok(())
},
ComprehensionType::Set,
Self::contains_await(elt),
)?;
}
Expr::DictComp(located_ast::ExprDictComp {
@@ -2480,19 +2492,28 @@ impl Compiler {
Ok(())
},
ComprehensionType::Dict,
Self::contains_await(key) || Self::contains_await(value),
)?;
}
Expr::GeneratorExp(located_ast::ExprGeneratorExp {
elt, generators, ..
}) => {
self.compile_comprehension("<genexpr>", None, generators, &|compiler| {
compiler.compile_comprehension_element(elt)?;
compiler.mark_generator();
emit!(compiler, Instruction::YieldValue);
emit!(compiler, Instruction::Pop);
self.compile_comprehension(
"<genexpr>",
None,
generators,
&|compiler| {
compiler.compile_comprehension_element(elt)?;
compiler.mark_generator();
emit!(compiler, Instruction::YieldValue);
emit!(compiler, Instruction::Pop);
Ok(())
})?;
Ok(())
},
ComprehensionType::Generator,
Self::contains_await(elt),
)?;
}
Expr::Starred(_) => {
return Err(self.error(CodegenErrorType::InvalidStarExpr));
@@ -2744,9 +2765,35 @@ impl Compiler {
init_collection: Option<Instruction>,
generators: &[located_ast::Comprehension],
compile_element: &dyn Fn(&mut Self) -> CompileResult<()>,
comprehension_type: ComprehensionType,
element_contains_await: bool,
) -> CompileResult<()> {
let prev_ctx = self.ctx;
let is_async = generators.iter().any(|g| g.is_async);
let has_an_async_gen = generators.iter().any(|g| g.is_async);
// async comprehensions are allowed in various contexts:
// - list/set/dict comprehensions in async functions
// - always for generator expressions
// Note: generators have to be treated specially since their async version is a fundamentally
// different type (aiter vs iter) instead of just an awaitable.
// for if it actually is async, we check if any generator is async or if the element contains await
// if the element expression contains await, but the context doesn't allow for async,
// then we continue on here with is_async=false and will produce a syntax once the await is hit
let is_async_list_set_dict_comprehension = comprehension_type
!= ComprehensionType::Generator
&& (has_an_async_gen || element_contains_await) // does it have to be async? (uses await or async for)
&& prev_ctx.func == FunctionContext::AsyncFunction; // is it allowed to be async? (in an async function)
let is_async_generator_comprehension = comprehension_type == ComprehensionType::Generator
&& (has_an_async_gen || element_contains_await);
// since one is for generators, and one for not generators, they should never both be true
debug_assert!(!(is_async_list_set_dict_comprehension && is_async_generator_comprehension));
let is_async = is_async_list_set_dict_comprehension || is_async_generator_comprehension;
self.ctx = CompileContext {
loop_data: None,
@@ -2838,7 +2885,7 @@ impl Compiler {
// End of for loop:
self.switch_to_block(after_block);
if is_async {
if has_an_async_gen {
emit!(self, Instruction::EndAsyncFor);
}
}
@@ -2877,7 +2924,7 @@ impl Compiler {
self.compile_expression(&generators[0].iter)?;
// Get iterator / turn item into an iterator
if is_async {
if has_an_async_gen {
emit!(self, Instruction::GetAIter);
} else {
emit!(self, Instruction::GetIter);
@@ -2885,11 +2932,15 @@ impl Compiler {
// Call just created <listcomp> function:
emit!(self, Instruction::CallFunctionPositional { nargs: 1 });
if is_async {
if is_async_list_set_dict_comprehension {
// async, but not a generator and not an async for
// in this case, we end up with an awaitable
// that evaluates to the list/set/dict, so here we add an await
emit!(self, Instruction::GetAwaitable);
self.emit_load_const(ConstantData::None);
emit!(self, Instruction::YieldFrom);
}
Ok(())
}
@@ -3016,6 +3067,117 @@ impl Compiler {
fn mark_generator(&mut self) {
self.current_code_info().flags |= bytecode::CodeFlags::IS_GENERATOR
}
/// Whether the expression contains an await expression and
/// thus requires the function to be async.
/// Async with and async for are statements, so I won't check for them here
fn contains_await(expression: &located_ast::Expr) -> bool {
use located_ast::*;
match &expression {
Expr::Call(ExprCall {
func,
args,
keywords,
..
}) => {
Self::contains_await(func)
|| args.iter().any(Self::contains_await)
|| keywords.iter().any(|kw| Self::contains_await(&kw.value))
}
Expr::BoolOp(ExprBoolOp { values, .. }) => values.iter().any(Self::contains_await),
Expr::BinOp(ExprBinOp { left, right, .. }) => {
Self::contains_await(left) || Self::contains_await(right)
}
Expr::Subscript(ExprSubscript { value, slice, .. }) => {
Self::contains_await(value) || Self::contains_await(slice)
}
Expr::UnaryOp(ExprUnaryOp { operand, .. }) => Self::contains_await(operand),
Expr::Attribute(ExprAttribute { value, .. }) => Self::contains_await(value),
Expr::Compare(ExprCompare {
left, comparators, ..
}) => Self::contains_await(left) || comparators.iter().any(Self::contains_await),
Expr::Constant(ExprConstant { .. }) => false,
Expr::List(ExprList { elts, .. }) => elts.iter().any(Self::contains_await),
Expr::Tuple(ExprTuple { elts, .. }) => elts.iter().any(Self::contains_await),
Expr::Set(ExprSet { elts, .. }) => elts.iter().any(Self::contains_await),
Expr::Dict(ExprDict { keys, values, .. }) => {
keys.iter()
.any(|key| key.as_ref().map_or(false, Self::contains_await))
|| values.iter().any(Self::contains_await)
}
Expr::Slice(ExprSlice {
lower, upper, step, ..
}) => {
lower.as_ref().map_or(false, |l| Self::contains_await(l))
|| upper.as_ref().map_or(false, |u| Self::contains_await(u))
|| step.as_ref().map_or(false, |s| Self::contains_await(s))
}
Expr::Yield(ExprYield { value, .. }) => {
value.as_ref().map_or(false, |v| Self::contains_await(v))
}
Expr::Await(ExprAwait { .. }) => true,
Expr::YieldFrom(ExprYieldFrom { value, .. }) => Self::contains_await(value),
Expr::JoinedStr(ExprJoinedStr { values, .. }) => {
values.iter().any(Self::contains_await)
}
Expr::FormattedValue(ExprFormattedValue {
value,
conversion: _,
format_spec,
..
}) => {
Self::contains_await(value)
|| format_spec
.as_ref()
.map_or(false, |fs| Self::contains_await(fs))
}
Expr::Name(located_ast::ExprName { .. }) => false,
Expr::Lambda(located_ast::ExprLambda { body, .. }) => Self::contains_await(body),
Expr::ListComp(located_ast::ExprListComp {
elt, generators, ..
}) => {
Self::contains_await(elt)
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
}
Expr::SetComp(located_ast::ExprSetComp {
elt, generators, ..
}) => {
Self::contains_await(elt)
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
}
Expr::DictComp(located_ast::ExprDictComp {
key,
value,
generators,
..
}) => {
Self::contains_await(key)
|| Self::contains_await(value)
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
}
Expr::GeneratorExp(located_ast::ExprGeneratorExp {
elt, generators, ..
}) => {
Self::contains_await(elt)
|| generators.iter().any(|gen| Self::contains_await(&gen.iter))
}
Expr::Starred(expr) => Self::contains_await(&expr.value),
Expr::IfExp(located_ast::ExprIfExp {
test, body, orelse, ..
}) => {
Self::contains_await(test)
|| Self::contains_await(body)
|| Self::contains_await(orelse)
}
Expr::NamedExpr(located_ast::ExprNamedExpr {
target,
value,
range: _,
}) => Self::contains_await(target) || Self::contains_await(value),
}
}
}
trait EmitArg<Arg: OpArgType> {

View File

@@ -0,0 +1,171 @@
import asyncio
from types import GeneratorType, AsyncGeneratorType
async def f_async(x):
await asyncio.sleep(0.001)
return x
def f_iter():
for i in range(5):
yield i
async def f_aiter():
for i in range(5):
await asyncio.sleep(0.001)
yield i
async def run_async():
# list
x = [i for i in range(5)]
assert isinstance(x, list)
for i, e in enumerate(x):
assert e == i
x = [await f_async(i) for i in range(5)]
assert isinstance(x, list)
for i, e in enumerate(x):
assert e == i
x = [e async for e in f_aiter()]
assert isinstance(x, list)
for i, e in enumerate(x):
assert e == i
x = [await f_async(i) async for i in f_aiter()]
assert isinstance(x, list)
for i, e in enumerate(x):
assert e == i
# set
x = {i for i in range(5)}
assert isinstance(x, set)
for e in x:
assert e in range(5)
assert x == {0, 1, 2, 3, 4}
x = {await f_async(i) for i in range(5)}
assert isinstance(x, set)
for e in x:
assert e in range(5)
assert x == {0, 1, 2, 3, 4}
x = {e async for e in f_aiter()}
assert isinstance(x, set)
for e in x:
assert e in range(5)
assert x == {0, 1, 2, 3, 4}
x = {await f_async(i) async for i in f_aiter()}
assert isinstance(x, set)
for e in x:
assert e in range(5)
assert x == {0, 1, 2, 3, 4}
# dict
x = {i: i for i in range(5)}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
x = {await f_async(i): i for i in range(5)}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
x = {i: await f_async(i) for i in range(5)}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
x = {await f_async(i): await f_async(i) for i in range(5)}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
x = {i: i async for i in f_aiter()}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
x = {await f_async(i): i async for i in f_aiter()}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
x = {i: await f_async(i) async for i in f_aiter()}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
x = {await f_async(i): await f_async(i) async for i in f_aiter()}
assert isinstance(x, dict)
for k, v in x.items():
assert k == v
assert x == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}
# generator
x = (i for i in range(5))
assert isinstance(x, GeneratorType)
for i, e in enumerate(x):
assert e == i
x = (await f_async(i) for i in range(5))
assert isinstance(x, AsyncGeneratorType)
i = 0
async for e in x:
assert e == i
i += 1
x = (e async for e in f_aiter())
assert isinstance(x, AsyncGeneratorType)
i = 0
async for e in x:
assert i == e
i += 1
x = (await f_async(i) async for i in f_aiter())
assert isinstance(x, AsyncGeneratorType)
i = 0
async for e in x:
assert i == e
i += 1
def run_sync():
async def test_async_for(x):
i = 0
async for e in x:
assert e == i
i += 1
x = (i for i in range(5))
assert isinstance(x, GeneratorType)
for i, e in enumerate(x):
assert e == i
x = (await f_async(i) for i in range(5))
assert isinstance(x, AsyncGeneratorType)
asyncio.run(test_async_for(x), debug=True)
x = (e async for e in f_aiter())
assert isinstance(x, AsyncGeneratorType)
asyncio.run(test_async_for(x), debug=True)
x = (await f_async(i) async for i in f_aiter())
assert isinstance(x, AsyncGeneratorType)
asyncio.run(test_async_for(x), debug=True)
asyncio.run(run_async(), debug=True)
run_sync()