# Source code for pyqres.dsl.codegen

"""
Code Generator for Quantum-Resource-Estimator DSL (composite-only).

Generates Python Composite subclasses from YAML composite operation definitions.
Handles loops, conditionals, and computed parameters.
"""

from typing import Dict, List, Any, Optional, Set
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class GeneratedClass:
    """Result of generating one Python class from a composite definition.

    Holds the emitted class source together with the metadata needed to
    write it to a file: the import lines and the operations it references.
    """

    # Name of the generated class.
    name: str
    # Base class name: Composite, StandardComposite, AbstractComposite.
    base_class: str
    # Import statements to place at the top of the generated file.
    imports: List[str]
    # Source text of the generated class.
    class_code: str
    # Other operations referenced by the implementation.
    dependencies: List[str]
class CodeGenerator:
    """Generate Python Composite subclasses from YAML composite definitions.

    Handles:
        - $name variable substitution
        - Controllers (all_ones, nonzero, bit, value)
        - Dagger modifiers
        - Temporary registers lifecycle
        - Computed parameters
        - traverse_override for special patterns
        - Loops (loop, loop_reverse)
        - Conditionals (if)
        - Comments

    All emitted code is indented in multiples of ``_INDENT`` so that the
    generated class is valid Python on its own.
    """

    STANDARD_IMPORTS = [
        "from ..core.operation import {base_class}",
        "from ..core.registry import OperationRegistry",
        "from ..core.utils import merge_controllers",
        "import math",
    ]

    # One indentation level of the generated source.
    _INDENT = " " * 4

    def __init__(self):
        # Per-definition state, refreshed by _generate_composite(). Initialised
        # here so the helper methods can also be called on their own.
        self._param_type_map = {}
        self._declared_op_params = set()

    def generate(self, definition: Dict[str, Any]) -> "GeneratedClass":
        """Generate a Python class from a YAML composite definition."""
        return self._generate_composite(definition)

    def _generate_composite(self, defn: Dict[str, Any]) -> "GeneratedClass":
        """Generate a Composite subclass for a composed operation.

        Args:
            defn: Parsed YAML definition. ``name`` is required; ``params``,
                ``impl``, ``description``, ``self_conjugate``,
                ``sum_t_count_formula`` and ``control_override`` are optional.

        Returns:
            A GeneratedClass with the class source, imports and dependencies.
        """
        i1 = self._INDENT
        name = defn["name"]
        params = defn.get("params", [])
        impl = defn.get("impl", [])
        has_custom_sum = defn.get("sum_t_count_formula") == "custom"
        base_class = "AbstractComposite" if has_custom_sum else "StandardComposite"

        # Param type map drives the for_each range() vs direct-iteration choice.
        self._param_type_map = {p["name"]: p.get("type", "int") for p in params}
        # Ops whose name matches a declared 'type: operation' param are emitted
        # as self.<name>(...) instead of OperationRegistry.get_class(...).
        self._declared_op_params = {
            p["name"]
            for p in params
            if isinstance(p, dict) and p.get("type") == "operation"
        }

        imports = self._generate_imports(base_class, defn)

        # Every referenced op counts as a dependency, including those nested
        # inside loops and conditionals.
        dependencies = [
            item["op"] for item in self._flatten_impl_items(impl) if "op" in item
        ]

        lines = [f"class {name}({base_class}):"]
        # The docstring must be the first statement of the class body,
        # otherwise Python does not treat it as the class docstring.
        desc = defn.get("description", f"{name} composite operation")
        lines.append(f'{i1}"""{desc}"""')
        if defn.get("self_conjugate", False):
            lines.append(f"{i1}__self_conjugate__ = True")

        lines.extend(self._generate_init(defn, base_class))

        # Loops/conditionals/comments/python blocks need a builder method.
        if any(self._is_complex_impl_item(item) for item in impl):
            lines.extend(self._generate_class_body(defn))

        control_override = defn.get("control_override")
        if control_override:
            lines.extend(self._generate_control_override(control_override))

        return GeneratedClass(
            name=name,
            base_class=base_class,
            imports=imports,
            class_code="\n".join(lines),
            dependencies=dependencies,
        )

    def _generate_imports(self, base_class: str, defn: Optional[Dict[str, Any]] = None) -> List[str]:
        """Generate import statements for the generated file.

        Args:
            base_class: Base class (StandardComposite, AbstractComposite).
            defn: YAML definition dict (used to detect needed imports). May be
                omitted, in which case only the standard imports are returned.
        """
        defn = defn or {}
        imports = [imp.format(base_class=base_class) for imp in self.STANDARD_IMPORTS]
        params = [p for p in defn.get("params", []) if isinstance(p, dict)]

        # pysparq is needed as soon as any param uses the qram type.
        pysparq_import = "import pysparq as ps"
        if any(p.get("type") == "qram" for p in params) and pysparq_import not in imports:
            imports.append(pysparq_import)

        # Callable params referring to the state-prep helpers need their import.
        state_prep_import = "from ..algorithms.state_prep import make_func, make_func_inv"
        for p in params:
            if p.get("type") != "callable":
                continue
            if p.get("name", "") in ("make_func", "make_func_inv") and state_prep_import not in imports:
                imports.append(state_prep_import)

        # Hoist import lines found in python: blocks anywhere in impl.
        seen = set(imports)
        for item in self._flatten_impl_items(defn.get("impl", [])):
            if "python" not in item:
                continue
            for raw in item["python"].strip().split("\n"):
                stripped = raw.strip()
                if stripped.startswith(("from ", "import ")) and stripped not in seen:
                    imports.append(stripped)
                    seen.add(stripped)
        return imports

    def _generate_init(self, defn: Dict[str, Any], base_class: str = "StandardComposite") -> List[str]:
        """Generate the source lines of ``__init__`` for a composite operation.

        Args:
            defn: YAML definition dict (reads qregs, params, temp_regs,
                computed_params and impl).
            base_class: Name of the base class whose ``__init__`` is invoked.

        Returns:
            Source lines, indented as members of the generated class.
        """
        i1, i2, i3 = (self._INDENT * k for k in (1, 2, 3))
        qregs = defn.get("qregs", [])
        params = defn.get("params", [])
        temp_regs = defn.get("temp_regs", [])
        computed_params = defn.get("computed_params", [])
        impl = defn.get("impl", [])

        has_params = len(params) > 0
        has_temp = len(temp_regs) > 0
        has_complex_impl = any(self._is_complex_impl_item(item) for item in impl)

        lines: List[str] = []

        # Signature
        sig_parts = ["reg_list"]
        if has_params:
            sig_parts.append("param_list=None")
        if has_temp:
            temp_default = "[" + ", ".join(
                f"('{r['name']}', {r['size']})" for r in temp_regs
            ) + "]"
            sig_parts.append(f"temp_reg_list={temp_default}")
        sig_parts.append("operations=None")
        lines.append(f"{i1}def __init__(self, {', '.join(sig_parts)}):")

        if has_params:
            lines.append(f"{i2}if param_list is None:")
            lines.append(f"{i3}param_list = []")

        # Base-class initialiser is called explicitly on base_class.
        super_args = ["self", "reg_list=reg_list"]
        if has_params:
            super_args.append("param_list=param_list")
        if has_temp:
            super_args.append("temp_reg_list=temp_reg_list")
        super_args.append("operations=operations")
        lines.append(f"{i2}{base_class}.__init__({', '.join(super_args)})")

        # Register attributes
        for idx, qr in enumerate(qregs):
            lines.append(f"{i2}self.{qr['name']} = reg_list[{idx}]")

        # Param attributes; operation params come from the operations list.
        for idx, p in enumerate(params):
            if isinstance(p, dict) and p.get("type") == "operation":
                continue
            lines.append(f"{i2}self.{p['name']} = param_list[{idx}]")

        # Computed params: qualify params and *earlier* computed params only.
        known_names = [p["name"] for p in params]
        for cp in computed_params:
            formula = self._qualify_names(str(cp.get("formula", "")), known_names)
            lines.append(f"{i2}self.{cp['name']} = {formula.strip()}")
            known_names.append(cp["name"])

        # Temp registers, kept in a dict for lookup by name.
        if has_temp:
            lines.append(f"{i2}# Store temp registers as instance attributes")
            lines.append(f"{i2}self._temp_reg_dict = {{}}")
            for temp_reg in temp_regs:
                reg_name = temp_reg["name"]
                reg_size = temp_reg["size"]
                lines.append(f"{i2}self._temp_reg_dict['{reg_name}'] = ('{reg_name}', {reg_size})")
                lines.append(f"{i2}self.{reg_name} = '{reg_name}'")

        # Operation params are bound positionally from the operations list.
        op_names = [
            p["name"] for p in params
            if isinstance(p, dict) and p.get("type") == "operation"
        ]
        if op_names:
            lines.append(f"{i2}if operations is None:")
            lines.append(f"{i3}operations = []")
            for idx, op_name in enumerate(op_names):
                lines.append(
                    f"{i2}self.{op_name} = operations[{idx}] if {idx} < len(operations) else None"
                )

        if impl and not has_complex_impl:
            # Simple implementation: a literal program_list.
            lines.append(f"{i2}self.program_list = [")
            for call in impl:
                lines.append(f"{i3}{self._generate_impl_item(call)},")
            lines.append(f"{i2}]")
            lines.append(f"{i2}self.declare_program_list()")
        elif has_complex_impl:
            # Complex implementation: keep the structure and build at runtime.
            lines.append(f"{i2}# Complex implementation with loops/conditionals")
            lines.append(f"{i2}self._impl_structure = " + self._serialize_impl(impl))
            lines.append(f"{i2}self._build_execute_method()")

        return lines

    def _flatten_impl_items(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Recursively flatten impl items nested in loops and conditionals.

        Control-flow wrappers (loop, loop_reverse, for_each, if) are replaced
        by the items of their bodies. For ``if`` the ``body`` branch, the
        legacy ``then`` branch, every ``elif`` body and the ``else`` branch are
        all visited. Non-dict entries are dropped.
        """
        flat: List[Dict[str, Any]] = []
        for item in items or []:
            if not isinstance(item, dict):
                continue
            wrapper_key = next(
                (k for k in ("loop", "loop_reverse", "for_each", "if") if k in item), None
            )
            if wrapper_key is None:
                flat.append(item)
                continue
            spec = item.get(wrapper_key) or {}
            if wrapper_key != "if":
                flat.extend(self._flatten_impl_items(spec.get("body", [])))
                continue
            flat.extend(self._flatten_impl_items(spec.get("body", [])))
            flat.extend(self._flatten_impl_items(spec.get("then", [])))
            for branch in spec.get("elif", []) or []:
                flat.extend(self._flatten_impl_items(branch.get("body", [])))
            flat.extend(self._flatten_impl_items(spec.get("else", [])))
        return flat

    def _is_complex_impl_item(self, item: Dict[str, Any]) -> bool:
        """Check if an impl item requires complex handling (loops, conditionals, etc.)."""
        complex_keys = ("loop", "loop_reverse", "if", "comment", "for_each", "python")
        return any(key in item for key in complex_keys)

    def _serialize_impl(self, impl: List[Dict[str, Any]]) -> str:
        """Serialize the impl structure as a Python literal expression.

        ``repr`` is used rather than patched JSON so that booleans and None
        come out as True/False/None and string contents are left untouched.
        """
        return repr(self._make_serializable(impl))

    def _make_serializable(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert impl items to a plain, tagged (``_type``) structure."""
        result = []
        for item in items:
            if "comment" in item and "op" not in item:
                result.append({"_type": "comment", "text": item["comment"]})
            elif "loop" in item:
                result.append({
                    "_type": "loop",
                    "iterations": item["loop"]["iterations"],
                    "body": self._make_serializable(item["loop"]["body"]),
                })
            elif "loop_reverse" in item:
                result.append({
                    "_type": "loop_reverse",
                    "iterations": item["loop_reverse"]["iterations"],
                    "body": self._make_serializable(item["loop_reverse"]["body"]),
                })
            elif "for_each" in item:
                result.append({
                    "_type": "for_each",
                    "var": item["for_each"]["var"],
                    "items": item["for_each"]["items"],
                    "body": self._make_serializable(item["for_each"]["body"]),
                })
            elif "if" in item:
                if_def = item["if"]
                if_data = {
                    "_type": "if",
                    "condition": if_def["condition"],
                    "body": self._make_serializable(if_def["body"]),
                }
                if "else" in if_def:
                    if_data["else"] = self._make_serializable(if_def["else"])
                if "elif" in if_def:
                    if_data["elif"] = [
                        {"condition": e["condition"], "body": self._make_serializable(e["body"])}
                        for e in if_def["elif"]
                    ]
                result.append(if_data)
            elif "python" in item:
                result.append({"_type": "python", "code": item["python"]})
            elif "op" in item:
                new_item = {"_type": "op", "op": item["op"]}
                for key in ("qregs", "params", "dagger", "controllers"):
                    if key in item:
                        new_item[key] = item[key]
                result.append(new_item)
        return result

    def _generate_class_body(self, defn: Dict[str, Any]) -> List[str]:
        """Generate ``_build_execute_method`` for complex implementations.

        The generated method rebuilds ``self.program_list`` by running the
        loops and conditionals of the definition, then declares it.
        """
        i1, i2 = self._INDENT, self._INDENT * 2
        impl = defn.get("impl", [])
        params = defn.get("params", [])
        computed_params = defn.get("computed_params", [])
        param_names = [p["name"] for p in params] + [p["name"] for p in computed_params]

        lines = [
            "",
            f"{i1}def _build_execute_method(self):",
            f"{i2}# Build program_list by expanding loops and conditionals",
            f"{i2}self.program_list = []",
        ]
        for call in impl:
            # Top-level python: blocks made only of imports are already
            # hoisted by _generate_imports and are not repeated here.
            if isinstance(call, dict) and call.get("python") is not None:
                stripped_lines = (ln.strip() for ln in call["python"].strip().split("\n"))
                if not any(s and not s.startswith(("from ", "import ")) for s in stripped_lines):
                    continue
            self._add_impl_lines(lines, call, self._INDENT, depth=2, param_names=param_names)
        lines.append(f"{i2}self.declare_program_list()")
        return lines

    def _add_block_body(self, lines: List[str], body: List[Dict[str, Any]], indent: str, depth: int,
                        param_names: Optional[List[str]] = None,
                        local_vars: Optional[Dict[str, str]] = None) -> None:
        """Emit the items of a nested block, adding ``pass`` if it has no statement."""
        start = len(lines)
        for item in body:
            self._add_impl_lines(lines, item, indent, depth, param_names, local_vars)
        emitted = (ln.strip() for ln in lines[start:])
        if not any(s and not s.startswith("#") for s in emitted):
            # A block holding nothing but comments is a syntax error.
            lines.append(f"{indent * depth}pass")

    def _add_impl_lines(self, lines: List[str], call: Dict[str, Any], indent: str, depth: int,
                        param_names: List[str] = None, local_vars: Dict[str, str] = None):
        """Append the statements for one impl item to ``lines``, recursively.

        Args:
            lines: Output list, mutated in place.
            call: The impl item (comment, python, loop, loop_reverse,
                for_each, if, or an operation call).
            indent: Text of one indentation unit.
            depth: Number of indentation units for this item.
            param_names: Declared and computed parameter names.
            local_vars: for_each iteration variables currently in scope.
        """
        prefix = indent * depth

        # Standalone comment (a dict with only the "comment" key).
        if "comment" in call and "op" not in call and len(call) == 1:
            lines.append(f"{prefix}# {call['comment']}")
            return

        # Raw python code is inserted line by line after reference resolution.
        if "python" in call:
            for code_line in call["python"].strip().split("\n"):
                resolved = self._resolve_python_block_line(code_line, param_names, local_vars)
                lines.append(f"{prefix}{resolved}")
            return

        if "loop" in call:
            spec = call["loop"]
            iter_expr = self._resolve_loop_expr(spec["iterations"], param_names, local_vars)
            lines.append(f"{prefix}for i in range({iter_expr}):")
            self._add_block_body(lines, spec["body"], indent, depth + 1, param_names, local_vars)
            return

        if "loop_reverse" in call:
            spec = call["loop_reverse"]
            iter_expr = self._resolve_loop_expr(spec["iterations"], param_names, local_vars)
            lines.append(f"{prefix}for i in range({iter_expr} - 1, -1, -1):")
            self._add_block_body(lines, spec["body"], indent, depth + 1, param_names, local_vars)
            return

        if "for_each" in call:
            spec = call["for_each"]
            var_name = spec["var"]
            items_expr = self._resolve_for_each_items_expr(spec["items"], param_names, local_vars)
            lines.append(f"{prefix}for {var_name} in {items_expr}:")
            # The iteration variable is used bare (no $ prefix) in generated code.
            inner_vars = dict(local_vars or {})
            inner_vars[var_name] = var_name
            self._add_block_body(lines, spec["body"], indent, depth + 1, param_names, inner_vars)
            return

        if "if" in call:
            if_def = call["if"]
            condition = self._resolve_expr(if_def["condition"], param_names, local_vars)
            lines.append(f"{prefix}if {condition}:")
            self._add_block_body(lines, if_def["body"], indent, depth + 1, param_names, local_vars)
            for elif_def in if_def.get("elif", []):
                elif_condition = self._resolve_expr(elif_def["condition"], param_names, local_vars)
                lines.append(f"{prefix}elif {elif_condition}:")
                self._add_block_body(lines, elif_def["body"], indent, depth + 1, param_names, local_vars)
            if "else" in if_def:
                lines.append(f"{prefix}else:")
                self._add_block_body(lines, if_def["else"], indent, depth + 1, param_names, local_vars)
            return

        # Regular operation
        op_expr = self._build_op_call(call, param_names, local_vars)
        lines.append(f"{prefix}self.program_list.append({op_expr})")

    def _build_op_call(self, call: Dict[str, Any], param_names: Optional[List[str]] = None,
                       local_vars: Optional[Dict[str, Any]] = None) -> str:
        """Build the expression that instantiates one operation call.

        Produces ``<factory>(reg_list=[...], param_list=[...])`` followed by
        the optional ``.dagger()`` and controller chain. ``temp_out`` is not
        emitted: there is no argument for it, and a comment inside the call
        would swallow the closing parenthesis.
        """
        op_name = call.get("op", "")
        args = []

        qregs_refs = call.get("qregs", [])
        if qregs_refs:
            qregs_str = ", ".join(f"self.{self._substitute(r, local_vars)}" for r in qregs_refs)
            args.append(f"reg_list=[{qregs_str}]")

        params_refs = call.get("params", [])
        if params_refs:
            params_str = ", ".join(
                self._resolve_param_ref(p, param_names, local_vars) for p in params_refs
            )
            args.append(f"param_list=[{params_str}]")

        joined = ", ".join(args)
        # Declared 'type: operation' params are called through self.<name>;
        # everything else is looked up in the registry.
        declared_ops = getattr(self, "_declared_op_params", None) or ()
        if op_name in declared_ops:
            expr = f"self.{op_name}({joined})"
        else:
            expr = f'OperationRegistry.get_class("{op_name}")({joined})'

        if call.get("dagger"):
            expr += ".dagger()"
        controllers = call.get("controllers", {})
        if controllers:
            expr += self._generate_controller_chain(controllers)
        return expr

    def _generate_impl_item(self, call: Dict[str, Any], local_vars: Dict[str, Any] = None) -> str:
        """Generate a single item in program_list.

        Args:
            call: The operation call dictionary.
            local_vars: Local variables (like $iteration) for template
                substitution.

        Returns:
            A ``# ...`` comment, the text of a loop/conditional, or the
            operation-call expression.
        """
        if "comment" in call and "op" not in call:
            return "# " + call["comment"]
        if "loop" in call:
            return self._generate_loop(call["loop"], local_vars)
        if "loop_reverse" in call:
            return self._generate_loop(call["loop_reverse"], local_vars, reverse=True)
        if "for_each" in call:
            return self._generate_for_each(call["for_each"], local_vars)
        if "if" in call:
            return self._generate_conditional(call["if"], local_vars)
        return self._build_op_call(call, local_vars=local_vars)

    def _substitute(self, value: str, local_vars: Dict[str, Any] = None) -> str:
        """Substitute template variables like $iteration.

        Longer variable names are replaced first so that ``$i`` cannot
        clobber the start of ``$iteration``. Non-string values are returned
        unchanged.
        """
        if not local_vars or not isinstance(value, str):
            return value
        result = value
        for var_name in sorted(local_vars, key=len, reverse=True):
            result = result.replace(f"${var_name}", str(local_vars[var_name]))
        return result

    def _format_body_items(self, body: List[Dict[str, Any]], local_vars: Optional[Dict[str, Any]],
                           indent: str) -> List[str]:
        """Render body items one per line; non-comment items get a trailing comma."""
        rendered = []
        for item in body:
            item_str = self._generate_impl_item(item, local_vars)
            suffix = "" if item_str.startswith("#") else ","
            rendered.append(f"{indent}{item_str}{suffix}")
        return rendered

    def _generate_loop(self, loop_def: Dict[str, Any], local_vars: Dict[str, Any] = None,
                       reverse: bool = False) -> str:
        """Generate a for loop.

        String iteration counts are treated as attribute names (``self.<x>``);
        anything else is emitted via ``str``. Inside the body ``$iteration``
        maps to the loop variable (``i`` or ``i_reverse``).
        """
        iterations = loop_def.get("iterations")
        body = loop_def.get("body", [])
        iter_expr = f"self.{iterations}" if isinstance(iterations, str) else str(iterations)

        iter_var = "i_reverse" if reverse else "i"
        scope = dict(local_vars or {})
        scope["iteration"] = iter_var

        lines = [f"# Loop: {iterations} iterations"]
        if reverse:
            lines.append(f"for {iter_var} in range({iter_expr} - 1, -1, -1):")
        else:
            lines.append(f"for {iter_var} in range({iter_expr}):")
        lines.extend(self._format_body_items(body, scope, self._INDENT))
        return "\n".join(lines)

    def _generate_conditional(self, cond_def: Dict[str, Any], local_vars: Dict[str, Any] = None) -> str:
        """Generate an if statement with optional else/elif."""
        condition = cond_def.get("condition", "")
        lines = [f"# Conditional: {condition}", f"if {condition}:"]
        lines.extend(self._format_body_items(cond_def.get("body", []), local_vars, self._INDENT))

        for elif_def in cond_def.get("elif", []):
            lines.append(f"elif {elif_def.get('condition', '')}:")
            lines.extend(self._format_body_items(elif_def.get("body", []), local_vars, self._INDENT))

        if "else" in cond_def:
            lines.append("else:")
            lines.extend(self._format_body_items(cond_def["else"], local_vars, self._INDENT))
        return "\n".join(lines)

    def _generate_for_each(self, for_each_def: Dict[str, Any], local_vars: Dict[str, Any] = None) -> str:
        """Generate a for_each loop."""
        var_name = for_each_def.get("var", "item")
        items = for_each_def.get("items", [])
        body = for_each_def.get("body", [])
        items_expr = self._resolve_for_each_items_expr(items, local_vars=local_vars)

        # The iteration variable joins the local vars for $var substitution.
        scope = dict(local_vars or {})
        scope[var_name] = var_name

        lines = [f"# For-each loop over {items_expr}", f"for {var_name} in {items_expr}:"]
        lines.extend(self._format_body_items(body, scope, self._INDENT))
        return "\n".join(lines)

    def _resolve_param_ref(self, ref: Any, param_names: List[str] = None,
                           local_vars: Dict[str, str] = None) -> str:
        """Resolve a parameter reference to Python code.

        Args:
            ref: Parameter value (could be string, number, list, or a typed
                dict: callable, op_instance, qram, qram_ref, literal, expr).
            param_names: List of declared parameter names.
            local_vars: Local variables for $var substitution.
        """
        if isinstance(ref, str):
            # An exact $var reference resolves to the variable it maps to
            # (e.g. "$iteration" -> "i", "$angle" -> "angle").
            if ref.startswith("$") and local_vars and ref[1:] in local_vars:
                return str(local_vars[ref[1:]])
            # Already-qualified references are kept as-is.
            if ref.startswith("self."):
                return ref
            result = self._substitute(ref, local_vars)
            if result in ("True", "False", "None"):
                return result
            # Everything else is treated as an attribute of the instance.
            return f"self.{result}"

        if isinstance(ref, (int, float)):
            return str(ref)

        if isinstance(ref, list):
            inner = ", ".join(self._resolve_param_ref(x, param_names, local_vars) for x in ref)
            return f"[{inner}]"

        if isinstance(ref, dict):
            ptype = ref.get("type")
            if ptype == "callable":
                # Python symbol name; its import is added by _generate_imports.
                return ref.get("name", "make_func")
            if ptype == "op_instance":
                op_name = ref.get("name", "")
                args_str = ", ".join(
                    self._resolve_param_ref(a, param_names, local_vars)
                    for a in ref.get("args", [])
                )
                return f'OperationRegistry.get_class("{op_name}")({args_str})'
            if ptype == "qram":
                addr = ref.get("addr_size", 0)
                data = ref.get("data_size", 0)
                mem_str = repr(ref.get("memory", []))
                return f"ps.QRAMCircuit_qutrit({addr}, {data}, {mem_str})"
            if ptype == "qram_ref":
                # Reference to a declared param that holds a QRAM circuit.
                ref_name = ref.get("name", "")
                if param_names and ref_name in param_names:
                    return f"self.{ref_name}"
                return ref_name
            if ptype == "literal":
                return repr(ref.get("value"))
            if ptype == "expr":
                return self._resolve_expr(str(ref.get("value", "")), param_names, local_vars)
            return str(ref)

        return str(ref)

    def _resolve_loop_expr(self, expr: Any, param_names: List[str] = None,
                           local_vars: Dict[str, str] = None) -> str:
        """Resolve a loop iteration expression for range(...)."""
        if isinstance(expr, (int, float)):
            return str(expr)
        if isinstance(expr, dict) and expr.get("type") == "expr":
            return self._resolve_expr(str(expr.get("value", "")), param_names, local_vars)
        if isinstance(expr, str):
            if param_names and expr in param_names:
                return f"self.{expr}"
            return self._resolve_expr(expr, param_names, local_vars)
        return str(expr)

    def _resolve_for_each_items_expr(self, items: Any, param_names: List[str] = None,
                                     local_vars: Dict[str, str] = None) -> str:
        """Resolve a for_each items source expression.

        Lists and ``literal`` dicts are emitted via ``repr``; ``expr`` dicts
        go through _resolve_expr; a string names a parameter that is iterated
        directly when its declared type is array/list and wrapped in
        ``range(...)`` otherwise.
        """
        if isinstance(items, list):
            return repr(items)
        if isinstance(items, dict):
            ptype = items.get("type")
            if ptype == "literal":
                return repr(items.get("value"))
            if ptype == "expr":
                return self._resolve_expr(str(items.get("value", "")), param_names, local_vars)
        if isinstance(items, str):
            type_map = getattr(self, "_param_type_map", None) or {}
            if type_map.get(items, "int") in ("array", "list"):
                return f"self.{items}"
            return f"range(self.{items})"
        return str(items)

    @staticmethod
    def _qualify_names(expr: str, names, skip=()) -> str:
        """Prefix whole-word occurrences of ``names`` in ``expr`` with ``self.``.

        Done in a single pass so that inserted text is never rewritten again.
        A name directly preceded by ``.`` is an attribute access (``self.n``,
        ``math.log``) and is left alone.
        """
        import re

        skip_set = set(skip)
        candidates = sorted({n for n in names if n and n not in skip_set}, key=len, reverse=True)
        if not candidates:
            return expr
        pattern = r"(?<![\w.])(?:" + "|".join(re.escape(n) for n in candidates) + r")\b"
        return re.sub(pattern, lambda m: "self." + m.group(0), expr)

    def _resolve_expr(self, expr: str, param_names: List[str] = None,
                      local_vars: Dict[str, str] = None) -> str:
        """Resolve a YAML expression into generated Python code.

        Parameter names are qualified with ``self.`` unless shadowed by a
        local variable. Expressions that already start with ``self.`` are
        returned untouched.
        """
        if not isinstance(expr, str):
            expr = str(expr)
        if expr.lstrip().startswith("self."):
            return expr
        if not param_names:
            return expr
        return self._qualify_names(expr, param_names, skip=set(local_vars or {}))

    def _generate_controller_chain(self, controllers: Dict[str, Any]) -> str:
        """Generate controller method chain.

        Emits, in this order and only for the keys present:
        control_by_all_ones, control_by_nonzero, control_by_bit,
        control_by_value.
        """
        chain = ""
        for key in ("all_ones", "nonzero"):
            if key in controllers:
                refs_str = ", ".join(f"self.{r}" for r in controllers[key])
                chain += f".control_by_{key}([{refs_str}])"
        for key in ("bit", "value"):
            if key in controllers:
                pairs_str = ", ".join(f"(self.{p[0]}, {p[1]})" for p in controllers[key])
                chain += f".control_by_{key}([{pairs_str}])"
        return chain

    def _generate_control_override(self, override_type: str) -> List[str]:
        """Generate traverse_children override for special controller propagation patterns.

        ``cnot_swap`` forwards the merged controllers only to the second of
        three program_list entries; any other type yields a method that
        raises NotImplementedError.
        """
        i1, i2 = self._INDENT, self._INDENT * 2
        lines = [f"{i1}def traverse_children(self, visitor, dagger_ctx=False, controllers_ctx=None):"]
        if override_type == "cnot_swap":
            # Swap: only the middle CNOT receives controllers.
            lines.append(f"{i2}controllers_ctx = controllers_ctx or {{}}")
            lines.append(f"{i2}controllers = merge_controllers(self.controllers, controllers_ctx)")
            lines.append(f"{i2}self.program_list[0].traverse(visitor, False, {{}})")
            lines.append(f"{i2}self.program_list[1].traverse(visitor, False, controllers)")
            lines.append(f"{i2}self.program_list[2].traverse(visitor, False, {{}})")
        else:
            lines.append(f"{i2}# Control override type: {override_type}")
            lines.append(
                f"{i2}raise NotImplementedError(\"control override '{override_type}' not implemented\")"
            )
        return lines

    def _resolve_python_block_line(self, code_line: str, param_names: List[str] = None,
                                   local_vars: Dict[str, str] = None) -> str:
        """Resolve ``$`` references in one line of a python block.

        ``$var`` for a for_each variable is replaced by the variable, then
        ``$param`` for a declared parameter becomes ``self.param``. Bare
        identifiers without ``$`` are left untouched.
        """
        import re

        result = self._substitute(code_line, local_vars)
        names = sorted({n for n in (param_names or []) if n}, key=len, reverse=True)
        if names:
            pattern = r"\$(" + "|".join(re.escape(n) for n in names) + r")\b"
            result = re.sub(pattern, lambda m: "self." + m.group(1), result)
        return result

    def generate_file_content(self, gen_class: "GeneratedClass") -> str:
        """Generate complete file content for a generated class.

        Layout: header comment, blank line, imports, blank line, class code.
        """
        lines = ["# Generated from YAML definition", ""]
        lines.extend(gen_class.imports)
        lines.append("")
        lines.append(gen_class.class_code)
        return "\n".join(lines)
def generate_class(definition: Dict[str, Any]) -> GeneratedClass:
    """Build a GeneratedClass from ``definition`` using a fresh CodeGenerator.

    Shortcut for ``CodeGenerator().generate(definition)``.
    """
    return CodeGenerator().generate(definition)