# Source code for pytest_quantum.assertions.distributions

"""Shot-distribution assertions for quantum tests.

These assertions test the *statistical* output of a circuit — the probability
distribution over measurement outcomes. ``assert_measurement_distribution``
uses a chi-square goodness-of-fit test against expected probabilities instead
of an ad-hoc tolerance, while ``assert_counts_close`` compares two sets of
counts by their total variation distance (TVD) against a ``max_tvd`` threshold.
"""

from __future__ import annotations

import warnings

import numpy as np

from pytest_quantum.stats.tests import chi_square_test, tvd_from_counts


def assert_measurement_distribution(
    counts: dict[str, int],
    expected_probs: dict[str, float],
    *,
    significance: float = 0.05,
    min_expected_per_bucket: int = 5,
) -> None:
    """Assert that measured counts match the expected probability distribution.

    Uses a chi-square goodness-of-fit test — the standard statistical tool
    for this exact problem. The test fails only when the deviation is
    statistically significant (p < *significance*), so occasional random
    fluctuations do not cause false failures.

    Args:
        counts: Measured counts dict, e.g. ``{"00": 489, "11": 511}``.
            Keys are bitstring labels; values are integer counts.
        expected_probs: Expected probability dict, e.g.
            ``{"00": 0.5, "11": 0.5}``. Values must be non-negative and sum
            to 1.0, checked with ``np.isclose(total, 1.0, atol=1e-6)``.
            Because NumPy's default ``rtol=1e-5`` also applies, the effective
            tolerance is about 1.1e-5. Outcomes not present are assumed to
            have zero expected probability.
        significance: P-value threshold below which the test fails
            (default ``0.05``).
        min_expected_per_bucket: The chi-square approximation is only
            reliable when every non-zero cell has an expected count of at
            least this value (default ``5``). If it is violated, a
            ``UserWarning`` is emitted but the test does not fail; consider
            increasing shots.

    Raises:
        AssertionError: If ``p_value < significance``, with a per-state
            breakdown of observed vs expected probabilities. Also raised
            when *counts* records no shots (empty dict or all-zero counts).
        ValueError: If *expected_probs* is empty, contains a negative
            probability, or does not sum to 1.0.

    Example::

        def test_bell_distribution(aer_simulator):
            from qiskit import QuantumCircuit, transpile
            from pytest_quantum import assert_measurement_distribution

            qc = QuantumCircuit(2)
            qc.h(0)
            qc.cx(0, 1)
            qc.measure_all()
            qc_t = transpile(qc, aer_simulator)
            counts = aer_simulator.run(qc_t, shots=2000).result().get_counts()

            assert_measurement_distribution(
                counts,
                expected_probs={"00": 0.5, "11": 0.5},
            )
    """
    if not expected_probs:
        raise ValueError("expected_probs cannot be empty.")

    prob_sum = sum(expected_probs.values())
    if not np.isclose(prob_sum, 1.0, atol=1e-6):
        raise ValueError(
            f"expected_probs must sum to 1.0, got {prob_sum:.8f}. "
            "Normalise your expected probabilities."
        )

    # A distribution with negative entries can still sum to 1.0
    # (e.g. 1.5 and -0.5); reject it before it reaches the chi-square test.
    negative = {k: p for k, p in expected_probs.items() if p < 0}
    if negative:
        raise ValueError(
            f"expected_probs must be non-negative, got negative values: {negative}."
        )

    total_shots = sum(counts.values())
    if total_shots == 0:
        # Kept as AssertionError (not ValueError): existing callers may rely
        # on this exception type.
        raise AssertionError("counts dict is empty — no shots were recorded.")

    # Warn if any non-zero bucket has too few expected counts for chi-square
    low_buckets = [
        k
        for k, p in expected_probs.items()
        if 0 < p * total_shots < min_expected_per_bucket
    ]
    if low_buckets:
        warnings.warn(
            f"Some buckets have expected count < {min_expected_per_bucket}: "
            f"{low_buckets}. "
            f"Chi-square may be unreliable — consider increasing shots or "
            f"merging low-probability outcomes.",
            UserWarning,
            stacklevel=2,
        )

    stat, pvalue = chi_square_test(counts, expected_probs)
    if pvalue < significance:
        # Union of keys so outcomes that were observed but not expected
        # (or expected but never observed) both show up in the report.
        all_keys = sorted(set(counts) | set(expected_probs))
        rows = []
        for k in all_keys:
            obs_p = counts.get(k, 0) / total_shots
            exp_p = expected_probs.get(k, 0.0)
            diff = obs_p - exp_p
            # `!s` renders non-str labels instead of crashing the report
            # with a format error; identical output for str labels.
            rows.append(
                f" {k!s:>12}: observed {obs_p:.4f} expected {exp_p:.4f} diff {diff:+.4f}"
            )
        table = "\n".join(rows)
        raise AssertionError(
            f"Measurement distribution mismatch (chi-square test failed).\n"
            f" χ² statistic : {stat:.4f}\n"
            f" p-value : {pvalue:.6f} (threshold: {significance})\n"
            f" total shots : {total_shots}\n"
            f"\n"
            f" Per-state breakdown:\n{table}\n"
            f"\n"
            f" Hint: if this test is inherently probabilistic and you see "
            f"occasional failures, use @pytest.mark.quantum_slow and increase "
            f"shot count with min_shots(epsilon=0.02) to reduce flakiness."
        )
def assert_counts_close(
    counts_a: dict[str, int],
    counts_b: dict[str, int],
    *,
    max_tvd: float = 0.05,
) -> None:
    """Assert that two count dictionaries are statistically close.

    Both counts dicts are normalised and compared by their Total Variation
    Distance (TVD). The assertion fails when that distance is larger than
    *max_tvd*. Typical uses are comparing two backends, or checking that
    transpilation left a circuit's output distribution unchanged.

    Args:
        counts_a: First counts dict.
        counts_b: Second counts dict.
        max_tvd: Largest TVD accepted (default ``0.05``). A TVD of 0 means
            the distributions are identical; 1 means their supports are
            disjoint.

    Raises:
        AssertionError: If the TVD is strictly greater than *max_tvd*.

    Example::

        def test_transpile_preserves_distribution(aer_simulator):
            from qiskit import QuantumCircuit, transpile

            qc = QuantumCircuit(2)
            qc.h(0)
            qc.cx(0, 1)
            qc.measure_all()

            # ideal vs noise-free transpiled
            qc_t = transpile(qc, aer_simulator, optimization_level=3)
            counts_ideal = aer_simulator.run(qc, shots=2000).result().get_counts()
            counts_transpiled = (
                aer_simulator.run(qc_t, shots=2000).result().get_counts()
            )
            assert_counts_close(counts_ideal, counts_transpiled, max_tvd=0.05)
    """
    tvd = tvd_from_counts(counts_a, counts_b)
    if tvd > max_tvd:
        overshoot = tvd - max_tvd
        report = (
            "Count distributions differ beyond allowed TVD.",
            f" Total Variation Distance : {tvd:.4f}",
            f" Maximum allowed TVD : {max_tvd}",
            f" Excess : {overshoot:.4f}",
        )
        raise AssertionError("\n".join(report))