"""Intermediate primitives for QEC-compiler integration.
These Primitive subclasses map directly to QEC-compiler AbstractGate names.
Each implements:
- to_abstract_gates(): returns AbstractGate list for QEC lowering
- pyqsparse_object(): decomposed PySparQ simulation for verification
- t_count(): resource estimation
"""
from __future__ import annotations
from typing import Any
from ..core.operation import Primitive
from ..core.utils import get_control_qubit_count, merge_controllers, reg_sz, mcx_t_count
from ..core.simulator import PyQSparseOperationWrapper
def _lazy_abstract_gate(name: str, qubits: tuple[int, ...], params: tuple[float, ...] = ()):
    """Construct a QEC-compiler ``AbstractGate``.

    ``qec_compiler`` is imported inside the function body, so the import only
    happens when a gate is actually built, not when this module is loaded.

    Args:
        name: AbstractGate name (e.g. ``"H"`` or ``"MCX"``).
        qubits: Qubits the gate acts on, as supplied by the caller.
        params: Numeric gate parameters; empty by default.

    Returns:
        A new ``qec_compiler.ir.AbstractGate``.
    """
    from qec_compiler.ir import AbstractGate as _Gate

    gate_fields = dict(name=name, qubits=qubits, params=params)
    return _Gate(**gate_fields)
class _IndexedIRGate(Primitive):
    """Base for QEC-Compiler IR gates addressed by register-local qubit index.

    Subclasses set ``gate_name`` (the AbstractGate name) and ``qubit_count``.
    ``param_list`` is laid out as ``[idx_0, ..., idx_{qubit_count-1}, *params]``:
    the first ``qubit_count`` entries are qubit indices into ``reg_list[0]``
    and any remaining entries are numeric gate parameters (e.g. angles).
    """
    __abstract__ = True
    __self_conjugate__ = False
    gate_name = ""
    qubit_count = 0

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list or [])
        self.register = reg_list[0] if reg_list else None
        # Leading entries are qubit indices; everything after is a gate parameter.
        self.bit_indices = tuple(int(q) for q in self.param_list[:self.qubit_count])
        self.gate_params = tuple(float(p) for p in self.param_list[self.qubit_count:])

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Not supported: these IR gates have no PySparQ reference yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} is a QEC IR primitive. "
            "Add a PySparQ reference before using it for simulation."
        )

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return the T-count of this gate under the given control context.

        Uncontrolled Clifford gates cost 0. Gates that equal an MCX up to
        conjugation by T-free Clifford gates are costed with ``mcx_t_count``
        over their intrinsic controls plus the external control qubits:

        * X, Y, Z: target only (Y = S X S^dag, Z = H X H);
        * CNOT: one intrinsic control;
        * SWAP: one intrinsic control (C-SWAP = CNOT . Toffoli . CNOT);
        * CCX: two intrinsic controls.

        Raises:
            NotImplementedError: for any other gate/control combination (e.g.
                rotations or a controlled H). Those must be costed through
                QEC-Compiler lowering.
        """
        ncontrols = get_control_qubit_count(
            merge_controllers(self.controllers, controllers_ctx or {}))
        if ncontrols == 0 and self.gate_name in {"H", "X", "Y", "Z", "CNOT", "SWAP"}:
            return 0
        # Previously controlled Cliffords were reported as 0 T gates, which
        # under-counted; cost the MCX-equivalent ones explicitly instead.
        mcx_intrinsic_controls = {"X": 0, "Y": 0, "Z": 0, "CNOT": 1, "SWAP": 1, "CCX": 2}
        intrinsic = mcx_intrinsic_controls.get(self.gate_name)
        if intrinsic is not None:
            return mcx_t_count(intrinsic + ncontrols)
        raise NotImplementedError(
            f"{type(self).__name__} has no pyqres-local T-count model. "
            "Lower it through QEC-Compiler for resource accounting."
        )

    def to_abstract_gates(self, qubit_map):
        """Lower to a single AbstractGate on the mapped qubits.

        Args:
            qubit_map: Mapping from register to the sequence of qubits allocated
                for it; ``bit_indices`` index into that sequence.

        Returns:
            A one-element list holding the ``gate_name`` AbstractGate with
            ``gate_params`` as its parameters.

        Raises:
            ValueError: if the register is not in ``qubit_map``, or if the
                number of qubit indices does not equal ``qubit_count``.
        """
        if self.register not in qubit_map:
            raise ValueError(f"Register {self.register!r} is not allocated")
        if len(self.bit_indices) != self.qubit_count:
            # A short param_list would otherwise lower to a malformed gate.
            raise ValueError(
                f"{type(self).__name__} expects {self.qubit_count} qubit index(es), "
                f"got {len(self.bit_indices)}"
            )
        reg_qubits = qubit_map[self.register]
        qubits = tuple(reg_qubits[index] for index in self.bit_indices)
        # NOTE(review): gate_params are emitted as given and dagger_flag is not
        # consulted here; confirm how inverses of non-self-conjugate gates
        # (e.g. rotations) are handled during lowering.
        return [_lazy_abstract_gate(self.gate_name, qubits, self.gate_params)]
class H(_IndexedIRGate):
    """Single-qubit Hadamard, lowered to ``AbstractGate('H')``.

    ``param_list``: ``[target]`` (index into ``reg_list[0]``); self-conjugate.
    """
    gate_name = "H"
    qubit_count = 1
    __self_conjugate__ = True
class Z(_IndexedIRGate):
    """Single-qubit Pauli-Z, lowered to ``AbstractGate('Z')``.

    ``param_list``: ``[target]`` (index into ``reg_list[0]``); self-conjugate.
    """
    gate_name = "Z"
    qubit_count = 1
    __self_conjugate__ = True
class SWAP(_IndexedIRGate):
    """Exchange of two qubits, lowered to ``AbstractGate('SWAP')``.

    ``param_list``: ``[qubit_a, qubit_b]`` (indices into ``reg_list[0]``);
    self-conjugate.
    """
    gate_name = "SWAP"
    qubit_count = 2
    __self_conjugate__ = True
class CPHASE(_IndexedIRGate):
    """Controlled phase rotation, lowered to ``AbstractGate('CPHASE')``.

    ``param_list``: ``[control, target, theta]``; the two indices address
    ``reg_list[0]`` and ``theta`` becomes the gate parameter.
    """
    qubit_count = 2
    gate_name = "CPHASE"
class RX(_IndexedIRGate):
    """Rotation about the X axis, lowered to ``AbstractGate('RX')``.

    ``param_list``: ``[target, theta]``.
    """
    qubit_count = 1
    gate_name = "RX"
class RY(_IndexedIRGate):
    """Rotation about the Y axis, lowered to ``AbstractGate('RY')``.

    ``param_list``: ``[target, theta]``.
    """
    qubit_count = 1
    gate_name = "RY"
class RZ(_IndexedIRGate):
    """Rotation about the Z axis, lowered to ``AbstractGate('RZ')``.

    ``param_list``: ``[target, theta]``.
    """
    qubit_count = 1
    gate_name = "RZ"
class CCX(_IndexedIRGate):
    """Toffoli (doubly-controlled X), lowered to ``AbstractGate('CCX')``.

    ``param_list``: ``[control_a, control_b, target]``; self-conjugate.
    """
    gate_name = "CCX"
    qubit_count = 3
    __self_conjugate__ = True
class MCX(Primitive):
    """Multi-controlled X gate. Maps directly to AbstractGate('MCX').

    ``reg_list`` is ``[*control_regs, target_reg]``: every register except the
    last is a control register, and the last one is the target register.
    """

    # Applying an MCX twice is the identity, just like its CCX special case,
    # which is also marked self-conjugate.
    __self_conjugate__ = True

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list)
        self.control_regs = reg_list[:-1] if reg_list else []
        self.target_reg = reg_list[-1] if reg_list else None

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Build the PySparQ reference op: ``FlipBools`` on the target register.

        ``dagger_ctx`` is not used because the gate is its own inverse.
        """
        import pysparq as ps
        controllers_ctx = merge_controllers(self.controllers, controllers_ctx or {})
        # NOTE(review): only self.controllers / controllers_ctx condition the
        # flip here; self.control_regs are not added to the controller, and
        # FlipBools is given the whole target register while
        # to_abstract_gates targets only its first qubit. Confirm this matches
        # the intended MCX semantics before relying on it for verification.
        obj = PyQSparseOperationWrapper(ps.FlipBools(self.target_reg))
        obj.set_controller(controllers_ctx)
        return obj

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return ``mcx_t_count`` over intrinsic control qubits plus external controls."""
        ncontrols = get_control_qubit_count(
            merge_controllers(self.controllers, controllers_ctx or {}))
        n_controls_intrinsic = sum(reg_sz(r) for r in self.control_regs)
        return mcx_t_count(n_controls_intrinsic + ncontrols)

    def to_abstract_gates(self, qubit_map):
        """Lower to one ``MCX`` AbstractGate.

        The qubit tuple is every qubit of every control register (in order),
        followed by the first qubit of the target register.
        """
        ctrl_qubits = [q for reg in self.control_regs for q in qubit_map[reg]]
        tgt_qubits = qubit_map[self.target_reg]
        return [_lazy_abstract_gate("MCX", tuple(ctrl_qubits) + (tgt_qubits[0],))]
class ADD(Primitive):
    """N-bit ripple-carry adder. Maps to AbstractGate('ADD', (n_bits,)).

    ``reg_list`` is ``[input_reg1, input_reg2]``. ``param_list[0]``, when
    given, sets ``n_bits``; otherwise the size of ``input_reg1`` is used
    (or 1 if there is no register).
    """

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list)
        self.input_reg1 = reg_list[0] if reg_list else None
        # Guard reg_list=None: len(None) would raise TypeError.
        self.input_reg2 = reg_list[1] if reg_list and len(reg_list) > 1 else None
        if param_list:
            self.n_bits = param_list[0]
        elif self.input_reg1:
            self.n_bits = reg_sz(self.input_reg1)
        else:
            self.n_bits = 1

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Build the PySparQ reference op ``Add_UInt_UInt_InPlace``.

        The dagger setting is ``dagger_ctx`` XOR this instance's own
        ``dagger_flag``, the same convention PLUS_ONE and MOD_MUL use.
        Previously the instance flag was ignored, so an inverted ADD was
        simulated as a forward addition.
        """
        import pysparq as ps
        controllers_ctx = merge_controllers(self.controllers, controllers_ctx or {})
        obj = PyQSparseOperationWrapper(
            ps.Add_UInt_UInt_InPlace(self.input_reg1, self.input_reg2))
        obj.set_dagger(dagger_ctx ^ self.dagger_flag)
        obj.set_controller(controllers_ctx)
        return obj

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return ``2 * (n_bits - 1)`` Toffoli costs, each with the external controls added."""
        ncontrols = get_control_qubit_count(
            merge_controllers(self.controllers, controllers_ctx or {}))
        n = self.n_bits
        return 2 * (n - 1) * mcx_t_count(ncontrols + 2)

    def to_abstract_gates(self, qubit_map):
        """Lower to one ``ADD`` AbstractGate on the qubits of both inputs, parameter ``(n_bits,)``."""
        a_qubits = qubit_map[self.input_reg1]
        b_qubits = qubit_map[self.input_reg2]
        return [_lazy_abstract_gate("ADD", tuple(a_qubits) + tuple(b_qubits), (self.n_bits,))]
class PLUS_ONE(Primitive):
    """Increment circuit. Maps to AbstractGate('PLUS_ONE', (n_bits,)).

    ``reg_list`` is ``[main_reg]`` or ``[main_reg, overflow_reg]``.
    ``param_list[0]``, when given, sets ``n_bits``; otherwise the size of
    ``main_reg`` is used (or 1 if there is no register).
    """

    # Increment is not its own inverse; the simulation honors dagger_flag.
    __self_conjugate__ = False

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list)
        self.main_reg = reg_list[0] if reg_list else None
        # Guard reg_list=None: len(None) would raise TypeError.
        self.overflow_reg = reg_list[1] if reg_list and len(reg_list) > 1 else None
        if param_list:
            self.n_bits = param_list[0]
        elif self.main_reg:
            self.n_bits = reg_sz(self.main_reg)
        else:
            self.n_bits = 1

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Build the PySparQ reference op ``PlusOneAndOverflow``.

        The dagger setting is ``dagger_ctx`` XOR this instance's ``dagger_flag``.
        """
        import pysparq as ps
        controllers_ctx = merge_controllers(self.controllers, controllers_ctx or {})
        # Without an explicit overflow register, fall back to one named
        # "_overflow". NOTE(review): confirm that register is allocated in the
        # simulator whenever PLUS_ONE is used without an overflow register.
        overflow = self.overflow_reg if self.overflow_reg else "_overflow"
        obj = PyQSparseOperationWrapper(
            ps.PlusOneAndOverflow(self.main_reg, overflow))
        obj.set_dagger(dagger_ctx ^ self.dagger_flag)
        obj.set_controller(controllers_ctx)
        return obj

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return the linear estimate ``4 * n_bits + 4 * ncontrols``."""
        ncontrols = get_control_qubit_count(
            merge_controllers(self.controllers, controllers_ctx or {}))
        n = self.n_bits
        return 4 * n + ncontrols * 4

    def to_abstract_gates(self, qubit_map):
        """Lower to one ``PLUS_ONE`` AbstractGate.

        The qubits are ``main_reg`` followed by the overflow register's qubits,
        when an overflow register is set and present in ``qubit_map``.
        """
        qubits = list(qubit_map[self.main_reg])
        if self.overflow_reg and self.overflow_reg in qubit_map:
            qubits.extend(qubit_map[self.overflow_reg])
        return [_lazy_abstract_gate("PLUS_ONE", tuple(qubits), (self.n_bits,))]
class REFLECT(Primitive):
    """Multi-controlled Z (reflection). Maps to AbstractGate('REFLECT', (n_bits,)).

    ``reg_list`` holds every register the reflection acts on. ``param_list[0]``,
    when given, is passed to PySparQ's ``Reflection_Bool`` as its second
    argument; it defaults to ``True``.
    """
    __self_conjugate__ = True

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list)
        self.target_regs = reg_list

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Build the PySparQ reference op ``Reflection_Bool`` over all target registers."""
        import pysparq as ps
        controllers_ctx = merge_controllers(self.controllers, controllers_ctx or {})
        inverse = self.param_list[0] if self.param_list else True
        obj = PyQSparseOperationWrapper(
            ps.Reflection_Bool(self.target_regs, inverse))
        obj.set_controller(controllers_ctx)
        return obj

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return ``mcx_t_count`` over all target qubits plus external controls."""
        ncontrols = get_control_qubit_count(
            merge_controllers(self.controllers, controllers_ctx or {}))
        n = sum(reg_sz(r) for r in self.target_regs)
        return mcx_t_count(n + ncontrols)

    def to_abstract_gates(self, qubit_map):
        """Lower to one ``REFLECT`` AbstractGate whose parameter is the total qubit count."""
        qubits = [q for reg in self.target_regs for q in qubit_map[reg]]
        n_bits = len(qubits)
        return [_lazy_abstract_gate("REFLECT", tuple(qubits), (n_bits,))]
class MOD_ADD(Primitive):
    """Modular addition a+b mod N. Maps to AbstractGate('MOD_ADD', (modulus,)).

    ``reg_list`` is ``[a_reg, b_reg]``; ``param_list[0]`` is the modulus
    (defaults to 2).
    """

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list)
        self.a_reg = reg_list[0] if reg_list else None
        # Guard reg_list=None: len(None) would raise TypeError.
        self.b_reg = reg_list[1] if reg_list and len(reg_list) > 1 else None
        self.modulus = param_list[0] if param_list else 2

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Not supported: there is no matching PySparQ primitive.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "MOD_ADD has no matching PySparQ reference primitive yet. "
            "Using Add_UInt_UInt_InPlace would violate the modular-add contract."
        )

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return a rough ``8 * n`` T-count estimate for an n-qubit ``a_reg``.

        NOTE(review): unlike the other primitives in this module, controllers
        are not taken into account here.
        """
        n = reg_sz(self.a_reg) if self.a_reg else 1
        return 4 * n * 2  # Approximate: add + compare + conditional sub

    def to_abstract_gates(self, qubit_map):
        """Lower to one ``MOD_ADD`` AbstractGate on a's then b's qubits, parameter ``(modulus,)``."""
        qubits = list(qubit_map[self.a_reg]) + list(qubit_map[self.b_reg])
        return [_lazy_abstract_gate("MOD_ADD", tuple(qubits), (self.modulus,))]
class MOD_MUL(Primitive):
    """Modular multiplication a*c mod N. Maps to AbstractGate('MOD_MUL', (multiplier, modulus)).

    ``reg_list`` is ``[reg]``; ``param_list`` is ``[multiplier, modulus]``,
    which default to 1 and 2 respectively.
    """

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list)
        self.reg = reg_list[0] if reg_list else None
        self.multiplier = param_list[0] if param_list else 1
        # Guard param_list=None: len(None) would raise TypeError.
        self.modulus = param_list[1] if param_list and len(param_list) > 1 else 2

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Build the PySparQ reference op for in-place ``reg *= multiplier mod modulus``.

        When daggered (``dagger_ctx`` XOR ``dagger_flag``), the multiplier is
        replaced by its modular inverse.

        Raises:
            ValueError: if ``multiplier`` and ``modulus`` are not coprime.
        """
        import pysparq as ps
        import math
        controllers_ctx = merge_controllers(self.controllers, controllers_ctx or {})
        multiplier = int(self.multiplier)
        modulus = int(self.modulus)
        if math.gcd(multiplier, modulus) != 1:
            raise ValueError(
                f"MOD_MUL requires multiplier coprime to modulus, got "
                f"multiplier={multiplier}, modulus={modulus}"
            )
        if dagger_ctx ^ self.dagger_flag:
            # Inverse of multiplication by c is multiplication by c^-1 mod N.
            multiplier = pow(multiplier, -1, modulus)
        # Prefer the explicit in-place variant when this pysparq build has it.
        op_cls = getattr(ps, "Mod_Mult_UInt_ConstUInt_InPlace", None)
        if op_cls is None:
            op_cls = getattr(ps, "Mod_Mult_UInt_ConstUInt")
        # PySparQ's primitive computes reg *= a^(2^x) mod N. Use x=0 so
        # the multiplier is exactly the intermediate-layer constant c.
        obj = PyQSparseOperationWrapper(
            op_cls(self.reg, multiplier, 0, modulus))
        obj.set_controller(controllers_ctx)
        return obj

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return ``4 * n`` Toffoli costs (with external controls added) for an n-qubit ``reg``."""
        n = reg_sz(self.reg) if self.reg else 1
        ncontrols = get_control_qubit_count(
            merge_controllers(self.controllers, controllers_ctx or {}))
        return 4 * n * mcx_t_count(ncontrols + 2)

    def to_abstract_gates(self, qubit_map):
        """Lower to one ``MOD_MUL`` AbstractGate with parameters ``(multiplier, modulus)``."""
        qubits = list(qubit_map[self.reg])
        return [_lazy_abstract_gate("MOD_MUL", tuple(qubits),
                                    (self.multiplier, self.modulus))]
class CMUL_MOD_N(Primitive):
    """Controlled modular multiplication lowered to ``AbstractGate('CMUL_MOD_N')``.

    ``param_list`` must be ``[qubits, multiplier, modulus]``. Here ``qubits``
    lists ``[control, *work_register]`` as indices into ``reg_list[0]``.
    Only QEC lowering is implemented; simulation and T-counting raise
    ``NotImplementedError``.
    """

    def __init__(self, reg_list=None, param_list=None):
        super().__init__(reg_list=reg_list, param_list=param_list or [])
        params = self.param_list
        self.register = None if not reg_list else reg_list[0]
        self.bit_indices = tuple(map(int, params[0]))
        self.multiplier = float(params[1])
        self.modulus = float(params[2])

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Not supported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "CMUL_MOD_N has no pyqres-local PySparQ reference implementation yet."
        )

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Not supported.

        Raises:
            NotImplementedError: always; QEC-Compiler lowering owns this cost.
        """
        raise NotImplementedError(
            "CMUL_MOD_N resource accounting is owned by QEC-Compiler lowering."
        )

    def to_abstract_gates(self, qubit_map):
        """Map ``bit_indices`` through ``qubit_map[register]`` and emit one gate.

        The gate parameters are ``(multiplier, modulus)``.
        """
        register_qubits = qubit_map[self.register]
        selected = tuple([register_qubits[i] for i in self.bit_indices])
        gate_params = (self.multiplier, self.modulus)
        return [_lazy_abstract_gate("CMUL_MOD_N", selected, gate_params)]