| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- """
- 灵感点与人设匹配分析 - Agent 框架版
- 基于 how_decode_v1.py 的 Agent 框架实现
- 参考 step1_match_inspiration_to_persona_v11.py 的业务逻辑
- """
- import asyncio
- import json
- import os
- import sys
- from typing import List, Dict
- from agents import trace
- from agents.tracing.create import custom_span
- from lib.my_trace import set_trace_smith as set_trace
- from lib.async_utils import process_tasks_with_semaphore
- from lib.match_analyzer import match_single
- from lib.data_loader import load_persona_data, load_inspiration_list, select_inspiration
- # 模型配置
- MODEL_NAME = "google/gemini-2.5-pro"
def build_context_str(perspective_name: str, level1_name: str = None) -> str:
    """Build the context string that accompanies a persona element.

    Args:
        perspective_name: Name of the perspective the element belongs to.
        level1_name: Name of the parent level-1 category. Supplied only when
            the element being matched is a level-2 category. Any falsy value
            (``None`` or ``""``) produces the perspective-only form.

    Returns:
        A line naming the perspective, followed by a second line naming the
        level-1 category when ``level1_name`` is truthy.
    """
    # The perspective line is always present.
    context_lines = [f"所属视角: {perspective_name}"]

    # Level-2 elements additionally carry their parent level-1 category.
    if level1_name:
        context_lines.append(f"一级分类: {level1_name}")

    return "\n".join(context_lines)
- # ========== 核心匹配逻辑 ==========
async def match_single_task(task: dict, _index: int) -> dict:
    """Run one inspiration-vs-element match and wrap it in a result record.

    Args:
        task: Match task. The keys read here are:
            - "灵感": the inspiration text (passed as B)
            - "要素": the element name (passed as A)
            - "上下文": the context string for the element
        _index: Task index; accepted to fit the task-runner callback
            signature and not used in this function.

    Returns:
        A dict with three sections, in this order: "输入信息" (the A/B inputs
        and their contexts), "匹配结果" (whatever ``match_single`` returned)
        and "业务信息" (the inspiration and the matched element).
    """
    inspiration, element, context_str = task["灵感"], task["要素"], task["上下文"]

    # Delegate to the shared matcher: B = inspiration, A = element,
    # a_context = the element's context.
    # NOTE(review): the original comment states that match_single handles its
    # own errors and span tracing; that is not visible from this file.
    match_result = await match_single(
        b_content=inspiration,
        a_content=element,
        model_name=MODEL_NAME,
        a_context=context_str,
    )

    # Generic fields come first; business-specific fields are kept last.
    input_info = {
        "B": inspiration,            # item being matched: the inspiration
        "A": element,                # item matched against: the element
        "B_Context": "",             # context for B (currently always empty)
        "A_Context": context_str,    # context for A: perspective / level-1 category
    }
    business_info = {
        "灵感": inspiration,
        "匹配要素": element,
    }
    return {
        "输入信息": input_info,
        "匹配结果": match_result,
        "业务信息": business_info,
    }
- # ========== 任务构建 ==========
def build_match_tasks(
    persona_data: dict,
    inspiration: str,
    max_tasks: int = None
) -> List[dict]:
    """Expand persona data into a flat list of match tasks.

    Walks ``persona_data["灵感点列表"]`` (perspectives), then each
    perspective's "模式列表" (level-1 categories), then each category's
    "二级细分" (level-2 categories). Every level-1 category yields one task,
    immediately followed by one task per level-2 category beneath it.

    Args:
        persona_data: Persona data dictionary.
        inspiration: Inspiration text copied into every task.
        max_tasks: Upper bound on the number of tasks; ``None`` means no
            limit.

    Returns:
        List of task dicts with the keys "灵感", "要素", "要素类型" and
        "上下文".
    """
    collected = []

    def quota_exhausted() -> bool:
        # True once a limit is set and has been reached.
        return not (max_tasks is None or len(collected) < max_tasks)

    def enqueue(element_name: str, element_type: str, context: str) -> None:
        # Append one task for the given element.
        collected.append({
            "灵感": inspiration,
            "要素": element_name,
            "要素类型": element_type,
            "上下文": context,
        })

    for view in persona_data.get("灵感点列表", []):
        if quota_exhausted():
            break
        view_name = view.get("视角名称", "")

        for category in view.get("模式列表", []):
            if quota_exhausted():
                break
            category_name = category.get("分类名称", "")

            # Level-1 task: the context carries the perspective only.
            enqueue(category_name, "一级分类", build_context_str(view_name))

            # Level-2 tasks: the context carries perspective + level-1 name.
            for subcategory in category.get("二级细分", []):
                if quota_exhausted():
                    break
                enqueue(
                    subcategory.get("分类名称", ""),
                    "二级分类",
                    build_context_str(view_name, category_name),
                )

    return collected
- # ========== 核心业务逻辑 ==========
async def process_inspiration_match(
    persona_data: dict,
    inspiration: str,
    max_tasks: int = None,
    max_concurrent: int = 3,
    current_time: str = None,
    log_url: str = None
) -> dict:
    """Match one inspiration against the persona and collect ranked results.

    Args:
        persona_data: Persona data dictionary.
        inspiration: Inspiration text.
        max_tasks: Maximum number of match tasks; ``None`` means no limit.
        max_concurrent: Concurrency limit handed to the task runner.
        current_time: Timestamp stored in the output metadata.
        log_url: Trace/log link stored in the output metadata.

    Returns:
        A dict with "元数据" (current_time, log_url, model), "灵感" and
        "匹配结果列表" (the results sorted by score, highest first).
    """
    pending_tasks = build_match_tasks(persona_data, inspiration, max_tasks)
    task_count = len(pending_tasks)

    print(f"\n开始匹配分析: {inspiration}")
    print(f"任务数: {task_count}, 模型: {MODEL_NAME}\n")

    # One span groups the whole matching run in the trace.
    span_payload = {
        "灵感": inspiration,
        "任务总数": task_count,
        "模型": MODEL_NAME,
        "并发数": max_concurrent,
        "步骤": "字面语义匹配分析",
    }
    with custom_span(name=f"Step1: 灵感与人设匹配 - {inspiration}", data=span_payload):
        # Run all tasks concurrently, bounded by max_concurrent.
        # NOTE(review): the original comment says match_single_task already
        # handles every error; that is not verifiable from this function.
        results = await process_tasks_with_semaphore(
            pending_tasks,
            match_single_task,
            max_concurrent=max_concurrent,
            show_progress=True,
        )

    def score_of(entry):
        # A missing "匹配结果" section or missing "score" counts as 0.
        return entry.get('匹配结果', {}).get('score', 0)

    # Highest score first; the sort is in place.
    results.sort(key=score_of, reverse=True)

    metadata = {
        "current_time": current_time,
        "log_url": log_url,
        "model": MODEL_NAME,
    }
    return {
        "元数据": metadata,
        "灵感": inspiration,
        "匹配结果列表": results,
    }
- # ========== 主函数 ==========
async def main(current_time: str = None, log_url: str = None, force: bool = False):
    """Script entry point: parse CLI arguments, load data, run, save results.

    Positional command-line arguments (all optional):
        1. Persona directory (defaults to "data/阿里多多酱/out/人设_1110").
        2. Inspiration selector passed to ``select_inspiration`` (defaults
           to "0").
        3. Task limit: an integer, or "all" / omitted for no limit.
        4. The literal "force" to re-run even if the output file exists.

    Args:
        current_time: Timestamp supplied by the caller; stored in the output.
        log_url: Trace link supplied by the caller; stored in the output and
            printed at the end when set.
        force: Re-run even when the output file already exists. A "force"
            fourth command-line argument also turns this on.
    """
    argv = sys.argv
    argc = len(argv)

    # Arg 1: persona directory.
    persona_dir = argv[1] if argc > 1 else "data/阿里多多酱/out/人设_1110"

    # Arg 2: inspiration selector (index or name, resolved by select_inspiration).
    inspiration_arg = argv[2] if argc > 2 else "0"

    # Arg 3: task limit; absent or "all" leaves it unlimited.
    max_tasks = None
    if argc > 3 and argv[3] != "all":
        max_tasks = int(argv[3])

    # Arg 4: a literal "force" overrides the function parameter.
    if argc > 4 and argv[4] == "force":
        force = True

    # Load inputs.
    # NOTE(review): the original comment says these helpers exit on failure;
    # that behavior lives in lib.data_loader and is not visible here.
    persona_data = load_persona_data(persona_dir)
    inspiration_list = load_inspiration_list(persona_dir)
    test_inspiration = select_inspiration(inspiration_arg, inspiration_list)

    # Output path: <persona_dir>/how/灵感点/<inspiration>/<scope>_step1_<step>_<model>.json
    output_dir = os.path.join(persona_dir, "how", "灵感点", test_inspiration)
    model_tag = MODEL_NAME.replace("google/", "").replace("/", "_")
    step_label = "灵感人设匹配"
    scope_tag = "all" if max_tasks is None else f"top{max_tasks}"
    output_file = os.path.join(
        output_dir, f"{scope_tag}_step1_{step_label}_{model_tag}.json"
    )

    # Skip the run when a previous result exists, unless forced.
    if not force and os.path.exists(output_file):
        print(f"\n✓ 输出文件已存在,跳过执行: {output_file}")
        print("提示: 如需重新执行,请添加 'force' 参数\n")
        return

    output = await process_inspiration_match(
        persona_data=persona_data,
        inspiration=test_inspiration,
        max_tasks=max_tasks,
        max_concurrent=5,
        current_time=current_time,
        log_url=log_url,
    )

    # Create the target directory only once there is something to write.
    os.makedirs(output_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as out_handle:
        json.dump(output, out_handle, ensure_ascii=False, indent=2)

    print(f"\n完成!结果已保存到: {output_file}")
    if log_url:
        print(f"Trace: {log_url}\n")
if __name__ == "__main__":
    # Set up tracing; set_trace returns the timestamp and log URL that are
    # forwarded to main().
    current_time, log_url = set_trace()
    # Wrap the whole run in a single trace context.
    with trace("灵感与人设匹配"):
        asyncio.run(main(current_time, log_url))
|