import os import logging from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from dotenv import load_dotenv from ..models.schemas import QuestionRequest, QueryResponse, HealthResponse from ..agent.query_agent import QueryGenerationAgent from ..database.connection import init_database_manager, get_db_manager from ..database.models import QueryTaskDAO, QueryTaskStatus, get_query_task_dao from ..tools.scheduler import create_scheduler, get_scheduler import time # 加载环境变量 load_dotenv() # 配置日志 log_level = os.getenv("LOG_LEVEL", "INFO") log_file = os.getenv("LOG_FILE", "logs/app.log") logging.basicConfig( level=getattr(logging, log_level), format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_file), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) # 全局Agent实例 agent = None task_dao = None scheduler = None @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理""" global agent, task_dao, scheduler # 启动时初始化数据库 logger.info("正在初始化数据库连接...") try: init_database_manager( host=os.getenv("DB_HOST", "localhost"), port=int(os.getenv("DB_PORT", 3306)), user=os.getenv("DB_USER", "root"), password=os.getenv("DB_PASSWORD", ""), database=os.getenv("DB_NAME", "ai_knowledge"), charset=os.getenv("DB_CHARSET", "utf8") ) task_dao = get_query_task_dao() logger.info("数据库初始化成功") except Exception as e: logger.error(f"数据库初始化失败: {e}") raise # 启动时初始化Agent logger.info("正在初始化查询生成Agent...") try: agent = QueryGenerationAgent( gemini_api_key=os.getenv("GEMINI_API_KEY", ""), model_name=os.getenv("GEMINI_MODEL", "gemini-1.5-pro") ) logger.info("Agent初始化成功") except Exception as e: logger.error(f"Agent初始化失败: {e}") raise # 启动时初始化调度器 logger.info("正在启动任务调度器...") try: scheduler = create_scheduler(agent, task_dao) await scheduler.start() logger.info("任务调度器启动成功") except Exception as e: logger.error(f"任务调度器启动失败: {e}") raise yield # 关闭时清理资源 logger.info("正在关闭服务...") # 停止调度器 if scheduler: try: await scheduler.stop() logger.info("任务调度器已停止") except Exception as e: logger.error(f"停止任务调度器失败: {e}") # 创建FastAPI应用 app = FastAPI( title="知识工具API", description="将问题转换为查询词的服务", version="1.0.0", lifespan=lifespan ) # 添加CORS中间件 app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/health", response_model=HealthResponse) async def health_check(): """健康检查接口""" return HealthResponse( status="healthy", message="服务运行正常" ) @app.post("/generate-queries", response_model=QueryResponse) async def generate_queries(request: QuestionRequest): """ 生成查询词接口 Args: request: 包含question字段的请求 Returns: 任务信息(状态为待执行) """ global task_dao if task_dao is None: raise HTTPException(status_code=503, detail="数据库未初始化") try: logger.info(f"收到问题: {request.question}") # 生成任务ID(使用时间戳) task_id = int(time.time() * 1000) # 创建任务记录,状态设置为0(待执行) task_dao.create_task(task_id, request.question, request.knowledgeType) logger.info(f"创建任务: {task_id},状态: 待执行") # 立即返回待执行状态 return QueryResponse( task_id=task_id, queries=[], # 空列表,因为还未处理 original_question=request.question, total_count=0, # 0个查询词,因为还未处理 status=QueryTaskStatus.PENDING # 待执行状态 ) except Exception as e: logger.error(f"创建任务失败: {e}") raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}") @app.get("/task/{task_id}") async def get_task(task_id: int): """ 获取任务信息接口 Args: task_id: 任务ID Returns: 任务信息 """ global task_dao if task_dao is None: raise HTTPException(status_code=503, detail="数据库未初始化") try: task = task_dao.get_task(task_id) if task is None: raise HTTPException(status_code=404, detail="任务不存在") return { "task_id": task.task_id, "question": task.question, "queries": task.querys, "status": task.status, "status_text": { 0: "待执行", 1: "执行中", 2: "成功", 3: "失败" }.get(task.status, "未知") } except HTTPException: raise except Exception as e: logger.error(f"获取任务信息失败: {e}") raise HTTPException(status_code=500, detail=f"获取任务信息失败: {str(e)}") @app.get("/tasks") async def get_tasks(status: int = None, limit: int = 100): """ 获取任务列表接口 Args: status: 任务状态筛选(可选) limit: 限制数量 Returns: 任务列表 """ global task_dao if task_dao is None: raise HTTPException(status_code=503, detail="数据库未初始化") try: if status is not None: tasks = task_dao.get_tasks_by_status(status, limit) else: # 获取所有状态的任务(这里简化处理,实际可能需要更复杂的查询) all_tasks = [] for s in [0, 1, 2, 3]: tasks_by_status = task_dao.get_tasks_by_status(s, limit // 4) all_tasks.extend(tasks_by_status) tasks = all_tasks[:limit] return { "tasks": [ { "task_id": task.task_id, "question": task.question, "queries": task.querys, "status": task.status, "status_text": { 0: "待执行", 1: "执行中", 2: "成功", 3: "失败" }.get(task.status, "未知") } for task in tasks ], "total": len(tasks) } except Exception as e: logger.error(f"获取任务列表失败: {e}") raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}") @app.get("/") async def root(): """根路径""" return { "message": "知识工具API服务", "version": "1.0.0", "docs": "/docs" } if __name__ == "__main__": import uvicorn uvicorn.run( "src.api.main:app", host=os.getenv("HOST", "0.0.0.0"), port=int(os.getenv("PORT", 8079)), reload=os.getenv("DEBUG", "True").lower() == "true", log_level=os.getenv("LOG_LEVEL", "INFO").lower() )