Source code for pytest_quantum.assertions.primitives

"""Assertions for Qiskit Primitives (StatevectorSampler / StatevectorEstimator)."""

from __future__ import annotations

from typing import Any

import numpy as np


def assert_sampler_distribution(
    result: Any,
    expected_probs: dict[str, float],
    *,
    pub_idx: int = 0,
    significance: float = 0.05,
) -> None:
    """Check that a Qiskit Sampler result follows an expected distribution.

    The counts of the selected pub are extracted from ``result`` and handed
    to ``assert_measurement_distribution``, which performs the chi-square
    goodness-of-fit comparison.

    Args:
        result: PrimitiveResult returned by ``StatevectorSampler.run()``.
        expected_probs: Mapping from bitstring to expected probability,
            e.g. ``{"00": 0.5, "11": 0.5}``.
        pub_idx: Index of the pub result to inspect (default 0).
        significance: p-value threshold forwarded to the distribution
            check (default 0.05).

    Raises:
        AssertionError: If the measured distribution does not match.

    Example::

        def test_sampler_bell(qiskit_sampler):
            from qiskit.circuit import QuantumCircuit
            from pytest_quantum import assert_sampler_distribution

            qc = QuantumCircuit(2, 2)
            qc.h(0)
            qc.cx(0, 1)
            qc.measure([0, 1], [0, 1])
            result = qiskit_sampler.run([(qc,)]).result()
            assert_sampler_distribution(result, {"00": 0.5, "11": 0.5})
    """
    # Imported at call time, as in the original module layout.
    from pytest_quantum.assertions.distributions import (
        assert_measurement_distribution,
    )

    assert_measurement_distribution(
        _extract_sampler_counts(result, pub_idx),
        expected_probs,
        significance=significance,
    )
def assert_estimator_close(
    result: Any,
    expected: float,
    *,
    atol: float = 0.1,
    pub_idx: int = 0,
) -> None:
    """Check that a Qiskit Estimator result is near an expected value.

    The expectation value of the selected pub is extracted from ``result``
    and compared through ``assert_expectation_value_close``.

    Args:
        result: PrimitiveResult returned by ``StatevectorEstimator.run()``.
        expected: Expected expectation value.
        atol: Absolute tolerance (default 0.1).
        pub_idx: Index of the pub result to inspect (default 0).

    Raises:
        AssertionError: If ``|actual - expected| > atol``.

    Example::

        def test_estimator_z(qiskit_estimator):
            from qiskit.circuit import QuantumCircuit
            from qiskit.quantum_info import SparsePauliOp
            from pytest_quantum import assert_estimator_close

            qc = QuantumCircuit(1)  # |0>, <Z> = 1.0
            obs = SparsePauliOp("Z")
            result = qiskit_estimator.run([(qc, obs)]).result()
            assert_estimator_close(result, expected=1.0, atol=0.01)
    """
    # Imported at call time, as in the original module layout.
    from pytest_quantum.assertions.observables import (
        assert_expectation_value_close,
    )

    assert_expectation_value_close(
        _extract_estimator_value(result, pub_idx),
        expected,
        atol=atol,
    )
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _extract_sampler_counts(result: Any, pub_idx: int) -> dict[str, int]:
    """Return bitstring counts for one pub of a sampler result.

    Two result layouts are tried, in this order:

    1. ``result[pub_idx].data`` exposing at least one public attribute whose
       value provides ``get_counts()``; the first such attribute is used.
    2. An object with ``quasi_dists`` (and optionally ``metadata``); the
       quasi-probabilities are scaled by the shot count.

    Args:
        result: Sampler result object (duck-typed).
        pub_idx: Index of the pub result to read.

    Returns:
        Mapping from bitstring to integer count.

    Raises:
        TypeError: If neither layout applies. The failure from the first
            attempt is attached as ``__cause__`` to aid debugging.
    """
    primary_error = None
    try:
        pub_result = result[pub_idx]
        data = pub_result.data
        creg_name = next(k for k in vars(data) if not k.startswith("_"))
        bit_array = getattr(data, creg_name)
        return dict(bit_array.get_counts())
    except Exception as exc:
        # Best-effort probe: remember why it failed, then try the fallback.
        primary_error = exc

    if hasattr(result, "quasi_dists"):
        quasi_dist = result.quasi_dists[pub_idx]
        meta = result.metadata[pub_idx] if hasattr(result, "metadata") else {}
        shots = meta.get("shots", 1024)
        width = meta.get("num_qubits")
        if width is None:
            # Widest non-negative key decides the padding; ``default`` covers
            # an empty distribution or one holding only negative keys.
            width = max(
                (len(format(k, "b")) for k in quasi_dist if k >= 0),
                default=1,
            )
        return {
            # Round rather than truncate: a product such as 0.57 * 100 is
            # 56.99999999999999 in floating point and must count as 57.
            format(k, f"0{width}b"): int(round(v * shots))
            for k, v in quasi_dist.items()
            if v > 0 and k >= 0
        }

    raise TypeError(
        f"Cannot extract counts from {type(result).__qualname__!r}. "
        "Expected PrimitiveResult from StatevectorSampler."
    ) from primary_error


def _extract_estimator_value(result: Any, pub_idx: int) -> float:
    """Return the expectation value for one pub of an estimator result.

    Two result layouts are tried, in this order:

    1. ``result[pub_idx].data.evs``; the first element of the flattened
       array is returned.
    2. An object with a ``values`` attribute, indexed by ``pub_idx``.

    Args:
        result: Estimator result object (duck-typed).
        pub_idx: Index of the pub result to read.

    Returns:
        The expectation value as a Python ``float``.

    Raises:
        TypeError: If neither layout applies. The failure from the first
            attempt is attached as ``__cause__`` to aid debugging.
    """
    primary_error = None
    try:
        pub_result = result[pub_idx]
        return float(np.asarray(pub_result.data.evs).flat[0])
    except Exception as exc:
        # Best-effort probe: remember why it failed, then try the fallback.
        primary_error = exc

    if hasattr(result, "values"):
        return float(np.asarray(result.values)[pub_idx])

    raise TypeError(
        f"Cannot extract expectation value from {type(result).__qualname__!r}. "
        "Expected PrimitiveResult from StatevectorEstimator."
    ) from primary_error