123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- import os
- import logging
- from fastapi import FastAPI, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
- from contextlib import asynccontextmanager
- from dotenv import load_dotenv
- from ..models.schemas import QuestionRequest, QueryResponse, HealthResponse
- from ..agent.query_agent import QueryGenerationAgent
- from ..database.connection import init_database_manager, get_db_manager
- from ..database.models import QueryTaskDAO, QueryTaskStatus, get_query_task_dao
- from ..tools.scheduler import create_scheduler, get_scheduler
- import time
- # 加载环境变量
- load_dotenv()
- # 配置日志
- log_level = os.getenv("LOG_LEVEL", "INFO")
- log_file = os.getenv("LOG_FILE", "logs/app.log")
- logging.basicConfig(
- level=getattr(logging, log_level),
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler(log_file),
- logging.StreamHandler()
- ]
- )
- logger = logging.getLogger(__name__)
- # 全局Agent实例
- agent = None
- task_dao = None
- scheduler = None
- @asynccontextmanager
- async def lifespan(app: FastAPI):
- """应用生命周期管理"""
- global agent, task_dao, scheduler
-
- # 启动时初始化数据库
- logger.info("正在初始化数据库连接...")
- try:
- init_database_manager(
- host=os.getenv("DB_HOST", "localhost"),
- port=int(os.getenv("DB_PORT", 3306)),
- user=os.getenv("DB_USER", "root"),
- password=os.getenv("DB_PASSWORD", ""),
- database=os.getenv("DB_NAME", "ai_knowledge"),
- charset=os.getenv("DB_CHARSET", "utf8")
- )
- task_dao = get_query_task_dao()
- logger.info("数据库初始化成功")
- except Exception as e:
- logger.error(f"数据库初始化失败: {e}")
- raise
-
- # 启动时初始化Agent
- logger.info("正在初始化查询生成Agent...")
- try:
- agent = QueryGenerationAgent(
- gemini_api_key=os.getenv("GEMINI_API_KEY", ""),
- model_name=os.getenv("GEMINI_MODEL", "gemini-1.5-pro")
- )
- logger.info("Agent初始化成功")
- except Exception as e:
- logger.error(f"Agent初始化失败: {e}")
- raise
-
- # 启动时初始化调度器
- logger.info("正在启动任务调度器...")
- try:
- scheduler = create_scheduler(agent, task_dao)
- await scheduler.start()
- logger.info("任务调度器启动成功")
- except Exception as e:
- logger.error(f"任务调度器启动失败: {e}")
- raise
-
- yield
-
- # 关闭时清理资源
- logger.info("正在关闭服务...")
-
- # 停止调度器
- if scheduler:
- try:
- await scheduler.stop()
- logger.info("任务调度器已停止")
- except Exception as e:
- logger.error(f"停止任务调度器失败: {e}")
- # 创建FastAPI应用
- app = FastAPI(
- title="知识工具API",
- description="将问题转换为查询词的服务",
- version="1.0.0",
- lifespan=lifespan
- )
- # 添加CORS中间件
- app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
- @app.get("/health", response_model=HealthResponse)
- async def health_check():
- """健康检查接口"""
- return HealthResponse(
- status="healthy",
- message="服务运行正常"
- )
- @app.post("/generate-queries", response_model=QueryResponse)
- async def generate_queries(request: QuestionRequest):
- """
- 生成查询词接口
-
- Args:
- request: 包含question字段的请求
-
- Returns:
- 任务信息(状态为待执行)
- """
- global task_dao
-
- if task_dao is None:
- raise HTTPException(status_code=503, detail="数据库未初始化")
-
- try:
- logger.info(f"收到问题: {request.question}")
-
- # 生成任务ID(使用时间戳)
- task_id = int(time.time() * 1000)
-
- # 创建任务记录,状态设置为0(待执行)
- task_dao.create_task(task_id, request.question, request.knowledgeType)
- logger.info(f"创建任务: {task_id},状态: 待执行")
-
- # 立即返回待执行状态
- return QueryResponse(
- task_id=task_id,
- queries=[], # 空列表,因为还未处理
- original_question=request.question,
- total_count=0, # 0个查询词,因为还未处理
- status=QueryTaskStatus.PENDING # 待执行状态
- )
-
- except Exception as e:
- logger.error(f"创建任务失败: {e}")
- raise HTTPException(status_code=500, detail=f"创建任务失败: {str(e)}")
- @app.get("/task/{task_id}")
- async def get_task(task_id: int):
- """
- 获取任务信息接口
-
- Args:
- task_id: 任务ID
-
- Returns:
- 任务信息
- """
- global task_dao
-
- if task_dao is None:
- raise HTTPException(status_code=503, detail="数据库未初始化")
-
- try:
- task = task_dao.get_task(task_id)
- if task is None:
- raise HTTPException(status_code=404, detail="任务不存在")
-
- return {
- "task_id": task.task_id,
- "question": task.question,
- "queries": task.querys,
- "status": task.status,
- "status_text": {
- 0: "待执行",
- 1: "执行中",
- 2: "成功",
- 3: "失败"
- }.get(task.status, "未知")
- }
-
- except HTTPException:
- raise
- except Exception as e:
- logger.error(f"获取任务信息失败: {e}")
- raise HTTPException(status_code=500, detail=f"获取任务信息失败: {str(e)}")
- @app.get("/tasks")
- async def get_tasks(status: int = None, limit: int = 100):
- """
- 获取任务列表接口
-
- Args:
- status: 任务状态筛选(可选)
- limit: 限制数量
-
- Returns:
- 任务列表
- """
- global task_dao
-
- if task_dao is None:
- raise HTTPException(status_code=503, detail="数据库未初始化")
-
- try:
- if status is not None:
- tasks = task_dao.get_tasks_by_status(status, limit)
- else:
- # 获取所有状态的任务(这里简化处理,实际可能需要更复杂的查询)
- all_tasks = []
- for s in [0, 1, 2, 3]:
- tasks_by_status = task_dao.get_tasks_by_status(s, limit // 4)
- all_tasks.extend(tasks_by_status)
- tasks = all_tasks[:limit]
-
- return {
- "tasks": [
- {
- "task_id": task.task_id,
- "question": task.question,
- "queries": task.querys,
- "status": task.status,
- "status_text": {
- 0: "待执行",
- 1: "执行中",
- 2: "成功",
- 3: "失败"
- }.get(task.status, "未知")
- }
- for task in tasks
- ],
- "total": len(tasks)
- }
-
- except Exception as e:
- logger.error(f"获取任务列表失败: {e}")
- raise HTTPException(status_code=500, detail=f"获取任务列表失败: {str(e)}")
- @app.get("/")
- async def root():
- """根路径"""
- return {
- "message": "知识工具API服务",
- "version": "1.0.0",
- "docs": "/docs"
- }
- if __name__ == "__main__":
- import uvicorn
-
- uvicorn.run(
- "src.api.main:app",
- host=os.getenv("HOST", "0.0.0.0"),
- port=int(os.getenv("PORT", 8079)),
- reload=os.getenv("DEBUG", "True").lower() == "true",
- log_level=os.getenv("LOG_LEVEL", "INFO").lower()
- )
|