Source code for uniqc.qem.readout_em

"""Unified readout error mitigation.

Provides a single interface for applying readout EM to measurement counts,
automatically dispatching to 1-qubit or 2-qubit calibration as needed.
Internally calls the ``ReadoutCalibrator`` from ``uniqc.calibration.readout``.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    from uniqc.backend_adapter.task.adapters.base import QuantumAdapter

__all__ = ["ReadoutEM"]


def _outcome_to_int(outcome: Any) -> int:
    """Normalize a count/prob dict key to an int outcome.

    Accepts ``int``, ``str`` of ``0``/``1`` characters (interpreted in
    binary), or anything else convertible by ``int()``.
    """
    if isinstance(outcome, int):
        return outcome
    if isinstance(outcome, str):
        s = outcome.strip()
        if s and all(c in "01" for c in s):
            return int(s, 2)
        return int(s)
    return int(outcome)


[docs] class ReadoutEM: """Unified readout error mitigator. This is the primary interface for applying readout EM to measurement results. It wraps a ``ReadoutCalibrator`` and provides mitigation for arbitrary measurement counts. The mitigator automatically selects: - **1-qubit calibration** for single-qubit measurement results - **2-qubit calibration** for two-qubit joint measurement results - **Per-qubit 1-qubit calibration** (sequential approximation) for >2 qubits Args: adapter: A ``QuantumAdapter`` instance for running calibration circuits. max_age_hours: Maximum acceptable age of cached calibration data in hours. cache_dir: Directory for calibration cache. Defaults to ``~/.uniqc/calibration_cache/``. shots: Number of shots per calibration circuit. """ def __init__( self, adapter: QuantumAdapter, max_age_hours: float = 24.0, cache_dir: str | None = None, shots: int = 1000, ) -> None: from uniqc.calibration.readout import ReadoutCalibrator self.adapter = adapter self.max_age_hours = max_age_hours self.cache_dir = cache_dir self.shots = shots self._calibrator = ReadoutCalibrator(adapter=adapter, shots=shots, cache_dir=cache_dir) # Cache of loaded M3Mitigator instances: (qubit_ident) → M3Mitigator self._mitigators: dict[str, Any] = {}
[docs] def apply(self, result: Any, measured_qubits: list[int] | None = None) -> Any: """Apply readout mitigation to a :class:`UnifiedResult`. Pipeline-style API: returns a new :class:`UnifiedResult` ready to be consumed by the rest of uniqc. ``measured_qubits`` defaults to ``list(range(width))`` inferred from the bitstring length. """ from uniqc.backend_adapter.task.result_types import UnifiedResult if not isinstance(result, UnifiedResult): raise TypeError(f"ReadoutEM.apply expects a UnifiedResult; got {type(result).__name__}.") if not result.counts: return result width = max(len(b) for b in result.counts.keys()) if measured_qubits is None: measured_qubits = list(range(width)) counts_int = {int(b, 2): int(c) for b, c in result.counts.items()} mitigated = self.mitigate_counts(counts_int, measured_qubits) new_counts: dict[str, int] = {} for outcome, value in mitigated.items(): new_counts[format(outcome, f"0{width}b")] = int(round(value)) total = sum(new_counts.values()) or 1 new_probs = {k: v / total for k, v in new_counts.items()} return UnifiedResult( counts=new_counts, probabilities=new_probs, shots=result.shots, platform=result.platform, task_id=result.task_id, backend_name=result.backend_name, execution_time=result.execution_time, raw_result=result, error_message=result.error_message, )
[docs] def mitigate_counts( self, counts: dict[int, int], measured_qubits: list[int], ) -> dict[int, float]: """Apply readout EM to measurement counts. Automatically dispatches to the appropriate calibration: - 1 qubit → 1q calibrator - 2 qubits → 2q calibrator - N>2 qubits → sequential per-qubit 1q mitigation Args: counts: Dict mapping outcome (int) → observed count. measured_qubits: List of qubit indices that were measured. The order matters for the bitstring encoding. Returns: Dict mapping outcome → corrected count (float, total preserved). """ n = len(measured_qubits) if n == 1: return self._mitigate_1q(counts, measured_qubits[0]) elif n == 2: return self._mitigate_2q(counts, measured_qubits[0], measured_qubits[1]) else: return self._mitigate_nq(counts, measured_qubits)
[docs] def mitigate_probabilities( self, probs: dict[int, float] | dict[str, float], measured_qubits: list[int], ) -> dict[int, float]: """Apply readout EM to a probability dictionary. Args: probs: Dict mapping outcome → probability. measured_qubits: List of measured qubit indices. Returns: Dict mapping outcome (int) → corrected probability. """ # Normalize string keys to int (bitstrings interpreted in binary) if probs and isinstance(next(iter(probs)), str): probs = {_outcome_to_int(k): v for k, v in probs.items()} n = len(measured_qubits) if n == 1: return self._mitigate_probs_1q(probs, measured_qubits[0]) elif n == 2: return self._mitigate_probs_2q(probs, measured_qubits[0], measured_qubits[1]) else: return self._mitigate_probs_nq(probs, measured_qubits)
# ------------------------------------------------------------------------- # 1q mitigation # ------------------------------------------------------------------------- def _mitigate_1q(self, counts: dict[int, int], qubit: int) -> dict[int, float]: """Apply 1-qubit readout EM to counts.""" mit = self._get_mitigator_1q(qubit) return mit.mitigate_counts(counts) def _mitigate_probs_1q(self, probs: dict[int, float], qubit: int) -> dict[int, float]: """Apply 1-qubit readout EM to probabilities.""" mit = self._get_mitigator_1q(qubit) return mit.mitigate_probabilities(probs) def _get_mitigator_1q(self, qubit: int): """Get or create a cached M3Mitigator for a single qubit.""" key = f"1q_{qubit}" if key not in self._mitigators: from uniqc.qem.m3 import M3Mitigator self._mitigators[key] = M3Mitigator( max_age_hours=self.max_age_hours, backend=getattr(self.adapter, "name", "unknown"), qubit=qubit, cache_dir=self.cache_dir, ) return self._mitigators[key] # ------------------------------------------------------------------------- # 2q mitigation # ------------------------------------------------------------------------- def _mitigate_2q(self, counts: dict[int, int], q0: int, q1: int) -> dict[int, float]: """Apply 2-qubit joint readout EM to counts.""" mit = self._get_mitigator_2q(q0, q1) return mit.mitigate_counts(counts) def _mitigate_probs_2q(self, probs: dict[int, float], q0: int, q1: int) -> dict[int, float]: """Apply 2-qubit joint readout EM to probabilities.""" mit = self._get_mitigator_2q(q0, q1) return mit.mitigate_probabilities(probs) def _get_mitigator_2q(self, q0: int, q1: int): """Get or create a cached M3Mitigator for a qubit pair.""" key = f"2q_{q0}_{q1}" if key not in self._mitigators: from uniqc.qem.m3 import M3Mitigator self._mitigators[key] = M3Mitigator( max_age_hours=self.max_age_hours, backend=getattr(self.adapter, "name", "unknown"), qubit=(q0, q1), cache_dir=self.cache_dir, ) return self._mitigators[key] # ------------------------------------------------------------------------- # Nq mitigation (per-qubit sequential approximation) # ------------------------------------------------------------------------- def _mitigate_nq(self, counts: dict[int, int], qubits: list[int]) -> dict[int, float]: """Apply per-qubit readout EM sequentially for n>2 qubits. This is an approximation: each qubit is corrected independently using its 1-qubit confusion matrix, applied in order. """ result = {_outcome_to_int(k): float(v) for k, v in counts.items()} for bit_position, q in enumerate(qubits): result = self._apply_1q_matrix(result, q, bit_position, len(qubits)) return result def _mitigate_probs_nq(self, probs: dict[int, float], qubits: list[int]) -> dict[int, float]: """Apply per-qubit readout EM sequentially to probabilities.""" result = {_outcome_to_int(k): float(v) for k, v in probs.items()} for bit_position, q in enumerate(qubits): result = self._apply_1q_matrix_probs(result, q, bit_position, len(qubits)) return result def _apply_1q_matrix( self, counts: dict[int, float], qubit: int, bit_position: int | None = None, n_total: int | None = None, ) -> dict[int, float]: """Apply 1-qubit confusion matrix to an N-qubit counts vector. This marginalizes over all other qubits and applies the 1q correction. """ mit = self._get_mitigator_1q(qubit) # Get the 1q confusion matrix (works with both dict and dataclass) cal = mit.calibration_result if hasattr(cal, "confusion_matrix"): cm = cal.confusion_matrix else: cm = cal["confusion_matrix"] C = np.array(cm) # 2x2: [p(meas|prep)] if n_total is None: max_outcome = max(counts.keys()) if counts else 0 n_total = max(1, int(np.ceil(np.log2(max_outcome + 1)))) if bit_position is None: bit_position = qubit # Proper implementation: apply 1q matrix via tensor product with identity return self._tensor_apply(counts, C, n_total, bit_position) def _tensor_apply( self, counts: dict[int, float], C: np.ndarray, n_total: int, target_qubit: int ) -> dict[int, float]: """Apply a 2x2 1q confusion matrix C to the target qubit in an n-qubit system. Uses the tensor product structure: full_matrix = I⊗...⊗C⊗...⊗I. """ n = 2**n_total try: C_inv = np.linalg.inv(C) except np.linalg.LinAlgError: C_inv = np.eye(2) mats = [np.eye(2) if i != target_qubit else C_inv for i in range(n_total - 1, -1, -1)] full_C = mats[0] for m in mats[1:]: full_C = np.kron(full_C, m) n_obs = np.zeros(n) for outcome, cnt in counts.items(): n_obs[_outcome_to_int(outcome)] = float(cnt) n_corr = full_C @ n_obs n_corr = np.clip(n_corr, 0, None) total = n_obs.sum() if total > 0 and n_corr.sum() > 0: n_corr *= total / n_corr.sum() return {int(i): float(v) for i, v in enumerate(n_corr)} def _apply_1q_matrix_probs( self, probs: dict[int, float], qubit: int, bit_position: int | None = None, n_total: int | None = None, ) -> dict[int, float]: """Apply 1-qubit confusion matrix to an N-qubit probability vector.""" counts = {k: float(v) for k, v in probs.items()} corrected = self._apply_1q_matrix(counts, qubit, bit_position, n_total) total = sum(corrected.values()) if total > 0: return {k: v / total for k, v in corrected.items()} return corrected