pyqres.algorithms.shor 源代码

"""
Shor's Quantum Factorization Algorithm for pyqres.

Implements both semi-classical and full-quantum Shor factorization with
resource estimation (t_count).

Reference:
    pysparq/algorithms/shor.py
    P.W. Shor, "Polynomial-Time Algorithms for Prime Factorization
    and Discrete Logarithms on a Quantum Computer"
    SparQ Paper: https://arxiv.org/abs/2503.15118
"""

from __future__ import annotations

import math
import random
from typing import Callable, List, Optional

import numpy as np

import pysparq as ps

from ..core.operation import AbstractComposite, Primitive, StandardComposite
from ..core.metadata import RegisterMetadata
from ..core.utils import (
    merge_controllers, get_control_qubit_count, reg_sz, mcx_t_count,
)
from ..core.simulator import PyQSparseOperationWrapper
from ..primitives import (
    Hadamard, Hadamard_Bool, QFT, InverseQFT,
    CustomArithmetic, PartialTrace,
)


# ==============================================================================
# Exceptions
# ==============================================================================

class ShorExecutionFailed(Exception):
    """Signal that a run of Shor's algorithm did not yield usable factors.

    The classical post-processing helpers in this module raise it when a
    measurement outcome cannot be turned into a valid period.
    """


# ==============================================================================
# Classical helper functions (deterministic, no quantum dependency)
# ==============================================================================

def general_expmod(a: int, x: int, N: int) -> int:
    """Compute ``a**x mod N``.

    Delegates to the built-in three-argument ``pow`` so that very large
    exponents (this module calls it with ``2 ** k``) are handled
    iteratively instead of through one recursive call per exponent bit.

    Args:
        a: Base.
        x: Non-negative exponent.
        N: Modulus; must be non-zero.

    Returns:
        ``a**x mod N``. For ``x == 0`` the result is always ``1``.

    Raises:
        ValueError: If ``x`` is negative or ``N`` is zero.
    """
    if x < 0:
        raise ValueError(f"x={x} must be non-negative")
    if x == 0:
        # Kept as an explicit case so the result is 1 for every modulus.
        return 1
    return pow(a, x, N)
def find_best_fraction(y: int, Q: int, N: int):
    """Find the fraction c/r closest to y/Q with denominator r <= N.

    Walks mediants between 0/1 and 1/1 (Farey / Stern-Brocot style) and
    keeps the closest mediant seen. All comparisons use exact integer
    cross-multiplication, so the result stays correct when ``Q`` is far
    larger than a float mantissa can resolve.

    Args:
        y: Numerator of the target (the measurement result).
        Q: Denominator of the target; must be positive.
        N: Largest denominator that may be returned.

    Returns:
        ``(r, c)`` - denominator first, then numerator. If no mediant has
        a denominator <= N the initial candidate ``(1, 0)`` is returned.

    Raises:
        ValueError: If ``Q`` is not positive.
    """
    if Q <= 0:
        raise ValueError(f"Q={Q} must be positive")

    low_num, low_den = 0, 1
    high_num, high_den = 1, 1
    best_num, best_den = 0, 1
    # Smallest error so far, kept as the exact fraction err_num / err_den.
    # It starts at 1 so the first mediant (1/2) is always accepted.
    best_err_num, best_err_den = 1, 1

    while True:
        mediant_num = low_num + high_num
        mediant_den = low_den + high_den
        if mediant_den > N:
            break

        # Sign of (mediant - y/Q) without any division.
        delta = mediant_num * Q - y * mediant_den
        err_num = abs(delta)
        err_den = mediant_den * Q
        # err_num/err_den < best_err_num/best_err_den (denominators > 0).
        if err_num * best_err_den < best_err_num * err_den:
            best_err_num, best_err_den = err_num, err_den
            best_num, best_den = mediant_num, mediant_den

        if delta < 0:
            low_num, low_den = mediant_num, mediant_den
        else:
            high_num, high_den = mediant_num, mediant_den

    return best_den, best_num  # (r, c)


def compute_period(meas_result: int, size: int, N: int) -> int:
    """Derive a candidate period from a measurement result.

    Args:
        meas_result: Measured integer y.
        size: Number of precision bits; the target fraction is y / 2**size.
        N: Modulus; also the upper bound on the denominator searched.

    Returns:
        The denominator r of the best fraction, when ``0 < r < N``.

    Raises:
        ShorExecutionFailed: If ``meas_result`` is 0 or no denominator in
            the open range (0, N) is found.
    """
    if meas_result == 0:
        raise ShorExecutionFailed("Measurement result y = 0, algorithm failed")

    Q = 2 ** size
    r, _ = find_best_fraction(meas_result, Q, N)
    if 0 < r < N:
        return r
    raise ShorExecutionFailed("Failed to find a suitable period")


def check_period(period: int, a: int, N: int) -> None:
    """Check that a period can be used to extract factors.

    Args:
        period: Candidate period r.
        a: Base used in the modular exponentiation.
        N: Modulus.

    Raises:
        ShorExecutionFailed: If the period exceeds N, is odd, or
            ``a**(r/2)`` is congruent to -1 modulo N.
    """
    if period > N:
        raise ShorExecutionFailed(f"Period r = {period} > N = {N}")
    if period % 2 == 1:
        raise ShorExecutionFailed(f"Odd period r = {period}")

    a_exp_r_half = general_expmod(a, period // 2, N)
    if a_exp_r_half == N - 1:
        raise ShorExecutionFailed(f"a^(r/2) = -1 mod N for r = {period}")
def shor_postprocess(meas: int, size: int, a: int, N: int):
    """Turn a measurement result into a pair of candidate factors of N.

    Args:
        meas: Measured integer.
        size: Number of precision bits used for the measurement.
        a: Base used in the modular exponentiation.
        N: Integer being factored.

    Returns:
        ``(gcd(a**(r/2) + 1, N), gcd(a**(r/2) - 1, N))`` for the recovered
        period r, or ``(1, 1)`` when no valid period could be derived.
    """
    try:
        r = compute_period(meas, size, N)
        check_period(r, a, N)
    except ShorExecutionFailed:
        # An unusable measurement is reported as the trivial pair.
        return (1, 1)

    root = general_expmod(a, r // 2, N)
    return (math.gcd(root + 1, N), math.gcd(root - 1, N))
# ==============================================================================
# ModMul — Controlled Modular Multiplication Primitive
# ==============================================================================
class ModMul(Primitive):
    """Controlled modular multiplication: |y⟩ → |y * a^(2^x) mod N⟩.

    Backed by ``pysparq.Mod_Mult_UInt_ConstUInt``. The operation is
    in-place: reg = reg * opnum mod N. The resource model assumes
    ripple-carry addition built from multi-controlled Toffoli gates.

    Attributes:
        reg: Register name (UnsignedInteger).
        a: Base for exponentiation.
        x: Power-of-two exponent (the multiplier is a^(2^x) mod N).
        N: Modulus.
        opnum: Precomputed a^(2^x) mod N, or None if a, x or N is missing.
    """

    __self_conjugate__ = False  # dagger requires modular inverse

    def __init__(self, reg_list, param_list=None, reg=None, a=None, x=None, N=None):
        """Accept either ``reg_list``/``param_list`` or explicit keywords.

        When ``reg`` is given, the keyword values ``a``, ``x`` and ``N``
        are used. Otherwise the register is ``reg_list[0]`` and the
        parameters are read positionally from ``param_list`` as
        ``[a, x, N]``, with missing entries set to None.
        """
        if reg is not None:
            super().__init__(reg_list=[reg], param_list=param_list)
            self.reg, self.a, self.x, self.N = reg, a, x, N
        else:
            super().__init__(reg_list=reg_list, param_list=param_list)
            self.reg = reg_list[0]
            given = param_list if param_list else ()
            self.a, self.x, self.N = (
                given[i] if i < len(given) else None for i in range(3)
            )

        if None in (self.a, self.x, self.N):
            self.opnum = None
        else:
            self.opnum = general_expmod(self.a, 2 ** self.x, self.N)

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Build the wrapped pysparq operation for simulation.

        The dagger flag is the XOR of the context flag and this
        operation's own flag; controllers from the context are merged
        with the operation's own controllers.
        """
        merged = merge_controllers(self.controllers, controllers_ctx or {})
        backend_op = ps.Mod_Mult_UInt_ConstUInt(self.reg, self.a, self.x, self.N)
        wrapper = PyQSparseOperationWrapper(backend_op)
        wrapper.set_dagger(dagger_ctx ^ self.dagger_flag)
        wrapper.set_controller(merged)
        return wrapper

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Estimate the T-count of the controlled modular multiplication.

        Model: 4 * n * mcx_t_count(ncontrols + 2), where n is the register
        size and ncontrols the number of control qubits after merging the
        context controllers with the operation's own.
        """
        width = reg_sz(self.reg)
        merged = merge_controllers(self.controllers, controllers_ctx or {})
        control_qubits = get_control_qubit_count(merged)
        return 4 * width * mcx_t_count(control_qubits + 2)
# ==============================================================================
# ExpMod — Modular Exponentiation Primitive
# ==============================================================================
class ExpMod(Primitive):
    """Modular exponentiation: |x⟩|z⟩ → |x⟩|z XOR (a^x mod N)⟩.

    Backed by ``pysparq.CustomArithmetic`` with a precomputed table of
    powers of ``a`` modulo ``N``. Used in the full-quantum Shor algorithm.

    Attributes:
        input_reg: Input register (holds exponent x).
        output_reg: Output register (receives a^x mod N).
        a: Base.
        N: Modulus.
        period: Period of a^x mod N; taken from the table length when not
            supplied.
        axmodn: Table ``[a^0, a^1, ...] mod N``, cut off just before the
            first power that returns to 1.
    """

    __self_conjugate__ = True  # XOR semantics make it self-adjoint

    def __init__(
        self,
        reg_list,
        param_list=None,
        input_reg=None,
        output_reg=None,
        a=None,
        N=None,
        period=None,
    ):
        """Accept either ``reg_list``/``param_list`` or explicit keywords.

        When ``input_reg`` is given, the keyword values are used.
        Otherwise the registers are ``reg_list[0]`` and ``reg_list[1]``
        and the parameters are read positionally from ``param_list`` as
        ``[a, N, period]``, with missing entries set to None.
        """
        if input_reg is not None:
            super().__init__(reg_list=[input_reg, output_reg], param_list=param_list)
            self.input_reg = input_reg
            self.output_reg = output_reg
            self.a = a
            self.N = N
            self.period = period
        else:
            super().__init__(reg_list=reg_list, param_list=param_list)
            self.input_reg = reg_list[0]
            self.output_reg = reg_list[1]
            self.a = param_list[0] if param_list else None
            self.N = param_list[1] if param_list and len(param_list) > 1 else None
            self.period = param_list[2] if param_list and len(param_list) > 2 else None

        # Precompute a^k mod N, stopping early once the powers cycle back to 1.
        if None not in (self.a, self.N):
            self.axmodn = [1]
            period_estimate = self.period or (self.N or 16)
            for _ in range(1, period_estimate):
                next_val = (self.axmodn[-1] * self.a) % self.N
                if next_val == 1:
                    break
                self.axmodn.append(next_val)
            if self.period is None:
                self.period = len(self.axmodn)
        else:
            self.axmodn = [1]

    def pyqsparse_object(self, dagger_ctx=False, controllers_ctx=None):
        """Build the wrapped pysparq operation for simulation.

        The lookup reduces the exponent modulo the table length rather
        than ``self.period``: the table stops at the true multiplicative
        order, so a supplied period that is a multiple of that order would
        otherwise index past the end of the table.

        ``dagger_ctx`` is not applied; the class is marked self-conjugate.

        Raises:
            ValueError: If ``a`` or ``N`` was never provided, since no
                table of powers can be built without them.
        """
        if None in (self.a, self.N):
            raise ValueError("ExpMod requires both a and N to be simulated")

        controllers_ctx = merge_controllers(self.controllers, controllers_ctx or {})
        table = self.axmodn
        table_len = len(table)

        def expmod_func(x: int) -> int:
            return table[x % table_len]

        # NOTE(review): the two 64s are presumably register widths for
        # CustomArithmetic - confirm against the pysparq API.
        obj = PyQSparseOperationWrapper(
            ps.CustomArithmetic(
                [self.input_reg, self.output_reg], 64, 64, expmod_func))
        obj.set_controller(controllers_ctx)
        return obj

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Estimate the T-count of the modular exponentiation.

        Model: 4 * r * n * mcx_t_count(ncontrols + 2), where n is the size
        of the output register and r is the period (falling back to n when
        no period is known).
        """
        ncontrols = get_control_qubit_count(
            merge_controllers(self.controllers, controllers_ctx or {}))
        n = reg_sz(self.output_reg)
        r = self.period or n
        return 4 * r * n * mcx_t_count(ncontrols + 2)
# ==============================================================================
# Semi-Classical Shor (AbstractComposite)
# ==============================================================================
class SemiClassicalShor(AbstractComposite):
    """Semi-classical Shor algorithm using iterative phase estimation.

    The program list built here contains one ``ModMul`` per precision bit;
    Hadamard, phase correction and measurement steps are not represented
    as operations in it.

    Args:
        reg_list: [anc_reg] - auxiliary register for modular arithmetic.
        param_list: [a, N] - base and modulus.

    Resource estimation:
        T-count = size × 4n × mcx_t_count(3)
        where n is the bit length of N and size = 2n.

    Raises:
        ValueError: If a and N are not coprime, or N is not positive.
    """

    __self_conjugate__ = False

    def __init__(self, reg_list, param_list=None, submodules=None):
        super().__init__(reg_list=reg_list, param_list=param_list,
                         submodules=submodules or [])
        self.anc_reg = reg_list[0]
        self.a = param_list[0]
        self.N = param_list[1]

        if math.gcd(self.a, self.N) != 1:
            raise ValueError(f"a={self.a} and N={self.N} must be coprime")
        if self.N < 1:
            raise ValueError(f"N={self.N} must be positive")

        # Exact bit width; a float log2 rounds up for large N just below a
        # power of two and would report one bit too many.
        self.n = int(self.N).bit_length()
        self.size = self.n * 2  # precision bits

        self._build_program_list()

    def _build_program_list(self):
        """Build the list of ModMul steps, highest power of two first."""
        self.program_list = []
        # Ancilla initialized to |1⟩ (classical init, no T-cost)
        for x in range(self.size):
            power = self.size - 1 - x
            self.program_list.append(
                ModMul(
                    reg_list=[self.anc_reg],
                    param_list=[self.a, power, self.N]))
        self.declare_program_list()

    def sum_t_count(self, t_count_list):
        """Return the total T-count for semi-classical Shor.

        The total is computed in closed form as
        ``size * 4 * n * mcx_t_count(3)`` (each ModMul is modelled with
        one control qubit); ``t_count_list`` is not consulted.
        """
        ncontrols = 1  # one work qubit controls ModMul
        modmul_tc = 4 * self.n * mcx_t_count(ncontrols + 2)
        return self.size * modmul_tc

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return the closed-form total; the context arguments are unused."""
        return self.sum_t_count([0] * len(self.program_list))
# ==============================================================================
# Full Quantum Shor (AbstractComposite)
# ==============================================================================
class Shor(AbstractComposite):
    """Full quantum Shor algorithm with quantum phase estimation.

    Circuit: Hadamard(work) → ExpMod(work, anc) → InverseQFT(work)

    Args:
        reg_list: [work_reg, anc_reg].
        param_list: [a, N] or [a, N, period]. When the period is omitted
            it is computed classically by iterating powers of a modulo N.

    Raises:
        ValueError: If N is not positive.
    """

    __self_conjugate__ = False

    def __init__(self, reg_list, param_list=None, submodules=None):
        super().__init__(reg_list=reg_list, param_list=param_list,
                         submodules=submodules or [])
        self.work_reg = reg_list[0]
        self.anc_reg = reg_list[1]
        self.a = param_list[0]
        self.N = param_list[1]
        self.period = param_list[2] if len(param_list) > 2 else self._compute_period()

        if self.N < 1:
            raise ValueError(f"N={self.N} must be positive")
        # Exact bit width; a float log2 rounds up for large N just below a
        # power of two and would report one bit too many.
        self.n = int(self.N).bit_length()
        self.size = self.n * 2

        self._build_program_list()

    def _compute_period(self):
        """Count the powers a^0, a^1, ... mod N seen before returning to 1.

        At most N entries are examined.
        """
        axmodn = [1]
        for _ in range(1, self.N):
            nxt = (axmodn[-1] * self.a) % self.N
            if nxt == 1:
                break
            axmodn.append(nxt)
        return len(axmodn)

    def _build_program_list(self):
        """Assemble the three-stage circuit and declare it."""
        self.program_list = [
            Hadamard(reg_list=[self.work_reg]),
            ExpMod(
                reg_list=[self.work_reg, self.anc_reg],
                param_list=[self.a, self.N, self.period]),
            InverseQFT(reg_list=[self.work_reg]),
        ]
        self.declare_program_list()

    def sum_t_count(self, t_count_list):
        """Return the total T-count for full quantum Shor.

        Computed in closed form; ``t_count_list`` is not consulted.

        Components:
            - Hadamard(work): contributes nothing.
            - ExpMod: 4 * period * n * mcx_t_count(2).
            - InverseQFT: (size - 1) * size // 2 * mcx_t_count(2).
        """
        ncontrols = 0
        expmod_tc = 4 * self.period * self.n * mcx_t_count(ncontrols + 2)
        n = self.size
        qft_tc = (n - 1) * n // 2 * mcx_t_count(ncontrols + 2)
        return expmod_tc + qft_tc

    def t_count(self, dagger_ctx=False, controllers_ctx=None):
        """Return the closed-form total; the context arguments are unused."""
        return self.sum_t_count([0] * len(self.program_list))
# ==============================================================================
# Convenience functions
# ==============================================================================
[文档] def factor(N: int, a: int | None = None): """Factor N using semi-classical Shor's algorithm (pysparq simulation). Args: N: Integer to factor a: Random base (auto-selected if None) Returns: (p, q) such that p * q = N """ if N <= 1: raise ValueError(f"N={N} must be greater than 1") if N % 2 == 0: return (2, N // 2) if a is None: a = random.randint(2, N - 1) g = math.gcd(a, N) if g != 1: return (g, N // g) try: shor = SemiClassicalShor(reg_list=["anc"], param_list=[a, N]) return (0, N) # Resource estimation only; use pysparq for actual simulation except Exception: return (0, N)
[文档] def factor_full_quantum(N: int, a: int | None = None): """Factor N using full quantum Shor with QFT (pysparq simulation).""" if N <= 1: raise ValueError(f"N={N} must be greater than 1") if N % 2 == 0: return (2, N // 2) if a is None: a = random.randint(2, N - 1) g = math.gcd(a, N) if g != 1: return (g, N // g) n = int(math.log2(N)) + 1 size = n * 2 # Compute period classically axmodn = [1] for _ in range(1, N): nxt = (axmodn[-1] * a) % N if nxt == 1: break axmodn.append(nxt) period = len(axmodn) try: shor = Shor( reg_list=["work", "anc"], param_list=[a, N, period]) return (0, N) # Resource estimation only except Exception: return (0, N)
# Public API of this module: the failure exception, the classical helper
# functions, the two primitives, the two Shor composites and the two
# convenience entry points.
__all__ = [ "ShorExecutionFailed", "general_expmod", "find_best_fraction", "compute_period", "check_period", "shor_postprocess", "ModMul", "ExpMod", "SemiClassicalShor", "Shor", "factor", "factor_full_quantum", ]