import json
import os
import re
import sys
from typing import Any, Dict, List

# Project root: three directory levels above this file. It is used both to make the
# `utils` package importable when this module runs as a script and to locate the
# prompt template directory.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Add the project root directory to the Python path.
sys.path.append(_PROJECT_ROOT)

from utils.logging_config import get_logger
from utils.mysql_db import MysqlHelper

logger = get_logger('ExpandAgent')

try:
    from gemini import GeminiProcessor
    HAS_GEMINI = True
except ImportError:
    HAS_GEMINI = False

# Values written to knowledge_request.expansion_status by this module.
_STATUS_DONE = 2
_STATUS_FAILED = 3

# Placeholders recognised in prompt/expansion.md.
_PLACEHOLDER_RE = re.compile(r"\{(Original_Query|Content_For_Expansion)\}")


def _fetch_parsing_data_by_request(request_id: str) -> List[str]:
    """Return up to 10 non-empty `parsing_data` values for `request_id`.

    Rows are read from `knowledge_parsing_content`, ordered by `id DESC`
    (highest id first). Empty/NULL values are dropped.
    """
    sql = "SELECT parsing_data FROM knowledge_parsing_content WHERE request_id = %s ORDER BY id DESC LIMIT 10"
    rows = MysqlHelper.get_values(sql, (request_id,)) or []

    results = [row[0] for row in rows if row[0]]
    logger.info(f"Final results: {len(results)} items")
    return results


def _build_prompt(data_samples: str, input_query: str) -> str:
    """Build the query-expansion prompt from the `prompt/expansion.md` template.

    `{Original_Query}` is replaced with `input_query` (or "无" when empty) and
    `{Content_For_Expansion}` with `data_samples`. If the template cannot be read,
    a built-in fallback prompt with the same two inputs is returned instead.
    """
    query_text = input_query or "无"

    try:
        template_path = os.path.join(_PROJECT_ROOT, "prompt", "expansion.md")
        with open(template_path, 'r', encoding='utf-8') as f:
            template = f.read()
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"读取 expansion.md 模板失败: {e}")
        # Template unavailable: fall back to an inline prompt.
        return f"""你是一位顶级的知识库专家,精通语义分析、信息检索和搜索优化策略。

根据以下原始查询和内容,生成扩展的查询词列表:

**原始查询:** {query_text}

**扩展要基于的内容:** {data_samples}

请生成一个JSON数组形式的扩展查询词列表,确保每个查询词都具体、可操作、信息量丰富。

输出格式:
```json
[
    "查询词1",
    "查询词2",
    ...
]
```"""

    values = {
        "Original_Query": query_text,
        "Content_For_Expansion": str(data_samples),
    }
    # Single-pass substitution: text inserted for one placeholder is never re-scanned,
    # so a query that happens to contain "{Content_For_Expansion}" is left untouched
    # (chained str.replace calls would have expanded it).
    return _PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)


def _strip_code_fence(text: str) -> str:
    """Return the content of the first ```json (or plain ```) fenced block, else `text`."""
    for fence in ("```json", "```"):
        start = text.find(fence)
        if start == -1:
            continue
        start += len(fence)
        end = text.find("```", start)
        # An unclosed fence leaves the text as-is (no fallback to the other fence type).
        return text[start:end].strip() if end != -1 else text
    return text


def _run_llm(prompt: str) -> List[str]:
    """Call Gemini with `prompt` and parse its reply as a JSON list of query strings.

    Returns an empty list when Gemini is unavailable, reports an error, the reply
    is not a JSON list, or the call raises.
    """
    if not HAS_GEMINI:
        return []

    try:
        processor = GeminiProcessor()
        result = processor.process(content=prompt, system_prompt="你是专业的查询扩展助手")
        logger.debug(f"result: {result}")

        if isinstance(result, dict):
            if "error" in result:
                logger.error(f"Gemini API 返回错误: {result['error']}")
                return []
            # The reply text is in "result", falling back to "raw_response".
            payload = result.get("result", "") or result.get("raw_response", "") or ""
        else:
            payload = result

        # NOTE(review): a pre-parsed list is accepted directly; unclear whether
        # GeminiProcessor ever returns one — harmless either way.
        if isinstance(payload, list):
            return [str(q) for q in payload if q]

        text = _strip_code_fence(payload if isinstance(payload, str) else str(payload))

        try:
            queries = json.loads(text)
        except json.JSONDecodeError as e:
            logger.error(f"JSON 解析失败: {e}, 原始文本: {text}")
            return []

        if not isinstance(queries, list):
            logger.warning(f"Gemini 返回的不是列表格式: {type(queries)}")
            return []
        # Keep non-empty entries, coerced to strings.
        return [str(q) for q in queries if q]

    except Exception as e:
        logger.error(f"LLM调用失败: {e}")
        return []


def _heuristic_expand(input_query: str) -> List[str]:
    """Template-based expansion used when the LLM yields nothing.

    Returns the stripped query plus four fixed suffix variants, or an empty
    list when the query is empty/None.
    """
    base = (input_query or "").strip()
    if not base:
        return []

    return [
        base,
        f"{base} 教程",
        f"{base} 实战",
        f"{base} 入门指南",
        f"{base} 高级技巧",
    ]


def _update_expansion_status(requestId: str, status: int):
    """Set knowledge_request.expansion_status for `requestId`."""
    sql = "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s"
    MysqlHelper.update_values(sql, (status, requestId))


def _mark_failed(requestId: str) -> None:
    """Best-effort: set the failed status, logging (not raising) if that also fails."""
    try:
        _update_expansion_status(requestId, _STATUS_FAILED)
    except Exception as e:
        logger.error(f"更新扩展状态失败: requestId={requestId}, error={e}")


def _generate_expanded_queries(requestId: str, query: str) -> List[str]:
    """Produce expanded queries via the LLM, falling back to the heuristic."""
    data_samples = _fetch_parsing_data_by_request(requestId)

    expanded: List[str] = []
    if data_samples:
        # Only the first row (highest id) is used as expansion content.
        expanded = _run_llm(_build_prompt(data_samples[0], query))
    else:
        logger.warning(f"未找到解析数据, 使用启发式扩展: requestId={requestId}")

    return expanded or _heuristic_expand(query)


def _save_expanded_queries(requestId: str, query: str, expanded: List[str]) -> None:
    """Upsert the expanded queries into knowledge_expand_content (stored as a JSON array)."""
    check_sql = "SELECT id FROM knowledge_expand_content WHERE request_id = %s LIMIT 1"
    existing_record = MysqlHelper.get_values(check_sql, (requestId,))

    expand_querys_json = json.dumps(expanded, ensure_ascii=False)
    logger.debug(f"expand_querys_json: {expand_querys_json}")

    if existing_record:
        update_sql = """
        UPDATE knowledge_expand_content 
        SET expand_querys = %s, query = %s, create_time = NOW()
        WHERE request_id = %s
        """
        affected_rows = MysqlHelper.update_values(update_sql, (expand_querys_json, query, requestId))
        logger.info(f"扩展查询结果已更新: requestId={requestId}, affected_rows={affected_rows}")
    else:
        insert_sql = """
        INSERT INTO knowledge_expand_content 
        (request_id, create_time, expand_querys, query) 
        VALUES (%s, NOW(), %s, %s)
        """
        insert_result = MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
        logger.info(f"扩展查询结果已插入: requestId={requestId}, insert_id={insert_result}")


def execute_expand_agent_with_api(requestId: str, query: str = "") -> Dict[str, Any]:
    """Public API: generate and persist expanded queries for `requestId`.

    On success, expansion_status is set to 2. If generation or persistence fails,
    it is set to 3 on a best-effort basis. The same acknowledgement dict is
    returned on every path.
    """
    response = {"status": 1, "requestId": requestId}

    try:
        expanded = _generate_expanded_queries(requestId, query)
    except Exception as e:
        logger.error(f"生成扩展查询失败: requestId={requestId}, error={e}")
        _mark_failed(requestId)
        return response

    try:
        _save_expanded_queries(requestId, query, expanded)
        _update_expansion_status(requestId, _STATUS_DONE)
    except Exception as e:
        logger.error(f"保存扩展查询结果到数据库时出错: requestId={requestId}, error={e}")
        _mark_failed(requestId)

    return response


if __name__ == "__main__":
    queries = execute_expand_agent_with_api("REQUEST_001")
    print(queries)