Source code for uniqc.task.store

"""SQLite-backed task storage for UnifiedQuantum.

This module is the single source of truth for task persistence. Both
:mod:`uniqc.task_manager` and :mod:`uniqc.task.persistence` delegate to
:class:`TaskStore`; there is no JSON/JSONL fallback path.

Storage layout::

    ~/.uniqc/cache/
        tasks.sqlite          # single database file

Schema versioning (best practices)
----------------------------------

The database uses SQLite's built-in metadata slots in the file header
rather than a custom bookkeeping table:

- ``PRAGMA application_id`` holds :data:`APPLICATION_ID` (the 4 bytes
  ``"UNIC"``). On open we refuse to touch a file whose ``application_id``
  is set to something else; a zero value means a fresh DB we can stamp.
- ``PRAGMA user_version`` holds the integer schema version. It is bumped
  by each migration and used to decide what else to run.

:data:`MIGRATIONS` is an ordered list of ``(target_version, migrate_fn)``
tuples. To evolve the schema:

1. Append a new tuple ``(N, _apply_vN)`` where ``N = CURRENT_SCHEMA_VERSION + 1``.
2. Bump :data:`CURRENT_SCHEMA_VERSION` to ``N``.
3. Implement ``_apply_vN(conn)`` (DDL or DML on the given connection).

On open, :class:`TaskStore` iterates the list and runs each outstanding
migration in its own transaction; ``user_version`` is stamped as part of
the same transaction so a crashed migration rolls back the version bump
too.
"""

from __future__ import annotations

__all__ = [
    "APPLICATION_ID",
    "CURRENT_SCHEMA_VERSION",
    "DEFAULT_CACHE_DIR",
    "MIGRATIONS",
    "TERMINAL_STATUSES",
    "TaskInfo",
    "TaskStatus",
    "TaskStore",
]

import json
import sqlite3
import threading
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Iterator

# ---------------------------------------------------------------------------
# Paths & constants
# ---------------------------------------------------------------------------

DEFAULT_CACHE_DIR: Path = Path.home() / ".uniqc" / "cache"
DB_FILENAME: str = "tasks.sqlite"

# Four ASCII bytes 'UNIC' (short for uniqc -- can't fit all five letters in
# 32 bits), stored in the SQLite header so we can tell "this is our cache"
# from "this is somebody else's random tasks.sqlite".
APPLICATION_ID: int = 0x554E4943  # b"UNIC"

CURRENT_SCHEMA_VERSION: int = 1

TERMINAL_STATUSES: tuple[str, ...] = ("success", "failed", "cancelled")


# ---------------------------------------------------------------------------
# Task types
# ---------------------------------------------------------------------------

[docs] class TaskStatus(str, Enum): """Enumeration of task statuses.""" PENDING = "pending" RUNNING = "running" SUCCESS = "success" FAILED = "failed" CANCELLED = "cancelled"
[docs] @dataclass class TaskInfo: """Information about a submitted task. Attributes: task_id: Unique identifier for the task. backend: The backend where the task was submitted. status: Current status of the task. result: Task result (if completed). shots: Number of shots requested. submit_time: ISO format timestamp of submission. update_time: ISO format timestamp of last status update. metadata: Additional metadata about the task. """ task_id: str backend: str status: str = TaskStatus.PENDING result: dict | None = None shots: int = 1000 submit_time: str = field(default_factory=lambda: datetime.now().isoformat()) update_time: str = field(default_factory=lambda: datetime.now().isoformat()) metadata: dict = field(default_factory=dict)
[docs] def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" return asdict(self)
[docs] @classmethod def from_dict(cls, data: dict[str, Any]) -> "TaskInfo": """Create from dictionary.""" return cls(**data)
# --------------------------------------------------------------------------- # Schema DDL + migrations # --------------------------------------------------------------------------- _TASKS_DDL = """ CREATE TABLE IF NOT EXISTS tasks ( task_id TEXT PRIMARY KEY, backend TEXT NOT NULL, status TEXT NOT NULL, shots INTEGER NOT NULL DEFAULT 0, submit_time TEXT NOT NULL, update_time TEXT NOT NULL, result_json TEXT, metadata_json TEXT NOT NULL DEFAULT '{}' ) """ _INDEX_DDL: tuple[str, ...] = ( "CREATE INDEX IF NOT EXISTS tasks_backend_idx ON tasks(backend)", "CREATE INDEX IF NOT EXISTS tasks_status_idx ON tasks(status)", "CREATE INDEX IF NOT EXISTS tasks_submit_time_idx ON tasks(submit_time)", ) def _apply_v1(conn: sqlite3.Connection) -> None: """Initial schema: ``tasks`` table + indices.""" conn.execute(_TASKS_DDL) for ddl in _INDEX_DDL: conn.execute(ddl) MigrateFn = Callable[[sqlite3.Connection], None] # Ordered list of ``(target_version, migrate_fn)``. Append future migrations # here; never reorder or delete past entries. MIGRATIONS: list[tuple[int, MigrateFn]] = [ (1, _apply_v1), ] # --------------------------------------------------------------------------- # Header metadata + status helpers # --------------------------------------------------------------------------- def _normalize_status(status: Any) -> str: """Return the canonical string value for a status. Accepts either a plain string or a ``TaskStatus`` enum member. The enum inherits from ``str`` but its ``__str__`` returns the repr form (``"TaskStatus.RUNNING"``), which is not what we want in the DB. ``status.value`` yields the literal. """ if isinstance(status, TaskStatus): return status.value return str(status) def _get_user_version(conn: sqlite3.Connection) -> int: return int(conn.execute("PRAGMA user_version").fetchone()[0]) def _set_user_version(conn: sqlite3.Connection, version: int) -> None: # PRAGMA values cannot be parameter-bound, so interpolate an int. conn.execute(f"PRAGMA user_version = {int(version)}") def _get_application_id(conn: sqlite3.Connection) -> int: return int(conn.execute("PRAGMA application_id").fetchone()[0]) def _set_application_id(conn: sqlite3.Connection, app_id: int) -> None: conn.execute(f"PRAGMA application_id = {int(app_id)}") # --------------------------------------------------------------------------- # Row <-> TaskInfo # --------------------------------------------------------------------------- def _row_to_info(row: sqlite3.Row) -> TaskInfo: result = json.loads(row["result_json"]) if row["result_json"] else None metadata = json.loads(row["metadata_json"]) if row["metadata_json"] else {} return TaskInfo( task_id=row["task_id"], backend=row["backend"], status=row["status"], shots=int(row["shots"]), submit_time=row["submit_time"], update_time=row["update_time"], result=result, metadata=metadata, ) # --------------------------------------------------------------------------- # TaskStore # ---------------------------------------------------------------------------
[docs] class TaskStore: """SQLite-backed storage for task records. ``TaskStore`` is safe to construct multiple times against the same cache directory; all operations use short-lived connections guarded by a per-instance lock. Args: cache_dir: Directory that holds ``tasks.sqlite``. Defaults to ``~/.uniqc/cache/``. The directory is created if missing. """ def __init__(self, cache_dir: Path | str | None = None) -> None: self.cache_dir: Path = Path(cache_dir) if cache_dir else DEFAULT_CACHE_DIR self.cache_dir.mkdir(parents=True, exist_ok=True) self.db_path: Path = self.cache_dir / DB_FILENAME self._lock = threading.RLock() self._init_schema() # -- connection / transaction helpers ----------------------------------- def _connect(self) -> sqlite3.Connection: conn = sqlite3.connect(str(self.db_path)) conn.row_factory = sqlite3.Row # WAL gives us concurrent readers and resilience across crashes. conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA foreign_keys=ON") return conn @contextmanager def _tx(self) -> Iterator[sqlite3.Connection]: """Yield a connection inside a committed transaction.""" with self._lock: conn = self._connect() try: yield conn conn.commit() finally: conn.close() # -- schema init + migrations ------------------------------------------- def _init_schema(self) -> None: # Open one connection for the whole init dance. PRAGMAs on journal # mode etc. need to be outside a transaction. with self._lock: conn = self._connect() try: self._stamp_or_check_application_id(conn) current = _get_user_version(conn) for target, migrate_fn in MIGRATIONS: if current < target: with conn: # BEGIN / COMMIT (or ROLLBACK on error) migrate_fn(conn) _set_user_version(conn, target) current = target finally: conn.close() @staticmethod def _stamp_or_check_application_id(conn: sqlite3.Connection) -> None: """Stamp ``APPLICATION_ID`` on a fresh DB; refuse foreign ones.""" existing = _get_application_id(conn) if existing == 0: _set_application_id(conn, APPLICATION_ID) elif existing != APPLICATION_ID: raise RuntimeError( "SQLite file has application_id=" f"0x{existing:08x}, expected 0x{APPLICATION_ID:08x} (UNIC). " "Refusing to operate on an unknown database." ) # -- CRUD ---------------------------------------------------------------
[docs] def save(self, task_info: TaskInfo) -> None: """Insert or update a task record (by ``task_id``). Also stamps ``task_info.update_time`` with the current timestamp so callers see the same value that was persisted. """ now = datetime.now().isoformat() task_info.update_time = now result_json = json.dumps(task_info.result) if task_info.result is not None else None metadata_json = json.dumps(task_info.metadata or {}, ensure_ascii=False) # ``TaskStatus`` inherits from ``str`` so passing it through binds the # underlying enum value (e.g. "running"); calling ``str()`` instead # would store the ``"TaskStatus.RUNNING"`` repr, which is wrong. params = { "task_id": task_info.task_id, "backend": task_info.backend, "status": _normalize_status(task_info.status), "shots": int(task_info.shots), "submit_time": task_info.submit_time, "update_time": now, "result_json": result_json, "metadata_json": metadata_json, } with self._tx() as conn: conn.execute( """ INSERT INTO tasks (task_id, backend, status, shots, submit_time, update_time, result_json, metadata_json) VALUES (:task_id, :backend, :status, :shots, :submit_time, :update_time, :result_json, :metadata_json) ON CONFLICT(task_id) DO UPDATE SET backend = excluded.backend, status = excluded.status, shots = excluded.shots, update_time = excluded.update_time, result_json = excluded.result_json, metadata_json = excluded.metadata_json """, params, )
[docs] def get(self, task_id: str) -> TaskInfo | None: """Return a task by id, or ``None`` if missing.""" with self._tx() as conn: row = conn.execute( "SELECT * FROM tasks WHERE task_id = ?", (task_id,) ).fetchone() return _row_to_info(row) if row is not None else None
[docs] def list( self, *, status: str | None = None, backend: str | None = None, limit: int | None = None, ) -> list[TaskInfo]: """List tasks, newest first (by ``submit_time``).""" conditions: list[str] = [] params: list[Any] = [] if status is not None: conditions.append("status = ?") params.append(_normalize_status(status)) if backend is not None: conditions.append("backend = ?") params.append(backend) sql = "SELECT * FROM tasks" if conditions: sql += " WHERE " + " AND ".join(conditions) sql += " ORDER BY submit_time DESC, rowid DESC" if limit is not None: sql += " LIMIT ?" params.append(int(limit)) with self._tx() as conn: rows = conn.execute(sql, params).fetchall() return [_row_to_info(row) for row in rows]
[docs] def count(self, *, status: str | None = None, backend: str | None = None) -> int: conditions: list[str] = [] params: list[Any] = [] if status is not None: conditions.append("status = ?") params.append(_normalize_status(status)) if backend is not None: conditions.append("backend = ?") params.append(backend) sql = "SELECT COUNT(*) AS c FROM tasks" if conditions: sql += " WHERE " + " AND ".join(conditions) with self._tx() as conn: row = conn.execute(sql, params).fetchone() return int(row["c"]) if row is not None else 0
[docs] def delete(self, task_id: str) -> bool: """Remove a single task. Returns ``True`` if it existed.""" with self._tx() as conn: cur = conn.execute("DELETE FROM tasks WHERE task_id = ?", (task_id,)) return cur.rowcount > 0
[docs] def clear_completed( self, terminal_statuses: tuple[str, ...] = TERMINAL_STATUSES ) -> int: """Delete tasks whose status is in ``terminal_statuses``.""" placeholders = ",".join("?" for _ in terminal_statuses) with self._tx() as conn: cur = conn.execute( f"DELETE FROM tasks WHERE status IN ({placeholders})", [_normalize_status(s) for s in terminal_statuses], ) return cur.rowcount
[docs] def clear_all(self) -> None: """Delete the on-disk SQLite file (and its WAL/SHM siblings). Matches the pre-SQLite ``clear_cache`` semantics: the cache file disappears. A subsequent operation re-initialises the schema. """ with self._lock: for path in ( self.db_path, self.db_path.with_name(self.db_path.name + "-wal"), self.db_path.with_name(self.db_path.name + "-shm"), ): if path.exists(): try: path.unlink() except OSError: pass