api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """
  2. Trace RESTful API
  3. 提供 Trace、GoalTree、Message 的查询接口
  4. """
  5. import os
  6. import json
  7. import httpx
  8. from datetime import datetime, timezone
  9. from typing import List, Optional, Dict, Any
  10. from fastapi import APIRouter, HTTPException, Query
  11. from pydantic import BaseModel
  12. from agent.llm.openrouter import openrouter_llm_call
  13. from .protocols import TraceStore
# Router that groups every endpoint in this module under the /api/traces
# prefix and the "traces" tag.
router = APIRouter(prefix="/api/traces", tags=["traces"])
# ===== Response models =====
class TraceListResponse(BaseModel):
    """Response body for the trace list endpoint."""

    # One plain dict per trace (list_traces fills this from each trace's to_dict()).
    traces: List[Dict[str, Any]]
class TraceDetailResponse(BaseModel):
    """Response body for a single trace: its metadata, GoalTree and sub-trace metadata."""

    # Serialized trace metadata.
    trace: Dict[str, Any]
    # Serialized GoalTree, or None when the store has no tree for the trace.
    goal_tree: Optional[Dict[str, Any]] = None
    # Sub-trace metadata keyed by sub-trace ID.
    sub_traces: Dict[str, Dict[str, Any]] = {}
class MessagesResponse(BaseModel):
    """Response body for the messages endpoint."""

    # One plain dict per message (get_messages fills this from each message's to_dict()).
    messages: List[Dict[str, Any]]
# ===== Global TraceStore (injected by api_server.py) =====
# Stays None until set_trace_store() is called; get_trace_store() raises while it is None.
_trace_store: Optional[TraceStore] = None
def set_trace_store(store: TraceStore) -> None:
    """Install the TraceStore instance used by every route in this module.

    Args:
        store: The store to keep in the module-level ``_trace_store`` reference.
    """
    global _trace_store
    _trace_store = store
  33. def get_trace_store() -> TraceStore:
  34. """获取 TraceStore 实例"""
  35. if _trace_store is None:
  36. raise RuntimeError("TraceStore not initialized")
  37. return _trace_store
  38. # ===== 路由 =====
  39. @router.get("", response_model=TraceListResponse)
  40. async def list_traces(
  41. mode: Optional[str] = None,
  42. agent_type: Optional[str] = None,
  43. uid: Optional[str] = None,
  44. status: Optional[str] = None,
  45. limit: int = Query(20, le=100)
  46. ):
  47. """
  48. 列出 Traces
  49. Args:
  50. mode: 模式过滤(call/agent)
  51. agent_type: Agent 类型过滤
  52. uid: 用户 ID 过滤
  53. status: 状态过滤(running/completed/failed)
  54. limit: 最大返回数量
  55. """
  56. store = get_trace_store()
  57. traces = await store.list_traces(
  58. mode=mode,
  59. agent_type=agent_type,
  60. uid=uid,
  61. status=status,
  62. limit=limit
  63. )
  64. return TraceListResponse(
  65. traces=[t.to_dict() for t in traces]
  66. )
  67. @router.get("/{trace_id}", response_model=TraceDetailResponse)
  68. async def get_trace(trace_id: str):
  69. """
  70. 获取 Trace 详情
  71. 返回 Trace 元数据、GoalTree、Sub-Traces 元数据(不含 Sub-Trace 内 GoalTree)
  72. Args:
  73. trace_id: Trace ID
  74. """
  75. store = get_trace_store()
  76. # 获取 Trace
  77. trace = await store.get_trace(trace_id)
  78. if not trace:
  79. raise HTTPException(status_code=404, detail="Trace not found")
  80. # 获取 GoalTree
  81. goal_tree = await store.get_goal_tree(trace_id)
  82. # 获取所有 Sub-Traces(通过 parent_trace_id 查询)
  83. sub_traces = {}
  84. all_traces = await store.list_traces(limit=1000) # 获取所有 traces
  85. for t in all_traces:
  86. if t.parent_trace_id == trace_id:
  87. sub_traces[t.trace_id] = t.to_dict()
  88. return TraceDetailResponse(
  89. trace=trace.to_dict(),
  90. goal_tree=goal_tree.to_dict() if goal_tree else None,
  91. sub_traces=sub_traces
  92. )
  93. @router.get("/{trace_id}/messages", response_model=MessagesResponse)
  94. async def get_messages(
  95. trace_id: str,
  96. mode: str = Query("main_path", description="查询模式:main_path(当前主路径消息)或 all(全部消息含所有分支)"),
  97. head: Optional[int] = Query(None, description="主路径的 head sequence(仅 mode=main_path 有效,默认用 trace.head_sequence)"),
  98. goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息。使用 '_init' 查询初始阶段(goal_id=None)的消息"),
  99. ):
  100. """
  101. 获取 Messages
  102. Args:
  103. trace_id: Trace ID
  104. mode: 查询模式
  105. - "main_path"(默认): 从 head 沿 parent_sequence 链回溯的主路径消息
  106. - "all": 返回所有消息(包含所有分支)
  107. head: 可选,指定主路径的 head sequence(仅 mode=main_path 有效)
  108. goal_id: 可选,过滤指定 Goal 的消息
  109. - 不指定: 不按 goal 过滤
  110. - "_init" 或 "null": 返回初始阶段(goal_id=None)的消息
  111. - 其他值: 返回指定 Goal 的消息
  112. """
  113. store = get_trace_store()
  114. # 验证 Trace 存在
  115. trace = await store.get_trace(trace_id)
  116. if not trace:
  117. raise HTTPException(status_code=404, detail="Trace not found")
  118. # 获取 Messages
  119. if goal_id and goal_id not in ("_init", "null"):
  120. # 按 Goal 过滤(独立查询)
  121. messages = await store.get_messages_by_goal(trace_id, goal_id)
  122. elif mode == "main_path":
  123. # 主路径模式
  124. head_seq = head if head is not None else trace.head_sequence
  125. if head_seq > 0:
  126. messages = await store.get_main_path_messages(trace_id, head_seq)
  127. else:
  128. messages = []
  129. else:
  130. # all 模式:返回所有消息
  131. messages = await store.get_trace_messages(trace_id)
  132. # goal_id 过滤(_init 表示 goal_id=None 的消息)
  133. if goal_id in ("_init", "null"):
  134. messages = [m for m in messages if m.goal_id is None]
  135. return MessagesResponse(
  136. messages=[m.to_dict() for m in messages]
  137. )
# ===== Knowledge feedback =====
class KnowledgeFeedbackItem(BaseModel):
    """One piece of user feedback about a knowledge entry."""

    # ID of the knowledge entry the feedback refers to.
    knowledge_id: str
    action: str  # "confirm" | "override" | "skip"
    eval_status: Optional[str] = None  # helpful | harmful | unused | irrelevant | neutral
    # Optional free-text comment from the user.
    feedback_text: Optional[str] = None
    source: Dict[str, Any] = {}  # {trace_id, goal_id, sequence, feedback_by, feedback_at}
class KnowledgeFeedbackRequest(BaseModel):
    """Request body for the knowledge_feedback endpoint: a batch of feedback items."""

    feedback_list: List[KnowledgeFeedbackItem]
  147. @router.get("/{trace_id}/knowledge_log")
  148. async def get_knowledge_log(trace_id: str):
  149. """获取 Trace 的知识注入日志"""
  150. store = get_trace_store()
  151. trace = await store.get_trace(trace_id)
  152. if not trace:
  153. raise HTTPException(status_code=404, detail="Trace not found")
  154. return await store.get_knowledge_log(trace_id)
  155. @router.post("/{trace_id}/knowledge_feedback")
  156. async def submit_knowledge_feedback(trace_id: str, req: KnowledgeFeedbackRequest):
  157. """提交知识使用反馈,同步更新 knowledge_log.json 并转发到 KnowHub"""
  158. store = get_trace_store()
  159. trace = await store.get_trace(trace_id)
  160. if not trace:
  161. raise HTTPException(status_code=404, detail="Trace not found")
  162. knowhub_url = os.getenv("KNOWHUB_API") or os.getenv("KNOWHUB_URL", "http://localhost:9999")
  163. updated_count = 0
  164. now_iso = datetime.now(timezone.utc).isoformat()
  165. async with httpx.AsyncClient(timeout=10.0) as client:
  166. for item in req.feedback_list:
  167. if item.action == "skip":
  168. continue
  169. # 1. 记录到 knowledge_log.json
  170. feedback_record = {
  171. "action": item.action,
  172. "eval_status": item.eval_status,
  173. "feedback_text": item.feedback_text,
  174. "feedback_by": item.source.get("feedback_by", "user"),
  175. "feedback_at": item.source.get("feedback_at", now_iso),
  176. }
  177. await store.update_user_feedback(trace_id, item.knowledge_id, feedback_record)
  178. # 2. 构建 history 条目(含完整溯源信息)
  179. history_entry = {
  180. "source": "user",
  181. "action": item.action,
  182. "eval_status": item.eval_status,
  183. "feedback_by": item.source.get("feedback_by", "user"),
  184. "feedback_at": now_iso,
  185. "trace_id": trace_id,
  186. "goal_id": item.source.get("goal_id"),
  187. "sequence": item.source.get("sequence"),
  188. "feedback_text": item.feedback_text,
  189. }
  190. # 3. 根据 action 和 eval_status 决定调用 KnowHub 的哪个字段
  191. if item.action == "confirm":
  192. payload = {"add_helpful_case": history_entry}
  193. elif item.action == "override":
  194. if item.eval_status == "harmful":
  195. payload = {"add_harmful_case": history_entry}
  196. else:
  197. # helpful / unused / irrelevant / neutral → 记为 helpful_case,history 内保留完整 eval_status
  198. payload = {"add_helpful_case": history_entry}
  199. else:
  200. continue
  201. try:
  202. await client.put(
  203. f"{knowhub_url}/api/knowledge/{item.knowledge_id}",
  204. json=payload
  205. )
  206. updated_count += 1
  207. except Exception as e:
  208. # 记录警告但不中断整体提交
  209. print(f"[KnowledgeFeedback] KnowHub 更新失败 {item.knowledge_id}: {e}")
  210. return {"status": "ok", "updated": updated_count}
  211. @router.post("/extract_comment", status_code=201)
  212. async def extract_comment_proxy(req: Dict[str, Any]):
  213. """调用 LLM 从评论提取结构化知识,再 POST 到远端 KnowHub /api/knowledge"""
  214. comment = (req.get("comment") or "").strip()
  215. if not comment:
  216. raise HTTPException(status_code=400, detail="comment is required")
  217. context = req.get("context") or ""
  218. prompt = f"""你是知识提取专家。根据用户的评论和 Agent 执行上下文,提取一条结构化知识。
  219. 【上下文(Agent 执行内容)】:
  220. {context or "(无上下文)"}
  221. 【用户评论】:
  222. {comment}
  223. 【输出格式】(严格 JSON,不要其他内容):
  224. {{
  225. "task": "任务场景描述(一句话,描述在什么情况下要完成什么目标)",
  226. "content": "核心知识内容(具体可操作的方法、注意事项)"
  227. }}"""
  228. try:
  229. response = await openrouter_llm_call(
  230. messages=[{"role": "user", "content": prompt}],
  231. model="google/gemini-2.5-flash-lite",
  232. )
  233. raw = response.get("content", "").strip()
  234. if "```" in raw:
  235. for part in raw.split("```"):
  236. part = part.strip().lstrip("json").strip()
  237. try:
  238. parsed = json.loads(part)
  239. if "task" in parsed and "content" in parsed:
  240. raw = part
  241. break
  242. except Exception:
  243. continue
  244. extracted = json.loads(raw)
  245. task = extracted.get("task", "").strip()
  246. content = extracted.get("content", "").strip()
  247. if not task or not content:
  248. raise ValueError("missing task or content")
  249. except Exception as e:
  250. raise HTTPException(status_code=500, detail=f"LLM 提取失败: {e}")
  251. knowhub_url = os.getenv("KNOWHUB_API") or os.getenv("KNOWHUB_URL", "http://localhost:9999")
  252. payload = {
  253. "task": task,
  254. "content": content,
  255. "types": req.get("types", ["strategy"]),
  256. "scopes": req.get("scopes", ["org:cybertogether"]),
  257. "owner": req.get("owner", "user"),
  258. "source": req.get("source", {}),
  259. }
  260. async with httpx.AsyncClient(timeout=15.0) as client:
  261. try:
  262. resp = await client.post(f"{knowhub_url}/api/knowledge", json=payload)
  263. resp.raise_for_status()
  264. data = resp.json()
  265. return {"status": "pending", "knowledge_id": data.get("id", ""), "task": task, "content": content}
  266. except Exception as e:
  267. raise HTTPException(status_code=502, detail=f"KnowHub 写入失败: {e}")