Source code for uniqc.simulator.torchquantum_simulator

"""TorchQuantum-based simulator with native PyTorch autograd.

This module provides a quantum circuit simulator backed by TorchQuantum,
enabling differentiable statevector simulation where gradients flow through
PyTorch autograd natively (no parameter-shift rule needed).

Unlike BaseSimulator subclasses that consume OriginIR/QASM strings,
this simulator reads Circuit.opcode_list directly.

Note: TorchQuantum uses qubit-0-as-MSB convention (the first dimension in the
state tensor is qubit 0). UnifiedQuantum uses qubit-0-as-LSB convention
(standard in most quantum computing frameworks). This simulator handles the
endianness conversion automatically.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

try:
    import torch

    TORCH_AVAILABLE = True
except ImportError:
    torch = None  # type: ignore
    TORCH_AVAILABLE = False

# TorchQuantum is only imported when PyTorch is available. On every failure
# path the TorchQuantum names are still bound (to ``None``), so ``tq``,
# ``tqf`` and ``expval_joint_analytical`` always exist as module attributes.
if TORCH_AVAILABLE:
    try:
        import torchquantum as tq
        import torchquantum.functional as tqf
        from torchquantum.measurement import expval_joint_analytical
    except ImportError:
        # Reset everything, including names bound before the failing import.
        tq = None  # type: ignore
        tqf = None  # type: ignore
        expval_joint_analytical = None  # type: ignore
        TORCHQUANTUM_AVAILABLE = False
    else:
        TORCHQUANTUM_AVAILABLE = True
else:
    tq = None  # type: ignore
    tqf = None  # type: ignore
    expval_joint_analytical = None  # type: ignore
    TORCHQUANTUM_AVAILABLE = False

if TYPE_CHECKING:
    from uniqc.circuit_builder.qcircuit import OpCode

__all__ = ["TORCHQUANTUM_AVAILABLE", "TorchQuantumSimulator"]

if TORCHQUANTUM_AVAILABLE:
    # Uniqc opcode name -> (TorchQuantum functional gate, takes angle params?).
    # Written as an ordered table so the resulting key order follows the listing.
    _GATE_MAP: dict[str, tuple] = {
        opcode: (gate, parametric)
        for opcode, gate, parametric in (
            # Fixed single-qubit gates
            ("H", tqf.hadamard, False),
            ("X", tqf.paulix, False),
            ("Y", tqf.pauliy, False),
            ("Z", tqf.pauliz, False),
            ("S", tqf.s, False),
            ("SX", tqf.sx, False),
            ("T", tqf.t, False),
            ("I", tqf.i, False),
            # Parametric single-qubit gates
            ("RX", tqf.rx, True),
            ("RY", tqf.ry, True),
            ("RZ", tqf.rz, True),
            ("U1", tqf.u1, True),
            ("U2", tqf.u2, True),
            ("U3", tqf.u3, True),
            # Fixed two-qubit gates
            ("CNOT", tqf.cnot, False),
            ("CZ", tqf.cz, False),
            ("SWAP", tqf.swap, False),
            ("ISWAP", tqf.iswap, False),
            # Parametric two-qubit rotations
            ("XX", tqf.rxx, True),
            ("YY", tqf.ryy, True),
            ("ZZ", tqf.rzz, True),
            # Three-qubit gates
            ("TOFFOLI", tqf.toffoli, False),
            ("CSWAP", tqf.cswap, False),
        )
    }

    # Opcodes whose daggered form has its own named TorchQuantum gate.
    _DAGGER_MAP: dict[str, tuple] = {
        opcode: (gate, False)
        for opcode, gate in (
            ("S", tqf.sdg),
            ("SX", tqf.sxdg),
            ("T", tqf.tdg),
        )
    }
else:
    # TorchQuantum missing: keep both names defined, but empty.
    _GATE_MAP = {}
    _DAGGER_MAP = {}


def _require_torchquantum() -> None:
    """Guard entry points that need the TorchQuantum backend.

    Raises:
        ImportError: If PyTorch or TorchQuantum could not be imported. The
            message carries the installation command.
    """
    if not TORCHQUANTUM_AVAILABLE:
        raise ImportError(
            "TorchQuantum backend requires PyTorch and a manual TorchQuantum install. "
            "Install with: pip install unified-quantum[pytorch] && "
            'pip install "torchquantum @ '
            'git+https://github.com/Agony5757/torchquantum.git@fix/optional-qiskit-deps"'
        )


def _extract_n_qubits(opcode_list: list[OpCode]) -> int:
    """Determine number of qubits from opcodes."""
    max_q = 0
    for _op_name, qubits, _cbits, _params, _dagger, controls in opcode_list:
        max_q = max(max_q, max(qubits) + 1) if isinstance(qubits, list) else max(max_q, qubits + 1)
        if controls:
            ctrl_list = list(controls) if not isinstance(controls, list) else controls
            max_q = max(max_q, max(ctrl_list) + 1)
    return max_q


def _reverse_bits(statevector: torch.Tensor, n_qubits: int) -> torch.Tensor:
    """Reverse bit order in statevector to convert between endianness conventions.

    TorchQuantum: qubit 0 = MSB (first tensor dimension)
    Standard:     qubit 0 = LSB (last bit in binary index)
    """
    dim = 2**n_qubits
    sv = statevector[:dim]
    indices = torch.arange(dim, device=sv.device)
    # Reverse the bits of each index
    reversed_indices = torch.zeros_like(indices)
    for b in range(n_qubits):
        bit = (indices >> b) & 1
        reversed_indices |= bit << (n_qubits - 1 - b)
    return sv[reversed_indices]


def _reverse_pauli_string(pauli_str: str) -> str:
    """Reverse Pauli string to match TorchQuantum's qubit ordering."""
    return pauli_str[::-1]


class TorchQuantumSimulator:
    """TorchQuantum-based simulator with native PyTorch autograd.

    Operates on ``Circuit.opcode_list`` directly (no string serialization).
    All operations are differentiable through PyTorch autograd.

    The ``n_wires`` parameter is optional. If it is not set (``0``), the qubit
    count is auto-detected from the opcodes of each circuit.
    """

    def __init__(self, n_wires: int = 0, device: str = "cpu"):
        """Create a simulator.

        Args:
            n_wires: Fixed number of qubits. ``0`` (the default) means the
                count is derived from each circuit's opcodes.
            device: Torch device for the ``QuantumDevice``. It is also used
                for parameter tensors built from literal opcode parameters.

        Raises:
            ImportError: If PyTorch or TorchQuantum is not installed.
        """
        _require_torchquantum()
        self.n_wires = n_wires
        self.device = device

    def _resolve_n_wires(self, opcode_list: list[OpCode]) -> int:
        """Return the fixed wire count, else the detected count (minimum 1)."""
        if self.n_wires > 0:
            return self.n_wires
        return _extract_n_qubits(opcode_list) or 1

    def _create_qdev(self, n_wires: int, bsz: int = 1) -> tq.QuantumDevice:
        """Allocate a fresh ``QuantumDevice`` on ``self.device``."""
        return tq.QuantumDevice(n_wires=n_wires, bsz=bsz, device=self.device)

    @staticmethod
    def _resolve_gate(op_name: str, dagger: bool) -> tuple:
        """Look up the TorchQuantum gate used for an opcode.

        Returns:
            ``(gate_fn, is_parametric, inverse)``. ``inverse`` is passed to the
            gate function so that a daggered opcode is applied as its adjoint.

        Raises:
            NotImplementedError: If the opcode has no TorchQuantum mapping.
        """
        if dagger and op_name in _DAGGER_MAP:
            # S/SX/T have dedicated adjoint gates (sdg/sxdg/tdg).
            gate_fn, is_parametric = _DAGGER_MAP[op_name]
            return gate_fn, is_parametric, False
        if op_name in _GATE_MAP:
            gate_fn, is_parametric = _GATE_MAP[op_name]
            # All other daggers go through TorchQuantum's ``inverse`` flag.
            # Negating the angles would only be correct for single-angle
            # rotations: it is wrong for U2/U3 and cannot express ISWAP^dagger.
            return gate_fn, is_parametric, bool(dagger)
        raise NotImplementedError(
            f"Gate '{op_name}' is not supported by TorchQuantum backend."
        )

    @staticmethod
    def _batch_params(gate_params: torch.Tensor, bsz: int) -> torch.Tensor:
        """Promote gate parameters to a 2-D ``(batch, n_params)`` tensor.

        A 0-d tensor becomes ``(1, 1)`` and a 1-d tensor becomes
        ``(1, n_params)``. A leading dimension that differs from ``bsz`` is
        then broadcast with ``expand``.
        """
        if gate_params.dim() == 0:
            gate_params = gate_params.unsqueeze(0).unsqueeze(0)
        elif gate_params.dim() == 1:
            gate_params = gate_params.unsqueeze(0)
        if gate_params.shape[0] != bsz:
            gate_params = gate_params.expand(bsz, -1)
        return gate_params

    def execute_opcodes(
        self,
        opcode_list: list[OpCode],
        param_overrides: dict[int, torch.Tensor] | None = None,
        n_qubits: int | None = None,
        bsz: int = 1,
    ) -> tq.QuantumDevice:
        """Execute opcodes on a fresh QuantumDevice.

        Daggered opcodes are applied as the adjoint of the gate, as follows:

        - S/SX/T map to the dedicated sdg/sxdg/tdg gates.
        - Every other gate is called with ``inverse=True``.

        This holds whether the gate's parameters come from the opcode or from
        ``param_overrides``. An override therefore supplies the un-daggered
        parameters.

        Args:
            opcode_list: Circuit.opcode_list.
            param_overrides: Map opcode index → torch.Tensor to inject
                differentiable parameters.
            n_qubits: Override number of qubits (auto-detected if None).
            bsz: Batch size for the QuantumDevice.

        Returns:
            The QuantumDevice after executing all gates.

        Raises:
            NotImplementedError: If an opcode has no TorchQuantum mapping.
        """
        if n_qubits is None:
            n_qubits = self._resolve_n_wires(opcode_list)
        qdev = self._create_qdev(n_qubits, bsz)
        param_overrides = param_overrides or {}

        for idx, opcode in enumerate(opcode_list):
            op_name, qubits, _cbits, params, dagger, controls = opcode
            gate_fn, is_parametric, inverse = self._resolve_gate(op_name, dagger)

            wires = list(qubits) if isinstance(qubits, (list, tuple)) else [qubits]
            if controls:
                # Control qubits are prepended to the target wires.
                # NOTE(review): this presumably only works when the TorchQuantum
                # gate accepts the extra leading wires — confirm for controlled opcodes.
                wires = list(controls) + wires

            if idx in param_overrides:
                gate_params = param_overrides[idx]
            elif is_parametric and params is not None:
                raw = list(params) if isinstance(params, (list, tuple)) else [params]
                gate_params = torch.tensor(raw, dtype=torch.float32, device=self.device)
            else:
                gate_params = None

            kwargs: dict = {"wires": wires, "inverse": inverse}
            if gate_params is not None:
                kwargs["params"] = self._batch_params(gate_params, bsz)
            gate_fn(qdev, **kwargs)

        return qdev

    def simulate_statevector(
        self,
        opcode_list: list[OpCode],
        param_overrides: dict[int, torch.Tensor] | None = None,
        n_qubits: int | None = None,
    ) -> torch.Tensor:
        """Execute circuit and return statevector (LSB convention).

        Args:
            opcode_list: Circuit.opcode_list.
            param_overrides: Differentiable parameter injection.
            n_qubits: Override number of qubits (auto-detected if None).

        Returns:
            Complex tensor of shape (2^n_qubits,) with the final statevector.
        """
        if n_qubits is None:
            n_qubits = self._resolve_n_wires(opcode_list)
        qdev = self.execute_opcodes(opcode_list, param_overrides, n_qubits, bsz=1)
        sv_tq = qdev.get_states_1d().squeeze(0)
        # Convert TorchQuantum's qubit-0-as-MSB layout to qubit-0-as-LSB.
        return _reverse_bits(sv_tq, n_qubits)

    def expectation(
        self,
        opcode_list: list[OpCode],
        hamiltonian: list[tuple[str, float]],
        param_overrides: dict[int, torch.Tensor] | None = None,
        n_qubits: int | None = None,
    ) -> torch.Tensor:
        """Compute <psi|H|psi> for a Pauli Hamiltonian.

        Handling of individual terms:

        - Terms with ``|coeff| < 1e-15`` are skipped.
        - All-identity terms add their coefficient without simulation.
        - Each remaining Pauli string is reversed to match TorchQuantum's
          qubit ordering.

        The circuit is simulated at most once, and only if some term needs
        the state.

        Args:
            opcode_list: Circuit.opcode_list.
            hamiltonian: List of (pauli_string, coefficient).
            param_overrides: Differentiable parameter injection.
            n_qubits: Override number of qubits.

        Returns:
            Scalar tensor with the expectation value (differentiable).
        """
        if n_qubits is None:
            n_qubits = self._resolve_n_wires(opcode_list)

        total = torch.tensor(0.0, dtype=torch.float32, device=self.device)
        qdev = None
        for pauli_str, coeff in hamiltonian:
            if abs(coeff) < 1e-15:
                continue
            if all(c == "I" for c in pauli_str):
                total = total + coeff
                continue
            if qdev is None:
                # The state does not depend on the Pauli term: simulate once and
                # reuse it, rather than re-running the circuit for every term.
                qdev = self.execute_opcodes(opcode_list, param_overrides, n_qubits, bsz=1)
            expval = expval_joint_analytical(qdev, _reverse_pauli_string(pauli_str))
            total = total + coeff * expval.squeeze()
        return total