"""Circuit structure assertions.
These assertions check static properties of a circuit — depth, gate counts,
qubit width — without executing it. Useful for catching regressions in
compiler output or ensuring a circuit meets hardware constraints.
"""
from __future__ import annotations
from typing import Any, cast
# ---------------------------------------------------------------------------
# Clifford gate sets per framework
# ---------------------------------------------------------------------------
# Each set lists the operation names ``assert_circuit_is_clifford`` accepts
# for one framework, spelled the way that function reads them from a circuit.

# Braket: compared case-sensitively with ``type(instr.operator).__name__``.
_CLIFFORD_BRAKET = frozenset(
    {
        # Identity and Paulis
        "I", "X", "Y", "Z",
        # Other single-qubit entries
        "H", "S", "Si", "V", "Vi",
        # Two-qubit entries
        "CNot", "CY", "CZ", "Swap",
    }
)

# PennyLane: compared case-sensitively with ``op.name``.
_CLIFFORD_PENNYLANE = frozenset(
    {
        # Identity and Paulis (long names)
        "Identity", "PauliX", "PauliY", "PauliZ",
        # Aliases (short names)
        "X", "Y", "Z", "H",
        # Other single-qubit entries and their adjoints
        "Hadamard", "S", "SX", "Adjoint(S)", "Adjoint(SX)",
        # Two-qubit entries
        "CNOT", "CY", "CZ", "SWAP", "ISWAP",
    }
)

# Qiskit: lower-case names, compared with the keys of ``count_ops()``.
_CLIFFORD_QISKIT = frozenset(
    {
        # Identity and Paulis
        "id", "x", "y", "z",
        # Other single-qubit entries
        "h", "s", "sdg", "sx", "sxdg",
        # Two-qubit entries
        "cx", "cy", "cz", "swap",
        # Non-gate instructions that are tolerated by the Clifford check
        "measure", "barrier", "reset",
    }
)

# Cirq: lower-case names, compared with ``str(op.gate).lower()``.
_CLIFFORD_CIRQ = frozenset(
    {
        "i", "x", "y", "z",
        "h", "s",
        "cnot", "cz", "swap",
        "measure",
    }
)
[docs]
def assert_circuit_depth(
    circuit: object,
    *,
    max_depth: int | None = None,
    min_depth: int | None = None,
) -> None:
    """Check that the depth of *circuit* lies within the requested bounds.

    Either bound may be omitted, but not both. The depth itself is obtained
    from ``_get_depth``, which handles Qiskit, Cirq, Amazon Braket, PennyLane
    and Pytket circuits.

    Args:
        circuit: A quantum circuit from a supported framework.
        max_depth: Upper bound (inclusive) on the depth, or ``None`` to skip
            the upper check.
        min_depth: Lower bound (inclusive) on the depth, or ``None`` to skip
            the lower check.

    Raises:
        AssertionError: If the depth is above *max_depth* or below *min_depth*.
        TypeError: If the circuit type is not supported.
        ValueError: If both bounds are ``None``.

    Example::

        def test_circuit_depth():
            from qiskit import QuantumCircuit
            qc = QuantumCircuit(2)
            qc.h(0)
            qc.cx(0, 1)
            assert_circuit_depth(qc, max_depth=3)
    """
    # Validate the arguments before touching the circuit at all.
    no_bounds_given = max_depth is None and min_depth is None
    if no_bounds_given:
        raise ValueError("Provide at least one of max_depth or min_depth.")

    measured = _get_depth(circuit)

    # The upper bound is checked first, so it wins if both are violated.
    too_deep = max_depth is not None and measured > max_depth
    if too_deep:
        raise AssertionError(f"Circuit depth {measured} exceeds max_depth {max_depth}.")

    too_shallow = min_depth is not None and measured < min_depth
    if too_shallow:
        raise AssertionError(f"Circuit depth {measured} is below min_depth {min_depth}.")
[docs]
def assert_circuit_width(
    circuit: object,
    expected_qubits: int,
) -> None:
    """Check that *circuit* has exactly *expected_qubits* qubits.

    The qubit count is obtained from ``_get_width``, which handles Qiskit,
    Cirq, Amazon Braket, PennyLane and Pytket circuits.

    Args:
        circuit: A quantum circuit from a supported framework.
        expected_qubits: The qubit count the circuit must have.

    Raises:
        AssertionError: If the circuit's qubit count differs from
            *expected_qubits*.
        TypeError: If the circuit type is not supported.

    Example::

        def test_circuit_width():
            from qiskit import QuantumCircuit
            qc = QuantumCircuit(3)
            qc.h(0)
            qc.cx(0, 1)
            qc.cx(1, 2)
            assert_circuit_width(qc, expected_qubits=3)
    """
    found = _get_width(circuit)
    if found == expected_qubits:
        return

    report = (
        f"Circuit qubit count mismatch.\n"
        f" Expected : {expected_qubits}\n"
        f" Actual : {found}"
    )
    raise AssertionError(report)
[docs]
def assert_gate_count(
    circuit: object,
    gate_name: str,
    expected: int,
) -> None:
    """Assert that a circuit contains exactly *expected* occurrences of *gate_name*.

    Supported frameworks: Qiskit, Cirq, Amazon Braket, PennyLane, Pytket.
    The framework is chosen from the module of ``type(circuit)``; an object
    exposing a ``device`` attribute is also treated as a PennyLane QNode.

    Matching is case-insensitive everywhere, but the name that is compared
    differs per framework:

    * Qiskit: keys of ``circuit.count_ops()``.
    * Cirq: ``str(op.gate)``.
    * Braket: ``instr.operator.name``.
    * PennyLane: the class name of each operation on the QNode's tape.
    * Pytket: ``cmd.op.type.name``.

    Args:
        circuit: A quantum circuit from a supported framework.
        gate_name: Gate name as a string, e.g. ``"cx"``, ``"h"``, ``"t"``,
            ``"CNOT"``, ``"Hadamard"``.
        expected: Expected count.

    Raises:
        AssertionError: If the actual count differs from *expected*.
        NotImplementedError: If the framework is not yet supported.
        TypeError: If a PennyLane QNode's tape cannot be obtained. When the
            dry-run call of the QNode failed, that failure is attached as
            ``__cause__``.

    Example::

        def test_t_count():
            from qiskit import QuantumCircuit
            qc = QuantumCircuit(2)
            qc.t(0)
            qc.t(1)
            qc.cx(0, 1)
            assert_gate_count(qc, "t", 2)
            assert_gate_count(qc, "cx", 1)
    """
    module = type(circuit).__module__
    c: Any = circuit
    # Every branch compares lower-cased names, so normalise the query once.
    name_lower = gate_name.lower()

    if module.startswith("qiskit"):
        actual = c.count_ops().get(name_lower, 0)
    elif module.startswith("cirq"):
        # Iterating the circuit yields moments; each op is matched through
        # str(op.gate), e.g. cirq.H -> "H", cirq.CNOT -> "CNOT".
        actual = sum(
            1
            for moment in c
            for op in moment.operations
            if str(op.gate).lower() == name_lower
        )
    elif module.startswith("braket"):
        actual = sum(
            1 for instr in c.instructions if instr.operator.name.lower() == name_lower
        )
    elif module.startswith("pennylane") or hasattr(circuit, "device"):
        # QNode: use the existing tape if there is one, otherwise call the
        # QNode once (without arguments) so that it builds its tape.
        tape = getattr(c, "tape", None)
        dry_run_error: Exception | None = None
        if tape is None:
            try:
                c()
                tape = c.tape
            except Exception as exc:
                # Best effort only, but keep the reason for the report below.
                dry_run_error = exc
                tape = None
        if tape is None:
            raise TypeError(
                "PennyLane QNode tape could not be obtained. "
                "Ensure the QNode is properly constructed."
            ) from dry_run_error
        actual = sum(
            1 for op in tape.operations if type(op).__name__.lower() == name_lower
        )
    elif module.startswith("pytket"):
        actual = sum(
            1 for cmd in c.get_commands() if cmd.op.type.name.lower() == name_lower
        )
    else:
        raise NotImplementedError(
            f"assert_gate_count supports Qiskit, Cirq, Braket, PennyLane, "
            f"and Pytket. Got circuit type: {type(circuit).__qualname__!r}."
        )

    if actual != expected:
        raise AssertionError(
            f"Gate count mismatch for {gate_name!r}.\n"
            f" Expected : {expected}\n"
            f" Actual : {actual}"
        )
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _get_depth(circuit: object) -> int:
    """Extract the depth of a circuit from any supported framework.

    The framework is chosen from the module of ``type(circuit)``; an object
    exposing a ``device`` attribute is also treated as a PennyLane QNode.

    Args:
        circuit: A Qiskit, Cirq, Amazon Braket, PennyLane or Pytket circuit.

    Returns:
        The circuit depth as reported by the framework.

    Raises:
        TypeError: If the circuit type is not supported, if PennyLane is
            needed but cannot be imported, or if ``qml.specs`` output exposes
            neither ``resources.depth`` nor ``depth``.
    """
    module = type(circuit).__module__
    c: Any = circuit
    if module.startswith("qiskit"):
        return int(c.depth())
    if module.startswith("cirq"):
        # len(circuit) is used as the depth. NOTE(review): per the Cirq docs
        # this is the number of moments, empty ones included — confirm.
        return len(c)
    if module.startswith("braket"):
        # Read as an attribute here, unlike the method call used for Qiskit.
        return int(c.depth)
    if module.startswith("pennylane") or hasattr(circuit, "device"):
        # Only the import is guarded, so an unrelated ImportError raised while
        # computing the specs is not misreported as "pennylane is required".
        try:
            import pennylane as qml
        except ImportError as exc:
            raise TypeError(
                "pennylane is required for PennyLane circuit depth. "
                "Install it with: pip install pytest-quantum[pennylane]"
            ) from exc
        specs = qml.specs(c)()
        # Try resources.depth first (newer PennyLane), then fall back to "depth".
        if hasattr(specs, "get"):
            resources = specs.get("resources", None)
            if resources is not None and hasattr(resources, "depth"):
                return int(resources.depth)
            depth_val = specs.get("depth", None)
            if depth_val is not None:
                return int(depth_val)
        raise TypeError(
            "Could not extract depth from qml.specs() output. "
            "Upgrade PennyLane to a version that exposes 'resources' or 'depth'."
        )
    if module.startswith("pytket"):
        return int(c.depth())
    raise TypeError(
        f"assert_circuit_depth does not support circuit type "
        f"{type(circuit).__qualname__!r}.\n"
        "Supported frameworks: Qiskit, Cirq, Amazon Braket, PennyLane, Pytket."
    )
def _get_width(circuit: object) -> int:
    """Extract the qubit count of a circuit from any supported framework.

    The framework is chosen from the module of ``type(circuit)``; an object
    exposing a ``device`` attribute is also treated as a PennyLane QNode.

    Args:
        circuit: A Qiskit, Cirq, Amazon Braket, PennyLane or Pytket circuit.

    Returns:
        The number of qubits: ``num_qubits`` (Qiskit), ``len(all_qubits())``
        (Cirq), ``qubit_count`` (Braket), ``len(device.wires)`` (PennyLane)
        or ``n_qubits`` (Pytket).

    Raises:
        TypeError: If the circuit type is not supported.
    """
    module = type(circuit).__module__
    c: Any = circuit
    if module.startswith("qiskit"):
        return int(c.num_qubits)
    if module.startswith("cirq"):
        return len(c.all_qubits())
    if module.startswith("braket"):
        return int(c.qubit_count)
    if module.startswith("pennylane") or hasattr(circuit, "device"):
        # The width comes from the wires of the QNode's device, not from the
        # operations in the circuit.
        return len(c.device.wires)
    if module.startswith("pytket"):
        return int(c.n_qubits)
    raise TypeError(
        f"assert_circuit_width does not support circuit type "
        f"{type(circuit).__qualname__!r}.\n"
        "Supported frameworks: Qiskit, Cirq, Amazon Braket, PennyLane, Pytket."
    )
[docs]
def assert_gates_in_basis_set(
    circuit: object,
    basis_gates: set[str],
    *,
    case_sensitive: bool = False,
) -> None:
    """Check that every operation in *circuit* is named in *basis_gates*.

    Typical use: confirm that a transpiled circuit only uses a target
    backend's native gates (e.g. after ``qiskit.transpile`` with
    ``basis_gates=[...]``).

    The name compared for each operation depends on the framework:
    ``instr.operation.name`` (Qiskit), ``str(op.gate)`` (Cirq), the operator's
    class name (Braket) and ``cmd.op.type.name`` (Pytket).

    Args:
        circuit: Qiskit, Cirq, Braket, or Pytket circuit.
        basis_gates: Names of the gates that are allowed.
        case_sensitive: When ``False`` (default) both sides are lower-cased
            before comparing; when ``True`` names must match exactly.

    Raises:
        AssertionError: If any operation is outside the basis. The message
            gives the number of offending operations and their distinct names.
        NotImplementedError: For unsupported frameworks.

    Example::

        from qiskit import QuantumCircuit, transpile
        from qiskit_aer import AerSimulator
        from pytest_quantum import assert_gates_in_basis_set
        qc = QuantumCircuit(2)
        qc.h(0)
        qc.cx(0, 1)
        transpiled = transpile(qc, basis_gates=["cx", "u3"])
        assert_gates_in_basis_set(transpiled, {"cx", "u3"})
    """
    module = type(circuit).__module__
    c = cast("Any", circuit)

    # Prepare the allowed set first so that a malformed *basis_gates* is
    # reported before the framework is even inspected.
    if case_sensitive:
        allowed = set(basis_gates)
    else:
        allowed = {gate.lower() for gate in basis_gates}

    # One lazy stream of raw operation names per framework.
    if module.startswith("qiskit"):
        raw_names = (instr.operation.name for instr in c.data)
    elif module.startswith("cirq"):
        raw_names = (str(op.gate) for moment in c for op in moment.operations)
    elif module.startswith("braket"):
        raw_names = (type(instr.operator).__name__ for instr in c.instructions)
    elif module.startswith("pytket"):
        raw_names = (cmd.op.type.name for cmd in c.get_commands())
    else:
        raise NotImplementedError(
            f"assert_gates_in_basis_set supports Qiskit, Cirq, Braket, Pytket; "
            f"got {module!r}"
        )

    # Offenders keep their original spelling; only the lookup key is folded.
    offenders = [
        raw
        for raw in raw_names
        if (raw if case_sensitive else raw.lower()) not in allowed
    ]
    if not offenders:
        return

    distinct = sorted(set(offenders))
    raise AssertionError(
        f"Circuit contains {len(offenders)} gate(s) not in basis set.\n"
        f" Non-basis gates found : {distinct}\n"
        f" Allowed basis : {sorted(basis_gates)}\n"
        f" Hint: use transpile(circuit, basis_gates=[...]) first."
    )
[docs]
def assert_circuit_is_clifford(circuit: object) -> None:
    """Assert a circuit uses only Clifford gates (H, S, S†, X, Y, Z, CNOT, CZ, SWAP).

    Clifford circuits are classically efficiently simulable.

    Supported: Qiskit, Cirq, Amazon Braket, PennyLane, Pytket. The framework
    is chosen from the module of ``type(circuit)``; an object exposing a
    ``device`` attribute is also treated as a PennyLane QNode.

    Qiskit, Cirq, Braket and PennyLane circuits are checked name-by-name
    against the module-level ``_CLIFFORD_*`` sets. Pytket circuits are checked
    by building a ``pytket.tableau.UnitaryTableau`` from the circuit; any
    exception raised by that construction is reported as a failed assertion.

    Args:
        circuit: A quantum circuit from a supported framework.

    Raises:
        AssertionError: If non-Clifford gates found.
        NotImplementedError: If framework not supported.
        TypeError: If a PennyLane QNode's tape cannot be obtained. When the
            dry-run call of the QNode failed, that failure is attached as
            ``__cause__``.
        ImportError: If a Pytket circuit is given but ``pytket.tableau``
            cannot be imported.

    Example::

        def test_is_clifford():
            from qiskit import QuantumCircuit
            from pytest_quantum import assert_circuit_is_clifford
            qc = QuantumCircuit(2)
            qc.h(0)
            qc.cx(0, 1)
            assert_circuit_is_clifford(qc)
    """
    module = type(circuit).__module__
    c: Any = circuit

    if module.startswith("qiskit"):
        ops = c.count_ops()
        non_clifford = sorted(g for g in ops if g not in _CLIFFORD_QISKIT)
        if non_clifford:
            # measure/barrier/reset are tolerated by the check but are left
            # out of the set shown to the user.
            raise AssertionError(
                f"Circuit contains non-Clifford gates: {non_clifford}\n"
                f" Clifford set: "
                f"{sorted(g for g in _CLIFFORD_QISKIT if g not in ('measure', 'barrier', 'reset'))}"
            )
        return

    if module.startswith("cirq"):
        # Membership is tested on the lower-cased name; the original spelling
        # is what gets reported.
        non_clifford_cirq: set[str] = set()
        for moment in c:
            for op in moment.operations:
                name = str(op.gate).lower()
                if name not in _CLIFFORD_CIRQ:
                    non_clifford_cirq.add(str(op.gate))
        if non_clifford_cirq:
            raise AssertionError(
                f"Circuit contains non-Clifford gates: {sorted(non_clifford_cirq)}"
            )
        return

    if module.startswith("braket"):
        non_clifford = [
            type(instr.operator).__name__
            for instr in c.instructions
            if type(instr.operator).__name__ not in _CLIFFORD_BRAKET
        ]
        if non_clifford:
            raise AssertionError(
                f"Circuit contains non-Clifford gates: {sorted(set(non_clifford))}. "
                f"Clifford set: {sorted(_CLIFFORD_BRAKET)}"
            )
        return

    if module.startswith("pennylane") or hasattr(circuit, "device"):
        # QNode: use the existing tape if there is one, otherwise call the
        # QNode once (without arguments) so that it builds its tape.
        tape = getattr(c, "tape", None)
        dry_run_error: Exception | None = None
        if tape is None:
            try:
                c()
                tape = c.tape
            except Exception as exc:
                # Best effort only, but keep the reason for the report below.
                dry_run_error = exc
                tape = None
        if tape is None:
            raise TypeError(
                "Cannot check Clifford: QNode tape could not be obtained."
            ) from dry_run_error
        non_clifford = [
            op.name for op in tape.operations if op.name not in _CLIFFORD_PENNYLANE
        ]
        if non_clifford:
            raise AssertionError(
                f"Circuit contains non-Clifford operations: "
                f"{sorted(set(non_clifford))}. "
                f"Clifford set: {sorted(_CLIFFORD_PENNYLANE)}"
            )
        return

    if module.startswith("pytket"):
        try:
            from pytket.tableau import UnitaryTableau

            UnitaryTableau(c)  # raises if circuit contains non-Clifford gates
        except ImportError as exc:
            raise ImportError("pytket is required: pip install pytket") from exc
        except Exception as exc:
            raise AssertionError(f"Circuit contains non-Clifford gates: {exc}") from exc
        return

    raise NotImplementedError(
        f"assert_circuit_is_clifford supports Qiskit and Cirq (and also "
        f"Braket, PennyLane, Pytket). Got: {type(circuit).__qualname__!r}"
    )
[docs]
def assert_no_mid_circuit_measurement(circuit: object) -> None:
    """Check that every measurement in *circuit* is terminal.

    A mid-circuit measurement is a measurement followed by further gate
    operations; not all hardware backends support them.

    Supported frameworks: Qiskit, Cirq. The two are checked at different
    granularity:

    * Qiskit: per qubit. A non-measurement instruction (barriers excepted)
      acting on a qubit that was already measured is a violation.
    * Cirq: per moment. A moment containing a measurement is a violation if
      it comes before the last moment that contains a non-measurement
      operation, whichever qubits are involved.

    Args:
        circuit: A quantum circuit from a supported framework.

    Raises:
        AssertionError: If mid-circuit measurements are detected.
        NotImplementedError: If framework is not supported.

    Example::

        from qiskit import QuantumCircuit
        from pytest_quantum import assert_no_mid_circuit_measurement
        qc = QuantumCircuit(2, 2)
        qc.h(0)
        qc.cx(0, 1)
        qc.measure_all()
        assert_no_mid_circuit_measurement(qc)  # passes — measurements are terminal
    """
    module = type(circuit).__module__
    c: Any = circuit

    if module.startswith("qiskit"):
        from qiskit.circuit import Measure

        # qubit index -> position of the most recent measurement on that qubit
        measured_at: dict[int, int] = {}
        problems: list[str] = []
        for position, instruction in enumerate(c.data):
            touched = [c.find_bit(bit).index for bit in instruction.qubits]
            operation = instruction.operation
            if isinstance(operation, Measure):
                for qubit in touched:
                    measured_at[qubit] = position
                continue
            label = operation.name
            if label in ("barrier",):
                # Barriers after a measurement are harmless.
                continue
            problems.extend(
                f"qubit {qubit}: gate '{label}' at position {position} "
                f"follows measurement at position {measured_at[qubit]}"
                for qubit in touched
                if qubit in measured_at
            )
        if problems:
            listing = "\n".join(f" - {entry}" for entry in problems)
            raise AssertionError(
                f"Mid-circuit measurements detected ({len(problems)} violation(s)):\n"
                + listing
                + "\n Hint: move all measurements to the end of the circuit."
            )
        return

    if module.startswith("cirq"):
        import cirq

        def _measures(op: Any) -> bool:
            return isinstance(op.gate, cirq.MeasurementGate)

        measuring_moments: list[int] = []
        # Stays at -1 when no moment holds a non-measurement operation, which
        # makes the filter below come out empty.
        final_gate_moment = -1
        for index, moment in enumerate(list(c.moments)):
            if any(_measures(op) for op in moment.operations):
                measuring_moments.append(index)
            if any(not _measures(op) for op in moment.operations):
                final_gate_moment = index

        premature = [i for i in measuring_moments if i < final_gate_moment]
        if premature:
            raise AssertionError(
                f"Mid-circuit measurements detected in Cirq circuit.\n"
                f" Measurement moments: {premature}\n"
                f" Last gate moment : {final_gate_moment}\n"
                f" Hint: reorder operations so all measurements come last."
            )
        return

    raise NotImplementedError(
        f"assert_no_mid_circuit_measurement supports Qiskit and Cirq; got {module!r}"
    )
[docs]
def assert_has_diagram(circuit: object, expected: str, *, strict: bool = False) -> None:
    """Check the text rendering of *circuit* against *expected*.

    How the text is produced:

    * Qiskit: ``str(circuit.draw('text'))``.
    * Cirq: ``str(circuit)`` (``circuit.to_text_diagram()``).
    * Pytket: ``str(circuit)``.

    Args:
        circuit: Any supported framework circuit.
        expected: Text to look for. Must equal the whole diagram when *strict*
            is ``True``; otherwise it only has to occur somewhere inside it.
        strict: If ``True``, compare for equality after stripping leading /
            trailing whitespace from both sides. If ``False`` (default), do a
            plain substring search with no stripping.

    Raises:
        AssertionError: If diagram doesn't match.
        NotImplementedError: For frameworks without text diagram support.

    Example::

        from qiskit import QuantumCircuit
        from pytest_quantum import assert_has_diagram
        qc = QuantumCircuit(1)
        qc.h(0)
        assert_has_diagram(qc, "H")
    """
    module = type(circuit).__module__
    c: Any = circuit

    if module.startswith("qiskit"):
        rendered = str(c.draw("text"))
    elif module.startswith("cirq"):
        rendered = str(c)
    elif module.startswith("pytket"):
        try:
            rendered = str(c)
        except Exception as exc:
            raise NotImplementedError("pytket diagram not available") from exc
    else:
        raise NotImplementedError(
            f"assert_has_diagram supports Qiskit and Cirq; got {module!r}"
        )

    if not strict:
        if expected in rendered:
            return
        raise AssertionError(
            f"Expected pattern not found in circuit diagram.\n"
            f"Pattern: {expected!r}\n"
            f"Diagram:\n{rendered}"
        )

    if rendered.strip() != expected.strip():
        raise AssertionError(
            f"Circuit diagram mismatch.\nExpected:\n{expected}\nGot:\n{rendered}"
        )