Browse Source

Merge branch 'main' of https://git.yishihui.com/ai/knowledge-agent

jihuaqiang 1 month ago
parent
commit
ab3ba0b5a7
1 changed files with 38 additions and 15 deletions
  1. 38 15
      agent.py

+ 38 - 15
agent.py

@@ -12,6 +12,7 @@ import time
 from typing import Any, Dict, List, Optional, TypedDict, Annotated
 from typing import Any, Dict, List, Optional, TypedDict, Annotated
 from contextlib import asynccontextmanager
 from contextlib import asynccontextmanager
 import asyncio
 import asyncio
+from utils.mysql_db import MysqlHelper
 
 
 # 保证可以导入本项目模块
 # 保证可以导入本项目模块
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
@@ -584,21 +585,30 @@ async def process_request_background(request_id: str):
         # 处理失败,更新状态为3
         # 处理失败,更新状态为3
         update_request_status(request_id, 3)
         update_request_status(request_id, 3)
 
 
+
+extraction_requests: set = set()
+
 @app.post("/extract")
 @app.post("/extract")
-async def extract(input: str):
-    """
-    执行Agent处理用户指令
-    
-    Args:
-        input: 包含用户指令的对象
-        
-    Returns:
-        dict: 包含执行结果的字典
-    """
+async def extract(request_id: str, query_word: str):
     try:
     try:
-        result = execute_agent_with_api(input)
-        return {"status": "success", "result": result}
+        # 检查请求是否已经在处理中
+        async with RUNNING_LOCK:
+            if request_id in extraction_requests:
+                return {"status": 1, "request_id": request_id, "message": "请求已在处理中"}
+            extraction_requests.add(request_id)
+        
+        try:
+            # 更新状态为处理中
+            update_extract_status(request_id, 1)
+            # 执行Agent
+            result = execute_agent_with_api(json.dumps({"query_word":query_word, "request_id": request_id}
+        finally:
+            # 无论成功失败,都从运行集合中移除
+            async with RUNNING_LOCK:
+                extraction_requests.discard(request_id)
     except Exception as e:
     except Exception as e:
+        # 发生异常,更新状态为处理失败
+        update_request_status(request_id, 3)
         raise HTTPException(status_code=500, detail=f"执行Agent时出错: {str(e)}")
         raise HTTPException(status_code=500, detail=f"执行Agent时出错: {str(e)}")
 
 
 @app.post("/expand")
 @app.post("/expand")
@@ -626,8 +636,21 @@ async def expand(request: ExpandRequest, background_tasks: BackgroundTasks):
         return {"status": 1, "requestId": requestId, "message": "扩展查询处理已启动"}
         return {"status": 1, "requestId": requestId, "message": "扩展查询处理已启动"}
         
         
     except Exception as e:
     except Exception as e:
-        logger.error(f"启动扩展查询处理失败: requestId={requestId}, error={e}")
-        raise HTTPException(status_code=500, detail=f"启动扩展查询处理时出错: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"执行Agent时出错: {str(e)}")
+def update_extract_status(request_id: str, status: int):
+    try:
+        from utils.mysql_db import MysqlHelper
+        
+        sql = "UPDATE knowledge_request SET extraction_status = %s WHERE request_id = %s"
+        result = MysqlHelper.update_values(sql, (status, request_id))
+        
+        if result is not None:
+            logger.info(f"更新请求状态成功: requestId={request_id}, status={status}")
+        else:
+            logger.error(f"更新请求状态失败: requestId={request_id}, status={status}")
+            
+    except Exception as e:
+        logger.error(f"更新请求状态异常: requestId={request_id}, status={status}, error={e}")
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     # 启动服务
     # 启动服务
@@ -637,4 +660,4 @@ if __name__ == "__main__":
         port=8080,
         port=8080,
         reload=True,  # 开发模式,自动重载
         reload=True,  # 开发模式,自动重载
         log_level="info"
         log_level="info"
-    ) 
+    )