Explorar o código

支持并发请求

刘立冬 hai 1 semana
pai
achega
d2a2a175ce
Modificáronse 5 ficheiros con 351 adicións e 137 borrados
  1. 141 0
      SETUP_ENV.md
  2. 5 1
      api/config.py
  3. 32 37
      api/pipeline_wrapper.py
  4. 56 0
      api/request_context.py
  5. 117 99
      api/search_service.py

+ 141 - 0
SETUP_ENV.md

@@ -0,0 +1,141 @@
+# 虚拟环境设置指南
+
+## 创建虚拟环境
+
+### 方法1:使用 venv(推荐)
+
+```bash
+# 创建虚拟环境(在当前目录下创建 .venv 文件夹)
+python3 -m venv .venv
+
+# 或者指定名称
+python3 -m venv venv
+```
+
+### 方法2:使用 virtualenv
+
+```bash
+# 先安装 virtualenv(如果未安装)
+pip install virtualenv
+
+# 创建虚拟环境
+virtualenv venv
+```
+
+## 激活虚拟环境
+
+### macOS/Linux
+
+```bash
+# 激活虚拟环境
+source .venv/bin/activate
+
+# 或者如果命名为 venv
+source venv/bin/activate
+```
+
+激活成功后,终端提示符前会显示 `(venv)` 或 `(.venv)`
+
+### Windows
+
+```cmd
+# 激活虚拟环境
+.venv\Scripts\activate
+
+# 或者如果命名为 venv
+venv\Scripts\activate
+```
+
+## 停用虚拟环境
+
+```bash
+deactivate
+```
+
+## 完整设置流程
+
+```bash
+# 1. 创建虚拟环境
+python3 -m venv .venv
+
+# 2. 激活虚拟环境
+source .venv/bin/activate
+
+# 3. 升级pip(可选但推荐)
+pip install --upgrade pip
+
+# 4. 安装项目依赖
+pip install -r requirements.txt
+
+# 5. 设置环境变量
+export OPENROUTER_API_KEY='your-api-key-here'
+
+# 6. 验证安装
+python --version
+pip list
+```
+
+## 检查虚拟环境是否激活
+
+```bash
+# 方法1:查看Python路径
+which python
+# 应该显示虚拟环境路径,如:/path/to/project/.venv/bin/python
+
+# 方法2:查看pip路径
+which pip
+# 应该显示虚拟环境路径,如:/path/to/project/.venv/bin/pip
+
+# 方法3:查看环境变量
+echo $VIRTUAL_ENV
+# 应该显示虚拟环境路径
+```
+
+## 常见问题
+
+### 1. 找不到 python3 命令
+
+```bash
+# 尝试使用 python
+python -m venv .venv
+
+# 或者查找Python安装路径
+which python3
+```
+
+### 2. 权限错误
+
+```bash
+# 使用 sudo(不推荐,但有时需要)
+sudo python3 -m venv .venv
+```
+
+### 3. 虚拟环境已存在
+
+```bash
+# 删除旧虚拟环境
+rm -rf .venv
+
+# 重新创建
+python3 -m venv .venv
+```
+
+## 项目特定设置
+
+对于本项目,建议使用 `.venv` 作为虚拟环境名称(已在 .gitignore 中忽略)
+
+```bash
+# 创建虚拟环境
+python3 -m venv .venv
+
+# 激活虚拟环境
+source .venv/bin/activate
+
+# 安装依赖
+pip install -r requirements.txt
+
+# 设置API密钥
+export OPENROUTER_API_KEY='your-api-key-here'
+```
+
+

+ 5 - 1
api/config.py

@@ -14,7 +14,11 @@ class APIConfig:
     # API服务配置
     API_HOST: str = os.getenv("API_HOST", "0.0.0.0")
     API_PORT: int = int(os.getenv("API_PORT", "8001"))
-    
+
+    # 并发控制配置
+    MAX_CONCURRENT_REQUESTS: int = int(os.getenv("MAX_CONCURRENT_REQUESTS", "5"))
+    TEMP_REQUESTS_DIR: str = os.getenv("TEMP_REQUESTS_DIR", "temp_requests")
+
     # Pipeline配置
     OPENROUTER_API_KEY: Optional[str] = os.getenv("OPENROUTER_API_KEY")
     OUTPUT_DIR: str = os.getenv("OUTPUT_DIR", "output_v2")

+ 32 - 37
api/pipeline_wrapper.py

@@ -19,25 +19,30 @@ logger = logging.getLogger(__name__)
 
 class PipelineWrapper:
     """Pipeline包装器,复用阶段3-7"""
-    
-    def __init__(self):
-        """初始化Pipeline包装器"""
-        # 创建临时输出目录
-        self.temp_output_dir = tempfile.mkdtemp(prefix='api_pipeline_')
-        logger.info(f"创建临时输出目录: {self.temp_output_dir}")
-        
-        # 初始化EnhancedSearchV2实例
-        # 注意:how_json_path参数是必需的,但我们不会使用它(因为我们跳过阶段1-2)
-        # 创建一个空的临时文件作为占位符
-        temp_how_file = os.path.join(self.temp_output_dir, 'temp_how.json')
+
+    def __init__(self, output_dir: str, request_id: str = None):
+        """
+        初始化Pipeline包装器
+
+        Args:
+            output_dir: 请求专用的输出目录(由RequestContext提供)
+            request_id: 请求ID(用于日志标识)
+        """
+        self.output_dir = output_dir
+        self.request_id = request_id or "unknown"
+
+        logger.info(f"[{self.request_id}] 初始化Pipeline,输出目录: {output_dir}")
+
+        # 创建占位符how.json(API模式不需要真实文件)
+        temp_how_file = os.path.join(output_dir, 'placeholder_how.json')
         with open(temp_how_file, 'w', encoding='utf-8') as f:
-            import json
             json.dump({'解构结果': {}}, f)
-        
+
+        # 初始化EnhancedSearchV2实例,使用传入的output_dir
         self.pipeline = EnhancedSearchV2(
             how_json_path=temp_how_file,  # 占位符文件,实际不会使用
             openrouter_api_key=APIConfig.OPENROUTER_API_KEY,
-            output_dir=self.temp_output_dir,
+            output_dir=output_dir,  # 使用独立目录
             top_n=10,
             max_total_searches=APIConfig.MAX_TOTAL_SEARCHES,
             search_max_workers=APIConfig.SEARCH_MAX_WORKERS,
@@ -60,22 +65,22 @@ class PipelineWrapper:
             similarity_max_workers=APIConfig.SIMILARITY_MAX_WORKERS,
             similarity_min_similarity=APIConfig.SIMILARITY_MIN_SIMILARITY
         )
-        
-        logger.info("Pipeline包装器初始化完成")
-    
-    async def run_stages_3_to_7(
+
+        logger.info(f"[{self.request_id}] Pipeline包装器初始化完成")
+
+    def run_stages_3_to_7_sync(
         self,
         features_data: List[Dict[str, Any]]
     ) -> Dict[str, Any]:
         """
-        执行阶段3-7的完整流程
-        
+        执行阶段3-7的完整流程(同步版本,用于在线程池中执行)
+
         Args:
             features_data: 阶段2的输出格式数据(candidate_words.json格式)
-        
+
         Returns:
             包含阶段3-7结果的字典
-        
+
         Raises:
             Exception: 当任何阶段执行失败时
         """
@@ -132,19 +137,19 @@ class PipelineWrapper:
             # 阶段7:相似度分析
             logger.info("阶段7:相似度分析...")
             try:
-                # 在异步环境中直接调用run_async而不是run
-                similarity_results = await self.pipeline.similarity_analyzer.run_async(
+                # 同步版本使用run方法
+                similarity_results = self.pipeline.similarity_analyzer.run(
                     deep_results,
-                    output_path=os.path.join(self.temp_output_dir, "similarity_analysis_results.json")
+                    output_path=os.path.join(self.output_dir, "similarity_analysis_results.json")
                 )
             except Exception as e:
                 logger.error(f"阶段7执行失败: {e}", exc_info=True)
                 raise Exception(f"相似度分析失败: {str(e)}")
-            
+
             # 重要:similarity_analyzer.run会更新文件中的evaluation_results(添加comprehensive_score)
             # 需要重新加载更新后的文件,因为内存中的evaluation_results变量还没有被更新
             logger.info("重新加载更新后的评估结果(包含comprehensive_score)...")
-            evaluated_results_path = os.path.join(self.temp_output_dir, "evaluated_results.json")
+            evaluated_results_path = os.path.join(self.output_dir, "evaluated_results.json")
             if os.path.exists(evaluated_results_path):
                 with open(evaluated_results_path, 'r', encoding='utf-8') as f:
                     evaluation_results = json.load(f)
@@ -165,14 +170,4 @@ class PipelineWrapper:
         except Exception as e:
             logger.error(f"执行阶段3-7失败: {e}", exc_info=True)
             raise
-    
-    def cleanup(self):
-        """清理临时文件"""
-        try:
-            import shutil
-            if os.path.exists(self.temp_output_dir):
-                shutil.rmtree(self.temp_output_dir)
-                logger.info(f"已清理临时目录: {self.temp_output_dir}")
-        except Exception as e:
-            logger.warning(f"清理临时目录失败: {e}")
 

+ 56 - 0
api/request_context.py

@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+请求上下文管理器
+为每个API请求提供独立的临时工作空间
+"""
+
+import uuid
+import shutil
+from pathlib import Path
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RequestContextManager:
+    """请求级别的临时工作空间管理器"""
+
+    def __init__(self, base_dir: str = "temp_requests"):
+        """
+        初始化请求上下文管理器
+
+        Args:
+            base_dir: 临时目录基础路径(项目内相对路径)
+        """
+        self.base_dir = Path(base_dir)
+        self.request_id = str(uuid.uuid4())[:8]  # 短UUID,便于日志查看
+        self.work_dir = None
+
+    async def __aenter__(self):
+        """
+        创建请求专用工作目录
+
+        Returns:
+            self: 返回自身,提供request_id和work_dir访问
+        """
+        self.work_dir = self.base_dir / f"request_{self.request_id}"
+        self.work_dir.mkdir(parents=True, exist_ok=True)
+        logger.info(f"[{self.request_id}] 创建工作目录: {self.work_dir}")
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """
+        清理工作目录
+
+        Args:
+            exc_type: 异常类型
+            exc_val: 异常值
+            exc_tb: 异常追踪
+        """
+        if self.work_dir and self.work_dir.exists():
+            try:
+                shutil.rmtree(self.work_dir)
+                logger.info(f"[{self.request_id}] 已清理工作目录")
+            except Exception as e:
+                logger.warning(f"[{self.request_id}] 清理工作目录失败: {e}")

+ 117 - 99
api/search_service.py

@@ -5,8 +5,10 @@ FastAPI服务主文件
 提供搜索API端点
 """
 
+import asyncio
 import logging
 from typing import List, Dict, Any
+from concurrent.futures import ThreadPoolExecutor
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field
 
@@ -15,6 +17,7 @@ from api.data_converter import (
     convert_pipeline_output_to_api_response
 )
 from api.pipeline_wrapper import PipelineWrapper
+from api.request_context import RequestContextManager
 
 logger = logging.getLogger(__name__)
 
@@ -25,8 +28,10 @@ app = FastAPI(
     version="1.0.0"
 )
 
-# 全局Pipeline包装器实例
-pipeline_wrapper: PipelineWrapper = None
+# 并发控制
+REQUEST_SEMAPHORE = None
+MAX_CONCURRENT_REQUESTS = 5
+EXECUTOR = None  # 线程池执行器,用于运行阻塞的pipeline代码
 
 
 # 请求模型
@@ -78,120 +83,132 @@ class SearchResponse(BaseModel):
 
 @app.on_event("startup")
 async def startup_event():
-    """应用启动时初始化Pipeline包装器"""
-    global pipeline_wrapper
-    try:
-        pipeline_wrapper = PipelineWrapper()
-        logger.info("Pipeline包装器初始化成功")
-    except Exception as e:
-        logger.error(f"Pipeline包装器初始化失败: {e}", exc_info=True)
-        raise
+    """应用启动时初始化并发限流器和线程池"""
+    global REQUEST_SEMAPHORE, EXECUTOR
+    REQUEST_SEMAPHORE = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
+    EXECUTOR = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_REQUESTS)
+    logger.info(f"API服务已启动,最大并发请求数: {MAX_CONCURRENT_REQUESTS}")
 
 
 @app.on_event("shutdown")
 async def shutdown_event():
     """应用关闭时清理资源"""
-    global pipeline_wrapper
-    if pipeline_wrapper:
-        try:
-            pipeline_wrapper.cleanup()
-            logger.info("Pipeline包装器清理完成")
-        except Exception as e:
-            logger.warning(f"Pipeline包装器清理失败: {e}")
+    global EXECUTOR
+    if EXECUTOR:
+        EXECUTOR.shutdown(wait=True)
+        logger.info("线程池已关闭")
+    logger.info("API服务关闭")
 
 
 @app.post("/what/search", response_model=SearchResponse)
 async def search(request: SearchRequest):
     """
-    执行搜索和评估
-    
+    执行搜索和评估(支持并发)
+
+    并发控制:最多 MAX_CONCURRENT_REQUESTS 个请求同时处理,超出的请求会等待
+
     Args:
         request: 搜索请求
-    
+
     Returns:
         搜索结果响应
-    
+
     Raises:
         HTTPException: 当请求参数无效或处理失败时
     """
-    try:
-        logger.info(f"收到搜索请求: original_target={request.original_target}, "
-                   f"persona_features数量={len(request.persona_features)}, "
-                   f"candidate_words数量={len(request.candidate_words)}")
-        
-        # 验证Pipeline包装器是否已初始化
-        if pipeline_wrapper is None:
-            logger.error("Pipeline包装器未初始化")
-            raise HTTPException(status_code=503, detail="Pipeline包装器未初始化,请稍后重试")
-        
-        # 验证输入参数
-        if not request.original_target or not request.original_target.strip():
-            raise HTTPException(status_code=400, detail="original_target不能为空")
-        
-        if not request.persona_features or len(request.persona_features) == 0:
-            raise HTTPException(status_code=400, detail="persona_features不能为空")
-        
-        if not request.candidate_words or len(request.candidate_words) == 0:
-            raise HTTPException(status_code=400, detail="candidate_words不能为空")
-        
-        # 验证persona_features中的persona_feature_name
-        for idx, pf in enumerate(request.persona_features):
-            if not pf.persona_feature_name or not pf.persona_feature_name.strip():
-                raise HTTPException(
-                    status_code=400,
-                    detail=f"persona_features[{idx}].persona_feature_name不能为空"
+    # 步骤1:获取并发许可(限流)
+    async with REQUEST_SEMAPHORE:
+        # 步骤2:创建请求专用上下文
+        async with RequestContextManager(base_dir="temp_requests") as ctx:
+            try:
+                logger.info(f"[{ctx.request_id}] 收到搜索请求: "
+                           f"original_target={request.original_target}, "
+                           f"persona_features={len(request.persona_features)}, "
+                           f"candidate_words={len(request.candidate_words)}")
+
+                # 验证输入参数
+                if not request.original_target or not request.original_target.strip():
+                    raise HTTPException(status_code=400, detail="original_target不能为空")
+
+                if not request.persona_features or len(request.persona_features) == 0:
+                    raise HTTPException(status_code=400, detail="persona_features不能为空")
+
+                if not request.candidate_words or len(request.candidate_words) == 0:
+                    raise HTTPException(status_code=400, detail="candidate_words不能为空")
+
+                # 验证persona_features中的persona_feature_name
+                for idx, pf in enumerate(request.persona_features):
+                    if not pf.persona_feature_name or not pf.persona_feature_name.strip():
+                        raise HTTPException(
+                            status_code=400,
+                            detail=f"persona_features[{idx}].persona_feature_name不能为空"
+                        )
+
+                # 步骤3:创建请求专用的Pipeline实例
+                pipeline_wrapper = PipelineWrapper(
+                    output_dir=str(ctx.work_dir),
+                    request_id=ctx.request_id
                 )
-        
-        # 步骤1:将API输入转换为pipeline格式
-        logger.info("步骤1:转换API输入格式...")
-        try:
-            features_data = convert_api_input_to_pipeline_format(
-                original_target=request.original_target,
-                persona_features=[pf.dict() for pf in request.persona_features],
-                candidate_words=request.candidate_words
-            )
-        except Exception as e:
-            logger.error(f"API输入格式转换失败: {e}", exc_info=True)
-            raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}")
-        
-        if not features_data:
-            raise HTTPException(status_code=400, detail="无法构建有效的特征数据,请检查输入参数")
-        
-        # 步骤2:执行阶段3-7
-        logger.info("步骤2:执行阶段3-7...")
-        try:
-            pipeline_output = await pipeline_wrapper.run_stages_3_to_7(features_data)
-        except Exception as e:
-            logger.error(f"阶段3-7执行失败: {e}", exc_info=True)
-            raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}")
-        
-        # 验证pipeline输出
-        if not pipeline_output or 'evaluation_results' not in pipeline_output:
-            logger.error("Pipeline输出格式不正确")
-            raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")
-        
-        # 步骤3:将pipeline输出转换为API响应格式
-        logger.info("步骤3:转换API输出格式...")
-        try:
-            response = convert_pipeline_output_to_api_response(
-                pipeline_results=pipeline_output['evaluation_results'],
-                original_target=request.original_target,
-                similarity_results=pipeline_output.get('similarity_results')
-            )
-        except Exception as e:
-            logger.error(f"API输出格式转换失败: {e}", exc_info=True)
-            raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}")
-        
-        logger.info(f"搜索完成: 找到 {len(response['search_results'])} 个有效结果 "
-                   f"(综合得分P > 0)")
-        
-        return response
-        
-    except HTTPException:
-        raise
-    except Exception as e:
-        logger.error(f"搜索请求处理失败: {e}", exc_info=True)
-        raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
+
+                # 步骤4:转换API输入格式
+                logger.info(f"[{ctx.request_id}] 转换API输入格式...")
+                try:
+                    features_data = convert_api_input_to_pipeline_format(
+                        original_target=request.original_target,
+                        persona_features=[pf.dict() for pf in request.persona_features],
+                        candidate_words=request.candidate_words
+                    )
+                except Exception as e:
+                    logger.error(f"[{ctx.request_id}] API输入格式转换失败: {e}", exc_info=True)
+                    raise HTTPException(status_code=400, detail=f"输入格式转换失败: {str(e)}")
+
+                if not features_data:
+                    raise HTTPException(status_code=400, detail="无法构建有效的特征数据")
+
+                # 步骤5:执行阶段3-7(在线程池中运行以避免阻塞)
+                logger.info(f"[{ctx.request_id}] 执行阶段3-7...")
+                try:
+                    # 获取当前事件循环
+                    loop = asyncio.get_event_loop()
+
+                    # 在线程池中运行阻塞的pipeline代码
+                    pipeline_output = await loop.run_in_executor(
+                        EXECUTOR,
+                        pipeline_wrapper.run_stages_3_to_7_sync,
+                        features_data
+                    )
+                except Exception as e:
+                    logger.error(f"[{ctx.request_id}] 阶段3-7执行失败: {e}", exc_info=True)
+                    raise HTTPException(status_code=500, detail=f"Pipeline执行失败: {str(e)}")
+
+                # 验证pipeline输出
+                if not pipeline_output or 'evaluation_results' not in pipeline_output:
+                    logger.error(f"[{ctx.request_id}] Pipeline输出格式不正确")
+                    raise HTTPException(status_code=500, detail="Pipeline输出格式不正确")
+
+                # 步骤6:转换API输出格式
+                logger.info(f"[{ctx.request_id}] 转换API输出格式...")
+                try:
+                    response = convert_pipeline_output_to_api_response(
+                        pipeline_results=pipeline_output['evaluation_results'],
+                        original_target=request.original_target,
+                        similarity_results=pipeline_output.get('similarity_results')
+                    )
+                except Exception as e:
+                    logger.error(f"[{ctx.request_id}] API输出格式转换失败: {e}", exc_info=True)
+                    raise HTTPException(status_code=500, detail=f"输出格式转换失败: {str(e)}")
+
+                logger.info(f"[{ctx.request_id}] 搜索完成: "
+                           f"找到 {len(response['search_results'])} 个有效结果")
+
+                # 步骤7:返回结果(工作目录会在退出上下文时自动清理)
+                return response
+
+            except HTTPException:
+                raise
+            except Exception as e:
+                logger.error(f"[{ctx.request_id}] 请求失败: {e}", exc_info=True)
+                raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
 
 
 @app.get("/health")
@@ -199,6 +216,7 @@ async def health_check():
     """健康检查端点"""
     return {
         "status": "healthy",
-        "pipeline_initialized": pipeline_wrapper is not None
+        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
+        "semaphore_initialized": REQUEST_SEMAPHORE is not None
     }