diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 58e9a4988..b831d8916 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -145,14 +145,17 @@ class _GeneratorContextManager( except StopIteration: return False else: - raise RuntimeError("generator didn't stop") + try: + raise RuntimeError("generator didn't stop") + finally: + self.gen.close() else: if value is None: # Need to force instantiation so we can reliably # tell if we get the same exception back value = typ() try: - self.gen.throw(typ, value, traceback) + self.gen.throw(value) except StopIteration as exc: # Suppress StopIteration *unless* it's the same exception that # was passed to throw(). This prevents a StopIteration @@ -187,7 +190,10 @@ class _GeneratorContextManager( raise exc.__traceback__ = traceback return False - raise RuntimeError("generator didn't stop after throw()") + try: + raise RuntimeError("generator didn't stop after throw()") + finally: + self.gen.close() class _AsyncGeneratorContextManager( _GeneratorContextManagerBase, @@ -212,14 +218,17 @@ class _AsyncGeneratorContextManager( except StopAsyncIteration: return False else: - raise RuntimeError("generator didn't stop") + try: + raise RuntimeError("generator didn't stop") + finally: + await self.gen.aclose() else: if value is None: # Need to force instantiation so we can reliably # tell if we get the same exception back value = typ() try: - await self.gen.athrow(typ, value, traceback) + await self.gen.athrow(value) except StopAsyncIteration as exc: # Suppress StopIteration *unless* it's the same exception that # was passed to throw(). This prevents a StopIteration @@ -254,7 +263,10 @@ class _AsyncGeneratorContextManager( raise exc.__traceback__ = traceback return False - raise RuntimeError("generator didn't stop after athrow()") + try: + raise RuntimeError("generator didn't stop after athrow()") + finally: + await self.gen.aclose() def contextmanager(func): @@ -441,7 +453,16 @@ class suppress(AbstractContextManager): # exactly reproduce the limitations of the CPython interpreter. # # See http://bugs.python.org/issue12029 for more details - return exctype is not None and issubclass(exctype, self._exceptions) + if exctype is None: + return + if issubclass(exctype, self._exceptions): + return True + if issubclass(exctype, BaseExceptionGroup): + match, rest = excinst.split(self._exceptions) + if rest is None: + return True + raise rest + return False class _BaseExitStack: diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 71215ecf5..f6478b43e 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -10,6 +10,7 @@ import unittest from contextlib import * # Tests __all__ from test import support from test.support import os_helper +from test.support.testcase import ExceptionIsLikeMixin import weakref @@ -158,9 +159,45 @@ class ContextManagerTestCase(unittest.TestCase): yield ctx = whoo() ctx.__enter__() - self.assertRaises( - RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None - ) + with self.assertRaises(RuntimeError): + ctx.__exit__(TypeError, TypeError("foo"), None) + if support.check_impl_detail(cpython=True): + # The "gen" attribute is an implementation detail. + self.assertFalse(ctx.gen.gi_suspended) + + def test_contextmanager_trap_no_yield(self): + @contextmanager + def whoo(): + if False: + yield + ctx = whoo() + with self.assertRaises(RuntimeError): + ctx.__enter__() + + def test_contextmanager_trap_second_yield(self): + @contextmanager + def whoo(): + yield + yield + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(None, None, None) + if support.check_impl_detail(cpython=True): + # The "gen" attribute is an implementation detail. + self.assertFalse(ctx.gen.gi_suspended) + + def test_contextmanager_non_normalised(self): + @contextmanager + def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError + ctx = whoo() + ctx.__enter__() + with self.assertRaises(SyntaxError): + ctx.__exit__(RuntimeError, None, None) def test_contextmanager_except(self): state = [] @@ -241,6 +278,23 @@ def woohoo(): self.assertEqual(ex.args[0], 'issue29692:Unchained') self.assertIsNone(ex.__cause__) + def test_contextmanager_wrap_runtimeerror(self): + @contextmanager + def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f'caught {exc}') from exc + with self.assertRaises(RuntimeError): + with woohoo(): + 1 / 0 + # If the context manager wrapped StopIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopIteration): + with woohoo(): + raise StopIteration + def _create_contextmanager_attribs(self): def attribs(**kw): def decorate(func): @@ -252,6 +306,7 @@ def woohoo(): @attribs(foo='bar') def baz(spam): """Whee!""" + yield return baz def test_contextmanager_attribs(self): @@ -308,8 +363,11 @@ def woohoo(): def test_recursive(self): depth = 0 + ncols = 0 @contextmanager def woohoo(): + nonlocal ncols + ncols += 1 nonlocal depth before = depth depth += 1 @@ -323,6 +381,7 @@ def woohoo(): recursive() recursive() + self.assertEqual(ncols, 10) self.assertEqual(depth, 0) @@ -374,12 +433,10 @@ class FileContextTestCase(unittest.TestCase): def testWithOpen(self): tfn = tempfile.mktemp() try: - f = None with open(tfn, "w", encoding="utf-8") as f: self.assertFalse(f.closed) f.write("Booh\n") self.assertTrue(f.closed) - f = None with self.assertRaises(ZeroDivisionError): with open(tfn, "r", encoding="utf-8") as f: self.assertFalse(f.closed) @@ -1160,7 +1217,7 @@ class TestRedirectStderr(TestRedirectStream, unittest.TestCase): orig_stream = "stderr" -class TestSuppress(unittest.TestCase): +class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase): @support.requires_docstrings def test_instance_docs(self): @@ -1214,6 +1271,51 @@ class TestSuppress(unittest.TestCase): 1/0 self.assertTrue(outer_continued) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_exception_groups(self): + eg_ve = lambda: ExceptionGroup( + "EG with ValueErrors only", + [ValueError("ve1"), ValueError("ve2"), ValueError("ve3")], + ) + eg_all = lambda: ExceptionGroup( + "EG with many types of exceptions", + [ValueError("ve1"), KeyError("ke1"), ValueError("ve2"), KeyError("ke2")], + ) + with suppress(ValueError): + raise eg_ve() + with suppress(ValueError, KeyError): + raise eg_all() + with self.assertRaises(ExceptionGroup) as eg1: + with suppress(ValueError): + raise eg_all() + self.assertExceptionIsLike( + eg1.exception, + ExceptionGroup( + "EG with many types of exceptions", + [KeyError("ke1"), KeyError("ke2")], + ), + ) + + # Check handling of BaseExceptionGroup, using GeneratorExit so that + # we don't accidentally discard a ctrl-c with KeyboardInterrupt. + with suppress(GeneratorExit): + raise BaseExceptionGroup("message", [GeneratorExit()]) + # If we raise a BaseException group, we can still suppress parts + with self.assertRaises(BaseExceptionGroup) as eg1: + with suppress(KeyError): + raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")]) + self.assertExceptionIsLike( + eg1.exception, BaseExceptionGroup("message", [GeneratorExit("g")]), + ) + # If we suppress all the leaf BaseExceptions, we get a non-base ExceptionGroup + with self.assertRaises(ExceptionGroup) as eg1: + with suppress(GeneratorExit): + raise BaseExceptionGroup("message", [GeneratorExit("g"), KeyError("k")]) + self.assertExceptionIsLike( + eg1.exception, ExceptionGroup("message", [KeyError("k")]), + ) + class TestChdir(unittest.TestCase): def make_relative_path(self, *parts):