Merge pull request #5276 from youknowone/async-for-comprehension

Async for comprehension
This commit is contained in:
Jeong, YunWon
2024-04-25 21:00:22 +09:00
committed by GitHub
4 changed files with 147 additions and 83 deletions

View File

@@ -513,16 +513,15 @@ class AsyncGenAsyncioTest(unittest.TestCase):
return self.yielded
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
# TODO: RUSTPYTHON: async for gen expression compilation
# def test_async_gen_aiter(self):
# async def gen():
# yield 1
# yield 2
# g = gen()
# async def consume():
# return [i async for i in aiter(g)]
# res = self.loop.run_until_complete(consume())
# self.assertEqual(res, [1, 2])
def test_async_gen_aiter(self):
async def gen():
yield 1
yield 2
g = gen()
async def consume():
return [i async for i in aiter(g)]
res = self.loop.run_until_complete(consume())
self.assertEqual(res, [1, 2])
# TODO: RUSTPYTHON, NameError: name 'aiter' is not defined
@unittest.expectedFailure
@@ -1569,22 +1568,23 @@ class AsyncGenAsyncioTest(unittest.TestCase):
self.assertIn('unhandled exception during asyncio.run() shutdown',
message['message'])
# TODO: RUSTPYTHON: async for gen expression compilation
# def test_async_gen_expression_01(self):
# async def arange(n):
# for i in range(n):
# await asyncio.sleep(0.01)
# yield i
# TODO: RUSTPYTHON; TypeError: object async_generator can't be used in 'await' expression
@unittest.expectedFailure
def test_async_gen_expression_01(self):
async def arange(n):
for i in range(n):
await asyncio.sleep(0.01)
yield i
# def make_arange(n):
# # This syntax is legal starting with Python 3.7
# return (i * 2 async for i in arange(n))
def make_arange(n):
# This syntax is legal starting with Python 3.7
return (i * 2 async for i in arange(n))
# async def run():
# return [i async for i in make_arange(10)]
async def run():
return [i async for i in make_arange(10)]
# res = self.loop.run_until_complete(run())
# self.assertEqual(res, [i * 2 for i in range(10)])
res = self.loop.run_until_complete(run())
self.assertEqual(res, [i * 2 for i in range(10)])
# TODO: RUSTPYTHON: async for gen expression compilation
# def test_async_gen_expression_02(self):

View File

@@ -418,44 +418,46 @@ class GrammarTests(unittest.TestCase):
gns['__annotations__']
# TODO: RUSTPYTHON
# def test_var_annot_custom_maps(self):
# # tests with custom locals() and __annotations__
# ns = {'__annotations__': CNS()}
# exec('X: int; Z: str = "Z"; (w): complex = 1j', ns)
# self.assertEqual(ns['__annotations__']['x'], int)
# self.assertEqual(ns['__annotations__']['z'], str)
# with self.assertRaises(KeyError):
# ns['__annotations__']['w']
# nonloc_ns = {}
# class CNS2:
# def __init__(self):
# self._dct = {}
# def __setitem__(self, item, value):
# nonlocal nonloc_ns
# self._dct[item] = value
# nonloc_ns[item] = value
# def __getitem__(self, item):
# return self._dct[item]
# exec('x: int = 1', {}, CNS2())
# self.assertEqual(nonloc_ns['__annotations__']['x'], int)
@unittest.expectedFailure
def test_var_annot_custom_maps(self):
# tests with custom locals() and __annotations__
ns = {'__annotations__': CNS()}
exec('X: int; Z: str = "Z"; (w): complex = 1j', ns)
self.assertEqual(ns['__annotations__']['x'], int)
self.assertEqual(ns['__annotations__']['z'], str)
with self.assertRaises(KeyError):
ns['__annotations__']['w']
nonloc_ns = {}
class CNS2:
def __init__(self):
self._dct = {}
def __setitem__(self, item, value):
nonlocal nonloc_ns
self._dct[item] = value
nonloc_ns[item] = value
def __getitem__(self, item):
return self._dct[item]
exec('x: int = 1', {}, CNS2())
self.assertEqual(nonloc_ns['__annotations__']['x'], int)
# TODO: RUSTPYTHON
# def test_var_annot_refleak(self):
# # complex case: custom locals plus custom __annotations__
# # this was causing refleak
# cns = CNS()
# nonloc_ns = {'__annotations__': cns}
# class CNS2:
# def __init__(self):
# self._dct = {'__annotations__': cns}
# def __setitem__(self, item, value):
# nonlocal nonloc_ns
# self._dct[item] = value
# nonloc_ns[item] = value
# def __getitem__(self, item):
# return self._dct[item]
# exec('X: str', {}, CNS2())
# self.assertEqual(nonloc_ns['__annotations__']['x'], str)
@unittest.expectedFailure
def test_var_annot_refleak(self):
# complex case: custom locals plus custom __annotations__
# this was causing refleak
cns = CNS()
nonloc_ns = {'__annotations__': cns}
class CNS2:
def __init__(self):
self._dct = {'__annotations__': cns}
def __setitem__(self, item, value):
nonlocal nonloc_ns
self._dct[item] = value
nonloc_ns[item] = value
def __getitem__(self, item):
return self._dct[item]
exec('X: str', {}, CNS2())
self.assertEqual(nonloc_ns['__annotations__']['x'], str)
def test_var_annot_rhs(self):

View File

@@ -2629,24 +2629,30 @@ impl Compiler {
compile_element: &dyn Fn(&mut Self) -> CompileResult<()>,
) -> CompileResult<()> {
let prev_ctx = self.ctx;
let is_async = generators.iter().any(|g| g.is_async);
self.ctx = CompileContext {
loop_data: None,
in_class: prev_ctx.in_class,
func: FunctionContext::Function,
func: if is_async {
FunctionContext::AsyncFunction
} else {
FunctionContext::Function
},
};
// We must have at least one generator:
assert!(!generators.is_empty());
let flags = bytecode::CodeFlags::NEW_LOCALS | bytecode::CodeFlags::IS_OPTIMIZED;
let flags = if is_async {
flags | bytecode::CodeFlags::IS_COROUTINE
} else {
flags
};
// Create magnificent function <listcomp>:
self.push_output(
bytecode::CodeFlags::NEW_LOCALS | bytecode::CodeFlags::IS_OPTIMIZED,
1,
1,
0,
name.to_owned(),
);
self.push_output(flags, 1, 1, 0, name.to_owned());
let arg0 = self.varname(".0")?;
let return_none = init_collection.is_none();
@@ -2657,13 +2663,11 @@ impl Compiler {
let mut loop_labels = vec![];
for generator in generators {
if generator.is_async {
unimplemented!("async for comprehensions");
}
let loop_block = self.new_block();
let after_block = self.new_block();
// emit!(self, Instruction::SetupLoop);
if loop_labels.is_empty() {
// Load iterator onto stack (passed as first argument):
emit!(self, Instruction::LoadFast(arg0));
@@ -2672,20 +2676,36 @@ impl Compiler {
self.compile_expression(&generator.iter)?;
// Get iterator / turn item into an iterator
emit!(self, Instruction::GetIter);
if generator.is_async {
emit!(self, Instruction::GetAIter);
} else {
emit!(self, Instruction::GetIter);
}
}
loop_labels.push((loop_block, after_block));
self.switch_to_block(loop_block);
emit!(
self,
Instruction::ForIter {
target: after_block,
}
);
self.compile_store(&generator.target)?;
if generator.is_async {
emit!(
self,
Instruction::SetupExcept {
handler: after_block,
}
);
emit!(self, Instruction::GetANext);
self.emit_constant(ConstantData::None);
emit!(self, Instruction::YieldFrom);
self.compile_store(&generator.target)?;
emit!(self, Instruction::PopBlock);
} else {
emit!(
self,
Instruction::ForIter {
target: after_block,
}
);
self.compile_store(&generator.target)?;
}
// Now evaluate the ifs:
for if_condition in &generator.ifs {
@@ -2701,6 +2721,9 @@ impl Compiler {
// End of for loop:
self.switch_to_block(after_block);
if is_async {
emit!(self, Instruction::EndAsyncFor);
}
}
if return_none {
@@ -2737,10 +2760,19 @@ impl Compiler {
self.compile_expression(&generators[0].iter)?;
// Get iterator / turn item into an iterator
emit!(self, Instruction::GetIter);
if is_async {
emit!(self, Instruction::GetAIter);
} else {
emit!(self, Instruction::GetIter);
};
// Call just created <listcomp> function:
emit!(self, Instruction::CallFunctionPositional { nargs: 1 });
if is_async {
emit!(self, Instruction::GetAwaitable);
self.emit_constant(ConstantData::None);
emit!(self, Instruction::YieldFrom);
}
Ok(())
}

View File

@@ -351,6 +351,10 @@ impl ExecutingFrame<'_> {
let mut arg_state = bytecode::OpArgState::default();
loop {
let idx = self.lasti() as usize;
// eprintln!(
// "location: {:?} {}",
// self.code.locations[idx], self.code.source_path
// );
self.update_lasti(|i| *i += 1);
let bytecode::CodeUnit { op, arg } = instrs[idx];
let arg = arg_state.extend(arg);
@@ -993,6 +997,9 @@ impl ExecutingFrame<'_> {
Ok(None)
}
bytecode::Instruction::GetANext => {
#[cfg(debug_assertions)] // remove when GetANext is fully implemented
let orig_stack_len = self.state.stack.len();
let aiter = self.top_value();
let awaitable = if aiter.class().is(vm.ctx.types.async_generator) {
vm.call_special_method(aiter, identifier!(vm, __anext__), ())?
@@ -1030,6 +1037,8 @@ impl ExecutingFrame<'_> {
})?
};
self.push_value(awaitable);
#[cfg(debug_assertions)]
debug_assert_eq!(orig_stack_len + 1, self.state.stack.len());
Ok(None)
}
bytecode::Instruction::EndAsyncFor => {
@@ -1238,6 +1247,7 @@ impl ExecutingFrame<'_> {
fn unwind_blocks(&mut self, vm: &VirtualMachine, reason: UnwindReason) -> FrameResult {
// First unwind all existing blocks on the block stack:
while let Some(block) = self.current_block() {
// eprintln!("unwinding block: {:.60?} {:.60?}", block.typ, reason);
match block.typ {
BlockType::Loop => match reason {
UnwindReason::Break { target } => {
@@ -1935,6 +1945,7 @@ impl ExecutingFrame<'_> {
}
fn push_block(&mut self, typ: BlockType) {
// eprintln!("block pushed: {:.60?} {}", typ, self.state.stack.len());
self.state.blocks.push(Block {
typ,
level: self.state.stack.len(),
@@ -1944,6 +1955,12 @@ impl ExecutingFrame<'_> {
#[track_caller]
fn pop_block(&mut self) -> Block {
let block = self.state.blocks.pop().expect("No more blocks to pop!");
// eprintln!(
// "block popped: {:.60?} {} -> {} ",
// block.typ,
// self.state.stack.len(),
// block.level
// );
#[cfg(debug_assertions)]
if self.state.stack.len() < block.level {
dbg!(&self);
@@ -1965,6 +1982,11 @@ impl ExecutingFrame<'_> {
#[inline]
#[track_caller] // not a real track_caller but push_value is not very useful
fn push_value(&mut self, obj: PyObjectRef) {
// eprintln!(
// "push_value {} / len: {} +1",
// obj.class().name(),
// self.state.stack.len()
// );
match self.state.stack.try_push(obj) {
Ok(()) => {}
Err(_e) => self.fatal("tried to push value onto stack but overflowed max_stackdepth"),
@@ -1975,7 +1997,14 @@ impl ExecutingFrame<'_> {
#[track_caller] // not a real track_caller but pop_value is not very useful
fn pop_value(&mut self) -> PyObjectRef {
match self.state.stack.pop() {
Some(x) => x,
Some(x) => {
// eprintln!(
// "pop_value {} / len: {}",
// x.class().name(),
// self.state.stack.len()
// );
x
}
None => self.fatal("tried to pop value but there was nothing on the stack"),
}
}
@@ -2002,6 +2031,7 @@ impl ExecutingFrame<'_> {
}
#[inline]
#[track_caller]
fn nth_value(&self, depth: u32) -> &PyObject {
let stack = &self.state.stack;
&stack[stack.len() - depth as usize - 1]