main.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. import os
  2. import logging
  3. from fastapi import FastAPI, HTTPException
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from contextlib import asynccontextmanager
  6. from dotenv import load_dotenv
  7. from ..models.schemas import QuestionRequest, QueryResponse, HealthResponse
  8. from ..agent.query_agent import QueryGenerationAgent
  9. from ..database.connection import init_database_manager, get_db_manager
  10. from ..database.models import QueryTaskDAO, QueryTaskStatus, get_query_task_dao
  11. from ..tools.scheduler import create_scheduler, get_scheduler
  12. import time
  13. # 加载环境变量
  14. load_dotenv()
  15. # 配置日志
  16. log_level = os.getenv("LOG_LEVEL", "INFO")
  17. log_file = os.getenv("LOG_FILE", "logs/app.log")
  18. logging.basicConfig(
  19. level=getattr(logging, log_level),
  20. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  21. handlers=[
  22. logging.FileHandler(log_file),
  23. logging.StreamHandler()
  24. ]
  25. )
  26. logger = logging.getLogger(__name__)
  27. # 全局Agent实例
  28. agent = None
  29. task_dao = None
  30. scheduler = None
  31. @asynccontextmanager
  32. async def lifespan(app: FastAPI):
  33. """应用生命周期管理"""
  34. global agent, task_dao, scheduler
  35. # 启动时初始化数据库
  36. logger.info("正在初始化数据库连接...")
  37. try:
  38. init_database_manager(
  39. host=os.getenv("DB_HOST", "localhost"),
  40. port=int(os.getenv("DB_PORT", 3306)),
  41. user=os.getenv("DB_USER", "root"),
  42. password=os.getenv("DB_PASSWORD", ""),
  43. database=os.getenv("DB_NAME", "ai_knowledge"),
  44. charset=os.getenv("DB_CHARSET", "utf8")
  45. )
  46. task_dao = get_query_task_dao()
  47. logger.info("数据库初始化成功")
  48. except Exception as e:
  49. logger.error(f"数据库初始化失败: {e}")
  50. raise
  51. # 启动时初始化Agent
  52. logger.info("正在初始化查询生成Agent...")
  53. try:
  54. agent = QueryGenerationAgent(
  55. gemini_api_key=os.getenv("GEMINI_API_KEY", ""),
  56. model_name=os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
  57. )
  58. logger.info("Agent初始化成功")
  59. except Exception as e:
  60. logger.error(f"Agent初始化失败: {e}")
  61. raise
  62. # 启动时初始化调度器
  63. logger.info("正在启动任务调度器...")
  64. try:
  65. scheduler = create_scheduler(agent, task_dao)
  66. await scheduler.start()
  67. logger.info("任务调度器启动成功")
  68. except Exception as e:
  69. logger.error(f"任务调度器启动失败: {e}")
  70. raise
  71. yield
  72. # 关闭时清理资源
  73. logger.info("正在关闭服务...")
  74. # 停止调度器
  75. if scheduler:
  76. try:
  77. await scheduler.stop()
  78. logger.info("任务调度器已停止")
  79. except Exception as e:
  80. logger.error(f"停止任务调度器失败: {e}")
  81. # 创建FastAPI应用
  82. app = FastAPI(
  83. title="知识工具API",
  84. description="将问题转换为查询词的服务",
  85. version="1.0.0",
  86. lifespan=lifespan
  87. )
  88. # 添加CORS中间件
  89. app.add_middleware(
  90. CORSMiddleware,
  91. allow_origins=["*"],
  92. allow_credentials=True,
  93. allow_methods=["*"],
  94. allow_headers=["*"],
  95. )
  96. @app.get("/health", response_model=HealthResponse)
  97. async def health_check():
  98. """健康检查接口"""
  99. return HealthResponse(
  100. status="healthy",
  101. message="服务运行正常"
  102. )
  103. @app.post("/generate-queries", response_model=QueryResponse)
  104. async def generate_queries(request: QuestionRequest):
  105. """
  106. 生成查询词接口
  107. Args:
  108. request: 包含question字段的请求
  109. Returns:
  110. 任务信息(状态为待执行)
  111. """
  112. global task_dao
  113. if task_dao is None:
  114. raise HTTPException(status_code=503, detail="数据库未初始化")
  115. try:
  116. logger.info(f"收到问题: {request.question}")
  117. # 生成任务ID(使用时间戳)
  118. task_id = int(time.time() * 1000)
  119. # 创建任务记录,状态设置为0(待执行)
  120. task_dao.create_task(task_id, request.question, request.knowledgeType)
  121. logger.info(f"创建任务: {task_id},状态: 待执行")
  122. # 立即返回待执行状态
  123. return QueryResponse(
  124. task_id=task_id,
  125. queries=[], # 空列表,因为还未处理
  126. original_question=request.question,
  127. total_count=0, # 0个查询词,因为还未处理
  128. status=QueryTaskStatus.PENDING # 待执行状态
  129. )
  130. except Exception as e:
  131. logger.error(f"创建任务失败: {e}")
  132. raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}")
  133. @app.get("/task/{task_id}")
  134. async def get_task(task_id: int):
  135. """
  136. 获取任务信息接口
  137. Args:
  138. task_id: 任务ID
  139. Returns:
  140. 任务信息
  141. """
  142. global task_dao
  143. if task_dao is None:
  144. raise HTTPException(status_code=503, detail="数据库未初始化")
  145. try:
  146. task = task_dao.get_task(task_id)
  147. if task is None:
  148. raise HTTPException(status_code=404, detail="任务不存在")
  149. return {
  150. "task_id": task.task_id,
  151. "question": task.question,
  152. "queries": task.querys,
  153. "status": task.status,
  154. "status_text": {
  155. 0: "待执行",
  156. 1: "执行中",
  157. 2: "成功",
  158. 3: "失败"
  159. }.get(task.status, "未知")
  160. }
  161. except HTTPException:
  162. raise
  163. except Exception as e:
  164. logger.error(f"获取任务信息失败: {e}")
  165. raise HTTPException(status_code=500, detail=f"获取任务信息失败: {str(e)}")
  166. @app.get("/tasks")
  167. async def get_tasks(status: int = None, limit: int = 100):
  168. """
  169. 获取任务列表接口
  170. Args:
  171. status: 任务状态筛选(可选)
  172. limit: 限制数量
  173. Returns:
  174. 任务列表
  175. """
  176. global task_dao
  177. if task_dao is None:
  178. raise HTTPException(status_code=503, detail="数据库未初始化")
  179. try:
  180. if status is not None:
  181. tasks = task_dao.get_tasks_by_status(status, limit)
  182. else:
  183. # 获取所有状态的任务(这里简化处理,实际可能需要更复杂的查询)
  184. all_tasks = []
  185. for s in [0, 1, 2, 3]:
  186. tasks_by_status = task_dao.get_tasks_by_status(s, limit // 4)
  187. all_tasks.extend(tasks_by_status)
  188. tasks = all_tasks[:limit]
  189. return {
  190. "tasks": [
  191. {
  192. "task_id": task.task_id,
  193. "question": task.question,
  194. "queries": task.querys,
  195. "status": task.status,
  196. "status_text": {
  197. 0: "待执行",
  198. 1: "执行中",
  199. 2: "成功",
  200. 3: "失败"
  201. }.get(task.status, "未知")
  202. }
  203. for task in tasks
  204. ],
  205. "total": len(tasks)
  206. }
  207. except Exception as e:
  208. logger.error(f"获取任务列表失败: {e}")
  209. raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}")
  210. @app.get("/")
  211. async def root():
  212. """根路径"""
  213. return {
  214. "message": "知识工具API服务",
  215. "version": "1.0.0",
  216. "docs": "/docs"
  217. }
  218. if __name__ == "__main__":
  219. import uvicorn
  220. uvicorn.run(
  221. "src.api.main:app",
  222. host=os.getenv("HOST", "0.0.0.0"),
  223. port=int(os.getenv("PORT", 8079)),
  224. reload=os.getenv("DEBUG", "True").lower() == "true",
  225. log_level=os.getenv("LOG_LEVEL", "INFO").lower()
  226. )