Source code for uniqc.gateway.api.tasks

"""Task listing and management API — /api/tasks."""

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Any

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from uniqc.backend_adapter.task.store import TERMINAL_STATUSES, TaskInfo, TaskStore
from uniqc.gateway.db.archive_store import ArchiveStore

# Router collecting the task endpoints declared in this module. Route paths
# below are relative; the module docstring indicates they are served under
# /api/tasks (the mount point itself is configured outside this file).
router = APIRouter()


class TaskIdsRequest(BaseModel):
    """Request body carrying the task ids a bulk endpoint should act on.

    ``task_ids`` defaults to an empty list, so a body without the field is
    accepted and results in no work being done.
    """

    task_ids: list[str] = Field(default_factory=list)
def _info_to_dict(t: TaskInfo) -> dict[str, Any]:
    """Serialise a task into the summary dict used by the list endpoint.

    The result payload itself is left out; ``has_result`` only reports
    whether ``t.result`` is set.
    """
    leading_fields = (
        "task_id",
        "backend",
        "status",
        "shots",
        "submit_time",
        "update_time",
    )
    summary: dict[str, Any] = {name: getattr(t, name) for name in leading_fields}
    summary["has_result"] = t.result is not None
    summary["metadata"] = t.metadata
    summary["archived_at"] = t.archived_at
    return summary
@router.get("")
def list_tasks(
    status: str | None = None,
    backend: str | None = None,
    limit: int | None = 100,
    offset: int | None = 0,
) -> dict[str, Any]:
    """Return one page of active (non-archived) tasks.

    Query params:
        status: only tasks with this status
            (pending/running/success/failed/cancelled)
        backend: only tasks submitted to this backend
        limit: maximum number of rows in the page (default 100)
        offset: number of matching rows to skip before the page starts

    The response carries the page under ``tasks``, the number of tasks
    matching the same status/backend filters under ``total``, and echoes
    ``limit`` and ``offset`` (a missing offset is reported as 0).
    """
    task_store = TaskStore()
    page = task_store.list(status=status, backend=backend, limit=limit, offset=offset)
    rows = list(map(_info_to_dict, page))
    # ``total`` ignores limit/offset: it counts every row matching the filters.
    matching = task_store.count(status=status, backend=backend)
    return {
        "tasks": rows,
        "total": matching,
        "limit": limit,
        "offset": offset if offset else 0,
    }
@router.get("/counts")
def task_counts() -> dict[str, int]:
    """Return the number of active tasks per status, plus their sum.

    One count is taken for each of pending/running/success/failed/cancelled;
    ``total`` is the sum of those five counts.
    """
    task_store = TaskStore()
    tracked = ("pending", "running", "success", "failed", "cancelled")
    counts: dict[str, int] = {
        name: task_store.count(status=name) for name in tracked
    }
    # Sum before inserting the key so "total" does not count itself.
    counts["total"] = sum(counts.values())
    return counts
@router.get("/{task_id}")
def get_task(task_id: str) -> dict[str, Any]:
    """Return every stored detail of one task, result payload included.

    Raises:
        HTTPException: 404 when the store has no task with ``task_id``.
    """
    record = TaskStore().get(task_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")

    exposed_fields = (
        "task_id",
        "backend",
        "status",
        "shots",
        "submit_time",
        "update_time",
        "result",
        "metadata",
        "archived_at",
    )
    return {name: getattr(record, name) for name in exposed_fields}
@router.get("/{task_id}/shards")
def get_task_shards(task_id: str) -> dict[str, Any]:
    """Return the shard mapping recorded for a uniqc task id.

    A shard is one submission to the underlying cloud platform, and
    ``circuit_count`` says how many circuits that submission covers (>= 1 on
    native-batch platforms, == 1 otherwise). Result payloads are not part of
    this response; use ``GET /api/tasks/{task_id}`` for the aggregated result.

    Raises:
        HTTPException: 404 when the store has no task with ``task_id``.
    """

    def _describe(shard: Any) -> dict[str, Any]:
        # Flatten one shard record into plain JSON-friendly fields.
        return {
            "uniqc_task_id": shard.uniqc_task_id,
            "shard_index": shard.shard_index,
            "platform_task_id": shard.platform_task_id,
            "backend": shard.backend,
            "circuit_count": shard.circuit_count,
            "sub_index_offset": shard.sub_index_offset,
            "status": shard.status,
            "error_message": shard.error_message,
            "submit_time": shard.submit_time,
            "update_time": shard.update_time,
        }

    task_store = TaskStore()
    # Check the task itself first so an unknown id is a 404 rather than an
    # empty shard list.
    if task_store.get(task_id) is None:
        raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")

    shard_records = task_store.get_shards(task_id)
    shard_total = len(shard_records)
    return {
        "task_id": task_id,
        "shard_count": shard_total,
        "shards": [_describe(shard) for shard in shard_records],
    }
@router.delete("/{task_id}")
def delete_task(task_id: str) -> dict[str, str]:
    """Permanently remove one task from the task store.

    Raises:
        HTTPException: 404 when the store reports that nothing was deleted.
    """
    if TaskStore().delete(task_id):
        return {"deleted": task_id}
    raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
@router.post("/bulk-delete")
def bulk_delete_tasks(payload: TaskIdsRequest) -> dict[str, Any]:
    """Permanently delete the active tasks named in the request body.

    Ids are processed in request order. Each id ends up either in
    ``deleted`` or, when the store reports no deletion, in ``missing``;
    ``count`` is the number of deleted ids.
    """
    task_store = TaskStore()
    removed: list[str] = []
    not_found: list[str] = []
    for task_id in payload.task_ids:
        bucket = removed if task_store.delete(task_id) else not_found
        bucket.append(task_id)
    return {"deleted": removed, "missing": not_found, "count": len(removed)}
@router.post("/bulk-archive")
def bulk_archive_tasks(payload: TaskIdsRequest) -> dict[str, Any]:
    """Move the active tasks named in the request body into the archive.

    Ids are processed in request order. Each id ends up either in
    ``archived`` or, when the archive store reports no move, in ``missing``;
    ``count`` is the number of archived ids.
    """
    archive_store = ArchiveStore()
    moved: list[str] = []
    not_found: list[str] = []
    for task_id in payload.task_ids:
        bucket = moved if archive_store.archive_task(task_id) else not_found
        bucket.append(task_id)
    return {"archived": moved, "missing": not_found, "count": len(moved)}
def _parse_time(value: str | None) -> datetime | None:
    """Parse an ISO-8601 timestamp string into a timezone-aware UTC datetime.

    Args:
        value: Timestamp text such as ``2024-01-01T12:00:00Z`` or
            ``2024-01-01T12:00:00+02:00``. ``None`` (or any other
            non-string) is tolerated and treated as unparseable.

    Returns:
        The timestamp converted to UTC, or ``None`` when ``value`` is not a
        string or is not valid ISO-8601. A timestamp without an offset is
        taken to already be in UTC.
    """
    # A missing timestamp must not crash the caller: calling ``.endswith``
    # on None raises AttributeError, which the parse handler below would
    # not catch.
    if not isinstance(value, str):
        return None

    # Rewrite a trailing "Z" as an explicit offset; ``fromisoformat`` only
    # accepts the "Z" suffix on newer Python versions.
    if value.endswith("Z"):
        value = value[:-1] + "+00:00"

    try:
        parsed = datetime.fromisoformat(value)
    except (TypeError, ValueError):
        return None

    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)
@router.post("/archive-expired")
def archive_expired_tasks(
    hours: int = 72,
    terminal_only: bool = True,
) -> dict[str, Any]:
    """Archive active tasks whose last timestamp is at least ``hours`` old.

    With ``terminal_only`` left on (the default), tasks whose status is not
    in ``TERMINAL_STATUSES`` are never archived, so pending/running jobs stay
    visible in the active task view; their ids are reported under
    ``skipped_running``.

    A task's age is taken from ``update_time``, falling back to
    ``submit_time``. Tasks whose timestamp cannot be parsed are left alone.
    """
    threshold = datetime.now(timezone.utc) - timedelta(hours=hours)
    task_store = TaskStore()
    archive_store = ArchiveStore()

    moved: list[str] = []
    left_active: list[str] = []
    for entry in task_store.list(limit=None):
        if terminal_only and entry.status not in TERMINAL_STATUSES:
            left_active.append(entry.task_id)
            continue

        last_touched = _parse_time(entry.update_time or entry.submit_time)
        is_expired = last_touched is not None and last_touched <= threshold
        # Only ids the archive store confirms as moved are reported.
        if is_expired and archive_store.archive_task(entry.task_id):
            moved.append(entry.task_id)

    return {
        "archived": moved,
        "count": len(moved),
        "hours": hours,
        "terminal_only": terminal_only,
        "skipped_running": left_active,
    }