Source code for pytest_quantum.assertions.stim_assertions

"""Assertions for Stim stabilizer circuits and quantum error correction."""

from __future__ import annotations

from typing import Any

import numpy as np


def assert_stim_logical_error_rate_below(
    circuit: Any,
    max_error_rate: float,
    *,
    shots: int = 10_000,
    seed: int | None = None,
) -> None:
    """Assert that a Stim QEC circuit has logical error rate below threshold.

    The circuit must contain DETECTOR and OBSERVABLE_INCLUDE instructions.
    A shot is counted as a logical error when *any* tracked observable is
    flipped in that shot.

    Args:
        circuit: ``stim.Circuit`` with detectors and observables.
        max_error_rate: Maximum allowed logical error rate. The check fails
            only when the measured rate is strictly greater than this value.
        shots: Number of shots (default 10,000). Must be positive.
        seed: Optional random seed.

    Raises:
        AssertionError: If logical error rate exceeds *max_error_rate*.
        ImportError: If stim is not installed.
        ValueError: If *shots* is not positive, or the circuit has no
            observables.

    Example::

        import stim
        from pytest_quantum import assert_stim_logical_error_rate_below

        c = stim.Circuit.generated(
            "repetition_code:memory",
            rounds=3,
            distance=3,
            after_clifford_depolarization=0.001,
        )
        assert_stim_logical_error_rate_below(c, max_error_rate=0.05, shots=1000)
    """
    # Validate before sampling: with zero shots the error rate would be
    # NaN, and ``NaN > max_error_rate`` is False, so the assertion would
    # silently pass without having checked anything.
    if shots <= 0:
        raise ValueError(f"shots must be a positive integer, got {shots}")

    try:
        import stim  # noqa: F401
    except ImportError as exc:
        raise ImportError("stim is required: pip install stim") from exc

    if circuit.num_observables == 0:
        raise ValueError(
            "Circuit has no observables (OBSERVABLE_INCLUDE instructions). "
            "Add OBSERVABLE_INCLUDE to track logical qubit errors."
        )

    sampler = circuit.compile_detector_sampler(seed=seed)
    # Only the observable flips matter here; detection events are discarded.
    _det_data, obs_data = sampler.sample(shots, separate_observables=True)

    # A shot is a logical failure if any observable flipped.
    logical_errors = np.any(obs_data, axis=1)
    num_errors = int(np.count_nonzero(logical_errors))
    error_rate = num_errors / shots

    if error_rate > max_error_rate:
        raise AssertionError(
            f"Logical error rate {error_rate:.4f} exceeds threshold "
            f"{max_error_rate:.4f}\n"
            f" Shots: {shots}\n"
            f" Logical errors: {num_errors}/{shots}\n"
            f" Observables tracked: {circuit.num_observables}"
        )
def assert_stim_detector_error_rate_below(
    circuit: Any,
    max_error_rate: float,
    *,
    shots: int = 10_000,
    seed: int | None = None,
) -> None:
    """Assert that the mean detector error rate is below threshold.

    Each detector's rate is the fraction of sampled shots in which it fired.
    The checked quantity is the mean of those per-detector rates. Useful for
    verifying that a noise model produces errors at the expected rate.

    Args:
        circuit: ``stim.Circuit`` with DETECTOR instructions and noise.
        max_error_rate: Maximum allowed mean detector error rate. The check
            fails only when the mean rate is strictly greater than this value.
        shots: Number of shots (default 10,000). Must be positive.
        seed: Optional random seed.

    Raises:
        AssertionError: If mean detector error rate exceeds *max_error_rate*.
        ImportError: If stim is not installed.
        ValueError: If *shots* is not positive, or the circuit has no
            detectors.
    """
    # Validate before sampling: with zero shots every per-detector rate is
    # NaN, ``NaN > max_error_rate`` is False, and the assertion would pass
    # vacuously.
    if shots <= 0:
        raise ValueError(f"shots must be a positive integer, got {shots}")

    try:
        import stim  # noqa: F401
    except ImportError as exc:
        raise ImportError("stim is required: pip install stim") from exc

    if circuit.num_detectors == 0:
        raise ValueError("Circuit has no detectors (DETECTOR instructions).")

    sampler = circuit.compile_detector_sampler(seed=seed)
    det_data = sampler.sample(shots)

    # axis=0 averages over shots, giving one firing rate per detector.
    per_detector_rates = np.mean(det_data, axis=0)
    mean_rate = float(np.mean(per_detector_rates))

    if mean_rate > max_error_rate:
        raise AssertionError(
            f"Mean detector error rate {mean_rate:.4f} exceeds threshold "
            f"{max_error_rate:.4f}\n"
            f" Detectors: {circuit.num_detectors}\n"
            f" Max single-detector rate: "
            f"{float(np.max(per_detector_rates)):.4f}"
        )
def assert_stabilizer_state(
    tableau_simulator: Any,
    expected_stabilizers: list[str],
) -> None:
    """Check that every given Pauli string stabilizes the simulator's state.

    Each entry of *expected_stabilizers* is parsed into a ``stim.PauliString``
    and its expectation value is peeked from the simulator without
    collapsing the state. Every entry must report an expectation of exactly
    +1; all offending entries are collected and reported together.

    Args:
        tableau_simulator: ``stim.TableauSimulator`` after running a circuit.
        expected_stabilizers: Pauli strings such as ``["+XX", "+ZZ"]``.

    Raises:
        AssertionError: If one or more stabilizers do not have expectation +1.
        ImportError: If stim is not installed.

    Example::

        import stim
        from pytest_quantum import assert_stabilizer_state

        sim = stim.TableauSimulator()
        sim.h(0)
        sim.cnot(0, 1)
        assert_stabilizer_state(sim, ["+XX", "+ZZ"])
    """
    try:
        import stim
    except ImportError as exc:
        raise ImportError("stim is required: pip install stim") from exc

    report_lines = []
    for pauli_text in expected_stabilizers:
        value = tableau_simulator.peek_observable_expectation(
            stim.PauliString(pauli_text)
        )
        if value == 1:
            continue
        report_lines.append(
            f" {pauli_text!r}: expectation = {value} (expected +1)"
        )

    if not report_lines:
        return

    header = (
        f"Stabilizer state check failed for "
        f"{len(report_lines)}/{len(expected_stabilizers)} stabilizers:\n"
    )
    raise AssertionError(header + "\n".join(report_lines))