server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. """
  2. 内容寻找服务
  3. 提供:
  4. 1. API 接口:POST /api/tasks - 触发内容寻找任务
  5. 2. 定时调度:每 10 分钟调用外部 API 获取 query 并执行任务
  6. 3. 并发控制:限制最大并发任务数
  7. """
  8. import asyncio
  9. import logging
  10. import os
  11. from datetime import datetime
  12. from pathlib import Path
  13. from typing import Optional
  14. import sys
  15. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  16. import httpx
  17. from fastapi import FastAPI, HTTPException
  18. from pydantic import BaseModel
  19. from apscheduler.schedulers.asyncio import AsyncIOScheduler
  20. from dotenv import load_dotenv
  21. load_dotenv()
  22. import core
  23. # 配置日志
  24. log_dir = Path(__file__).parent / '.cache'
  25. log_dir.mkdir(exist_ok=True)
  26. logging.basicConfig(
  27. level=logging.INFO,
  28. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  29. handlers=[
  30. logging.FileHandler(log_dir / 'server.log'),
  31. logging.StreamHandler()
  32. ]
  33. )
  34. logger = logging.getLogger(__name__)
  35. # FastAPI 应用
  36. app = FastAPI(
  37. title="内容寻找服务",
  38. version="1.0.0",
  39. description="抖音内容寻找 Agent 服务"
  40. )
  41. # 定时调度器
  42. scheduler = AsyncIOScheduler()
  43. # 并发控制
  44. MAX_CONCURRENT_TASKS = int(os.getenv("MAX_CONCURRENT_TASKS", "3"))
  45. task_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS)
  46. # 统计信息
  47. stats = {
  48. "total_tasks": 0,
  49. "completed_tasks": 0,
  50. "failed_tasks": 0,
  51. "scheduled_tasks": 0
  52. }
  53. # ============ 数据模型 ============
  54. class TaskRequest(BaseModel):
  55. query: Optional[str] = None
  56. class TaskResponse(BaseModel):
  57. trace_id: str
  58. status: str
  59. query: str
  60. message: str
  61. # ============ 核心函数 ============
  62. async def execute_task(query: str, task_type: str = "api"):
  63. """
  64. 执行任务(带并发控制)
  65. Args:
  66. query: 查询内容
  67. task_type: 任务类型("api" 或 "scheduled")
  68. """
  69. async with task_semaphore:
  70. current_concurrent = MAX_CONCURRENT_TASKS - task_semaphore._value + 1
  71. logger.info(f"任务开始 [{task_type}]: query={query[:50]}..., 当前并发={current_concurrent}/{MAX_CONCURRENT_TASKS}")
  72. start_time = datetime.now()
  73. stats["total_tasks"] += 1
  74. if task_type == "scheduled":
  75. stats["scheduled_tasks"] += 1
  76. try:
  77. # 执行 agent(不流式输出)
  78. result = await core.run_agent(query, stream_output=False)
  79. duration = (datetime.now() - start_time).total_seconds()
  80. if result["status"] == "completed":
  81. stats["completed_tasks"] += 1
  82. logger.info(f"任务完成 [{task_type}]: trace_id={result['trace_id']}, 耗时={duration:.1f}s")
  83. else:
  84. stats["failed_tasks"] += 1
  85. logger.error(f"任务失败 [{task_type}]: trace_id={result.get('trace_id')}, 错误={result.get('error')}, 耗时={duration:.1f}s")
  86. except Exception as e:
  87. stats["failed_tasks"] += 1
  88. duration = (datetime.now() - start_time).total_seconds()
  89. logger.error(f"任务异常 [{task_type}]: {e}, 耗时={duration:.1f}s", exc_info=True)
  90. async def scheduled_task():
  91. """
  92. 定时任务:每 10 分钟执行一次
  93. 流程:
  94. 1. 调用外部 API 获取 query
  95. 2. 如果成功,执行任务
  96. 3. 如果失败,跳过本次执行(不使用兜底)
  97. """
  98. logger.info("定时任务触发")
  99. try:
  100. # 1. 调用外部 API 获取 query
  101. query_api = os.getenv("SCHEDULE_QUERY_API")
  102. if not query_api:
  103. logger.warning("未配置 SCHEDULE_QUERY_API,跳过定时任务")
  104. return
  105. api_key = os.getenv("SCHEDULE_QUERY_API_KEY", "")
  106. timeout = float(os.getenv("SCHEDULE_QUERY_API_TIMEOUT", "10.0"))
  107. async with httpx.AsyncClient() as client:
  108. headers = {}
  109. if api_key:
  110. headers["Authorization"] = f"Bearer {api_key}"
  111. logger.info(f"调用外部 API: {query_api}")
  112. response = await client.get(
  113. query_api,
  114. headers=headers,
  115. timeout=timeout
  116. )
  117. response.raise_for_status()
  118. data = response.json()
  119. # 2. 提取 query
  120. query = data.get("query")
  121. if not query:
  122. logger.info("定时任务跳过:外部 API 返回的 query 为空")
  123. return
  124. # 3. 执行任务
  125. logger.info(f"定时任务启动: query={query[:50]}...")
  126. asyncio.create_task(execute_task(query, task_type="scheduled"))
  127. except httpx.HTTPStatusError as e:
  128. logger.error(f"定时任务失败:外部 API 返回错误 {e.response.status_code}: {e.response.text}")
  129. except httpx.RequestError as e:
  130. logger.error(f"定时任务失败:外部 API 请求失败: {e}")
  131. except httpx.TimeoutException:
  132. logger.error(f"定时任务失败:外部 API 请求超时")
  133. except Exception as e:
  134. logger.error(f"定时任务失败:未知错误: {e}", exc_info=True)
# ============ API endpoints ============
@app.post("/api/tasks", response_model=TaskResponse)
async def create_task(request: TaskRequest) -> TaskResponse:
    """
    Create a content-finding task.

    The agent runs in the background; the request only blocks (up to 5s)
    until the run yields its first Trace, so the trace_id can be returned.

    Args:
        request.query: query text (optional; falls back to core.DEFAULT_QUERY).

    Returns:
        {
            "trace_id": "20260317_103046_xyz789",
            "status": "started",
            "query": "...",
            "message": "..."  # points at .cache/traces/<trace_id>/
        }
    """
    # Resolve the effective query.
    query = request.query or core.DEFAULT_QUERY
    # Event + holder dict let the background coroutine hand the trace_id
    # back to this request handler.
    trace_id_ready = asyncio.Event()
    trace_id_holder = {"id": None}

    async def run_and_capture():
        try:
            # Imported lazily; Trace instances identify the run's trace_id.
            from agent import Trace
            async with task_semaphore:
                # Build a fresh runner per request so each task gets its own
                # prompt, trace store and config.
                from agent import AgentRunner, RunConfig, FileSystemTraceStore
                from agent.llm import create_openrouter_llm_call
                from agent.llm.prompts import SimplePrompt
                from agent.tools.builtin.knowledge import KnowledgeConfig
                prompt_path = Path(__file__).parent / "content_finder.prompt"
                prompt = SimplePrompt(prompt_path)
                messages = prompt.build_messages(query=query)
                # NOTE(review): api_key is read but never passed anywhere;
                # presumably create_openrouter_llm_call reads the env var
                # itself - confirm and drop this local if so.
                api_key = os.getenv("OPEN_ROUTER_API_KEY")
                model_name = prompt.config.get("model", "sonnet-4.6")
                model = os.getenv("MODEL", f"anthropic/claude-{model_name}")
                temperature = float(prompt.config.get("temperature", 0.3))
                max_iterations = int(os.getenv("MAX_ITERATIONS", "30"))
                trace_dir = os.getenv("TRACE_DIR", ".cache/traces")
                skills_dir = str(Path(__file__).parent / "skills")
                Path(trace_dir).mkdir(parents=True, exist_ok=True)
                store = FileSystemTraceStore(base_path=trace_dir)
                # Tool whitelist for this agent.
                allowed_tools = [
                    "douyin_search",
                    "douyin_user_videos",
                    "get_content_fans_portrait",
                    "get_account_fans_portrait",
                ]
                runner = AgentRunner(
                    llm_call=create_openrouter_llm_call(model=model),
                    trace_store=store,
                    skills_dir=skills_dir,
                )
                config = RunConfig(
                    name="内容寻找",
                    model=model,
                    temperature=temperature,
                    max_iterations=max_iterations,
                    tools=allowed_tools,
                    extra_llm_params={"max_tokens": 8192},
                    knowledge=KnowledgeConfig(
                        enable_extraction=True,
                        enable_completion_extraction=True,
                        enable_injection=True,
                        owner="content_finder_agent",
                        default_tags={"project": "content_finder"},
                        default_scopes=["com.piaoquantv.supply"],
                        default_search_types=["tool", "usecase", "definition"],
                        default_search_owner="content_finder_agent"
                    )
                )
                async for item in runner.run(messages=messages, config=config):
                    if isinstance(item, Trace):
                        # First Trace carries the trace_id: unblock the
                        # waiting HTTP handler exactly once.
                        if not trace_id_holder["id"]:
                            trace_id_holder["id"] = item.trace_id
                            trace_id_ready.set()
                            logger.info(f"任务启动 [api]: trace_id={item.trace_id}")
                        if item.status == "completed":
                            stats["completed_tasks"] += 1
                            logger.info(f"任务完成 [api]: trace_id={item.trace_id}")
                            break
                        elif item.status == "failed":
                            stats["failed_tasks"] += 1
                            logger.error(f"任务失败 [api]: trace_id={item.trace_id}, 错误={item.error_message}")
                            break
        except Exception as e:
            stats["failed_tasks"] += 1
            logger.error(f"任务异常 [api]: {e}", exc_info=True)
            # Unblock the waiting request with a synthetic id so the HTTP
            # call does not hang for the full 5s timeout.
            if not trace_id_holder["id"]:
                trace_id_holder["id"] = f"error_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                trace_id_ready.set()

    # Launch the background task (counted immediately, before it acquires
    # the semaphore, so queued tasks show up in stats too).
    stats["total_tasks"] += 1
    asyncio.create_task(run_and_capture())
    # Wait for the trace_id (at most 5 seconds).
    try:
        await asyncio.wait_for(trace_id_ready.wait(), timeout=5.0)
    except asyncio.TimeoutError:
        logger.error("获取 trace_id 超时")
        raise HTTPException(status_code=500, detail="任务启动超时")
    trace_id = trace_id_holder["id"]
    return TaskResponse(
        trace_id=trace_id,
        status="started",
        query=query,
        message=f"任务已启动,结果将保存到 .cache/traces/{trace_id}/"
    )
  242. @app.get("/health")
  243. async def health_check():
  244. """健康检查"""
  245. return {
  246. "status": "ok",
  247. "max_concurrent_tasks": MAX_CONCURRENT_TASKS,
  248. "current_tasks": MAX_CONCURRENT_TASKS - task_semaphore._value,
  249. "scheduler_running": scheduler.running,
  250. "stats": stats
  251. }
  252. @app.get("/")
  253. async def root():
  254. """根路径"""
  255. return {
  256. "service": "内容寻找服务",
  257. "version": "1.0.0",
  258. "endpoints": {
  259. "create_task": "POST /api/tasks",
  260. "health": "GET /health"
  261. }
  262. }
  263. # ============ 启动事件 ============
  264. @app.on_event("startup")
  265. async def startup():
  266. """服务启动时初始化"""
  267. logger.info("=" * 60)
  268. logger.info("内容寻找服务启动中...")
  269. logger.info(f"最大并发任务数: {MAX_CONCURRENT_TASKS}")
  270. # 配置定时任务
  271. query_api = os.getenv("SCHEDULE_QUERY_API")
  272. if query_api:
  273. # 每 10 分钟执行一次
  274. scheduler.add_job(scheduled_task, "cron", minute="*/10")
  275. scheduler.start()
  276. logger.info(f"定时任务已启动:每 10 分钟执行一次")
  277. logger.info(f"外部 API: {query_api}")
  278. else:
  279. logger.info("未配置 SCHEDULE_QUERY_API,定时任务未启动")
  280. logger.info("服务启动完成")
  281. logger.info("=" * 60)
  282. @app.on_event("shutdown")
  283. async def shutdown():
  284. """服务关闭时清理"""
  285. logger.info("服务关闭中...")
  286. if scheduler.running:
  287. scheduler.shutdown()
  288. logger.info("服务已关闭")
  289. # ============ 主函数 ============
  290. if __name__ == "__main__":
  291. import uvicorn
  292. port = int(os.getenv("PORT", "8080"))
  293. host = os.getenv("HOST", "0.0.0.0")
  294. logger.info(f"启动服务: http://{host}:{port}")
  295. uvicorn.run(app, host=host, port=port)