Fix BeforeAsyncWith

This commit is contained in:
Daeun Lee
2022-07-16 10:14:25 +09:00
parent a563d07919
commit c8e599e29d
2 changed files with 44 additions and 2 deletions

View File

@@ -1,5 +1,5 @@
import asyncio
import unittest
class ContextManager:
async def __aenter__(self):
@@ -70,3 +70,44 @@ assert ls == [
"hello3",
"hello4",
]
class TestAsyncWith(unittest.TestCase):
def testAenterAttributeError1(self):
class LacksAenter(object):
async def __aexit__(self, *exc):
pass
async def foo():
async with LacksAenter():
pass
with self.assertRaisesRegex(AttributeError, '__aenter__'):
foo().send(None)
def testAenterAttributeError2(self):
class LacksAenterAndAexit(object):
pass
async def foo():
async with LacksAenterAndAexit():
pass
with self.assertRaisesRegex(AttributeError, '__aenter__'):
foo().send(None)
def testAexitAttributeError(self):
class LacksAexit(object):
async def __aenter__(self):
pass
async def foo():
async with LacksAexit():
pass
with self.assertRaisesRegex(AttributeError, '__aexit__'):
foo().send(None)
if __name__ == "__main__":
unittest.main()

View File

@@ -788,9 +788,10 @@ impl ExecutingFrame<'_> {
}
bytecode::Instruction::BeforeAsyncWith => {
let mgr = self.pop_value();
let aenter_res =
vm.call_special_method(mgr.clone(), identifier!(vm, __aenter__), ())?;
let aexit = mgr.get_attr(identifier!(vm, __aexit__), vm)?;
self.push_value(aexit);
let aenter_res = vm.call_special_method(mgr, identifier!(vm, __aenter__), ())?;
self.push_value(aenter_res);
Ok(None)