|
@@ -0,0 +1,276 @@
|
|
|
|
+import os
|
|
|
|
+import logging
|
|
|
|
+from fastapi import FastAPI, HTTPException
|
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
+from contextlib import asynccontextmanager
|
|
|
|
+from dotenv import load_dotenv
|
|
|
|
+
|
|
|
|
+from ..models.schemas import QuestionRequest, QueryResponse, HealthResponse
|
|
|
|
+from ..agent.query_agent import QueryGenerationAgent
|
|
|
|
+from ..database.connection import init_database_manager, get_db_manager
|
|
|
|
+from ..database.models import QueryTaskDAO, QueryTaskStatus, get_query_task_dao
|
|
|
|
+from ..tools.scheduler import create_scheduler, get_scheduler
|
|
|
|
+import time
|
|
|
|
+
|
|
|
|
+# 加载环境变量
|
|
|
|
+load_dotenv()
|
|
|
|
+
|
|
|
|
+# 配置日志
|
|
|
|
+log_level = os.getenv("LOG_LEVEL", "INFO")
|
|
|
|
+log_file = os.getenv("LOG_FILE", "logs/app.log")
|
|
|
|
+
|
|
|
|
+logging.basicConfig(
|
|
|
|
+ level=getattr(logging, log_level),
|
|
|
|
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
|
|
+ handlers=[
|
|
|
|
+ logging.FileHandler(log_file),
|
|
|
|
+ logging.StreamHandler()
|
|
|
|
+ ]
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
+
|
|
|
|
+# 全局Agent实例
|
|
|
|
+agent = None
|
|
|
|
+task_dao = None
|
|
|
|
+scheduler = None
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@asynccontextmanager
|
|
|
|
+async def lifespan(app: FastAPI):
|
|
|
|
+ """应用生命周期管理"""
|
|
|
|
+ global agent, task_dao, scheduler
|
|
|
|
+
|
|
|
|
+ # 启动时初始化数据库
|
|
|
|
+ logger.info("正在初始化数据库连接...")
|
|
|
|
+ try:
|
|
|
|
+ init_database_manager(
|
|
|
|
+ host=os.getenv("DB_HOST", "localhost"),
|
|
|
|
+ port=int(os.getenv("DB_PORT", 3306)),
|
|
|
|
+ user=os.getenv("DB_USER", "root"),
|
|
|
|
+ password=os.getenv("DB_PASSWORD", ""),
|
|
|
|
+ database=os.getenv("DB_NAME", "ai_knowledge"),
|
|
|
|
+ charset=os.getenv("DB_CHARSET", "utf8")
|
|
|
|
+ )
|
|
|
|
+ task_dao = get_query_task_dao()
|
|
|
|
+ logger.info("数据库初始化成功")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"数据库初始化失败: {e}")
|
|
|
|
+ raise
|
|
|
|
+
|
|
|
|
+ # 启动时初始化Agent
|
|
|
|
+ logger.info("正在初始化查询生成Agent...")
|
|
|
|
+ try:
|
|
|
|
+ agent = QueryGenerationAgent(
|
|
|
|
+ gemini_api_key=os.getenv("GEMINI_API_KEY", ""),
|
|
|
|
+ model_name=os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
|
|
|
|
+ )
|
|
|
|
+ logger.info("Agent初始化成功")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Agent初始化失败: {e}")
|
|
|
|
+ raise
|
|
|
|
+
|
|
|
|
+ # 启动时初始化调度器
|
|
|
|
+ logger.info("正在启动任务调度器...")
|
|
|
|
+ try:
|
|
|
|
+ scheduler = create_scheduler(agent, task_dao)
|
|
|
|
+ await scheduler.start()
|
|
|
|
+ logger.info("任务调度器启动成功")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"任务调度器启动失败: {e}")
|
|
|
|
+ raise
|
|
|
|
+
|
|
|
|
+ yield
|
|
|
|
+
|
|
|
|
+ # 关闭时清理资源
|
|
|
|
+ logger.info("正在关闭服务...")
|
|
|
|
+
|
|
|
|
+ # 停止调度器
|
|
|
|
+ if scheduler:
|
|
|
|
+ try:
|
|
|
|
+ await scheduler.stop()
|
|
|
|
+ logger.info("任务调度器已停止")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"停止任务调度器失败: {e}")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# 创建FastAPI应用
|
|
|
|
+app = FastAPI(
|
|
|
|
+ title="知识工具API",
|
|
|
|
+ description="将问题转换为查询词的服务",
|
|
|
|
+ version="1.0.0",
|
|
|
|
+ lifespan=lifespan
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+# 添加CORS中间件
|
|
|
|
+app.add_middleware(
|
|
|
|
+ CORSMiddleware,
|
|
|
|
+ allow_origins=["*"],
|
|
|
|
+ allow_credentials=True,
|
|
|
|
+ allow_methods=["*"],
|
|
|
|
+ allow_headers=["*"],
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get("/health", response_model=HealthResponse)
|
|
|
|
+async def health_check():
|
|
|
|
+ """健康检查接口"""
|
|
|
|
+ return HealthResponse(
|
|
|
|
+ status="healthy",
|
|
|
|
+ message="服务运行正常"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.post("/generate-queries", response_model=QueryResponse)
|
|
|
|
+async def generate_queries(request: QuestionRequest):
|
|
|
|
+ """
|
|
|
|
+ 生成查询词接口
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ request: 包含question字段的请求
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ 任务信息(状态为待执行)
|
|
|
|
+ """
|
|
|
|
+ global task_dao
|
|
|
|
+
|
|
|
|
+ if task_dao is None:
|
|
|
|
+ raise HTTPException(status_code=503, detail="数据库未初始化")
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ logger.info(f"收到问题: {request.question}")
|
|
|
|
+
|
|
|
|
+ # 生成任务ID(使用时间戳)
|
|
|
|
+ task_id = int(time.time() * 1000)
|
|
|
|
+
|
|
|
|
+ # 创建任务记录,状态设置为0(待执行)
|
|
|
|
+ task_dao.create_task(task_id, request.question)
|
|
|
|
+ logger.info(f"创建任务: {task_id},状态: 待执行")
|
|
|
|
+
|
|
|
|
+ # 立即返回待执行状态
|
|
|
|
+ return QueryResponse(
|
|
|
|
+ task_id=task_id,
|
|
|
|
+ queries=[], # 空列表,因为还未处理
|
|
|
|
+ original_question=request.question,
|
|
|
|
+ total_count=0, # 0个查询词,因为还未处理
|
|
|
|
+ status=QueryTaskStatus.PENDING # 待执行状态
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"创建任务失败: {e}")
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get("/task/{task_id}")
|
|
|
|
+async def get_task(task_id: int):
|
|
|
|
+ """
|
|
|
|
+ 获取任务信息接口
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ task_id: 任务ID
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ 任务信息
|
|
|
|
+ """
|
|
|
|
+ global task_dao
|
|
|
|
+
|
|
|
|
+ if task_dao is None:
|
|
|
|
+ raise HTTPException(status_code=503, detail="数据库未初始化")
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ task = task_dao.get_task(task_id)
|
|
|
|
+ if task is None:
|
|
|
|
+ raise HTTPException(status_code=404, detail="任务不存在")
|
|
|
|
+
|
|
|
|
+ return {
|
|
|
|
+ "task_id": task.task_id,
|
|
|
|
+ "question": task.question,
|
|
|
|
+ "queries": task.querys,
|
|
|
|
+ "status": task.status,
|
|
|
|
+ "status_text": {
|
|
|
|
+ 0: "待执行",
|
|
|
|
+ 1: "执行中",
|
|
|
|
+ 2: "成功",
|
|
|
|
+ 3: "失败"
|
|
|
|
+ }.get(task.status, "未知")
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ except HTTPException:
|
|
|
|
+ raise
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"获取任务信息失败: {e}")
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取任务信息失败: {str(e)}")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get("/tasks")
|
|
|
|
+async def get_tasks(status: int = None, limit: int = 100):
|
|
|
|
+ """
|
|
|
|
+ 获取任务列表接口
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ status: 任务状态筛选(可选)
|
|
|
|
+ limit: 限制数量
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ 任务列表
|
|
|
|
+ """
|
|
|
|
+ global task_dao
|
|
|
|
+
|
|
|
|
+ if task_dao is None:
|
|
|
|
+ raise HTTPException(status_code=503, detail="数据库未初始化")
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ if status is not None:
|
|
|
|
+ tasks = task_dao.get_tasks_by_status(status, limit)
|
|
|
|
+ else:
|
|
|
|
+ # 获取所有状态的任务(这里简化处理,实际可能需要更复杂的查询)
|
|
|
|
+ all_tasks = []
|
|
|
|
+ for s in [0, 1, 2, 3]:
|
|
|
|
+ tasks_by_status = task_dao.get_tasks_by_status(s, limit // 4)
|
|
|
|
+ all_tasks.extend(tasks_by_status)
|
|
|
|
+ tasks = all_tasks[:limit]
|
|
|
|
+
|
|
|
|
+ return {
|
|
|
|
+ "tasks": [
|
|
|
|
+ {
|
|
|
|
+ "task_id": task.task_id,
|
|
|
|
+ "question": task.question,
|
|
|
|
+ "queries": task.querys,
|
|
|
|
+ "status": task.status,
|
|
|
|
+ "status_text": {
|
|
|
|
+ 0: "待执行",
|
|
|
|
+ 1: "执行中",
|
|
|
|
+ 2: "成功",
|
|
|
|
+ 3: "失败"
|
|
|
|
+ }.get(task.status, "未知")
|
|
|
|
+ }
|
|
|
|
+ for task in tasks
|
|
|
|
+ ],
|
|
|
|
+ "total": len(tasks)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"获取任务列表失败: {e}")
|
|
|
|
+ raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@app.get("/")
|
|
|
|
+async def root():
|
|
|
|
+ """根路径"""
|
|
|
|
+ return {
|
|
|
|
+ "message": "知识工具API服务",
|
|
|
|
+ "version": "1.0.0",
|
|
|
|
+ "docs": "/docs"
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
+ import uvicorn
|
|
|
|
+
|
|
|
|
+ uvicorn.run(
|
|
|
|
+ "src.api.main:app",
|
|
|
|
+ host=os.getenv("HOST", "0.0.0.0"),
|
|
|
|
+ port=int(os.getenv("PORT", 8079)),
|
|
|
|
+ reload=os.getenv("DEBUG", "True").lower() == "true",
|
|
|
|
+ log_level=os.getenv("LOG_LEVEL", "INFO").lower()
|
|
|
|
+ )
|