xueyiming 2 дней назад
Родитель
Сommit
d900116fb4
2 измененных файлов с 205 добавлено и 23 удалено
  1. 56 23
      examples/demand/run.py
  2. 149 0
      examples/demand/web_api.py

+ 56 - 23
examples/demand/run.py

@@ -1,6 +1,5 @@
 """demand 示例的最小可运行入口。"""
 
-import asyncio
 import copy
 import importlib
 import json
@@ -41,7 +40,7 @@ from agent.llm import create_openrouter_llm_call
 from agent.llm.prompts import SimplePrompt
 from agent.trace import FileSystemTraceStore, Message, Trace
 from agent.utils import setup_logging
-from log_capture import build_log, log
+from examples.demand.log_capture import build_log, log
 
 # 导入项目配置
 from examples.demand.config import DEBUG, LOG_FILE, LOG_LEVEL, RUN_CONFIG, TRACE_STORE_PATH
@@ -172,13 +171,26 @@ def _avg_score_for_joined_name(name: str, score_map: dict) -> float:
     return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts)
 
 
-def _create_demand_task(execution_id: int) -> Optional[int]:
+def _create_demand_task(
+    execution_id: int,
+    name: Optional[str] = None,
+    platform: Optional[str] = None,
+) -> Optional[int]:
     """创建 demand_task 记录,返回任务ID。"""
     try:
+        # 数据库字段 demand_task.name: varchar(32)
+        if name is not None:
+            name = str(name)[:32]
+        # 数据库字段 demand_task.platform: varchar(32)
+        if platform is not None:
+            platform = str(platform)[:32]
+
         task_id = mysql_db.insert(
             "demand_task",
             {
                 "execution_id": execution_id,
+                "name": name,
+                "platform": platform,
                 "status": 0,
                 "log": "",
             },
@@ -280,10 +292,9 @@ def write_demand_items_to_mysql(execution_id: int, merge_level2: str) -> int:
     return len(rows)
 
 
-async def run_once(execution_id, merge_level2) -> str:
-    # task_id = _create_demand_task(execution_id=execution_id)
-    # task_status = 2
+async def run_once(execution_id, merge_level2, task_id: Optional[int] = None) -> str:
     task_log_text = ""
+    task_status = 0
 
     TopicBuildAgentContext.set_execution_id(execution_id)
     prepare(execution_id)
@@ -356,35 +367,57 @@ async def run_once(execution_id, merge_level2) -> str:
 
             # agent 执行完成后:把本地 result JSON 写入 MySQL 表 demand_content
             # element_names -> name(逗号分隔);reason/desc -> ext_data JSON;merge_leve2 -> demand_content.merge_leve2
-            # try:
-            #     write_demand_items_to_mysql(execution_id=execution_id, merge_level2=merge_level2)
-            # except Exception as e:
-            #     log(f"[mysql] 写入 demand_content 异常:{e}")
+            try:
+                write_demand_items_to_mysql(execution_id=execution_id, merge_level2=merge_level2)
+            except Exception as e:
+                log(f"[mysql] 写入 demand_content 异常:{e}")
 
             task_log_text = log_buffer.getvalue()
-            with open(log_file_path, "w", encoding="utf-8") as f:
-                f.write(task_log_text)
-
             task_status = 1
     except Exception as e:
+        if not task_log_text:
+            # 如果异常发生在 build_log 内部,尽量回收已产生的日志
+            try:
+                existing = locals().get("log_buffer")
+                if existing is not None:
+                    task_log_text = existing.getvalue()  # type: ignore[attr-defined]
+            except Exception:
+                pass
         if not task_log_text:
             task_log_text = f"[run] 执行异常: {e}"
         task_status = 2
         raise
-    # finally:
-    # _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text)
+    finally:
+        if task_log_text:
+            try:
+                with open(log_file_path, "w", encoding="utf-8") as f:
+                    f.write(task_log_text)
+            except Exception:
+                # 兜底:即使写文件失败,也要确保 MySQL 状态被更新
+                pass
+        _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text)
 
     return final_text
 
 
-async def main(cluster_name, platform_type) -> None:
-    execution_id = None
-    if platform_type == "piaoquan":
-        execution_id = piaoquan_prepare(cluster_name)
-    elif platform_type == "changwen":
-        execution_id = changwen_prepare(cluster_name)
-    if execution_id:
-        await run_once(execution_id, cluster_name)
+async def main(
+        cluster_name: str,
+        platform_type: str,
+        execution_id: Optional[int] = None,
+        task_id: Optional[int] = None,
+) -> dict:
+    if execution_id is None:
+        if platform_type == "piaoquan":
+            execution_id = piaoquan_prepare(cluster_name)
+        elif platform_type == "changwen":
+            execution_id = changwen_prepare(cluster_name)
+        else:
+            execution_id = None
+    if not execution_id:
+        return {"execution_id": None, "final_text": ""}
+
+    final_text = await run_once(execution_id, cluster_name, task_id=task_id)
+    return {"execution_id": execution_id, "final_text": final_text}
 
 
 if __name__ == "__main__":

+ 149 - 0
examples/demand/web_api.py

@@ -0,0 +1,149 @@
+"""
+demand Web API(异步任务:发起 -> 立即返回 task_id -> 另一个接口查询状态)
+"""
+
+import asyncio
+import sys
+from pathlib import Path
+from typing import Literal, Optional
+
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+# 添加项目根目录到 Python 路径(与 run.py 保持一致)
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from examples.demand.changwen_prepare import changwen_prepare
+from examples.demand.mysql import mysql_db
+from examples.demand.piaoquan_prepare import piaoquan_prepare
+from examples.demand.run import _create_demand_task, main
+
+app = FastAPI(title="demand web api")
+
+
+class DemandStartRequest(BaseModel):
+    cluster_name: str
+    platform_type: Literal["piaoquan", "changwen"]
+
+
+@app.post("/demand/start")
+async def demand_start(req: DemandStartRequest):
+    # 注意:这里会同步计算 execution_id(prepare 阶段),随后 run_once 放到后台异步执行。
+    if req.platform_type == "piaoquan":
+        execution_id = piaoquan_prepare(req.cluster_name)
+    else:
+        execution_id = changwen_prepare(req.cluster_name)
+
+    if not execution_id:
+        raise HTTPException(status_code=400, detail="获取 execution_id 失败")
+
+    task_name = req.cluster_name[:32] if req.cluster_name else None
+    task_id = _create_demand_task(
+        execution_id=execution_id,
+        name=task_name,
+        platform=req.platform_type,
+    )
+    if not task_id:
+        raise HTTPException(status_code=500, detail="创建 demand_task 失败")
+
+    async def _job():
+        # run_once 内部会在 finally 里把 task 状态写回 MySQL。
+        await main(
+            req.cluster_name,
+            req.platform_type,
+            execution_id=execution_id,
+            task_id=task_id,
+        )
+
+    asyncio.create_task(_job())
+    return {"ok": True, "message": "调用成功", "task_id": task_id, "execution_id": execution_id}
+
+
+@app.get("/demand/task/{task_id}/status")
+def demand_task_status(task_id: int, max_log_chars: int = 2000):
+    row = mysql_db.select_one(
+        "demand_task",
+        columns="id, execution_id, name, platform, status, log",
+        where="id = %s",
+        where_params=(task_id,),
+    )
+    if not row:
+        raise HTTPException(status_code=404, detail="task not found")
+
+    status = int(row.get("status") or 0)
+    status_map = {0: "running", 1: "completed", 2: "failed"}
+
+    log_text = row.get("log") or ""
+    if max_log_chars and isinstance(log_text, str) and len(log_text) > max_log_chars:
+        log_text = log_text[:max_log_chars] + "...(truncated)"
+
+    execution_id = row.get("execution_id")
+    final_text: Optional[str] = None
+    if status == 1 and execution_id:
+        try:
+            result_path = Path(__file__).parent / "output" / str(execution_id) / "result.txt"
+            if result_path.exists():
+                with open(result_path, "r", encoding="utf-8") as f:
+                    final_text = f.read()
+        except Exception:
+            final_text = None
+
+    return {
+        "task_id": task_id,
+        "execution_id": execution_id,
+        "name": row.get("name"),
+        "platform": row.get("platform"),
+        "status": status,
+        "status_text": status_map.get(status, "unknown"),
+        "final_text": final_text,
+        "log": log_text,
+    }
+
+
+@app.get("/demand/tasks")
+def demand_tasks(
+    status: Optional[int] = None,
+    name: Optional[str] = None,
+    platform_type: Optional[str] = None,
+    page: int = 1,
+    page_size: int = 20,
+):
+    where_parts: list[str] = []
+    where_params: list = []
+
+    if status is not None:
+        status_int = int(status)
+        if status_int not in (0, 1, 2):
+            raise HTTPException(status_code=400, detail="status 必须为 0/1/2")
+        where_parts.append("status = %s")
+        where_params.append(status_int)
+
+    if name:
+        name_str = str(name).strip()
+        if name_str:
+            # 支持模糊匹配:根据需求名称字段(varchar(32))
+            where_parts.append("name LIKE %s")
+            where_params.append(f"%{name_str}%")
+
+    if platform_type:
+        platform_str = str(platform_type).strip()
+        if platform_str:
+            where_parts.append("platform = %s")
+            where_params.append(platform_str)
+
+    where = " AND ".join(where_parts)
+    params = tuple(where_params) if where_params else None
+
+    data = mysql_db.paginate(
+        "demand_task",
+        page=page,
+        page_size=page_size,
+        columns="id, execution_id, name, platform, status, create_time, update_time",
+        where=where,
+        where_params=params,
+        order_by="id DESC",
+    )
+
+    # 返回分页结构(data + pagination),便于前端直接展示
+    return data
+