diff --git a/extra_tests/snippets/syntax_async.py b/extra_tests/snippets/syntax_async.py index c0a86017a..0b0d3e048 100644 --- a/extra_tests/snippets/syntax_async.py +++ b/extra_tests/snippets/syntax_async.py @@ -1,5 +1,5 @@ import asyncio - +import unittest class ContextManager: async def __aenter__(self): @@ -70,3 +70,44 @@ assert ls == [ "hello3", "hello4", ] + + +class TestAsyncWith(unittest.TestCase): + def testAenterAttributeError1(self): + class LacksAenter(object): + async def __aexit__(self, *exc): + pass + + async def foo(): + async with LacksAenter(): + pass + + with self.assertRaisesRegex(AttributeError, '__aenter__'): + foo().send(None) + + def testAenterAttributeError2(self): + class LacksAenterAndAexit(object): + pass + + async def foo(): + async with LacksAenterAndAexit(): + pass + + with self.assertRaisesRegex(AttributeError, '__aenter__'): + foo().send(None) + + def testAexitAttributeError(self): + class LacksAexit(object): + async def __aenter__(self): + pass + + async def foo(): + async with LacksAexit(): + pass + + with self.assertRaisesRegex(AttributeError, '__aexit__'): + foo().send(None) + + +if __name__ == "__main__": + unittest.main() diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 31579ab0d..6f93c1654 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -788,9 +788,10 @@ impl ExecutingFrame<'_> { } bytecode::Instruction::BeforeAsyncWith => { let mgr = self.pop_value(); + let aenter_res = + vm.call_special_method(mgr.clone(), identifier!(vm, __aenter__), ())?; let aexit = mgr.get_attr(identifier!(vm, __aexit__), vm)?; self.push_value(aexit); - let aenter_res = vm.call_special_method(mgr, identifier!(vm, __aenter__), ())?; self.push_value(aenter_res); Ok(None)