Source code for time_machine

from __future__ import annotations

import datetime as dt
import functools
import inspect
import os
import sys
import time as time_module
import uuid
from collections.abc import Awaitable, Callable, Generator
from collections.abc import Generator as TypingGenerator
from enum import Enum
from time import gmtime as orig_gmtime
from time import struct_time
from types import TracebackType
from typing import Any, TypeAlias, TypeVar, cast, overload
from unittest import TestCase
from zoneinfo import ZoneInfo

import _time_machine

if sys.version_info >= (3, 11):
    from typing import assert_never
else:

    def assert_never(_: Any) -> None:  # pragma: no cover
        pass


# time.clock_gettime and time.CLOCK_REALTIME not always available
# e.g. on builds against old macOS = official Python.org installer
try:
    from time import CLOCK_REALTIME
except ImportError:
    # Dummy value that won't compare equal to any value
    CLOCK_REALTIME = sys.maxsize  # type: ignore[misc]

try:
    from time import tzset

    HAVE_TZSET = True
except ImportError:  # pragma: no cover
    # Windows
    HAVE_TZSET = False

try:
    from dateutil.parser import parse as parse_datetime

    HAVE_DATEUTIL = True
except ImportError:  # pragma: no cover
    HAVE_DATEUTIL = False

try:
    import pytest
except ImportError:  # pragma: no cover
    HAVE_PYTEST = False
else:
    HAVE_PYTEST = True

NANOSECONDS_PER_SECOND = 1_000_000_000

# Windows' time epoch is not unix epoch but in 1601. This constant helps us
# translate to it.
_system_epoch = orig_gmtime(0)
SYSTEM_EPOCH_TIMESTAMP_NS = int(
    dt.datetime(
        _system_epoch.tm_year,
        _system_epoch.tm_mon,
        _system_epoch.tm_mday,
        _system_epoch.tm_hour,
        _system_epoch.tm_min,
        _system_epoch.tm_sec,
        tzinfo=dt.timezone.utc,
    ).timestamp()
    * NANOSECONDS_PER_SECOND
)

DestinationBaseType: TypeAlias = (
    int | float | dt.datetime | dt.timedelta | dt.date | str
)
DestinationType: TypeAlias = (
    DestinationBaseType
    | Callable[[], DestinationBaseType]
    | TypingGenerator[DestinationBaseType, None, None]
)

_F = TypeVar("_F", bound=Callable[..., Any])
_AF = TypeVar("_AF", bound=Callable[..., Awaitable[Any]])
TestCaseType = TypeVar("TestCaseType", bound=type[TestCase])

# copied from typeshed:
_TimeTuple = tuple[int, int, int, int, int, int, int, int, int]


[docs] class NaiveMode(Enum): MIXED = 1 UTC = 2 LOCAL = 3 ERROR = 4
naive_mode = NaiveMode.MIXED def extract_timestamp_tzname( destination: DestinationType, ) -> tuple[float, str | None]: dest: DestinationBaseType if isinstance(destination, Generator): dest = next(destination) elif callable(destination): dest = destination() else: dest = destination timestamp: float tzname: str | None = None if isinstance(dest, int): timestamp = float(dest) elif isinstance(dest, float): timestamp = dest elif isinstance(dest, dt.datetime): if isinstance(dest.tzinfo, ZoneInfo): tzname = dest.tzinfo.key elif dest.tzinfo == dt.timezone.utc: tzname = "UTC" elif dest.tzinfo is None: if naive_mode == NaiveMode.MIXED or naive_mode == NaiveMode.UTC: dest = dest.replace(tzinfo=dt.timezone.utc) elif naive_mode == NaiveMode.LOCAL: pass elif naive_mode == NaiveMode.ERROR: raise RuntimeError( "Naive datetime provided while time_machine.naive_mode is set to ERROR. " "Please provide a timezone-aware datetime." ) else: # pragma: no cover assert_never(naive_mode) timestamp = dest.timestamp() elif isinstance(dest, dt.timedelta): timestamp = time_module.time() + dest.total_seconds() elif isinstance(dest, dt.date): if naive_mode == NaiveMode.MIXED or naive_mode == NaiveMode.UTC: timestamp = dt.datetime.combine( dest, dt.time(0, 0), tzinfo=dt.timezone.utc ).timestamp() elif naive_mode == NaiveMode.LOCAL: timestamp = dt.datetime.combine(dest, dt.time(0, 0)).timestamp() elif naive_mode == NaiveMode.ERROR: raise RuntimeError( "date object provided while time_machine.naive_mode is set to ERROR. " "Please provide a timezone-aware datetime." ) else: # pragma: no cover assert_never(naive_mode) elif isinstance(dest, str): try: parsed = dt.datetime.fromisoformat(dest) except ValueError as exc: if HAVE_DATEUTIL: try: parsed = parse_datetime(dest) except ValueError as dateutil_exc: raise dateutil_exc from None else: raise exc if parsed.tzinfo is None: if naive_mode == NaiveMode.MIXED: # Keep as naive, for backwards compatibility pass elif naive_mode == NaiveMode.UTC: parsed = parsed.replace(tzinfo=dt.timezone.utc) elif naive_mode == NaiveMode.LOCAL: pass elif naive_mode == NaiveMode.ERROR: raise RuntimeError( "Naive datetime string provided while time_machine.naive_mode is set to ERROR. " "Please provide a timezone-aware datetime string." ) else: # pragma: no cover assert_never(naive_mode) timestamp = parsed.timestamp() else: raise TypeError(f"Unsupported destination {dest!r}") return timestamp, tzname
[docs] class Traveller: def __init__( self, destination_timestamp: float, destination_tzname: str | None, tick: bool, ) -> None: self._destination_timestamp_ns = int( destination_timestamp * NANOSECONDS_PER_SECOND ) self._destination_tzname = destination_tzname self._tick = tick self._requested = False def time(self) -> float: return self.time_ns() / NANOSECONDS_PER_SECOND def time_ns(self) -> int: if not self._tick: return self._destination_timestamp_ns base = SYSTEM_EPOCH_TIMESTAMP_NS + self._destination_timestamp_ns now_ns: int = _time_machine.original_time_ns() if not self._requested: self._requested = True self._real_start_timestamp_ns = now_ns return base return base + (now_ns - self._real_start_timestamp_ns)
[docs] def shift(self, delta: dt.timedelta | int | float) -> None: if isinstance(delta, dt.timedelta): total_seconds = delta.total_seconds() elif isinstance(delta, (int, float)): total_seconds = delta else: raise TypeError(f"Unsupported type for delta argument: {delta!r}") self._destination_timestamp_ns += int(total_seconds * NANOSECONDS_PER_SECOND)
[docs] def move_to( self, destination: DestinationType, tick: bool | None = None, ) -> None: self._stop() timestamp, self._destination_tzname = extract_timestamp_tzname(destination) self._destination_timestamp_ns = int(timestamp * NANOSECONDS_PER_SECOND) self._requested = False self._start() if tick is not None: self._tick = tick
def _start(self) -> None: if HAVE_TZSET and self._destination_tzname is not None: self._orig_tz = os.environ.get("TZ") os.environ["TZ"] = self._destination_tzname tzset() def _stop(self) -> None: if HAVE_TZSET and self._destination_tzname is not None: if self._orig_tz is None: del os.environ["TZ"] else: os.environ["TZ"] = self._orig_tz tzset()
traveller_stack: list[Traveller] = [] original_uuid_generate_time_safe = None original_uuid_uuid_create = None
[docs] class travel: def __init__(self, destination: DestinationType, *, tick: bool = True) -> None: self.destination_timestamp, self.destination_tzname = extract_timestamp_tzname( destination ) self.tick = tick
[docs] def start(self) -> Traveller: if "freezegun" in sys.modules and dt.datetime.__name__ == "FakeDatetime": raise RuntimeError("time-machine cannot start when freezegun is active.") if not traveller_stack: _time_machine.patch() # During time travel, patch the uuid module's time-based generation function to # None, which makes it use time.time(). Otherwise it makes a system call to # find the current datetime. The time it finds is stored in generated UUID1 # values. global original_uuid_generate_time_safe global original_uuid_uuid_create original_uuid_generate_time_safe = uuid._generate_time_safe # type: ignore[attr-defined] original_uuid_uuid_create = uuid._UuidCreate # type: ignore[attr-defined] uuid._generate_time_safe = None # type: ignore[attr-defined] uuid._UuidCreate = None # type: ignore[attr-defined] traveller = Traveller( destination_timestamp=self.destination_timestamp, destination_tzname=self.destination_tzname, tick=self.tick, ) traveller_stack.append(traveller) traveller._start() return traveller
[docs] def stop(self) -> None: traveller_stack.pop()._stop() if not traveller_stack: _time_machine.unpatch() global original_uuid_generate_time_safe global original_uuid_uuid_create uuid._generate_time_safe = original_uuid_generate_time_safe # type: ignore[attr-defined] uuid._UuidCreate = original_uuid_uuid_create # type: ignore[attr-defined] original_uuid_generate_time_safe = None original_uuid_uuid_create = None
def __enter__(self) -> Traveller: return self.start() def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.stop() async def __aenter__(self) -> Traveller: return self.start() async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.stop() @overload def __call__(self, wrapped: TestCaseType) -> TestCaseType: # pragma: no cover ... @overload def __call__(self, wrapped: _AF) -> _AF: # pragma: no cover ... @overload def __call__(self, wrapped: _F) -> _F: # pragma: no cover ... # 'Any' below is workaround for Mypy error: # Overloaded function implementation does not accept all possible arguments # of signature def __call__( self, wrapped: TestCaseType | _AF | _F | Any ) -> TestCaseType | _AF | _F | Any: if isinstance(wrapped, type): # Class decorator if not issubclass(wrapped, TestCase): raise TypeError("Can only decorate unittest.TestCase subclasses.") # Modify the setUpClass method orig_setUpClass = wrapped.setUpClass.__func__ # type: ignore[attr-defined] @functools.wraps(orig_setUpClass) def setUpClass(cls: type[TestCase]) -> None: self.__enter__() try: orig_setUpClass(cls) except Exception: self.__exit__(*sys.exc_info()) raise wrapped.setUpClass = classmethod(setUpClass) # type: ignore[assignment] orig_tearDownClass = ( wrapped.tearDownClass.__func__ # type: ignore[attr-defined] ) @functools.wraps(orig_tearDownClass) def tearDownClass(cls: type[TestCase]) -> None: orig_tearDownClass(cls) self.__exit__(None, None, None) wrapped.tearDownClass = classmethod( # type: ignore[assignment] tearDownClass ) return cast(TestCaseType, wrapped) elif inspect.iscoroutinefunction(wrapped): @functools.wraps(wrapped) async def wrapper(*args: Any, **kwargs: Any) -> Any: with self: return await wrapped(*args, **kwargs) return cast(_AF, wrapper) else: assert callable(wrapped) @functools.wraps(wrapped) def wrapper(*args: Any, **kwargs: Any) -> Any: with self: return wrapped(*args, **kwargs) return cast(_F, wrapper)
# datetime module def now(tz: dt.tzinfo | None = None) -> dt.datetime: return dt.datetime.fromtimestamp(time(), tz) def utcnow() -> dt.datetime: return dt.datetime.fromtimestamp(time(), dt.timezone.utc).replace(tzinfo=None) # time module def clock_gettime(clk_id: int) -> float: if clk_id != CLOCK_REALTIME: result: float = _time_machine.original_clock_gettime(clk_id) return result return time() def clock_gettime_ns(clk_id: int) -> int: if clk_id != CLOCK_REALTIME: result: int = _time_machine.original_clock_gettime_ns(clk_id) return result return time_ns() def gmtime(secs: float | None = None) -> struct_time: result: struct_time if secs is not None: result = _time_machine.original_gmtime(secs) else: result = _time_machine.original_gmtime(traveller_stack[-1].time()) return result def localtime(secs: float | None = None) -> struct_time: result: struct_time if secs is not None: result = _time_machine.original_localtime(secs) else: result = _time_machine.original_localtime(traveller_stack[-1].time()) return result def strftime(format: str, t: _TimeTuple | struct_time | None = None) -> str: result: str if t is not None: result = _time_machine.original_strftime(format, t) else: result = _time_machine.original_strftime(format, localtime()) return result def time() -> float: return traveller_stack[-1].time() def time_ns() -> int: return traveller_stack[-1].time_ns() # pytest plugin if HAVE_PYTEST: # pragma: no branch def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: """ Add the fixture to any tests with the marker. """ for item in items: if item.get_closest_marker("time_machine"): item.fixturenames.insert(0, "time_machine") # type: ignore[attr-defined] def pytest_configure(config: pytest.Config) -> None: """ Register the marker. """ config.addinivalue_line( "markers", "time_machine(...): set the time with time-machine" ) class TimeMachineFixture: traveller: travel | None traveller_obj: Traveller | None def __init__(self) -> None: self.traveller = None self.traveller_obj = None def move_to( self, destination: DestinationType, tick: bool | None = None, ) -> None: if self.traveller is None: if tick is None: tick = True self.traveller = travel(destination, tick=tick) self.traveller_obj = self.traveller.start() else: assert self.traveller_obj is not None self.traveller_obj.move_to(destination, tick=tick) def shift(self, delta: dt.timedelta | int | float) -> None: if self.traveller is None: raise RuntimeError( "Initialize time_machine with move_to() before using shift()." ) assert self.traveller_obj is not None self.traveller_obj.shift(delta=delta) def stop(self) -> None: if self.traveller is not None: self.traveller.stop() @pytest.fixture(name="time_machine") def time_machine_fixture( request: pytest.FixtureRequest, ) -> TypingGenerator[TimeMachineFixture, None, None]: fixture = TimeMachineFixture() marker = request.node.get_closest_marker("time_machine") if marker: fixture.move_to(*marker.args, **marker.kwargs) yield fixture fixture.stop() # escape hatch class _EscapeHatchDatetimeDatetime: def now(self, tz: dt.tzinfo | None = None) -> dt.datetime: result: dt.datetime = _time_machine.original_now(tz) return result def utcnow(self) -> dt.datetime: result: dt.datetime = _time_machine.original_utcnow() return result class _EscapeHatchDatetime: def __init__(self) -> None: self.datetime = _EscapeHatchDatetimeDatetime() class _EscapeHatchTime: def clock_gettime(self, clk_id: int) -> float: result: float = _time_machine.original_clock_gettime(clk_id) return result def clock_gettime_ns(self, clk_id: int) -> int: result: int = _time_machine.original_clock_gettime_ns(clk_id) return result def gmtime(self, secs: float | None = None) -> struct_time: result: struct_time = _time_machine.original_gmtime(secs) return result def localtime(self, secs: float | None = None) -> struct_time: result: struct_time = _time_machine.original_localtime(secs) return result def strftime(self, format: str, t: _TimeTuple | struct_time | None = None) -> str: result: str if t is not None: result = _time_machine.original_strftime(format, t) else: result = _time_machine.original_strftime(format) return result def time(self) -> float: result: float = _time_machine.original_time() return result def time_ns(self) -> int: result: int = _time_machine.original_time_ns() return result class _EscapeHatch: def __init__(self) -> None: self.datetime = _EscapeHatchDatetime() self.time = _EscapeHatchTime() def is_travelling(self) -> bool: return bool(traveller_stack) escape_hatch = _EscapeHatch()