Source code for pytest_quantum.assertions.snapshot

"""Snapshot / golden-file assertions for quantum circuits and distributions."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any, cast

import numpy as np

# Directory holding all snapshot files. It is a relative path, so file
# operations on it resolve against the process's current working directory.
_SNAPSHOT_DIR: Path = Path(".pytest-quantum-snapshots")


def assert_unitary_snapshot(
    circuit: object,
    name: str,
    *,
    update: bool = False,
    atol: float = 1e-6,
    allow_global_phase: bool = True,
) -> None:
    """Assert that a circuit's unitary matches its saved snapshot.

    First run: saves the unitary to .pytest-quantum-snapshots/<name>.npy.
    Subsequent runs: loads and compares.
    Pass --quantum-update-snapshots on the CLI to regenerate all snapshots.
    The snapshot is also (re)written instead of compared when the environment
    variable ``PYTEST_QUANTUM_UPDATE_SNAPSHOTS`` is set to ``"1"``.

    Args:
        circuit: Any supported quantum circuit (Qiskit, Cirq, Braket, PennyLane).
        name: Unique snapshot name (used as filename, no path separators).
        update: If True, overwrite existing snapshot.
        atol: Absolute tolerance (default 1e-6). Matching uses ``np.allclose``,
            so numpy's default relative tolerance (``rtol=1e-5``) also applies.
        allow_global_phase: If True (default), ignore global phase differences,
            i.e. also accept ``actual ~= exp(i*theta) * expected`` for a real
            ``theta``.

    Raises:
        AssertionError: If the unitary differs from the snapshot.
        TypeError: If the circuit type is not supported.

    Example::

        def test_compiler_stable(compiled_circuit):
            from pytest_quantum import assert_unitary_snapshot
            assert_unitary_snapshot(compiled_circuit, "my_compiler_bell")
    """
    from pytest_quantum.converters.to_unitary import to_unitary

    _update = update or _should_update()
    path = _path(name, ".npy")
    actual = to_unitary(circuit)

    # Record (or re-record) the golden unitary instead of comparing.
    if _update or not path.exists():
        _write_snapshot_safe(path, actual)
        return

    expected = np.load(str(path))
    if actual.shape != expected.shape:
        raise AssertionError(
            f"Snapshot '{name}': shape changed {expected.shape} -> {actual.shape}.\n"
            f"  Run with --quantum-update-snapshots to regenerate."
        )
    if np.allclose(actual, expected, atol=atol):
        return

    # Matrices `actual` is allowed to match; used for the failure report too.
    references = [expected]
    if allow_global_phase:
        # Estimate the phase from the snapshot's largest-magnitude entry,
        # where the ratio a/e is best conditioned numerically.
        pivot = int(np.argmax(np.abs(expected)))
        e_val = expected.flat[pivot]
        a_val = actual.flat[pivot]
        if abs(e_val) > 1e-10 and abs(a_val) > 1e-10:
            ratio = a_val / e_val
            # A global phase has unit modulus; normalising stops a magnitude
            # mismatch at the pivot from being absorbed as a "phase".
            aligned = (ratio / abs(ratio)) * expected
            if np.allclose(actual, aligned, atol=atol):
                return
            references.append(aligned)

    # Report the residual against the closest allowed reference, so a pure
    # phase offset does not inflate the reported difference.
    max_diff = min(float(np.max(np.abs(actual - ref))) for ref in references)
    raise AssertionError(
        f"Snapshot '{name}': unitary has changed.\n"
        f"  Max |diff|: {max_diff:.2e} (tolerance: {atol:.2e})\n"
        f"  Run with --quantum-update-snapshots to accept the new behaviour."
    )
def assert_distribution_snapshot(
    counts: dict[str, int],
    name: str,
    *,
    update: bool = False,
    max_tvd: float = 0.05,
) -> None:
    """Assert that a measurement distribution matches its saved snapshot.

    Saves the normalised probability distribution derived from counts.
    Comparison uses Total Variation Distance (TVD), computed directly from the
    stored and current probabilities as ``0.5 * sum(|p_k - q_k|)`` over the
    union of outcomes (an outcome absent on one side counts as probability 0).

    The snapshot is (re)written instead of compared when ``update`` is True,
    when the environment variable ``PYTEST_QUANTUM_UPDATE_SNAPSHOTS`` is
    ``"1"``, or when no snapshot file exists yet.

    Args:
        counts: Measurement count dict e.g. {"00": 512, "11": 512}.
        name: Unique snapshot name.
        update: If True, overwrite existing snapshot.
        max_tvd: Maximum allowed TVD from snapshot (default 0.05).

    Raises:
        AssertionError: If TVD from snapshot exceeds max_tvd.
        ValueError: If counts is empty (or its counts sum to zero).

    Example::

        def test_distribution_stable(aer_simulator):
            counts = run_bell(aer_simulator, shots=4000)
            from pytest_quantum import assert_distribution_snapshot
            assert_distribution_snapshot(counts, "bell_distribution")
    """
    _update = update or _should_update()
    path = _path(name, "_dist.npy")

    total = sum(counts.values())
    if total == 0:
        raise ValueError("counts dict is empty — nothing to snapshot")
    probs = {k: v / total for k, v in counts.items()}

    if _update or not path.exists():
        # On-disk layout (kept for compatibility with existing snapshots):
        # a 2 x n object array; row 0 = outcome keys, row 1 = probabilities.
        keys = np.array(list(probs.keys()))
        vals = np.array(list(probs.values()), dtype=np.float64)
        _write_snapshot_safe(path, np.array([keys, vals], dtype=object))
        return

    # NOTE(security): object arrays are stored via pickle, so loading needs
    # allow_pickle=True and a tampered snapshot file could execute code here.
    # Only run against snapshot directories from trusted sources.
    raw = np.load(str(path), allow_pickle=True)
    snapshot = dict(zip(raw[0], (float(v) for v in raw[1]), strict=True))

    # Exact TVD on probabilities (no rounding to pseudo-counts), so rare
    # outcomes are not rounded away on wide distributions.
    outcomes = snapshot.keys() | probs.keys()
    distance = 0.5 * sum(
        abs(probs.get(k, 0.0) - snapshot.get(k, 0.0)) for k in outcomes
    )
    if distance <= max_tvd:
        return
    raise AssertionError(
        f"Snapshot '{name}': distribution has changed.\n"
        f"  TVD from snapshot: {distance:.4f} (max allowed: {max_tvd:.4f})\n"
        f"  Run with --quantum-update-snapshots to accept the new distribution."
    )
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _write_snapshot_safe(path: Path, data: object) -> None:
    """Write ``data`` to ``path`` with ``np.save``, safely under pytest-xdist.

    Concurrent writers are serialised with an advisory ``filelock`` lock when
    that package is installed; without it the write is unlocked, which is
    acceptable for a single worker. In both cases the bytes go to a temporary
    file first and are moved into place atomically (see ``_atomic_save``), so
    an unlocked reader never observes a partially written snapshot.

    Args:
        path: Destination snapshot file (the callers in this module use paths
            ending in ``.npy``).
        data: Array-like accepted by ``np.save``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    try:
        import filelock
    except ImportError:
        # filelock not installed — just write directly (acceptable for single-worker)
        _atomic_save(path, data)
        return
    with filelock.FileLock(str(path.with_suffix(".lock")), timeout=10):
        _atomic_save(path, data)


def _atomic_save(path: Path, data: object) -> None:
    """Save ``data`` to a process-unique temp file, then ``os.replace`` it onto ``path``."""
    # Same directory as the target so os.replace is a rename, not a copy;
    # the pid keeps concurrent unlocked writers from sharing a temp file.
    tmp = path.with_name(f"{path.name}.{os.getpid()}.tmp")
    try:
        with tmp.open("wb") as fh:
            np.save(fh, cast("Any", data))
        os.replace(tmp, path)
    finally:
        # After a successful replace the temp file no longer exists; this
        # only cleans up when the write or rename failed.
        tmp.unlink(missing_ok=True)


def _path(name: str, suffix: str) -> Path:
    """Return the snapshot file path for ``name`` with ``suffix`` appended.

    Forward and back slashes in ``name`` are replaced with ``_`` so the file
    is placed directly inside the snapshot directory.
    """
    safe = name.replace("/", "_").replace("\\", "_")
    return _SNAPSHOT_DIR / f"{safe}{suffix}"


def _should_update() -> bool:
    """Return True when ``PYTEST_QUANTUM_UPDATE_SNAPSHOTS`` is exactly ``"1"``."""
    return os.environ.get("PYTEST_QUANTUM_UPDATE_SNAPSHOTS", "") == "1"