"""PySparQ 动态算子扩展模块 - 提供运行时编译和加载自定义 C++ 算子的功能。"""
from typing import List, Tuple, Type, Optional
from .compiler import (
CompilerConfig,
CompilationError,
compile_cpp_code,
compute_code_hash,
find_project_root,
generate_cpp_source,
format_compile_error,
clear_cache,
get_cache_info,
quick_compile,
)
from .operator_wrapper import (
CppOperatorWrapper,
DynamicOperatorError,
DynamicOperatorLoadError,
DynamicOperatorFactoryError,
create_operator_class,
cleanup_all_instances,
)
__all__ = [
# 编译相关
"CompilerConfig",
"CompilationError",
"compile_cpp_code",
"compute_code_hash",
"find_project_root",
"generate_cpp_source",
"format_compile_error",
"clear_cache",
"get_cache_info",
"quick_compile",
# 动态算子相关
"compile_operator",
"CppOperatorWrapper",
"DynamicOperatorError",
"DynamicOperatorLoadError",
"DynamicOperatorFactoryError",
"create_operator_class",
"cleanup_all_instances",
]
__version__ = "0.2.0"
[文档]
def compile_operator(
name: str,
cpp_code: str,
base_class: str = "BaseOperator",
extra_includes: List[str] = None,
extra_libs: List[str] = None,
constructor_args: List[Tuple[str, str]] = None,
cache_dir: Optional[str] = None,
verbose: bool = False,
) -> Type:
"""编译 C++ 代码为动态算子类。
这是一个高级函数,将用户提供的 C++ 代码编译为共享库,
并包装为可直接在 Python 中使用的算子类。动态算子可以
像原生 PySparQ 算子一样应用于 SparseState。
Args:
name: 算子类名。必须是有效的 Python 类名,且必须与 C++ 代码中的类名匹配。
cpp_code: C++ 源代码,仅包含类定义部分。代码必须继承自 BaseOperator
或 SelfAdjointOperator,并实现 operator() 方法。
base_class: 基类名,决定 dagger 行为。可选值:
- "BaseOperator": 一般算子,需手动实现 dag() 方法
- "SelfAdjointOperator": 厄米算子,dag() 自动等于 operator()
默认为 "BaseOperator"。
extra_includes: 额外头文件搜索路径列表。PySparQ 头文件会自动包含。
extra_libs: 额外链接库列表。大多数算子不需要额外库。
constructor_args: 构造函数参数列表,格式为 [(类型, 名称), ...]。
支持的类型: size_t, int, long, double, float, bool, uint64_t。
示例: [("size_t", "reg_id"), ("double", "phase")]
cache_dir: 缓存目录路径。默认使用系统临时目录下的 pysparq_dynamic_ops/。
verbose: 是否输出详细编译日志,用于调试。
Returns:
动态生成的算子类。可通过关键字参数创建实例,如: OpClass(reg_id=0, phase=1.0)
Raises:
CompilationError: C++ 编译失败。错误信息包含详细的编译器输出。
DynamicOperatorLoadError: 动态库加载失败。
ValueError: 参数错误(如空名称、无效基类等)。
Example:
创建一个简单的翻转算子:
>>> from pysparq.dynamic_operator import compile_operator
>>>
>>> cpp_code = '''
... class FlipOp : public SelfAdjointOperator {
... size_t reg_id;
... public:
... FlipOp(size_t r) : reg_id(r) {}
... void operator()(std::vector<System>& state) const override {
... for (auto& s : state) {
... s.get(reg_id).value ^= 1;
... }
... }
... };
... '''
>>>
>>> FlipOp = compile_operator(
... name="FlipOp",
... cpp_code=cpp_code,
... base_class="SelfAdjointOperator",
... constructor_args=[("size_t", "reg_id")]
... )
>>>
>>> # 创建实例
>>> op = FlipOp(reg_id=0)
>>> print(repr(op)) # FlipOp(reg_id=0)
Note:
- 编译的库会基于代码哈希缓存,避免重复编译。
- Windows 上可能存在 ABI 兼容性问题(MSVC vs MinGW)。
- C++ 类名必须与 Python name 参数匹配。
- 算子中的状态访问: s.get(reg_id).value 获取值,s.amplitude 获取振幅。
See Also:
get_cache_info: 查询编译缓存状态。
clear_cache: 清除编译缓存。
CompilerConfig: 高级编译器配置。
"""
if extra_includes is None:
extra_includes = []
if extra_libs is None:
extra_libs = []
if constructor_args is None:
constructor_args = []
# 参数验证
if not name or not isinstance(name, str):
raise ValueError("name 必须是有效的字符串")
if not cpp_code or not isinstance(cpp_code, str):
raise ValueError("cpp_code 必须是有效的 C++ 代码字符串")
valid_base_classes = ["BaseOperator", "SelfAdjointOperator"]
if base_class not in valid_base_classes:
raise ValueError(f"base_class 必须是 {valid_base_classes} 之一")
# 构造构造函数参数字符串
ctor_params = ", ".join(f"{arg_type} {arg_name}" for arg_type, arg_name in constructor_args)
ctor_args = ", ".join(arg_name for _, arg_name in constructor_args)
# 使用 Python 增强模板
config = CompilerConfig(
include_paths=extra_includes,
libraries=extra_libs,
template=CompilerConfig.PYTHON_TEMPLATE.replace("{BASE_CLASS}", base_class),
)
if verbose:
print(f"[compile_operator] 编译算子: {name}")
print(f"[compile_operator] 基类: {base_class}")
print(f"[compile_operator] 参数: {ctor_params}")
# 自动检测项目根目录
project_root = find_project_root()
if project_root is None:
raise RuntimeError(
"无法自动检测项目根目录。请确保 SparQ/ 和 PySparQ/ 目录存在。"
)
# 编译 C++ 代码
lib_path = compile_cpp_code(
cpp_code=cpp_code,
class_name=name,
cache_dir=cache_dir,
ctor_params=ctor_params,
ctor_args=ctor_args,
config=config,
project_root=str(project_root),
verbose=verbose,
)
if verbose:
print(f"[compile_operator] 编译成功: {lib_path}")
print(f"[compile_operator] 创建 Python 类...")
# 创建 Python 类
OpClass = create_operator_class(
name=name,
lib_path=lib_path,
base_class=base_class,
constructor_args=constructor_args,
)
if verbose:
print(f"[compile_operator] 算子类 {name} 已创建")
return OpClass