PySparQ.pysparq.dynamic_operator.operator_wrapper 源代码

"""
动态算子包装模块

使用 ctypes 调用动态库中的 C++ 算子,创建 Python 代理类。
"""

import ctypes
import os
import weakref
from typing import Any, Callable, List, Tuple, Type

# 从 pysparq 导入基类(可选,用于类型提示)
try:
    from pysparq import SparseState
except ImportError:
    # pysparq 未编译时,定义占位符
[文档] SparseState = None
[文档] class DynamicOperatorError(Exception): """动态算子错误""" pass
[文档] class DynamicOperatorLoadError(DynamicOperatorError): """动态库加载错误""" pass
[文档] class DynamicOperatorFactoryError(DynamicOperatorError): """工厂函数调用错误""" pass
# 存储活跃实例用于清理 _active_instances = {} _instance_counter = [0] def _register_instance(instance): """注册实例用于跟踪""" instance_id = _instance_counter[0] _instance_counter[0] += 1 _active_instances[instance_id] = weakref.ref(instance) return instance_id def _unregister_instance(instance_id): """注销实例""" if instance_id in _active_instances: del _active_instances[instance_id]
[文档] class CppOperatorWrapper: """ C++ 算子包装器 负责加载动态库、调用工厂函数创建/销毁 C++ 算子对象 """ def __init__(self, lib_path: str): """ 初始化包装器 Args: lib_path: 动态库路径 """
[文档] self.lib_path = lib_path
self._handle = None self._create_func = None self._destroy_func = None self._get_name_func = None self._get_base_class_func = None self._apply_func = None self._apply_dag_func = None self._arg_types = []
[文档] def load(self, arg_types: List[str] = None): """ 加载动态库 Args: arg_types: 构造函数参数类型列表 Raises: DynamicOperatorLoadError: 加载失败 """ if not os.path.exists(self.lib_path): raise DynamicOperatorLoadError(f"动态库不存在: {self.lib_path}") try: # 使用 RTLD_GLOBAL 以便解析符号 self._handle = ctypes.CDLL(self.lib_path, mode=ctypes.RTLD_GLOBAL) except OSError as e: raise DynamicOperatorLoadError(f"加载动态库失败: {e}") # 获取工厂函数 try: self._create_func = self._handle.create_operator self._destroy_func = self._handle.destroy_operator self._get_name_func = self._handle.get_operator_name self._get_base_class_func = self._handle.get_base_class # Python 增强函数 - 通过 state._cpp_ptr() 获取 C++ SparseState* 指针 try: self._apply_func = self._handle.apply_operator self._apply_dag_func = self._handle.apply_operator_dag except AttributeError: pass # 旧版本模板没有这些函数 except AttributeError as e: raise DynamicOperatorLoadError(f"找不到必需的工厂函数: {e}") # 设置参数类型 if arg_types: self._arg_types = arg_types self._setup_arg_types()
def _setup_arg_types(self): """设置函数参数类型""" if not self._create_func: return # 根据参数类型设置 argtypes type_mapping = { 'int': ctypes.c_int, 'size_t': ctypes.c_size_t, 'unsigned int': ctypes.c_uint, 'unsigned long': ctypes.c_ulong, 'unsigned long long': ctypes.c_ulonglong, 'long': ctypes.c_long, 'long long': ctypes.c_longlong, 'float': ctypes.c_float, 'double': ctypes.c_double, 'bool': ctypes.c_bool, 'char': ctypes.c_char, 'char*': ctypes.c_char_p, 'const char*': ctypes.c_char_p, } argtypes = [] for arg_type in self._arg_types: ctype = type_mapping.get(arg_type) if ctype is None: # 默认使用 size_t ctype = ctypes.c_size_t argtypes.append(ctype) self._create_func.argtypes = argtypes self._create_func.restype = ctypes.c_void_p self._destroy_func.argtypes = [ctypes.c_void_p] self._destroy_func.restype = None if self._get_name_func: self._get_name_func.argtypes = [] self._get_name_func.restype = ctypes.c_char_p if self._get_base_class_func: self._get_base_class_func.argtypes = [] self._get_base_class_func.restype = ctypes.c_char_p if self._apply_func: # SparseState* 参数:ctypes.c_void_p 传递指针值 self._apply_func.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self._apply_func.restype = None if self._apply_dag_func: self._apply_dag_func.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self._apply_dag_func.restype = None
[文档] def create(self, *args) -> int: """ 创建 C++ 算子实例 Args: *args: 构造函数参数 Returns: C++ 对象地址(作为 Python int) """ if not self._create_func: raise DynamicOperatorFactoryError("工厂函数未加载") try: ptr = self._create_func(*args) return ptr except Exception as e: raise DynamicOperatorFactoryError(f"创建算子失败: {e}")
[文档] def destroy(self, ptr: int): """ 销毁 C++ 算子实例 Args: ptr: C++ 对象地址 """ if self._destroy_func and ptr: try: self._destroy_func(ptr) except Exception: pass # 忽略销毁错误
[文档] def close(self): """ 关闭动态库并释放资源 注意:在 Windows 上,必须确保所有 C++ 对象都已销毁 才能成功删除动态库文件 """ # 清理函数引用,帮助垃圾回收 self._create_func = None self._destroy_func = None self._get_name_func = None self._get_base_class_func = None self._apply_func = None self._apply_dag_func = None # 释放动态库句柄 if self._handle is not None: # 在 Windows 上,需要强制垃圾回收以确保句柄释放 import gc gc.collect() # 删除 handle 引用,让 ctypes 释放库 handle = self._handle self._handle = None # Windows 特定:强制释放库句柄 if os.name == 'nt': try: import ctypes kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) # 获取模块句柄并释放 hmodule = ctypes.c_void_p(handle._handle) if hmodule: kernel32.FreeLibrary(hmodule) except Exception: pass # 忽略释放错误 # 删除 handle 对象 del handle # 再次强制垃圾回收 gc.collect()
[文档] def get_name(self) -> str: """获取算子名称""" if self._get_name_func: result = self._get_name_func() if result: return result.decode('utf-8') return ""
[文档] def get_base_class(self) -> str: """获取基类名称""" if self._get_base_class_func: result = self._get_base_class_func() if result: return result.decode('utf-8') return "BaseOperator"
[文档] def apply(self, ptr: int, state_cpp_ptr: int): """ 应用算子到 SparseState Args: ptr: 算子对象地址 state_cpp_ptr: C++ SparseState* 指针(通过 state._cpp_ptr() 获取) """ if self._apply_func and ptr and state_cpp_ptr: self._apply_func(ptr, state_cpp_ptr)
[文档] def apply_dag(self, ptr: int, state_cpp_ptr: int): """ 应用 dagger 到 SparseState Args: ptr: 算子对象地址 state_cpp_ptr: C++ SparseState* 指针(通过 state._cpp_ptr() 获取) """ if self._apply_dag_func and ptr and state_cpp_ptr: self._apply_dag_func(ptr, state_cpp_ptr)
[文档] def create_operator_class( name: str, lib_path: str, base_class: str = "BaseOperator", constructor_args: List[Tuple[str, str]] = None ) -> Type: """ 创建动态算子 Python 类 Args: name: 算子类名 lib_path: 动态库路径 base_class: 基类名 ("BaseOperator" 或 "SelfAdjointOperator") constructor_args: 构造函数参数列表 [(type, name), ...] Returns: 动态创建的算子类 """ constructor_args = constructor_args or [] # 创建 C++ 包装器 wrapper = CppOperatorWrapper(lib_path) arg_types = [arg[0] for arg in constructor_args] wrapper.load(arg_types) # 验证基类 detected_base = wrapper.get_base_class() if detected_base and detected_base != base_class: import warnings warnings.warn(f"检测到基类为 {detected_base},但指定为 {base_class}") base_class = detected_base def custom_init(self, **kwargs): """ 动态算子构造函数 Args: **kwargs: 构造函数参数(按名称传参) """ # 收集参数值 args = [] for arg_type, arg_name in constructor_args: if arg_name not in kwargs: raise TypeError(f"缺少必需参数: {arg_name}") args.append(kwargs[arg_name]) # 存储参数用于 dag self._args = tuple(args) # 将 wrapper 和 base_class 存储在类上(而非实例上), # 这样所有实例共享同一个 wrapper,避免 __del__ 被调用时误关动态库。 self._wrapper = DynamicOpClass._wrapper self._base_class = DynamicOpClass._base_class self._instance_id = _register_instance(self) # 创建 C++ 算子实例 self._cpp_ptr = self._wrapper.create(*args) def call_method(self, state): """ 调用算子 Args: state: SparseState 对象 Returns: 返回输入状态(支持链式调用) """ if not self._cpp_ptr: raise RuntimeError("算子未初始化或已销毁") # 通过 state._cpp_ptr()(在 pysparq._core.SparseState 中暴露)获取 C++ SparseState* 指针 state_cpp_ptr = state._cpp_ptr() self._wrapper.apply(self._cpp_ptr, state_cpp_ptr) return state def dag_method(self, state): """ 调用 dagger 操作 Args: state: SparseState 对象 Returns: 返回输入状态 """ if not self._cpp_ptr: raise RuntimeError("算子未初始化或已销毁") state_cpp_ptr = state._cpp_ptr() if base_class == "SelfAdjointOperator": # 自伴算子 dagger 等于自身 self._wrapper.apply(self._cpp_ptr, state_cpp_ptr) else: # BaseOperator 使用 dagger 辅助函数 self._wrapper.apply_dag(self._cpp_ptr, state_cpp_ptr) return state def repr_method(self) -> str: """字符串表示""" arg_str = ", ".join(f"{arg_name}={repr(val)}" for (arg_type, arg_name), val in zip( constructor_args, self._args )) return f"{name}({arg_str})" def del_method(self): """析构函数""" if hasattr(self, '_cpp_ptr') and self._cpp_ptr: self._wrapper.destroy(self._cpp_ptr) self._cpp_ptr = 0 if hasattr(self, '_instance_id'): _unregister_instance(self._instance_id) # 注意:不要在这里关闭 wrapper,因为 wrapper 由类级别管理,跨所有实例共享。 # 动态库的关闭由 cleanup_all_instances() 或显式调用负责。 # 创建动态类 DynamicOpClass = type( name, (object,), { '__init__': custom_init, '__call__': call_method, 'dag': dag_method, '__repr__': repr_method, '__del__': del_method, '_is_dynamic_operator': True, '_base_class': base_class, '_lib_path': lib_path, } ) # 将 wrapper 和 base_class 存储在类级别(跨实例共享) DynamicOpClass._wrapper = wrapper DynamicOpClass._base_class = base_class # 添加文档字符串 arg_docs = "\n".join(f" {arg_name} ({arg_type})" for arg_type, arg_name in constructor_args) if constructor_args else " (无)" DynamicOpClass.__doc__ = f""" 动态生成的算子类: {name} 基类: {base_class} 构造函数参数: {arg_docs} 使用示例: >>> op = {name}({', '.join(f"{arg_name}=..." for _, arg_name in constructor_args) if constructor_args else ''}) >>> state = op(state) """ return DynamicOpClass
[文档] def cleanup_all_instances(): """清理所有活跃的动态算子实例""" import gc # 清理仍然存在的实例 for instance_id, ref in list(_active_instances.items()): instance = ref() if instance is not None: try: # 先销毁 C++ 对象 if hasattr(instance, '_cpp_ptr') and instance._cpp_ptr: instance._wrapper.destroy(instance._cpp_ptr) instance._cpp_ptr = 0 # 关闭动态库句柄 if hasattr(instance, '_wrapper'): instance._wrapper.close() except: pass _active_instances.clear() # 强制垃圾回收确保资源释放 gc.collect()