Source code for pytest_quantum.assertions.information

"""Additional information-theoretic distribution assertions."""

from __future__ import annotations

import numpy as np


[docs] def assert_hellinger_close( counts_a: dict[str, int], counts_b: dict[str, int], *, max_distance: float = 0.1, ) -> None: """Assert Hellinger distance H(p,q) <= max_distance. H(p,q) = (1/sqrt(2)) * ||sqrt(p) - sqrt(q)||_2 Range: [0, 1]. H=0 identical, H=1 disjoint support. More symmetric and bounded than KL divergence. Args: counts_a: First count dictionary. counts_b: Second count dictionary. max_distance: Maximum allowed Hellinger distance (default 0.1). Raises: AssertionError: If H(p,q) > max_distance, with a per-key table. ValueError: If both count dictionaries are empty. """ all_keys = sorted(set(counts_a) | set(counts_b)) if not all_keys: raise ValueError("Both count dictionaries are empty.") total_a = sum(counts_a.values()) or 1 total_b = sum(counts_b.values()) or 1 p = np.array([counts_a.get(k, 0) / total_a for k in all_keys], dtype=np.float64) q = np.array([counts_b.get(k, 0) / total_b for k in all_keys], dtype=np.float64) hellinger = float(np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q)) ** 2))) if hellinger > max_distance: # Build per-key table for the error message rows = [ f" {'Outcome':<12} {'p (left)':>10} {'q (right)':>10} {'√p':>10} {'√q':>10}", f" {'─' * 56}", ] for k, pi, qi in zip(all_keys, p, q, strict=True): rows.append( f" {k:<12} {pi:>10.4f} {qi:>10.4f} " f"{float(np.sqrt(pi)):>10.4f} {float(np.sqrt(qi)):>10.4f}" ) table = "\n".join(rows) raise AssertionError( f"Hellinger distance H = {hellinger:.4f} exceeds max_distance = {max_distance:.4f}.\n" f" H(p, q) = {hellinger:.4f} (max allowed: {max_distance:.4f})\n" f"\n" f"{table}\n" f"\n" f" Hint: H=0.0 means identical distributions, H=1.0 means completely different." )
def assert_kl_divergence_below(
    counts: dict[str, int],
    expected_probs: dict[str, float],
    *,
    max_kl: float = 0.1,
) -> None:
    """Assert KL divergence D_KL(observed || expected) <= max_kl.

    D_KL(P||Q) = sum_x P(x) * log2(P(x) / Q(x))

    P is ``counts`` normalised by its total; Q is ``expected_probs``.
    Outcomes with zero observed count contribute nothing and are skipped.

    Note: KL is asymmetric and infinite if Q(x)=0 but P(x)>0, and undefined
    if Q(x) is negative or NaN. Both cases raise ValueError instead of
    letting a NaN propagate (a NaN total would compare as "not greater than
    max_kl" and make the assertion pass silently).

    Args:
        counts: Observed count dictionary.
        expected_probs: Expected probability dictionary. Outcomes missing
            from it are treated as having probability 0.
        max_kl: Maximum allowed KL divergence in bits (default 0.1).

    Raises:
        AssertionError: If D_KL > max_kl, with a per-outcome contribution
            table.
        ValueError: If any outcome with non-zero observed count has an
            expected probability that is zero (KL would be infinite),
            negative, or NaN (KL would be undefined).
    """
    if not counts:
        return  # vacuously true

    total = sum(counts.values())
    if total == 0:
        return  # no observations recorded, nothing to compare

    # outcome -> (observed p, expected q, contribution); kept so the failure
    # table can reuse the values instead of recomputing them.
    contributions: dict[str, tuple[float, float, float]] = {}
    kl = 0.0
    for outcome, count in counts.items():
        p = count / total
        if p == 0.0:
            continue
        q = expected_probs.get(outcome, 0.0)
        if q == 0.0:
            raise ValueError(
                f"KL divergence is infinite: outcome {outcome!r} has "
                f"non-zero observed probability ({p:.4f}) but expected_probs "
                f"assigns it zero probability."
            )
        # "not q > 0" also catches NaN, which "q <= 0" would let through.
        if not q > 0.0:
            raise ValueError(
                f"KL divergence is undefined: outcome {outcome!r} has "
                f"non-zero observed probability ({p:.4f}) but expected_probs "
                f"assigns it an invalid probability ({q})."
            )
        term = p * float(np.log2(p / q))
        contributions[outcome] = (p, q, term)
        kl += term

    if kl > max_kl:
        # Build per-outcome contribution table, sorted by outcome label
        rows = [
            f" {'Outcome':<12} {'observed':>10} {'expected':>10} {'contribution':>14}",
            f" {'─' * 50}",
        ]
        rows.extend(
            f" {outcome:<12} {p_val:>10.4f} {q_val:>10.4f} {contrib:>+14.4f}"
            for outcome, (p_val, q_val, contrib) in sorted(contributions.items())
        )
        table = "\n".join(rows)
        raise AssertionError(
            f"KL divergence D(observed||expected) = {kl:.4f} exceeds max_kl = {max_kl:.4f}.\n"
            f" D_KL = {kl:.4f} (max allowed: {max_kl:.4f})\n"
            f"\n"
            f"{table}"
        )
def assert_cross_entropy_below(
    counts: dict[str, int],
    expected_probs: dict[str, float],
    *,
    max_ce: float = 1.0,
) -> None:
    """Check that the cross-entropy of observed counts against Q is bounded.

    Computes H(P, Q) = -sum_x P(x) log2 Q(x), where P is ``counts``
    normalised by its total and Q is ``expected_probs``. Outcomes with zero
    observed count are skipped. Empty ``counts`` or a zero total pass
    vacuously.

    Used in quantum supremacy experiments as XEB (cross-entropy
    benchmarking).

    Args:
        counts: Observed count dictionary.
        expected_probs: Expected probability dictionary Q. Outcomes missing
            from it are treated as having probability 0.
        max_ce: Maximum allowed cross-entropy in bits (default 1.0).

    Raises:
        AssertionError: If H(P, Q) > max_ce.
        ValueError: If any observed outcome has zero or negative expected
            probability.
    """
    shots = sum(counts.values()) if counts else 0
    if shots == 0:
        return  # nothing observed: vacuously true

    entropy_bits = 0.0
    for label, hits in counts.items():
        observed = hits / shots
        if observed != 0.0:
            expected = expected_probs.get(label, 0.0)
            if expected <= 0.0:
                raise ValueError(
                    f"Cross-entropy is undefined: outcome {label!r} has "
                    f"non-zero observed probability ({observed:.4f}) but expected_probs "
                    f"assigns it zero/negative probability ({expected})."
                )
            entropy_bits += -observed * float(np.log2(expected))

    # Written as "not >" so a NaN total takes the same path as before.
    if not entropy_bits > max_ce:
        return

    raise AssertionError(
        f"Cross-entropy H(P, Q) = {entropy_bits:.6f} bits exceeds "
        f"max_ce = {max_ce}.\n"
        f" H(P,Q) = -∑ P(x) log₂ Q(x) (lower = more similar)"
    )