unittest’s new context methods in Python 3.11 (with backports)

A testing concerto

Python 3.11 only made one change to unittest, but it’s a good one: context manager methods. These methods can simplify setup and teardown logic in many cases, such as dynamic use of unittest.mock.

In this post we’ll look at a couple of examples using the new methods, and at a backport you can copy-paste into your projects, in plain and Django flavours.

A time-saving example

Take this test case, which uses my time-mocking library time-machine to fix the current microsecond to 0 (for whatever reason):

import datetime as dt
from unittest import TestCase

import time_machine


class ExampleTests(TestCase):
    def setUp(self):
        now = dt.datetime.now().replace(microsecond=0)

        self.traveller = time_machine.travel(now)
        self.traveller.start()
        self.addCleanup(self.traveller.stop)

        super().setUp()

    def test_microsecond(self):
        self.assertEqual(dt.datetime.now().microsecond, 0)

It takes three steps to set up the mocking:

  1. Create a “traveller” by calling time_machine.travel().
  2. Make it mock time by calling its start() method.
  3. Schedule un-mocking by passing its stop() method to addCleanup().

This is verbose. Plus, it’s all too easy to forget the cleanup step, leading to test pollution.
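To make the pollution risk concrete, here’s a minimal sketch using only the standard library. The EXAMPLE_FLAG variable and the test names are hypothetical. The first test starts a unittest.mock patcher but never stops it, so the patched value leaks into the next test:

```python
import os
import unittest
from unittest import mock


class LeakyTests(unittest.TestCase):
    def test_a_forgets_cleanup(self) -> None:
        # Starts the patcher, but never passes patcher.stop to addCleanup().
        patcher = mock.patch.dict(os.environ, {"EXAMPLE_FLAG": "on"})
        patcher.start()
        self.assertEqual(os.environ["EXAMPLE_FLAG"], "on")

    def test_b_sees_leftover_state(self) -> None:
        # The loader sorts test methods by name, so this runs after test_a.
        # It only passes because test_a's patch is still active.
        self.assertEqual(os.environ.get("EXAMPLE_FLAG"), "on")


def run_tests(case_class: type) -> unittest.TestResult:
    """Run every test in case_class and return the collected result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(case_class)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Running run_tests(LeakyTests) reports both tests as passing, yet test_b_sees_leftover_state fails when run on its own. That order dependence is exactly the kind of pollution a forgotten cleanup causes.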

You can merge those three steps into one with Python 3.11’s enterContext() method:

import datetime as dt
from unittest import TestCase

import time_machine


class ExampleTests(TestCase):
    def setUp(self):
        now = dt.datetime.now().replace(microsecond=0)

        self.traveller = self.enterContext(time_machine.travel(now))

        super().setUp()

    def test_microsecond(self):
        self.assertEqual(dt.datetime.now().microsecond, 0)

Much tidier.

enterContext() executes three steps:

  1. Call the context manager’s __enter__ method.
  2. Call addCleanup to schedule the context manager’s __exit__ to run at test cleanup.
  3. Return whatever __enter__ returned.

The time-machine example works because the traveller class also behaves as a context manager. It has an __enter__() method equivalent to start(), and an __exit__() method equivalent to stop(). Many, maybe most, mocking tools work as context managers, such as unittest.mock patchers and requests_mock.Mocker.
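Here’s a minimal sketch of that dual interface, using a hypothetical Switch class rather than time-machine’s real code. Its __enter__() and __exit__() simply delegate to start() and stop(), so a with block, or enterContext(), does the same work as the manual calls:

```python
from typing import Any, List


class Switch:
    """Hypothetical mocking-style tool with start()/stop() and __enter__/__exit__."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def start(self) -> "Switch":
        self.log.append("start")
        return self

    def stop(self) -> None:
        self.log.append("stop")

    # The context manager protocol delegates to start() and stop(), so a
    # `with` block (or enterContext()) does the same work as calling them.
    def __enter__(self) -> "Switch":
        return self.start()

    def __exit__(self, *exc_info: Any) -> None:
        self.stop()


manual_log: List[str] = []
switch = Switch(manual_log)
switch.start()
switch.stop()

with_log: List[str] = []
with Switch(with_log):
    pass

print(manual_log == with_log)  # True: both are ["start", "stop"]
```

Any tool shaped like this can be handed straight to enterContext() or enterClassContext().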

A second, class-level example

Here’s another example test case, which uses unittest.mock to mock an environment variable at the class level:

import os
from unittest import mock
from unittest import TestCase

from example import get_text_colour


class ExampleTests(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.text_colour = get_text_colour()

        patcher = mock.patch.dict(os.environ, {"TEXT_COLOUR": cls.text_colour})
        patcher.start()
        cls.addClassCleanup(patcher.stop)

        super().setUpClass()

    def test_env_var(self):
        self.assertEqual(os.environ["TEXT_COLOUR"], self.text_colour)

The class-level mocking applies once for the whole test case. This can save time, and also ensures the mocking applies during any test data creation.

You can simplify this test case similarly by using Python 3.11’s new enterClassContext() method:

import os
from unittest import mock
from unittest import TestCase

from example import get_text_colour


class ExampleTests(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.text_colour = get_text_colour()

        cls.enterClassContext(
            mock.patch.dict(os.environ, {"TEXT_COLOUR": cls.text_colour})
        )

        super().setUpClass()

    def test_env_var(self):
        self.assertEqual(os.environ["TEXT_COLOUR"], self.text_colour)

Once more, a couple of lines saved.

All the methods

Python 3.11 added four enter*Context() methods: one each for the module, class, and test levels, plus an async version of the test-level one.

  1. The module-level unittest.enterModuleContext(), which you can use from within a setUpModule() function:

    from unittest import enterModuleContext
    from unittest import TestCase
    
    
    def setUpModule():
        enterModuleContext(...)
    
    
    class ExampleTests(TestCase):
        ...
    
  2. The class-level TestCase.enterClassContext(), which you can use within setUpClass() (and setUpTestData() in Django test cases):

    from unittest import TestCase
    
    
    class ExampleTests(TestCase):
        @classmethod
        def setUpClass(cls):
            cls.enterClassContext(...)
            super().setUpClass()
    
        ...
    
  3. The test-level TestCase.enterContext(), which you can use within setUp() or test methods:

    from unittest import TestCase
    
    
    class ExampleTests(TestCase):
        def test_something(self):
            self.enterContext(...)
    
            ...
    
  4. The async test-level IsolatedAsyncioTestCase.enterAsyncContext(), which you can use within asyncSetUp() or async tests to enter an asynchronous context manager:

    from unittest import IsolatedAsyncioTestCase
    
    
    class ExampleTests(IsolatedAsyncioTestCase):
        async def test_something(self):
            await self.enterAsyncContext(...)
    
            ...
    

That’s the lot.

A backport for plain unittest projects

If you’re not on Python 3.11 yet, you can backport this feature and start simplifying your tests now. Since it builds on the long-existing add*Cleanup() methods, it doesn’t take much code.

Below is the backport code for a plain unittest project. Add or merge it into a base test module, and import TestCase from there throughout your project. If you’re working on a Django project, see the next section for a Django-specific backport.

The code is sourced from unittest/case.py (and async_case.py) on the 3.11 branch of CPython. The type hints come from the corresponding files in typeshed. (I couldn’t properly type _enter_context(), due to this Mypy issue.)

The if sys.version_info ... checks mean that, once you upgrade to Python 3.11 or later, the native methods will be used. You can use the pyupgrade tool to automatically remove outdated version blocks like these.

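# Defer annotation evaluation so subscripted ABCs work before Python 3.9.
from __future__ import annotations

from collections.abc import Callable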
from contextlib import AbstractAsyncContextManager
from contextlib import AbstractContextManager
import sys
from typing import Any
from typing import TypeVar
import unittest

_T = TypeVar("_T")

if sys.version_info < (3, 11):

    def _enter_context(cm: Any, addcleanup: Callable[..., None]) -> Any:
        # We look up the special methods on the type to match the with
        # statement.
        cls = type(cm)
        try:
            enter = cls.__enter__
            exit = cls.__exit__
        except AttributeError:
            raise TypeError(
                f"'{cls.__module__}.{cls.__qualname__}' object does "
                f"not support the context manager protocol"
            ) from None
        result = enter(cm)
        addcleanup(exit, cm, None, None, None)
        return result


if sys.version_info < (3, 11):

    def enterModuleContext(cm: AbstractContextManager[_T]) -> _T:
        result: _T = _enter_context(cm, unittest.addModuleCleanup)
        return result

else:
    enterModuleContext = unittest.enterModuleContext


class TestCase(unittest.TestCase):
    if sys.version_info < (3, 11):

        def enterContext(self, cm: AbstractContextManager[_T]) -> _T:
            result: _T = _enter_context(cm, self.addCleanup)
            return result

        @classmethod
        def enterClassContext(cls, cm: AbstractContextManager[_T]) -> _T:
            result: _T = _enter_context(cm, cls.addClassCleanup)
            return result


class IsolatedAsyncioTestCase(unittest.IsolatedAsyncioTestCase):
    if sys.version_info < (3, 11):

        async def enterAsyncContext(
            self,
            cm: AbstractAsyncContextManager[_T],
        ) -> _T:
            """Enters the supplied asynchronous context manager.
            If successful, also adds its __aexit__ method as a cleanup
            function and returns the result of the __aenter__ method.
            """
            # We look up the special methods on the type to match the with
            # statement.
            cls = type(cm)
            try:
                enter = cls.__aenter__
                exit = cls.__aexit__
            except AttributeError:
                raise TypeError(
                    f"'{cls.__module__}.{cls.__qualname__}' object does "
                    f"not support the asynchronous context manager protocol"
                ) from None
            result = await enter(cm)
            self.addAsyncCleanup(exit, cm, None, None, None)
            return result
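Before relying on the backport, you might add a quick self-check like the sketch below. It condenses the enterContext() part of the code above so the block runs on its own (in your project you’d import TestCase from your base test module instead); BackportUsageTests and EXAMPLE_BACKPORT are hypothetical names. On Python 3.11 or later, the version check skips the backported method, so the same tests exercise the native enterContext():

```python
import os
import sys
import unittest
from typing import Any, Callable
from unittest import mock


# Condensed from the backport above so this block runs on its own.
def _enter_context(cm: Any, addcleanup: Callable[..., None]) -> Any:
    # Look the special methods up on the type, like the with statement.
    cls = type(cm)
    try:
        enter = cls.__enter__
        exit = cls.__exit__
    except AttributeError:
        raise TypeError(
            f"'{cls.__module__}.{cls.__qualname__}' object does "
            f"not support the context manager protocol"
        ) from None
    result = enter(cm)
    addcleanup(exit, cm, None, None, None)
    return result


class TestCase(unittest.TestCase):
    if sys.version_info < (3, 11):

        def enterContext(self, cm: Any) -> Any:
            return _enter_context(cm, self.addCleanup)


class BackportUsageTests(TestCase):
    def test_patch_applies_inside_test(self) -> None:
        # EXAMPLE_BACKPORT is a hypothetical variable, set only for this test.
        self.enterContext(mock.patch.dict(os.environ, {"EXAMPLE_BACKPORT": "yes"}))
        self.assertEqual(os.environ["EXAMPLE_BACKPORT"], "yes")

    def test_rejects_non_context_managers(self) -> None:
        # Plain objects lack __enter__/__exit__, so a TypeError is raised.
        with self.assertRaises(TypeError):
            self.enterContext(object())
```

If both tests pass and EXAMPLE_BACKPORT is unset afterwards, the cleanup half of the backport is doing its job.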

Enjoy.

A backport for Django projects

When testing a Django project, you normally use Django’s test case classes. I recommend adding your own subclasses of these classes in your projects, to allow customization, such as extra assertion methods or a custom test client.

Below is a version of the backport, excluding enterAsyncContext since Django doesn’t have a dedicated async test case. This backport sets up a custom SimpleTestCase and its subclasses. If you already have a custom class, you can merge the changes.

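# Defer annotation evaluation so subscripted ABCs work before Python 3.9.
from __future__ import annotations

from collections.abc import Callable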
from contextlib import AbstractContextManager
import sys
import unittest
from typing import Any
from typing import TypeVar

from django import test


_T = TypeVar("_T")


def _enter_context(cm: Any, addcleanup: Callable[..., None]) -> Any:
    # We look up the special methods on the type to match the with
    # statement.
    cls = type(cm)
    try:
        enter = cls.__enter__
        exit = cls.__exit__
    except AttributeError:
        raise TypeError(
            f"'{cls.__module__}.{cls.__qualname__}' object does "
            f"not support the context manager protocol"
        ) from None
    result = enter(cm)
    addcleanup(exit, cm, None, None, None)
    return result


if sys.version_info < (3, 11):

    def enterModuleContext(cm: AbstractContextManager[_T]) -> _T:
        result: _T = _enter_context(cm, unittest.addModuleCleanup)
        return result

else:
    enterModuleContext = unittest.enterModuleContext


class SimpleTestCase(test.SimpleTestCase):
    if sys.version_info < (3, 11):

        def enterContext(self, cm: AbstractContextManager[_T]) -> _T:
            result: _T = _enter_context(cm, self.addCleanup)
            return result

        @classmethod
        def enterClassContext(cls, cm: AbstractContextManager[_T]) -> _T:
            result: _T = _enter_context(cm, cls.addClassCleanup)
            return result


class TestCase(test.TestCase, SimpleTestCase):
    pass


class TransactionTestCase(test.TransactionTestCase, SimpleTestCase):
    pass


class LiveServerTestCase(test.LiveServerTestCase, SimpleTestCase):
    pass

There you go!

Fin

Thanks to Serhiy Storchaka for contributing this feature, and Andrew Svetlov for reviewing. And thanks to Alex Waygood for contributing the type hints to typeshed, and Jelle Zijlstra for reviewing that.

May your tests be tidier,

—Adam

