Source code for pytest_quantum.assertions.states

"""Statevector-level assertions for quantum tests.

Use these when you have the full output statevector from a simulator (Aer
statevector mode, Cirq simulator, or the Graphix backend) and want to verify
the quantum state directly — more informative than comparing shot distributions.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from pytest_quantum.stats.tests import fidelity as _fidelity

if TYPE_CHECKING:
    from numpy.typing import NDArray


[docs] def assert_state_fidelity_above( actual: NDArray[np.complex128], target: NDArray[np.complex128], threshold: float = 0.99, ) -> None: """Assert that two pure quantum states have fidelity at or above *threshold*. Fidelity :math:`F = |\\langle\\text{actual}|\\text{target}\\rangle|^2` equals 1.0 for identical states (up to global phase) and 0.0 for orthogonal states. This is the primary assertion for MBQC / Graphix tests where the circuit does not have a fixed unitary representation. Args: actual: Simulated output statevector, any shape (will be flattened). target: Ideal target statevector, same number of elements. threshold: Minimum acceptable fidelity (default ``0.99``). Raises: AssertionError: If ``fidelity(actual, target) < threshold``. ValueError: If the arrays have incompatible sizes. Example:: import numpy as np from pytest_quantum import assert_state_fidelity_above BELL = np.array([1, 0, 0, 1], dtype=complex) / np.sqrt(2) def test_bell_graphix(graphix_backend): from graphix.transpiler import Circuit circuit = Circuit(2) circuit.h(0) circuit.cnot(0, 1) pattern = circuit.transpile().pattern output = graphix_backend.run_pattern(pattern) assert_state_fidelity_above(output, BELL, threshold=0.999) """ f = _fidelity(actual, target) if f < threshold: raise AssertionError( f"State fidelity too low.\n" f" |⟨actual|target⟩|² = {f:.6f}\n" f" Required ≥ {threshold}\n" f" Shortfall = {threshold - f:.2e}" )
[docs] def assert_normalized(statevector: object, *, atol: float = 1e-6) -> None: """Assert statevector has unit norm: ||ψ||₂ = 1. A common bug in manual statevector construction is forgetting to normalize. Args: statevector: Complex array-like of any shape (flattened internally). atol: Absolute tolerance from 1.0 (default 1e-6). Raises: AssertionError: If ||sv||₂ is not within atol of 1.0, showing the actual norm. Example:: >>> import numpy as np >>> sv = np.array([1, 0, 0, 0], dtype=complex) # |00> >>> assert_normalized(sv) # passes >>> sv_bad = np.array([1, 1], dtype=complex) # NOT normalized >>> assert_normalized(sv_bad) # fails: norm = 1.4142 """ sv = np.asarray(statevector, dtype=np.complex128).flatten() norm = float(np.linalg.norm(sv)) if abs(norm - 1.0) > atol: raise AssertionError( f"Statevector is not normalized.\n" f" Norm: {norm:.6f} (expected 1.0, tolerance: {atol:.2e})\n" f" Deviation: {abs(norm - 1.0):.6f}\n" f" Hint: divide by np.linalg.norm(sv) to normalize." )
[docs] def assert_states_close( actual: NDArray[np.complex128], target: NDArray[np.complex128], *, atol: float = 1e-6, ) -> None: """Assert that two statevectors are element-wise close, up to global phase. Stricter than :func:`assert_state_fidelity_above` — use for exact simulator-to-simulator comparisons where you want bit-for-bit agreement. Args: actual: Simulated statevector (will be flattened and normalised). target: Ideal statevector (will be flattened and normalised). atol: Absolute tolerance per element (default ``1e-6``). Raises: AssertionError: If any element differs by more than *atol* after removing the global phase. Example:: def test_plus_state(aer_statevector_simulator): from qiskit import QuantumCircuit, transpile qc = QuantumCircuit(1) qc.h(0) qc.save_statevector() qc_t = transpile(qc, aer_statevector_simulator) sv = aer_statevector_simulator.run(qc_t).result().get_statevector() PLUS = np.array([1, 1]) / np.sqrt(2) assert_states_close(sv.data, PLUS) """ a = np.asarray(actual, dtype=np.complex128).flatten() t = np.asarray(target, dtype=np.complex128).flatten() if a.size != t.size: raise AssertionError( f"Statevector size mismatch: actual has {a.size} elements, " f"target has {t.size}." ) # Normalise a = a / np.linalg.norm(a) t = t / np.linalg.norm(t) # Remove global phase: align the largest-magnitude element of `a` to `t` idx = int(np.argmax(np.abs(t))) if abs(t[idx]) > 1e-10 and abs(a[idx]) > 1e-10: phase = a[idx] / t[idx] a = a / phase if not np.allclose(a, t, atol=atol): max_diff = float(np.max(np.abs(a - t))) raise AssertionError( f"Statevectors are not close (after global-phase alignment).\n" f" Max |difference|: {max_diff:.2e} (tolerance: {atol:.2e})" )