# Source code for pytest_quantum.assertions.sweeps

"""Parametrized circuit sweep assertions."""

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    from collections.abc import Callable

    from numpy.typing import NDArray


def assert_circuit_sweep(
    circuit_fn: Callable[..., Any],
    param_values: dict[str, list[float] | NDArray[np.float64]],
    expected_fn: Callable[..., NDArray[np.complex128]],
    *,
    atol: float = 1e-6,
    allow_global_phase: bool = True,
) -> None:
    """Assert a parametrized circuit matches expected unitary for all parameter values.

    Calls ``circuit_fn(**params)`` for every combination of ``param_values``,
    computes the unitary, and compares to ``expected_fn(**params)``.

    Args:
        circuit_fn: Callable that accepts keyword parameter arguments and
            returns a quantum circuit.
        param_values: Dict mapping parameter name to list/array of values.
            All combinations are tested (cartesian product).
        expected_fn: Callable that accepts the same keyword arguments and
            returns the expected unitary as a numpy array.
        atol: Absolute tolerance for unitary comparison (default 1e-6).
        allow_global_phase: If True, ignore global phase differences
            (default True). Only a unit-modulus factor is ignored; a
            difference in overall scale is still reported as a failure.

    Raises:
        AssertionError: If any parameter combination fails. The message lists
            every failing combination.

    Example::

        import numpy as np
        from pytest_quantum import assert_circuit_sweep
        from qiskit import QuantumCircuit

        def rx_circuit(theta):
            qc = QuantumCircuit(1)
            qc.rx(theta, 0)
            return qc

        def rx_expected(theta):
            c, s = np.cos(theta / 2), np.sin(theta / 2)
            return np.array([[c, -1j * s], [-1j * s, c]])

        assert_circuit_sweep(
            rx_circuit,
            {"theta": np.linspace(0, 2 * np.pi, 8)},
            rx_expected,
        )
    """
    from pytest_quantum.converters.to_unitary import to_unitary

    keys = list(param_values)
    value_lists = [list(param_values[k]) for k in keys]

    failures: list[str] = []
    for combo in itertools.product(*value_lists):
        kwargs = dict(zip(keys, combo, strict=False))
        actual = to_unitary(circuit_fn(**kwargs))
        expected = np.asarray(expected_fn(**kwargs), dtype=np.complex128)

        if actual.shape != expected.shape:
            failures.append(
                f" params={_fmt_params(kwargs)}: "
                f"shape mismatch {actual.shape} vs {expected.shape}"
            )
            continue

        if np.allclose(actual, expected, atol=atol):
            continue

        # Matrix the failure report is measured against; replaced by the
        # phase-aligned expectation when a global phase can be estimated.
        reference = expected
        if allow_global_phase:
            # Estimate the phase from the largest-magnitude expected entry so
            # the ratio is numerically well conditioned.
            flat_idx = int(np.argmax(np.abs(expected)))
            e_val = expected.flat[flat_idx]
            a_val = actual.flat[flat_idx]
            if abs(e_val) > 1e-10 and abs(a_val) > 1e-10:
                ratio = a_val / e_val
                # Keep only the angle: a global phase has modulus 1, so a
                # rescaled matrix must not be accepted as "phase-equivalent".
                phase = ratio / abs(ratio)
                reference = phase * expected
                if np.allclose(actual, reference, atol=atol):
                    continue

        max_diff = float(np.max(np.abs(actual - reference)))
        failures.append(
            f" params={_fmt_params(kwargs)}: max|diff|={max_diff:.2e} (tol={atol:.2e})"
        )

    if failures:
        raise AssertionError(
            f"assert_circuit_sweep failed for {len(failures)} / "
            f"{_count_combos(param_values)} parameter combinations:\n"
            + "\n".join(failures)
        )
def assert_circuit_sweep_states(
    circuit_fn: Callable[..., Any],
    initial_state: NDArray[np.complex128],
    param_values: dict[str, list[float] | NDArray[np.float64]],
    expected_fn: Callable[..., NDArray[np.complex128]],
    *,
    min_fidelity: float = 0.99,
) -> None:
    """Assert a parametrized circuit produces expected output states for all params.

    Applies ``circuit_fn(**params)`` to ``initial_state`` and compares the
    output statevector fidelity to ``expected_fn(**params)``. Both states are
    normalized before the comparison; the arrays supplied by the caller are
    never modified.

    Args:
        circuit_fn: Callable(**params) -> circuit.
        initial_state: Input statevector, shape (2**n,).
        param_values: Dict of parameter name -> list of values. All
            combinations are tested (cartesian product).
        expected_fn: Callable(**params) -> expected output statevector.
        min_fidelity: Minimum fidelity |<actual|expected>|^2 (default 0.99).

    Raises:
        AssertionError: If any combination fails the fidelity check, if the
            expected state has a different number of amplitudes than the
            circuit output, or if either state has zero norm (the fidelity is
            then undefined and reported as ``nan``).

    Example::

        import numpy as np
        from pytest_quantum import assert_circuit_sweep_states

        psi0 = np.array([1, 0], dtype=complex)

        def rz_circuit(phi):
            from qiskit import QuantumCircuit
            qc = QuantumCircuit(1)
            qc.rz(phi, 0)
            return qc

        def expected_state(phi):
            return np.array([np.exp(-1j * phi / 2), 0])

        assert_circuit_sweep_states(
            rz_circuit, psi0, {"phi": [0, np.pi / 2, np.pi]}, expected_state
        )
    """
    from pytest_quantum.converters.to_unitary import to_unitary

    keys = list(param_values)
    value_lists = [list(param_values[k]) for k in keys]
    # The input state does not depend on the swept parameters: convert once.
    psi_in = np.asarray(initial_state, dtype=np.complex128)

    failures: list[str] = []
    for combo in itertools.product(*value_lists):
        kwargs = dict(zip(keys, combo, strict=False))
        unitary = to_unitary(circuit_fn(**kwargs))
        actual = unitary @ psi_in
        expected = np.asarray(expected_fn(**kwargs), dtype=np.complex128)

        # np.vdot flattens its inputs, so only the element count has to agree.
        if actual.size != expected.size:
            failures.append(
                f" params={_fmt_params(kwargs)}: "
                f"shape mismatch {actual.shape} vs {expected.shape}"
            )
            continue

        norm_actual = float(np.linalg.norm(actual))
        norm_expected = float(np.linalg.norm(expected))
        if norm_actual == 0.0 or norm_expected == 0.0:
            # Fidelity is undefined for a zero vector; NaN makes the check
            # below fail instead of dividing by zero.
            fidelity = float("nan")
        else:
            # Out-of-place division: np.asarray may hand back the very array
            # expected_fn returned, which must not be modified.
            fidelity = float(
                abs(np.vdot(actual / norm_actual, expected / norm_expected)) ** 2
            )

        # Written as "not >=" so that a NaN fidelity counts as a failure.
        if not fidelity >= min_fidelity:
            failures.append(
                f" params={_fmt_params(kwargs)}: fidelity={fidelity:.4f} < {min_fidelity}"
            )

    if failures:
        raise AssertionError(
            f"assert_circuit_sweep_states failed for {len(failures)} / "
            f"{_count_combos(param_values)} parameter combinations:\n"
            + "\n".join(failures)
        )
def assert_parametrized_unitary_continuous(
    circuit_fn: Callable[..., Any],
    param_name: str,
    param_range: tuple[float, float],
    *,
    n_samples: int = 20,
    max_jump: float = 0.5,
) -> None:
    """Assert that a parametrized circuit's unitary varies continuously.

    Detects discontinuities (e.g. branch cut issues, phase jumps) by checking
    that consecutive unitary matrices are close in Frobenius norm.

    Args:
        circuit_fn: Callable(param_name=value) -> circuit.
        param_name: Name of the parameter to sweep.
        param_range: (start, end) range for the parameter.
        n_samples: Number of evenly-spaced sample points (default 20).
            Must be at least 2 so that there is a pair of samples to compare.
        max_jump: Maximum allowed Frobenius norm change between adjacent
            samples (default 0.5).

    Raises:
        ValueError: If ``n_samples`` is smaller than 2.
        AssertionError: If any discontinuity exceeds max_jump (a NaN
            difference is also reported as a discontinuity).

    Example::

        assert_parametrized_unitary_continuous(
            lambda theta: rx_circuit(theta), "theta", (0, 2 * np.pi)
        )
    """
    # With fewer than two samples there are no adjacent pairs, so the check
    # would pass without testing anything.
    if n_samples < 2:
        raise ValueError(
            f"n_samples must be at least 2 to compare adjacent samples, got {n_samples}"
        )

    from pytest_quantum.converters.to_unitary import to_unitary

    values = np.linspace(param_range[0], param_range[1], n_samples)
    samples = [
        (value, to_unitary(circuit_fn(**{param_name: float(value)})))
        for value in values
    ]

    jumps: list[str] = []
    for (prev_value, prev_u), (next_value, next_u) in itertools.pairwise(samples):
        diff = float(np.linalg.norm(next_u - prev_u, "fro"))
        # Written as "not <=" so that a NaN difference is flagged too.
        if not diff <= max_jump:
            jumps.append(
                f" {param_name}={prev_value:.4f}->{next_value:.4f}: ||DeltaU||_F={diff:.4f}"
            )

    if jumps:
        raise AssertionError(
            f"Unitary not continuous at {len(jumps)} transition(s) "
            f"(max_jump={max_jump}):\n" + "\n".join(jumps)
        )
def _fmt_params(kwargs: dict[str, float]) -> str: """Format parameter dict for error messages.""" return "{" + ", ".join(f"{k}={v:.4g}" for k, v in kwargs.items()) + "}" def _count_combos( param_values: dict[str, list[float] | NDArray[np.float64]], ) -> int: """Count total number of parameter combinations.""" result = 1 for v in param_values.values(): result *= len(v) return result