Browse Source

feat: 添加短语关系分析模块(relation_analyzer)

新增功能:
- 实现7分类关系分析:同义、同级、包含、被包含、部分重叠、相关、无关
- 支持语义接近程度评分(score 0-1)
- 提供可选的上下文参数以辅助理解
- model_name参数支持默认值

包含完整测试套件,14个测试用例100%通过

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
yangxiaohui 2 weeks ago
parent
commit
07f583f6a6
2 changed files with 496 additions and 0 deletions
  1. 287 0
      lib/relation_analyzer.py
  2. 209 0
      test_relation_analyzer.py

+ 287 - 0
lib/relation_analyzer.py

@@ -0,0 +1,287 @@
+"""
+短语关系分析模块
+
+分析两个短语之间的语义关系
+
+提供接口:
+analyze_relation(phrase_a, phrase_b, model_name, context_a="", context_b="") - 分析两个短语的关系
+
+支持可选的 Context 参数:
+- context_a: phrase_a 的补充上下文(帮助理解 phrase_a)
+- context_b: phrase_b 的补充上下文(帮助理解 phrase_b)
+- Context 默认为空,不提供时不会出现在 prompt 中
+
+返回格式:
+{
+    "relation": "same",           # 7种关系之一
+    "score": 0.95,                # 0-1,语义接近程度
+    "explanation": "说明"          # 关系判断的依据
+}
+"""
+import json
+from agents import Agent, Runner, ModelSettings
+from agents.tracing.create import custom_span
+from lib.client import get_model
+
+
+# ========== System Prompt ==========
+RELATION_SYSTEM_PROMPT = """
+# 任务
+分析两个短语 <A> 和 <B> 之间的语义关系。
+
+## 输入说明
+
+- **<A></A>**: 第一个短语(必选)
+- **<B></B>**: 第二个短语(必选)
+- **<A_Context></A_Context>**: A 的补充上下文(可选,帮助理解 A)
+- **<B_Context></B_Context>**: B 的补充上下文(可选,帮助理解 B)
+
+**重要**:关系分析发生在 <A> 和 <B> 之间,Context 仅作为补充理解的辅助信息。
+
+---
+
+## 关系类型(7种)
+
+### 1. same(同义)
+- **定义**:意思完全相同或非常接近,可以互相替换
+- **例子**:
+  - "医生" 和 "大夫" → same
+  - "计算机" 和 "电脑" → same
+  - "快乐" 和 "高兴" → same
+
+### 2. coordinate(同级)
+- **定义**:有共同的上位概念,属于并列关系,通常无交集
+- **例子**:
+  - "轿车" 和 "SUV" → coordinate(都是汽车)
+  - "苹果" 和 "香蕉" → coordinate(都是水果)
+  - "数学" 和 "物理" → coordinate(都是学科)
+
+### 3. contains(包含)
+- **定义**:A 的概念范围包含 B,B 是 A 的子类或特例
+- **例子**:
+  - "水果" contains "苹果"
+  - "汽车" contains "轿车"
+  - "动物" contains "狗"
+
+### 4. contained_by(被包含)
+- **定义**:A 被 B 包含,A 是 B 的子类或特例
+- **例子**:
+  - "苹果" contained_by "水果"
+  - "轿车" contained_by "汽车"
+  - "狗" contained_by "动物"
+
+### 5. overlap(部分重叠)
+- **定义**:两个概念有交集,但互不包含
+- **例子**:
+  - "红苹果" 和 "大苹果" → overlap(有又红又大的苹果)
+  - "亚洲国家" 和 "发展中国家" → overlap(如中国、印度等)
+  - "学生" 和 "运动员" → overlap(有学生运动员)
+
+### 6. related(相关)
+- **定义**:有语义联系,但不属于上述任何层级关系
+- **例子**:
+  - "医生" 和 "医院" → related(工作场所关系)
+  - "阅读" 和 "书籍" → related(动作-对象关系)
+  - "钥匙" 和 "锁" → related(工具-用途关系)
+  - "老师" 和 "学生" → related(角色关系)
+
+### 7. unrelated(无关)
+- **定义**:无明显语义关系
+- **例子**:
+  - "医生" 和 "石头" → unrelated
+  - "苹果" 和 "数学" → unrelated
+
+---
+
+## 评分标准(score: 0-1)
+
+**score 表示两个短语的语义接近程度:**
+
+- **0.9-1.0**:几乎完全相同(完全同义)
+- **0.8-0.9**:非常接近(高度同义、直接包含关系)
+- **0.7-0.8**:比较接近(近义、明确的同级或包含)
+- **0.6-0.7**:有一定接近度(同级但层级稍远、间接包含)
+- **0.5-0.6**:中等程度的关系(中等交集、中度相关)
+- **0.4-0.5**:关系较弱(小交集、弱相关)
+- **0.3-0.4**:关系很弱(勉强算同级、很弱的相关)
+- **0.0-0.3**:几乎无关或完全无关
+
+**不同关系类型的 score 范围参考:**
+- same: 通常 0.7-1.0(完全同义接近1.0,近义0.7-0.8)
+- contains/contained_by: 通常 0.5-0.9(直接包含0.8+,跨层级0.5-0.7)
+- coordinate: 通常 0.3-0.8(同级且上位概念近0.7+,同级但距离远0.3-0.5)
+- overlap: 通常 0.2-0.8(交集大0.6+,交集小0.2-0.4)
+- related: 通常 0.1-0.7(强相关0.5+,弱相关0.1-0.3)
+- unrelated: 通常 0.0-0.2
+
+---
+
+## 判断逻辑(按优先级)
+
+1. **A 和 B 意思相同或非常接近?** → same
+2. **A 包含 B 或 B 包含 A?** → contains 或 contained_by
+3. **A 和 B 有共同上位概念且无交集?** → coordinate
+4. **A 和 B 有交集但互不包含?** → overlap
+5. **A 和 B 有语义联系但不属于上述?** → related
+6. **A 和 B 完全无关?** → unrelated
+
+---
+
+## 输出格式(严格JSON)
+
+```json
+{
+  "relation": "same",
+  "score": 0.95,
+  "explanation": "简要说明为什么是这个关系,以及 score 的依据"
+}
+```
+
+**输出要求**:
+1. 必须严格按照上述JSON格式输出
+2. 所有字段都必须填写
+3. **relation字段**:必须是以下7个值之一:same, coordinate, contains, contained_by, overlap, related, unrelated
+4. **score字段**:必须是0-1之间的浮点数,保留2位小数
+5. **explanation字段**:必须简洁说明关系类型和评分依据(1-2句话)
+""".strip()
+
+
+def create_relation_agent(model_name: str) -> Agent:
+    """创建关系分析的 Agent
+
+    Args:
+        model_name: 模型名称
+
+    Returns:
+        Agent 实例
+    """
+    agent = Agent(
+        name="Phrase Relation Expert",
+        instructions=RELATION_SYSTEM_PROMPT,
+        model=get_model(model_name),
+        model_settings=ModelSettings(
+            temperature=0.0,
+            max_tokens=65536,
+        ),
+        tools=[],
+    )
+
+    return agent
+
+
+def parse_relation_response(response_content: str) -> dict:
+    """解析关系分析响应
+
+    Args:
+        response_content: Agent 返回的响应内容
+
+    Returns:
+        解析后的字典
+    """
+    try:
+        # 如果响应包含在 markdown 代码块中,提取 JSON 部分
+        if "```json" in response_content:
+            json_start = response_content.index("```json") + 7
+            json_end = response_content.index("```", json_start)
+            json_text = response_content[json_start:json_end].strip()
+        elif "```" in response_content:
+            json_start = response_content.index("```") + 3
+            json_end = response_content.index("```", json_start)
+            json_text = response_content[json_start:json_end].strip()
+        else:
+            json_text = response_content.strip()
+
+        return json.loads(json_text)
+    except Exception as e:
+        print(f"解析响应失败: {e}")
+        return {
+            "relation": "unrelated",
+            "score": 0.0,
+            "explanation": f"解析失败: {str(e)}"
+        }
+
+
+async def analyze_relation(
+    phrase_a: str,
+    phrase_b: str,
+    model_name: str = None,
+    context_a: str = "",
+    context_b: str = ""
+) -> dict:
+    """分析两个短语之间的关系
+
+    Args:
+        phrase_a: 第一个短语
+        phrase_b: 第二个短语
+        model_name: 使用的模型名称(可选,默认使用 client.py 中的 MODEL_NAME)
+        context_a: phrase_a 的补充上下文(可选,默认为空)
+        context_b: phrase_b 的补充上下文(可选,默认为空)
+
+    Returns:
+        关系分析结果字典:{"relation": "same", "score": 0.95, "explanation": "..."}
+    """
+    try:
+        # 如果未指定模型,使用默认模型
+        if model_name is None:
+            from lib.client import MODEL_NAME
+            model_name = MODEL_NAME
+
+        # 创建 Agent
+        agent = create_relation_agent(model_name)
+
+        # 构建任务描述
+        a_section = f"<A>\n{phrase_a}\n</A>"
+        if context_a:
+            a_section += f"\n\n<A_Context>\n{context_a}\n</A_Context>"
+
+        b_section = f"<B>\n{phrase_b}\n</B>"
+        if context_b:
+            b_section += f"\n\n<B_Context>\n{context_b}\n</B_Context>"
+
+        task_description = f"""## 本次分析任务
+
+{a_section}
+
+{b_section}
+
+请严格按照系统提示中的要求分析 <A> 和 <B> 之间的语义关系,并输出 JSON 格式的结果。"""
+
+        # 构造消息
+        messages = [{
+            "role": "user",
+            "content": [
+                {
+                    "type": "input_text",
+                    "text": task_description
+                }
+            ]
+        }]
+
+        # 使用 custom_span 追踪分析过程
+        # 截断显示内容,避免 span name 过长
+        a_short = (phrase_a[:30] + "...") if len(phrase_a) > 30 else phrase_a
+        b_short = (phrase_b[:30] + "...") if len(phrase_b) > 30 else phrase_b
+
+        with custom_span(
+            name=f"关系分析: {a_short} <-> {b_short}",
+            data={
+                "phrase_a": phrase_a,
+                "phrase_b": phrase_b,
+                "context_a": context_a if context_a else None,
+                "context_b": context_b if context_b else None,
+            }
+        ):
+            # 运行 Agent
+            result = await Runner.run(agent, input=messages)
+
+        # 解析响应
+        parsed_result = parse_relation_response(result.final_output)
+
+        return parsed_result
+
+    except Exception as e:
+        return {
+            "relation": "unrelated",
+            "score": 0.0,
+            "explanation": f"分析过程出错: {str(e)}"
+        }

+ 209 - 0
test_relation_analyzer.py

@@ -0,0 +1,209 @@
+"""
+测试 relation_analyzer 模块
+"""
+import asyncio
+from lib.relation_analyzer import analyze_relation
+
+
+async def test_all_relations():
+    """测试所有7种关系类型"""
+
+    # 测试用例:每种关系类型的典型例子
+    test_cases = [
+        # 1. same(同义)
+        {
+            "phrase_a": "医生",
+            "phrase_b": "大夫",
+            "expected_relation": "same",
+            "description": "完全同义"
+        },
+        {
+            "phrase_a": "计算机",
+            "phrase_b": "电脑",
+            "expected_relation": "same",
+            "description": "完全同义"
+        },
+
+        # 2. coordinate(同级)
+        {
+            "phrase_a": "轿车",
+            "phrase_b": "SUV",
+            "expected_relation": "coordinate",
+            "description": "都是汽车的子类"
+        },
+        {
+            "phrase_a": "苹果",
+            "phrase_b": "香蕉",
+            "expected_relation": "coordinate",
+            "description": "都是水果"
+        },
+
+        # 3. contains(包含)
+        {
+            "phrase_a": "水果",
+            "phrase_b": "苹果",
+            "expected_relation": "contains",
+            "description": "水果包含苹果"
+        },
+        {
+            "phrase_a": "汽车",
+            "phrase_b": "轿车",
+            "expected_relation": "contains",
+            "description": "汽车包含轿车"
+        },
+
+        # 4. contained_by(被包含)
+        {
+            "phrase_a": "苹果",
+            "phrase_b": "水果",
+            "expected_relation": "contained_by",
+            "description": "苹果被水果包含"
+        },
+        {
+            "phrase_a": "轿车",
+            "phrase_b": "交通工具",
+            "expected_relation": "contained_by",
+            "description": "轿车被交通工具包含"
+        },
+
+        # 5. overlap(部分重叠)
+        {
+            "phrase_a": "红苹果",
+            "phrase_b": "大苹果",
+            "expected_relation": "overlap",
+            "description": "有交集(又红又大的苹果)"
+        },
+        {
+            "phrase_a": "学生",
+            "phrase_b": "运动员",
+            "expected_relation": "overlap",
+            "description": "有交集(学生运动员)"
+        },
+
+        # 6. related(相关)
+        {
+            "phrase_a": "医生",
+            "phrase_b": "医院",
+            "expected_relation": "related",
+            "description": "工作场所关系"
+        },
+        {
+            "phrase_a": "阅读",
+            "phrase_b": "书籍",
+            "expected_relation": "related",
+            "description": "动作-对象关系"
+        },
+
+        # 7. unrelated(无关)
+        {
+            "phrase_a": "医生",
+            "phrase_b": "石头",
+            "expected_relation": "unrelated",
+            "description": "完全无关"
+        },
+        {
+            "phrase_a": "苹果",
+            "phrase_b": "数学",
+            "expected_relation": "unrelated",
+            "description": "完全无关"
+        },
+    ]
+
+    # 模型选择(根据你的配置调整)
+    model_name = "google/gemini-2.5-flash"  # 默认模型
+
+    print(f"=" * 80)
+    print(f"开始测试 relation_analyzer 模块")
+    print(f"使用模型: {model_name}")
+    print(f"测试用例数量: {len(test_cases)}")
+    print(f"=" * 80)
+    print()
+
+    results = []
+
+    for i, test_case in enumerate(test_cases, 1):
+        phrase_a = test_case["phrase_a"]
+        phrase_b = test_case["phrase_b"]
+        expected = test_case["expected_relation"]
+        description = test_case["description"]
+
+        print(f"[{i}/{len(test_cases)}] 测试: \"{phrase_a}\" <-> \"{phrase_b}\"")
+        print(f"     说明: {description}")
+        print(f"     期望关系: {expected}")
+
+        # 调用分析函数
+        result = await analyze_relation(
+            phrase_a=phrase_a,
+            phrase_b=phrase_b,
+            model_name=model_name
+        )
+
+        relation = result.get("relation", "unknown")
+        score = result.get("score", 0.0)
+        explanation = result.get("explanation", "")
+
+        # 判断是否符合预期
+        is_correct = (relation == expected)
+        status = "✓" if is_correct else "✗"
+
+        print(f"     实际关系: {relation} (score: {score:.2f}) {status}")
+        print(f"     解释: {explanation}")
+        print()
+
+        results.append({
+            "test_case": test_case,
+            "result": result,
+            "is_correct": is_correct
+        })
+
+    # 统计结果
+    correct_count = sum(1 for r in results if r["is_correct"])
+    total_count = len(results)
+    accuracy = correct_count / total_count * 100
+
+    print(f"=" * 80)
+    print(f"测试完成")
+    print(f"正确: {correct_count}/{total_count} ({accuracy:.1f}%)")
+    print(f"=" * 80)
+
+    # 显示错误的测试用例
+    errors = [r for r in results if not r["is_correct"]]
+    if errors:
+        print()
+        print("错误的测试用例:")
+        for error in errors:
+            tc = error["test_case"]
+            result = error["result"]
+            print(f"  - \"{tc['phrase_a']}\" <-> \"{tc['phrase_b']}\"")
+            print(f"    期望: {tc['expected_relation']}, 实际: {result['relation']}")
+
+    return results
+
+
+async def test_single_example():
+    """测试单个例子"""
+
+    print("测试单个例子:")
+    print()
+
+    result = await analyze_relation(
+        phrase_a="水果",
+        phrase_b="苹果",
+        model_name="google/gemini-2.5-flash"  # 默认模型
+    )
+
+    print(f"短语A: 水果")
+    print(f"短语B: 苹果")
+    print(f"关系: {result['relation']}")
+    print(f"分数: {result['score']}")
+    print(f"解释: {result['explanation']}")
+
+
+if __name__ == "__main__":
+    # 选择测试方式:
+
+    # 方式1:测试单个例子(快速验证)
+    # asyncio.run(test_single_example())
+
+    # 方式2:测试所有关系类型(完整测试)
+    asyncio.run(test_all_relations())