"""M3 readout error mitigator.
Provides confusion-matrix-based readout error mitigation with TTL-based
calibration freshness enforcement.
"""
from __future__ import annotations
import pathlib
from datetime import datetime, timezone
from typing import Any
import numpy as np
from uniqc._error_hints import format_enriched_message
from uniqc.exceptions import StaleCalibrationError # noqa: F401 — re-export
# Public API: the mitigator class plus ``StaleCalibrationError``, which is
# re-exported from ``uniqc.exceptions`` so callers can catch it from here.
__all__ = ["M3Mitigator", "StaleCalibrationError"]
def _get_field(cal: Any, name: str) -> Any:
"""Read ``name`` from ``cal`` whether it is a dataclass or a dict."""
if cal is None:
raise ValueError(format_enriched_message("calibration result is None", "calibration"))
if hasattr(cal, name):
return getattr(cal, name)
return cal[name]
class M3Mitigator:
    """Confusion-matrix readout error mitigator (M3-style API).

    Corrects measurement outcomes via linear inversion: it solves
    ``C · n_true = n_obs`` for the calibration confusion matrix ``C``. This is
    a dense solve over the full matrix, not the matrix-free solver of the
    original M3 (matrix-free measurement mitigation) method. Calibration data
    can come from three sources:

    - passed in directly;
    - loaded from an explicit JSON file;
    - looked up lazily in the calibration cache on first use.

    Args:
        calibration_result: Pre-loaded calibration result. This is either an
            object exposing ``confusion_matrix`` (and optionally
            ``calibrated_at``) attributes, or a dict with those keys.
        cache_path: Path to a cached calibration JSON file. Ignored when
            ``calibration_result`` is given.
        max_age_hours: Maximum acceptable age of calibration data in hours.
            Older data raises ``StaleCalibrationError``. ``None`` disables
            the age check.
        backend: Backend name used for cache lookup.
        qubit: Qubit index or pair of indices, used for cache lookup. A pair
            selects ``readout_2q`` results; a single index selects
            ``readout_1q`` results.
        cache_dir: Optional cache directory forwarded to
            ``find_cached_results``.

    Raises:
        StaleCalibrationError: If calibration data exceeds ``max_age_hours``.
        FileNotFoundError: If ``cache_path`` does not exist, or if the lazy
            cache lookup finds no fresh entry for ``backend``/``qubit``.
        ValueError: On lazy cache lookup when ``qubit`` was not provided.
    """

    def __init__(
        self,
        calibration_result: Any | None = None,
        cache_path: str | pathlib.Path | None = None,
        max_age_hours: float | None = 24.0,
        backend: str = "dummy:local:simulator",
        qubit: int | tuple[int, int] | None = None,
        cache_dir: str | pathlib.Path | None = None,
    ) -> None:
        self._backend = backend
        self._qubit = qubit
        self._max_age_hours = max_age_hours
        self._cache_dir = cache_dir
        if calibration_result is not None:
            # A missing timestamp is not an error; such data just cannot be
            # age-checked. TypeError covers non-subscriptable objects without
            # the attribute.
            try:
                calibrated_at = _get_field(calibration_result, "calibrated_at")
            except (KeyError, AttributeError, TypeError):
                calibrated_at = ""
            self._check_age(calibrated_at, max_age_hours)
            self._cal = calibration_result
        elif cache_path is not None:
            self._cal = self._load_from_path(cache_path, max_age_hours)
        else:
            # Resolved lazily from the cache by ``calibration_result``.
            self._cal = None

    @property
    def calibration_result(self) -> Any:
        """Calibration data in use (a dict or a result object).

        If none was supplied to the constructor, the most recently modified
        matching cache entry is loaded on first access and kept for reuse.
        """
        if self._cal is None:
            self._cal = self._load_from_cache(
                self._backend,
                self._qubit,
                self._max_age_hours,
                self._cache_dir,
            )
        return self._cal

    def apply(self, result: Any) -> Any:
        """Apply mitigation to a :class:`UnifiedResult` and return a new one.

        This is the recommended pipeline-style API. Call
        ``M3Mitigator(...).apply(result)`` and feed the returned object
        straight back into any uniqc workflow that expects a
        :class:`UnifiedResult`.

        Args:
            result: A :class:`uniqc.backend_adapter.task.result_types.UnifiedResult`
                produced by ``submit_task``/``wait_for_result``/``simulate``.

        Returns:
            A new ``UnifiedResult`` whose ``counts`` are the mitigated counts
            rounded to integers, and whose ``probabilities`` are those rounded
            counts normalised to sum to 1. ``shots``, ``platform``,
            ``task_id``, ``backend_name``, ``execution_time`` and
            ``error_message`` are copied over. The original result is kept as
            ``raw_result``.

        Raises:
            TypeError: If ``result`` is not a ``UnifiedResult``.
        """
        from uniqc.backend_adapter.task.result_types import UnifiedResult

        if not isinstance(result, UnifiedResult):
            raise TypeError(
                format_enriched_message(
                    "M3Mitigator.apply expects a UnifiedResult; "
                    f"got {type(result).__name__}. Use mitigate_counts/mitigate_probabilities "
                    "for raw dict input.",
                    "calibration",
                )
            )
        # Bitstring (or int) keys -> basis-state indices. Keys that name the
        # same state (e.g. "01" and 1) are summed rather than overwritten.
        counts_int: dict[int, int] = {}
        for key, count in result.counts.items():
            index = self._outcome_index(key)
            counts_int[index] = counts_int.get(index, 0) + int(count)
        mitigated_int = self.mitigate_counts(counts_int)

        # Output width: wide enough for every state of the confusion matrix,
        # and never narrower than the input bitstrings. This keeps all keys
        # the same length even when the input keys are ints or too short.
        state_width = max(1, (max(len(mitigated_int), 1) - 1).bit_length())
        key_widths = [len(key) for key in result.counts if isinstance(key, str)]
        width = max([state_width, *key_widths])

        new_counts = {
            format(outcome, f"0{width}b"): int(round(value))
            for outcome, value in mitigated_int.items()
        }
        total = sum(new_counts.values()) or 1
        new_probs = {k: v / total for k, v in new_counts.items()}
        return UnifiedResult(
            counts=new_counts,
            probabilities=new_probs,
            shots=result.shots,
            platform=result.platform,
            task_id=result.task_id,
            backend_name=result.backend_name,
            execution_time=result.execution_time,
            raw_result=result,
            error_message=result.error_message,
        )

    def mitigate_counts(self, counts: dict[int, int]) -> dict[int, float]:
        """Apply M3 mitigation to measurement counts.

        Uses linear inversion, ``n_corrected = C⁻¹ · n_obs``. Negative
        entries are then clipped to zero, and the remainder is rescaled so
        that the total count is preserved.

        Args:
            counts: Dict mapping outcome to observed count. Outcomes may be
                integer basis-state indices, or bitstrings such as ``"10"``
                (state 2).

        Returns:
            Dict mapping every basis-state index ``0..n-1`` to its corrected
            (float) count. The dict is empty when the observed total is zero.

        Raises:
            IndexError: If an outcome lies outside the confusion matrix.
        """
        C = self._confusion_matrix()
        observed = self._observed_vector(counts, len(C))
        total = observed.sum()
        if total == 0:
            return {}
        # Clip first (push negatives to 0), then rescale so the corrected
        # counts sum to the original total. Doing it in the opposite order
        # would re-introduce a mismatch with ``total`` whenever clipping
        # removed mass.
        corrected = np.clip(self._solve(C, observed), 0, None)
        mass = corrected.sum()
        if mass > 0:
            corrected *= total / mass
        return {int(i): float(v) for i, v in enumerate(corrected)}

    def mitigate_probabilities(self, probs: dict[str, float] | dict[int, float]) -> dict[int, float]:
        """Apply M3 mitigation to a probability dictionary.

        Args:
            probs: Dict mapping outcome to probability. Outcomes may be
                bitstrings such as ``"10"`` (state 2) or integer basis-state
                indices.

        Returns:
            Dict mapping every basis-state index ``0..n-1`` to its corrected
            probability. Values are clipped to be non-negative and
            renormalised to sum to 1; if nothing survives clipping, the
            values are all zero.

        Raises:
            IndexError: If an outcome lies outside the confusion matrix.
        """
        C = self._confusion_matrix()
        observed = self._observed_vector(probs, len(C))
        corrected = np.clip(self._solve(C, observed), 0, None)
        mass = corrected.sum()
        if mass > 0:
            corrected /= mass
        return {int(i): float(v) for i, v in enumerate(corrected)}

    def _confusion_matrix(self) -> np.ndarray:
        """Return the calibration confusion matrix as a float array."""
        return np.asarray(_get_field(self.calibration_result, "confusion_matrix"), dtype=float)

    @staticmethod
    def _outcome_index(outcome: Any) -> int:
        """Map an outcome key to its basis-state index.

        Strings are parsed as binary bitstrings (``"10"`` -> 2). Anything
        else goes through ``int``.
        """
        if isinstance(outcome, str):
            return int(outcome, 2)
        return int(outcome)

    @classmethod
    def _observed_vector(cls, values: dict[Any, float], n: int) -> np.ndarray:
        """Scatter ``{outcome: value}`` into a length-``n`` vector.

        Values for keys that resolve to the same state are summed.

        Raises:
            IndexError: If an outcome index is negative or ``>= n``. Negative
                indices are rejected explicitly, because numpy would otherwise
                silently wrap them to the last states.
        """
        vector = np.zeros(n)
        for outcome, value in values.items():
            index = cls._outcome_index(outcome)
            if not 0 <= index < n:
                raise IndexError(
                    format_enriched_message(
                        f"Outcome {outcome!r} maps to basis state {index}, outside the "
                        f"{n}-state confusion matrix.",
                        "calibration",
                    )
                )
            vector[index] += float(value)
        return vector

    @staticmethod
    def _solve(C: np.ndarray, observed: np.ndarray) -> np.ndarray:
        """Return ``C⁻¹ · observed``, computed via ``np.linalg.solve``.

        If ``C`` is singular (or not a square 2-D matrix), a warning is
        issued and ``observed`` is returned unchanged. The result is then
        equivalent to applying no mitigation.
        """
        import warnings

        try:
            return np.linalg.solve(C, observed)
        except np.linalg.LinAlgError:
            warnings.warn(
                "Confusion matrix is singular; returning unmitigated values.",
                RuntimeWarning,
                stacklevel=3,
            )
            return np.array(observed, dtype=float, copy=True)

    def _load_from_path(self, path: str | pathlib.Path, max_age_hours: float | None) -> dict[str, Any]:
        """Load a calibration JSON file and enforce the age limit.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            StaleCalibrationError: If the stored ``calibrated_at`` is too old.
        """
        import json

        with open(path, encoding="utf-8") as f:
            d = json.load(f)
        self._check_age(d.get("calibrated_at", ""), max_age_hours)
        return d

    def _load_from_cache(
        self,
        backend: str,
        qubit: int | tuple | list | None,
        max_age_hours: float | None,
        cache_dir: str | pathlib.Path | None = None,
    ) -> dict[str, Any]:
        """Find and load the newest cached calibration for ``backend``/``qubit``.

        Candidates come from ``find_cached_results``. Entries that are
        unreadable, are not JSON objects, or were recorded for a different
        qubit (or pair) are skipped. Among the exact matches, the file with
        the latest modification time wins.

        Raises:
            ValueError: If ``qubit`` is ``None``.
            FileNotFoundError: If no matching cache entry exists.
            StaleCalibrationError: If the chosen entry is older than
                ``max_age_hours``.
        """
        import json

        from uniqc.calibration.results import find_cached_results

        if qubit is None:
            raise ValueError(format_enriched_message("qubit must be provided when loading from cache", "calibration"))
        # Accept a list pair as well as a tuple; JSON stores pairs as lists.
        expected_qubit = tuple(qubit) if isinstance(qubit, (tuple, list)) else qubit
        result_type = "readout_2q" if isinstance(expected_qubit, tuple) else "readout_1q"
        paths = find_cached_results(
            backend,
            result_type,
            max_age_hours=max_age_hours,
            cache_dir=cache_dir,
        )

        best: tuple[float, dict[str, Any]] | None = None
        for raw_path in paths or ():
            path = pathlib.Path(raw_path)
            try:
                with open(path, encoding="utf-8") as f:
                    data = json.load(f)
                mtime = path.stat().st_mtime
            except (OSError, ValueError):
                continue  # unreadable or corrupt cache entry
            if not isinstance(data, dict):
                continue
            cached_qubit = data.get("qubit")
            if isinstance(cached_qubit, list):
                cached_qubit = tuple(cached_qubit)
            if cached_qubit != expected_qubit:
                continue
            # Strict ">" keeps the first entry on mtime ties.
            if best is None or mtime > best[0]:
                best = (mtime, data)

        if best is None:
            raise FileNotFoundError(
                format_enriched_message(
                    f"No fresh calibration result found for backend={backend}, "
                    f"qubit={qubit}, max_age_hours={max_age_hours}. "
                    f"Run calibration first.",
                    "calibration",
                )
            )
        data = best[1]
        self._check_age(data.get("calibrated_at", ""), max_age_hours)
        return data

    def _check_age(self, calibrated_at: str | datetime, max_age_hours: float | None) -> None:
        """Raise StaleCalibrationError if the calibration is too old.

        ``calibrated_at`` may be an ISO-8601 string (a trailing ``Z`` is
        accepted) or a ``datetime``. Naive timestamps are treated as UTC. The
        check is skipped in three cases:

        - ``calibrated_at`` is empty;
        - ``calibrated_at`` cannot be parsed;
        - ``max_age_hours`` is ``None``.
        """
        if not calibrated_at or max_age_hours is None:
            return
        if isinstance(calibrated_at, datetime):
            ts = calibrated_at
        else:
            try:
                ts = datetime.fromisoformat(str(calibrated_at).replace("Z", "+00:00"))
            except ValueError:
                return  # Can't parse timestamp — skip age check
        if ts.tzinfo is None:
            ts = ts.replace(tzinfo=timezone.utc)
        age_hours = (datetime.now(timezone.utc) - ts).total_seconds() / 3600
        # Raised outside the parse ``try`` so it cannot be swallowed, even if
        # StaleCalibrationError derives from ValueError.
        if age_hours > max_age_hours:
            raise StaleCalibrationError(
                format_enriched_message(
                    f"Calibration data is {age_hours:.1f} hours old "
                    f"(max_age_hours={max_age_hours}). "
                    f"Calibrated at: {calibrated_at}",
                    "calibration",
                )
            )