"""Query-expansion agent.

Reads parsed content rows for a request from MySQL, asks a Gemini LLM to
produce expanded search queries for each row, and stores the results.
"""
import json
import os
import sys
from typing import Any, Dict, List

# The project root is three directory levels above this file; it is used both
# for the import path and for locating the prompt template.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Add the project root to the Python path so the `utils` package can be imported.
sys.path.append(_PROJECT_ROOT)

from utils.logging_config import get_logger
from utils.mysql_db import MysqlHelper

logger = get_logger('ExpandAgent')

# Gemini is optional: without it, _run_llm() returns an empty list.
try:
    from gemini import GeminiProcessor
    HAS_GEMINI = True
except ImportError:
    HAS_GEMINI = False


def _fetch_parsing_data_by_request(request_id: str) -> List[Dict[str, Any]]:
    """Fetch the rows of knowledge_extraction_content for ``request_id``.

    Args:
        request_id: Value matched against the ``request_id`` column.

    Returns:
        One dict per row with keys ``data``, ``content_id`` and
        ``parsing_id``. Rows whose ``data`` column is empty/falsy are skipped.
    """
    sql = "SELECT data, content_id, parsing_id FROM knowledge_extraction_content WHERE request_id = %s"
    rows = MysqlHelper.get_values(sql, (request_id,)) or []

    results = [
        {"data": data, "content_id": content_id, "parsing_id": parsing_id}
        for data, content_id, parsing_id in rows
        if data
    ]
    logger.info(f"Final results: {len(results)} items")
    return results


def _build_prompt(data_samples: str, input_query: str) -> str:
    """Build the query-expansion prompt from the ``prompt/expansion.md`` template.

    Args:
        data_samples: Content the expansion should be based on; substituted
            for ``{Content_For_Expansion}``.
        input_query: Original query; substituted for ``{Original_Query}``
            ("无" is used when it is empty).

    Returns:
        The filled-in template, or a built-in fallback prompt when the
        template cannot be read or filled in.
    """
    try:
        template_path = os.path.join(_PROJECT_ROOT, "prompt", "expansion.md")
        with open(template_path, 'r', encoding='utf-8') as f:
            template = f.read()

        # Substitute the template placeholders.
        prompt = template.replace("{Original_Query}", input_query or "无")
        prompt = prompt.replace("{Content_For_Expansion}", data_samples)
        return prompt
    except Exception as e:
        logger.error(f"读取 expansion.md 模板失败: {e}")
        # Template unavailable: fall back to the built-in prompt.
        return f"""你是一位顶级的知识库专家,精通语义分析、信息检索和搜索优化策略。

根据以下原始查询和内容,生成扩展的查询词列表:

**原始查询:** {input_query or "无"}

**扩展要基于的内容:** {data_samples}

请生成一个JSON数组形式的扩展查询词列表,确保每个查询词都具体、可操作、信息量丰富。

输出格式:
```json
[
    "查询词1",
    "查询词2",
    ...
]
```"""


def _strip_code_fence(text: str) -> str:
    """Return the content of the first markdown code fence in ``text``.

    A ```` ```json ```` fence is preferred over a bare ```` ``` ```` fence.
    ``text`` is returned unchanged when it has no fence or the fence is never
    closed.
    """
    for marker in ("```json", "```"):
        if marker in text:
            start = text.find(marker) + len(marker)
            end = text.find("```", start)
            if end != -1:
                return text[start:end].strip()
            # Unclosed fence: leave the text as-is (do not try the next marker).
            return text
    return text


def _repair_braced_list(text: str) -> List[str]:
    """Recover items from a list that was wrapped in ``{}`` instead of ``[]``.

    The content between the braces is split on commas that are outside
    quotes, and one pair of surrounding quotes is removed from each item.

    Returns:
        The recovered items, or an empty list when ``text`` is not wrapped
        in braces or contains no items.
    """
    stripped = text.strip()
    if not (stripped.startswith('{') and stripped.endswith('}')):
        return []

    content = stripped[1:-1]  # drop the surrounding braces
    items = []
    current_item = ""
    in_quotes = False
    quote_char = None

    # Split on commas, but not on commas inside a quoted string.
    for char in content:
        if char in ('"', "'") and not in_quotes:
            in_quotes = True
            quote_char = char
            current_item += char
        elif char == quote_char and in_quotes:
            in_quotes = False
            quote_char = None
            current_item += char
        elif char == ',' and not in_quotes:
            if current_item.strip():
                items.append(current_item.strip())
            current_item = ""
        else:
            current_item += char

    # Flush the trailing item.
    if current_item.strip():
        items.append(current_item.strip())

    # Remove one layer of matching quotes from each item.
    cleaned_items = []
    for item in items:
        item = item.strip()
        if item.startswith('"') and item.endswith('"'):
            item = item[1:-1]
        elif item.startswith("'") and item.endswith("'"):
            item = item[1:-1]
        cleaned_items.append(item)
    return cleaned_items


def _run_llm(prompt: str) -> List[str]:
    """Call the LLM with ``prompt`` and return the expanded queries.

    Returns:
        The parsed queries as strings (falsy entries dropped). An empty list
        is returned when Gemini is unavailable, the API reports an error, the
        response is not a JSON list and cannot be repaired, or any exception
        occurs. This function never raises and never returns ``None``.
    """
    if not HAS_GEMINI:
        return []

    try:
        processor = GeminiProcessor()
        result = processor.process(content=prompt, system_prompt="你是专业的查询扩展助手")
        logger.debug(f"result: {result}")

        if isinstance(result, dict):
            # An error payload means there is nothing to parse.
            if "error" in result:
                logger.error(f"Gemini API 返回错误: {result['error']}")
                return []
            text = result.get("result", "") or result.get("raw_response", "")
        else:
            text = str(result)

        # Drop markdown code-fence markers around the JSON payload.
        text = _strip_code_fence(text)

        try:
            queries = json.loads(text)
        except json.JSONDecodeError as e:
            logger.error(f"JSON 解析失败: {e}, 原始文本: {text}")
            # Try to repair a common malformed shape: a list wrapped in braces.
            try:
                cleaned_items = _repair_braced_list(text)
            except Exception as fix_error:
                logger.error(f"JSON修复尝试失败: {fix_error}")
                return []
            if cleaned_items:
                logger.info(f"成功修复JSON格式,提取到 {len(cleaned_items)} 个项目")
            # Empty when the repair found nothing, so the result is always a list.
            return cleaned_items

        if isinstance(queries, list):
            # Make sure every element is a string.
            return [str(q) for q in queries if q]

        logger.warning(f"Gemini 返回的不是列表格式: {type(queries)}")
        return []
    except Exception as e:
        logger.error(f"LLM调用失败: {e}")
        return []


def _update_expansion_status(requestId: str, status: int):
    """Set ``expansion_status`` of the knowledge_request row for ``requestId``."""
    sql = "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s"
    MysqlHelper.update_values(sql, (status, requestId))


def execute_expand_agent_with_api(requestId: str, query: str = "") -> Dict[str, Any]:
    """Public API: generate and store expanded queries for a request.

    For every content row of ``requestId`` a prompt is built, the LLM is
    called and the result is inserted into knowledge_expand_content. The
    request's expansion status is set to 2 when the run finishes and to 3
    when an unexpected exception interrupts it.

    Args:
        requestId: Request whose extracted content should be expanded.
        query: Original query passed to the prompt and stored with each row.

    Returns:
        ``{"status": 1, "requestId": requestId}`` in every case.
    """
    # A request may have several content rows.
    data_samples = _fetch_parsing_data_by_request(requestId)

    try:
        total = 0
        success = 0

        if not data_samples:
            # Nothing to expand; only the final status update below applies.
            logger.info(f"没有数据,不进行扩展,直接返回: requestId={requestId}, total={total}, success={success}")
        else:
            # Generate and store an expansion for each content row separately.
            insert_sql = """
                INSERT INTO knowledge_expand_content
                    (request_id, create_time, expand_querys, query, content_id, parsing_id)
                VALUES (%s, NOW(), %s, %s, %s, %s)
            """
            for sample in data_samples:
                total += 1
                if not sample:
                    continue
                prompt = _build_prompt(sample["data"], query)
                expanded = _run_llm(prompt)
                try:
                    expand_querys_json = json.dumps(expanded, ensure_ascii=False)
                    MysqlHelper.insert_and_get_id(
                        insert_sql,
                        (requestId, expand_querys_json, query, sample["content_id"], sample["parsing_id"]),
                    )
                    success += 1
                except Exception as ie:
                    # A failed insert for one row must not stop the remaining rows.
                    logger.error(f"单条扩展结果入库失败: requestId={requestId}, error={ie}")

        # Mark the request as processed.
        _update_expansion_status(requestId, 2)
        logger.info(f"扩展完成: requestId={requestId}, total={total}, success={success}")
    except Exception as e:
        logger.error(f"保存扩展查询结果到数据库时出错: requestId={requestId}, error={e}")
        _update_expansion_status(requestId, 3)

    return {"status": 1, "requestId": requestId}


if __name__ == "__main__":
    outcome = execute_expand_agent_with_api("REQUEST_001")
    print(outcome)