batch_match_analyzer.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. """
  2. 批量匹配分析模块
  3. 分析单个特征与多个特征之间的语义匹配度(批量版本)
  4. 提供接口:
  5. analyze_batch_match(phrase_a, phrase_b_list, model_name) - 批量分析匹配度
  6. 返回格式:
  7. [
  8. {
  9. "特征": "...",
  10. "分数": 0.85,
  11. "说明": "..."
  12. },
  13. ...
  14. ]
  15. """
  16. from typing import List
  17. from agents import Agent, Runner, ModelSettings
  18. from agents.tracing.create import custom_span
  19. from lib.client import get_model
  20. from lib.utils import parse_json_from_text
  21. # ========== System Prompt ==========
  22. BATCH_MATCH_SYSTEM_PROMPT = """
  23. # 任务
  24. 分析单个特征 <A> 与多个特征 <B_List> 之间的语义匹配度。
  25. ## 输入说明
  26. - **<A></A>**: 待分析的特征(必选)
  27. - **<B_List></B_List>**: 多个特征列表(必选)
  28. **重要**:
  29. 1. 必须在同一个评分标准下对所有 B 进行评分,确保分数可比
  30. 2. **优先识别并给出高分**给与 <A> 相似度最高的特征
  31. 3. 严格区分高相似度和低相似度,避免分数过于集中
  32. ---
  33. ## 评分标准(0-1分)
  34. **核心原则**:从 <B_List> 中找出与 <A> 最相似的特征,给予最高分,其他按相似度递减。
  35. - **0.9-1.0**:几乎完全相同(同义词、可互换)
  36. - **0.7-0.9**:非常接近、高度相关(强关联、核心相关)
  37. - **0.5-0.7**:有一定关联(中等相关、间接关联)
  38. - **0.3-0.5**:关系较弱(弱相关、边缘关联)
  39. - **0.0-0.3**:几乎无关或完全无关
  40. **评分策略**:
  41. - 优先识别与 <A> 最相似的特征,给 0.7+ 高分
  42. - 对明显无关的特征,果断给 0.0-0.3 低分
  43. - 合理使用中间分数段,避免过度集中
  44. - 确保分数有梯度,体现明确的相似度差异
  45. ---
  46. ## 输出格式(严格JSON数组)
  47. ```json
  48. [
  49. {
  50. "特征": "第一个B的特征",
  51. "分数": 0.85,
  52. "说明": "简要说明评分依据"
  53. },
  54. {
  55. "特征": "第二个B的特征",
  56. "分数": 0.45,
  57. "说明": "简要说明评分依据"
  58. }
  59. ]
  60. ```
  61. **输出要求**:
  62. 1. 数组长度必须等于 <B_List> 的长度,顺序一一对应
  63. 2. 分数必须是0-1之间的浮点数,保留2位小数
  64. 3. 所有评分必须使用相同的标准,分数之间可比
  65. 4. **必须有明显的分数梯度**,最相似的给高分,不相关的给低分
  66. """.strip()
  67. def create_batch_match_agent(model_name: str) -> Agent:
  68. """创建批量匹配分析的 Agent
  69. Args:
  70. model_name: 模型名称
  71. Returns:
  72. Agent 实例
  73. """
  74. agent = Agent(
  75. name="Batch Match Expert",
  76. instructions=BATCH_MATCH_SYSTEM_PROMPT,
  77. model=get_model(model_name),
  78. model_settings=ModelSettings(
  79. temperature=0.0,
  80. max_tokens=65536,
  81. ),
  82. tools=[],
  83. )
  84. return agent
  85. def clean_json_text(text: str) -> str:
  86. """清理JSON文本中的常见错误
  87. Args:
  88. text: 原始JSON文本
  89. Returns:
  90. 清理后的JSON文本
  91. """
  92. import re
  93. # 1. 移除数组元素之间的异常字符(如 trib{)
  94. # 匹配模式:逗号后面跟着任意非空白字符,直到遇到正常的对象开始 {
  95. text = re.sub(r',\s*[a-zA-Z]+\s*\{', r',\n {', text)
  96. # 2. 移除对象之间的异常字符
  97. text = re.sub(r'\}\s*[a-zA-Z]+\s*\{', r'},\n {', text)
  98. return text
  99. def parse_batch_match_response(response_content: str) -> List[dict]:
  100. """解析批量匹配响应
  101. Args:
  102. response_content: Agent 返回的响应内容
  103. Returns:
  104. 解析后的字典列表
  105. """
  106. try:
  107. # 使用 parse_json_from_text 函数进行健壮的 JSON 解析
  108. result = parse_json_from_text(response_content)
  109. # 如果解析失败(返回空字典),尝试清理后再解析
  110. if not result:
  111. print(f"首次解析失败,尝试清理JSON文本后重新解析...")
  112. cleaned_text = clean_json_text(response_content)
  113. result = parse_json_from_text(cleaned_text)
  114. # 如果清理后仍然失败
  115. if not result:
  116. print(f"清理后仍解析失败: 无法从响应中提取有效JSON")
  117. return [{
  118. "特征": "",
  119. "分数": 0.0,
  120. "说明": "解析失败: 无法从响应中提取有效JSON"
  121. }]
  122. # 确保返回的是列表
  123. if not isinstance(result, list):
  124. return [result]
  125. return result
  126. except Exception as e:
  127. print(f"解析响应失败: {e}")
  128. return [{
  129. "特征": "",
  130. "分数": 0.0,
  131. "说明": f"解析失败: {str(e)}"
  132. }]
  133. async def analyze_batch_match(
  134. phrase_a: str,
  135. phrase_b_list: List[str],
  136. model_name: str = None
  137. ) -> List[dict]:
  138. """批量分析匹配度
  139. Args:
  140. phrase_a: 待分析的特征
  141. phrase_b_list: 多个特征列表
  142. model_name: 使用的模型名称(可选,默认使用 client.py 中的 MODEL_NAME)
  143. Returns:
  144. 匹配结果列表:[{"特征": "...", "分数": 0.85, "说明": "..."}, ...]
  145. """
  146. try:
  147. # 如果未指定模型,使用默认模型
  148. if model_name is None:
  149. from lib.client import MODEL_NAME
  150. model_name = MODEL_NAME
  151. # 创建 Agent
  152. agent = create_batch_match_agent(model_name)
  153. # 构建 B 列表字符串
  154. b_list_str = "\n".join([f"- {b}" for b in phrase_b_list])
  155. # 构建任务描述
  156. task_description = f"""## 本次分析任务
  157. <A>
  158. {phrase_a}
  159. </A>
  160. <B_List>
  161. {b_list_str}
  162. </B_List>
  163. 请分析 <A> 与 <B_List> 中每个特征的匹配度,输出 JSON 数组格式的结果。
  164. 重要:必须使用一致的评分标准!"""
  165. # 构造消息
  166. messages = [{
  167. "role": "user",
  168. "content": [
  169. {
  170. "type": "input_text",
  171. "text": task_description
  172. }
  173. ]
  174. }]
  175. # 使用 custom_span 追踪分析过程
  176. # 截断显示内容,避免 span name 过长
  177. a_short = (phrase_a[:30] + "...") if len(phrase_a) > 30 else phrase_a
  178. with custom_span(
  179. name=f"批量匹配分析: {a_short} vs {len(phrase_b_list)}个特征",
  180. data={
  181. "phrase_a": phrase_a,
  182. "phrase_b_list": phrase_b_list,
  183. "b_count": len(phrase_b_list)
  184. }
  185. ):
  186. # 运行 Agent
  187. result = await Runner.run(agent, input=messages)
  188. # 解析响应
  189. parsed_result = parse_batch_match_response(result.final_output)
  190. # 验证返回的结果数量
  191. if len(parsed_result) != len(phrase_b_list):
  192. print(f"警告: 返回结果数量 ({len(parsed_result)}) 与输入数量 ({len(phrase_b_list)}) 不匹配")
  193. # 补齐或截断
  194. while len(parsed_result) < len(phrase_b_list):
  195. parsed_result.append({
  196. "特征": phrase_b_list[len(parsed_result)],
  197. "分数": 0.0,
  198. "说明": "结果数量不匹配,自动补齐"
  199. })
  200. parsed_result = parsed_result[:len(phrase_b_list)]
  201. return parsed_result
  202. except Exception as e:
  203. # 返回错误信息(为每个 B 创建一个错误条目)
  204. return [{
  205. "特征": b,
  206. "分数": 0.0,
  207. "说明": f"分析过程出错: {str(e)}"
  208. } for b in phrase_b_list]