"""
Trace RESTful API.

Provides query endpoints for Trace, GoalTree and Message, plus endpoints that
record knowledge-usage feedback locally and forward it to a KnowHub service.
"""

import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import httpx
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel

from agent.llm.openrouter import openrouter_llm_call
from .protocols import TraceStore

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/traces", tags=["traces"])


# ===== Response models =====


class TraceListResponse(BaseModel):
    """Response body for the trace list endpoint."""

    traces: List[Dict[str, Any]]


class TraceDetailResponse(BaseModel):
    """Response body for trace details (includes the GoalTree and sub-trace metadata)."""

    trace: Dict[str, Any]
    goal_tree: Optional[Dict[str, Any]] = None
    # Keyed by sub-trace id; pydantic copies this default per instance.
    sub_traces: Dict[str, Dict[str, Any]] = {}


class MessagesResponse(BaseModel):
    """Response body for the messages endpoint."""

    messages: List[Dict[str, Any]]


# ===== Global TraceStore (injected by api_server.py) =====

_trace_store: Optional[TraceStore] = None


def set_trace_store(store: TraceStore):
    """Install the TraceStore instance used by every route in this module."""
    global _trace_store
    _trace_store = store


def get_trace_store() -> TraceStore:
    """Return the installed TraceStore.

    Raises:
        RuntimeError: if ``set_trace_store`` has not been called yet.
    """
    if _trace_store is None:
        raise RuntimeError("TraceStore not initialized")
    return _trace_store


async def _get_trace_or_404(store: TraceStore, trace_id: str) -> Any:
    """Fetch ``trace_id`` from ``store``, raising HTTP 404 if it is missing/falsy."""
    trace = await store.get_trace(trace_id)
    if not trace:
        raise HTTPException(status_code=404, detail="Trace not found")
    return trace


def _knowhub_url() -> str:
    """Return the KnowHub base URL; ``KNOWHUB_API`` takes precedence over ``KNOWHUB_URL``."""
    return os.getenv("KNOWHUB_API") or os.getenv("KNOWHUB_URL", "http://localhost:9999")


# ===== Routes =====


@router.get("", response_model=TraceListResponse)
async def list_traces(
    mode: Optional[str] = None,
    agent_type: Optional[str] = None,
    uid: Optional[str] = None,
    status: Optional[str] = None,
    limit: int = Query(20, le=100),
):
    """
    List traces.

    Args:
        mode: Filter by mode (call/agent).
        agent_type: Filter by agent type.
        uid: Filter by user id.
        status: Filter by status (running/completed/failed).
        limit: Maximum number of traces to return (at most 100).
    """
    store = get_trace_store()
    traces = await store.list_traces(
        mode=mode,
        agent_type=agent_type,
        uid=uid,
        status=status,
        limit=limit,
    )
    return TraceListResponse(traces=[t.to_dict() for t in traces])


@router.get("/{trace_id}", response_model=TraceDetailResponse)
async def get_trace(trace_id: str):
    """
    Get trace details.

    Returns the trace metadata, its GoalTree, and the metadata of its
    sub-traces (the sub-traces' own GoalTrees are not included).

    Args:
        trace_id: Trace ID.

    Raises:
        HTTPException: 404 if the trace does not exist.
    """
    store = get_trace_store()
    trace = await _get_trace_or_404(store, trace_id)

    goal_tree = await store.get_goal_tree(trace_id)

    # Sub-traces are located by scanning traces for a matching parent_trace_id.
    # NOTE(review): only the first 1000 traces returned by list_traces are
    # scanned, so sub-traces beyond that window are silently omitted; a store
    # query by parent_trace_id would be cheaper and complete.
    all_traces = await store.list_traces(limit=1000)
    sub_traces = {
        t.trace_id: t.to_dict()
        for t in all_traces
        if t.parent_trace_id == trace_id
    }

    return TraceDetailResponse(
        trace=trace.to_dict(),
        goal_tree=goal_tree.to_dict() if goal_tree else None,
        sub_traces=sub_traces,
    )


@router.get("/{trace_id}/messages", response_model=MessagesResponse)
async def get_messages(
    trace_id: str,
    mode: str = Query("main_path", description="查询模式:main_path(当前主路径消息)或 all(全部消息含所有分支)"),
    head: Optional[int] = Query(None, description="主路径的 head sequence(仅 mode=main_path 有效,默认用 trace.head_sequence)"),
    goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息。使用 '_init' 查询初始阶段(goal_id=None)的消息"),
):
    """
    Get messages of a trace.

    Args:
        trace_id: Trace ID.
        mode: Query mode.
            - "main_path" (default): messages on the main path, obtained by
              walking the parent_sequence chain back from ``head``.
            - any other value (documented as "all"): every message,
              including all branches.
        head: Optional head sequence for the main path (only used when
            mode == "main_path"; defaults to ``trace.head_sequence``).
        goal_id: Optional goal filter.
            - omitted: no goal filtering.
            - "_init" or "null": only messages of the initial phase
              (``goal_id is None``).
            - any other value: messages of that goal, fetched with a dedicated
              store query that ignores ``mode`` and ``head``.

    Raises:
        HTTPException: 404 if the trace does not exist.
    """
    store = get_trace_store()
    trace = await _get_trace_or_404(store, trace_id)

    init_phase_only = goal_id in ("_init", "null")

    if goal_id and not init_phase_only:
        # Filter by a specific goal (independent query).
        messages = await store.get_messages_by_goal(trace_id, goal_id)
    elif mode == "main_path":
        head_seq = head if head is not None else trace.head_sequence
        # Guard against a trace that has no head yet (None) as well as 0.
        if head_seq is not None and head_seq > 0:
            messages = await store.get_main_path_messages(trace_id, head_seq)
        else:
            messages = []
    else:
        # "all" mode: every message, across all branches.
        messages = await store.get_trace_messages(trace_id)

    if init_phase_only:
        # "_init"/"null" select the messages that belong to no goal.
        messages = [m for m in messages if m.goal_id is None]

    return MessagesResponse(messages=[m.to_dict() for m in messages])


# ===== Knowledge feedback =====


class KnowledgeFeedbackItem(BaseModel):
    """One piece of user feedback about a knowledge entry used in a trace."""

    knowledge_id: str
    action: str  # "confirm" | "override" | "skip"
    eval_status: Optional[str] = None  # helpful | harmful | unused | irrelevant | neutral
    feedback_text: Optional[str] = None
    # Provenance: {trace_id, goal_id, sequence, feedback_by, feedback_at}
    source: Dict[str, Any] = {}


class KnowledgeFeedbackRequest(BaseModel):
    """Request body for submitting a batch of knowledge feedback."""

    feedback_list: List[KnowledgeFeedbackItem]


def _knowhub_feedback_payload(
    action: str,
    eval_status: Optional[str],
    history_entry: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
    """Map a feedback action/eval_status to the KnowHub update payload (None = don't forward)."""
    if action == "confirm":
        return {"add_helpful_case": history_entry}
    if action == "override":
        # helpful / unused / irrelevant / neutral are all recorded as a helpful
        # case; the full eval_status is preserved inside the history entry.
        field = "add_harmful_case" if eval_status == "harmful" else "add_helpful_case"
        return {field: history_entry}
    return None


@router.get("/{trace_id}/knowledge_log")
async def get_knowledge_log(trace_id: str):
    """Return the knowledge-injection log of a trace (404 if the trace does not exist)."""
    store = get_trace_store()
    await _get_trace_or_404(store, trace_id)
    return await store.get_knowledge_log(trace_id)


@router.post("/{trace_id}/knowledge_feedback")
async def submit_knowledge_feedback(trace_id: str, req: KnowledgeFeedbackRequest):
    """
    Submit knowledge-usage feedback.

    For every item whose action is not "skip", the feedback is stored via
    ``store.update_user_feedback`` (the knowledge_log.json record) and then,
    for "confirm"/"override", forwarded to KnowHub as a PUT on
    ``/api/knowledge/{knowledge_id}``. KnowHub failures are logged and do not
    abort the batch.

    Returns:
        ``{"status": "ok", "updated": N}`` where N counts KnowHub updates that
        completed with a success status code.

    Raises:
        HTTPException: 404 if the trace does not exist.
    """
    store = get_trace_store()
    await _get_trace_or_404(store, trace_id)

    knowhub_url = _knowhub_url()
    updated_count = 0
    now_iso = datetime.now(timezone.utc).isoformat()

    async with httpx.AsyncClient(timeout=10.0) as client:
        for item in req.feedback_list:
            if item.action == "skip":
                continue

            feedback_by = item.source.get("feedback_by", "user")

            # 1. Record locally; the client-supplied feedback_at wins if present.
            feedback_record = {
                "action": item.action,
                "eval_status": item.eval_status,
                "feedback_text": item.feedback_text,
                "feedback_by": feedback_by,
                "feedback_at": item.source.get("feedback_at", now_iso),
            }
            await store.update_user_feedback(trace_id, item.knowledge_id, feedback_record)

            # 2. History entry with full provenance, stamped with server time.
            history_entry = {
                "source": "user",
                "action": item.action,
                "eval_status": item.eval_status,
                "feedback_by": feedback_by,
                "feedback_at": now_iso,
                "trace_id": trace_id,
                "goal_id": item.source.get("goal_id"),
                "sequence": item.source.get("sequence"),
                "feedback_text": item.feedback_text,
            }

            # 3. Decide which KnowHub field to update; unknown actions are not forwarded.
            payload = _knowhub_feedback_payload(item.action, item.eval_status, history_entry)
            if payload is None:
                continue

            # knowledge_id comes from the request body: encode it as a single
            # path segment so "/" or ".." cannot redirect the call elsewhere.
            knowledge_path = quote(item.knowledge_id, safe="")
            try:
                resp = await client.put(
                    f"{knowhub_url}/api/knowledge/{knowledge_path}",
                    json=payload,
                )
                # httpx does not raise on 4xx/5xx by itself; without this a
                # rejected update would still be counted as successful.
                resp.raise_for_status()
            except Exception as e:
                # Best-effort: one failing item must not abort the whole submission.
                logger.warning("[KnowledgeFeedback] KnowHub 更新失败 %s: %s", item.knowledge_id, e)
            else:
                updated_count += 1

    return {"status": "ok", "updated": updated_count}


def _extract_json_candidate(raw: str) -> str:
    """Return the ``` fenced segment of ``raw`` holding a JSON object with task/content.

    Falls back to ``raw`` unchanged when there are no fences or no fenced
    segment qualifies.
    """
    if "```" not in raw:
        return raw
    for part in raw.split("```"):
        part = part.strip()
        # Drop a leading "json" language tag. (str.lstrip("json") would strip
        # any run of the characters j/s/o/n, not the literal prefix.)
        if part[:4].lower() == "json":
            part = part[4:].strip()
        try:
            parsed = json.loads(part)
        except ValueError:
            continue
        if isinstance(parsed, dict) and "task" in parsed and "content" in parsed:
            return part
    return raw


@router.post("/extract_comment", status_code=201)
async def extract_comment_proxy(req: Dict[str, Any]):
    """
    Extract a structured knowledge entry from a user comment and store it in KnowHub.

    An LLM turns ``req["comment"]`` (plus optional ``req["context"]``) into
    ``{task, content}``, which is then POSTed to KnowHub ``/api/knowledge``
    together with the optional ``types``, ``scopes``, ``owner`` and ``source``
    fields from the request.

    Raises:
        HTTPException: 400 if the comment is empty; 500 if the LLM output
            cannot be parsed into a non-empty task and content; 502 if the
            KnowHub write fails.
    """
    comment = (req.get("comment") or "").strip()
    if not comment:
        raise HTTPException(status_code=400, detail="comment is required")
    context = req.get("context") or ""

    prompt = f"""你是知识提取专家。根据用户的评论和 Agent 执行上下文,提取一条结构化知识。

【上下文(Agent 执行内容)】:
{context or "(无上下文)"}

【用户评论】:
{comment}

【输出格式】(严格 JSON,不要其他内容):
{{
  "task": "任务场景描述(一句话,描述在什么情况下要完成什么目标)",
  "content": "核心知识内容(具体可操作的方法、注意事项)"
}}"""

    try:
        response = await openrouter_llm_call(
            messages=[{"role": "user", "content": prompt}],
            model="google/gemini-2.5-flash-lite",
        )
        raw = (response.get("content") or "").strip()
        extracted = json.loads(_extract_json_candidate(raw))
        task = (extracted.get("task") or "").strip()
        content = (extracted.get("content") or "").strip()
        if not task or not content:
            raise ValueError("missing task or content")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"LLM 提取失败: {e}") from e

    payload = {
        "task": task,
        "content": content,
        "types": req.get("types", ["strategy"]),
        "scopes": req.get("scopes", ["org:cybertogether"]),
        "owner": req.get("owner", "user"),
        "source": req.get("source", {}),
    }

    knowhub_url = _knowhub_url()
    async with httpx.AsyncClient(timeout=15.0) as client:
        try:
            resp = await client.post(f"{knowhub_url}/api/knowledge", json=payload)
            resp.raise_for_status()
            data = resp.json()
            return {"status": "pending", "knowledge_id": data.get("id", ""), "task": task, "content": content}
        except Exception as e:
            raise HTTPException(status_code=502, detail=f"KnowHub 写入失败: {e}") from e