|
@@ -2,19 +2,22 @@
|
|
|
搜索评估工具:搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
|
|
搜索评估工具:搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
|
|
|
|
|
|
|
|
处理流程:
|
|
处理流程:
|
|
|
-1. 使用 xhs(失败或空则用 zhihu)搜索帖子
|
|
|
|
|
-2. 并发对每篇帖子调用 LLM 判断人设匹配 & 提取关键词
|
|
|
|
|
-3. 对匹配人设的帖子,调用 match_derivation_to_post_points 匹配选题点
|
|
|
|
|
-4. 返回完整评估结果列表
|
|
|
|
|
|
|
+1. 接收 query_list(多个搜索 query),并发处理
|
|
|
|
|
+2. 每个 query:使用 xhs(失败或空则用 zhihu)搜索帖子
|
|
|
|
|
+3. 并发对每篇帖子调用 LLM 判断人设匹配 & 提取关键词
|
|
|
|
|
+4. 对匹配人设的帖子,调用 match_derivation_to_post_points 匹配选题点
|
|
|
|
|
+5. 返回按 query 分组的评估结果字典
|
|
|
|
|
+6. 支持本地文件缓存(.cache/search/{account_name}/{post_id}/)
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
import asyncio
|
|
import asyncio
|
|
|
|
|
+import hashlib
|
|
|
import json
|
|
import json
|
|
|
import logging
|
|
import logging
|
|
|
import re
|
|
import re
|
|
|
import sys
|
|
import sys
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
-from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
+from typing import Dict, List, Optional
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
@@ -39,6 +42,7 @@ except ImportError:
|
|
|
|
|
|
|
|
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
|
|
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
|
|
|
_TOOLS_DIR = Path(__file__).resolve().parent
|
|
_TOOLS_DIR = Path(__file__).resolve().parent
|
|
|
|
|
+_CACHE_ROOT = Path(__file__).resolve().parent.parent / ".cache" / "search"
|
|
|
|
|
|
|
|
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
|
|
BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
|
|
|
DEFAULT_TIMEOUT = 60.0
|
|
DEFAULT_TIMEOUT = 60.0
|
|
@@ -224,89 +228,155 @@ async def _eval_single_post(
|
|
|
return result
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _cache_key(query: str) -> str:
|
|
|
|
|
+ """将 query 转为安全的文件名:使用 MD5 哈希避免特殊字符问题"""
|
|
|
|
|
+ h = hashlib.md5(query.encode("utf-8")).hexdigest()[:12]
|
|
|
|
|
+ safe = re.sub(r'[^\w\u4e00-\u9fff]+', '_', query)[:60].strip('_')
|
|
|
|
|
+ return f"{safe}_{h}"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _get_cache_path(account_name: str, post_id: str, query: str) -> Path:
|
|
|
|
|
+ return _CACHE_ROOT / account_name / post_id / f"{_cache_key(query)}.json"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _read_cache(account_name: str, post_id: str, query: str) -> Optional[List[dict]]:
|
|
|
|
|
+ """读取缓存,存在且合法则返回帖子列表,否则返回 None"""
|
|
|
|
|
+ path = _get_cache_path(account_name, post_id, query)
|
|
|
|
|
+ if not path.is_file():
|
|
|
|
|
+ return None
|
|
|
|
|
+ try:
|
|
|
|
|
+ with open(path, "r", encoding="utf-8") as f:
|
|
|
|
|
+ data = json.load(f)
|
|
|
|
|
+ if isinstance(data, list):
|
|
|
|
|
+ logger.info("_read_cache: hit cache for query=%s, %d items", query, len(data))
|
|
|
|
|
+ return data
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning("_read_cache: failed to read cache for query=%s: %s", query, e)
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _write_cache(account_name: str, post_id: str, query: str, results: List[dict]) -> None:
|
|
|
|
|
+ """写入缓存"""
|
|
|
|
|
+ path = _get_cache_path(account_name, post_id, query)
|
|
|
|
|
+ try:
|
|
|
|
|
+ path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ with open(path, "w", encoding="utf-8") as f:
|
|
|
|
|
+ json.dump(results, f, ensure_ascii=False, indent=2)
|
|
|
|
|
+ logger.info("_write_cache: wrote cache for query=%s, %d items", query, len(results))
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.warning("_write_cache: failed to write cache for query=%s: %s", query, e)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def _search_and_eval_single_query(
|
|
|
|
|
+ query: str,
|
|
|
|
|
+ system_prompt: str,
|
|
|
|
|
+ account_name: str,
|
|
|
|
|
+ post_id: str,
|
|
|
|
|
+) -> List[dict]:
|
|
|
|
|
+ """处理单个 query 的搜索、评估、匹配流程,支持缓存"""
|
|
|
|
|
+ cached = _read_cache(account_name, post_id, query)
|
|
|
|
|
+ if cached is not None:
|
|
|
|
|
+ return cached
|
|
|
|
|
+
|
|
|
|
|
+ posts = await _search_posts(query)
|
|
|
|
|
+ if not posts:
|
|
|
|
|
+ logger.warning("_search_and_eval_single_query: no posts for query=%s", query)
|
|
|
|
|
+ _write_cache(account_name, post_id, query, [])
|
|
|
|
|
+ return []
|
|
|
|
|
+
|
|
|
|
|
+ logger.info("_search_and_eval_single_query: got %d posts for query=%s", len(posts), query)
|
|
|
|
|
+ tasks = [
|
|
|
|
|
+ _eval_single_post(post, system_prompt, account_name, post_id)
|
|
|
|
|
+ for post in posts
|
|
|
|
|
+ ]
|
|
|
|
|
+ results: List[dict] = await asyncio.gather(*tasks)
|
|
|
|
|
+
|
|
|
|
|
+ _write_cache(account_name, post_id, query, results)
|
|
|
|
|
+ return results
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
@tool(
|
|
@tool(
|
|
|
description=(
|
|
description=(
|
|
|
"搜索帖子并评估是否与账号人设匹配,提取帖子关键词并与帖子选题点进行匹配。"
|
|
"搜索帖子并评估是否与账号人设匹配,提取帖子关键词并与帖子选题点进行匹配。"
|
|
|
- "参数:account_name 账号名称;post_id 帖子ID;query 搜索词。"
|
|
|
|
|
|
|
+ "参数:account_name 账号名称;post_id 帖子ID;query_list 搜索词列表。"
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
async def search_and_eval(
|
|
async def search_and_eval(
|
|
|
account_name: str,
|
|
account_name: str,
|
|
|
post_id: str,
|
|
post_id: str,
|
|
|
- query: str,
|
|
|
|
|
|
|
+ query_list: List[str],
|
|
|
context: Optional[ToolContext] = None,
|
|
context: Optional[ToolContext] = None,
|
|
|
) -> ToolResult:
|
|
) -> ToolResult:
|
|
|
"""
|
|
"""
|
|
|
搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
|
|
搜索帖子并评估是否与账号人设匹配,提取关键词并匹配选题点。
|
|
|
|
|
+ 支持多个 query 并发处理,结果按 query 分组返回。
|
|
|
|
|
+ 本地文件缓存:.cache/search/{account_name}/{post_id}/ 下每个 query 一个 JSON 文件。
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
account_name: 账号名称,用于读取人设数据和选题点文件
|
|
account_name: 账号名称,用于读取人设数据和选题点文件
|
|
|
post_id: 帖子ID,用于定位选题点匹配文件
|
|
post_id: 帖子ID,用于定位选题点匹配文件
|
|
|
- query: 搜索词
|
|
|
|
|
|
|
+ query_list: 搜索词列表,每个元素为一个 query 字符串
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
- ToolResult,output 为 JSON 格式的帖子评估结果列表,每项包含:
|
|
|
|
|
- - channel_content_id: 帖子ID
|
|
|
|
|
- - title: 标题
|
|
|
|
|
- - body_text: 正文
|
|
|
|
|
- - images: 图集URL列表
|
|
|
|
|
|
|
+ ToolResult,output 为 JSON 格式的按 query 分组的结果字典:
|
|
|
|
|
+ {
|
|
|
|
|
+ "query1": [帖子评估结果列表],
|
|
|
|
|
+ "query2": [帖子评估结果列表],
|
|
|
|
|
+ ...
|
|
|
|
|
+ }
|
|
|
|
|
+ 每个帖子评估结果包含:
|
|
|
|
|
+ - channel_content_id, title, body_text, images
|
|
|
- persona_match_result: 是否与账号人设匹配(bool)
|
|
- persona_match_result: 是否与账号人设匹配(bool)
|
|
|
- post_keywords: 提取的帖子关键词列表
|
|
- post_keywords: 提取的帖子关键词列表
|
|
|
- - point_match_results: 关键词与帖子选题点的匹配结果列表,
|
|
|
|
|
- 每项含「推导选题点」「帖子选题点」「匹配分数」
|
|
|
|
|
|
|
+ - point_match_results: 关键词与帖子选题点的匹配结果列表
|
|
|
"""
|
|
"""
|
|
|
logger.info(
|
|
logger.info(
|
|
|
- "search_and_eval: account_name=%s post_id=%s query=%s",
|
|
|
|
|
|
|
+ "search_and_eval: account_name=%s post_id=%s query_list=%s",
|
|
|
account_name,
|
|
account_name,
|
|
|
post_id,
|
|
post_id,
|
|
|
- query,
|
|
|
|
|
|
|
+ query_list,
|
|
|
)
|
|
)
|
|
|
- try:
|
|
|
|
|
- # 1. 搜索帖子
|
|
|
|
|
- posts = await _search_posts(query)
|
|
|
|
|
- if not posts:
|
|
|
|
|
- logger.warning("search_and_eval: no posts found for query=%s", query)
|
|
|
|
|
- return ToolResult(
|
|
|
|
|
- title=f"搜索评估: {query}",
|
|
|
|
|
- output="[]",
|
|
|
|
|
- long_term_memory=f"search_and_eval: query='{query}', no posts found",
|
|
|
|
|
- )
|
|
|
|
|
|
|
|
|
|
- logger.info("search_and_eval: got %d posts, loading prompt and persona", len(posts))
|
|
|
|
|
- # 2. 构建 system prompt(替换账号人设)
|
|
|
|
|
|
|
+ if not query_list:
|
|
|
|
|
+ return ToolResult(
|
|
|
|
|
+ title="搜索评估: 空 query_list",
|
|
|
|
|
+ output="{}",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
prompt_template = _load_match_and_extract_prompt()
|
|
prompt_template = _load_match_and_extract_prompt()
|
|
|
persona_text = _load_persona_text(account_name)
|
|
persona_text = _load_persona_text(account_name)
|
|
|
system_prompt = prompt_template.replace("{persona}", persona_text)
|
|
system_prompt = prompt_template.replace("{persona}", persona_text)
|
|
|
|
|
|
|
|
- # 3. 并发评估所有帖子
|
|
|
|
|
tasks = [
|
|
tasks = [
|
|
|
- _eval_single_post(post, system_prompt, account_name, post_id)
|
|
|
|
|
- for post in posts
|
|
|
|
|
|
|
+ _search_and_eval_single_query(q, system_prompt, account_name, post_id)
|
|
|
|
|
+ for q in query_list
|
|
|
]
|
|
]
|
|
|
- results: List[dict] = await asyncio.gather(*tasks)
|
|
|
|
|
|
|
+ all_results: List[List[dict]] = await asyncio.gather(*tasks)
|
|
|
|
|
+
|
|
|
|
|
+ grouped: Dict[str, List[dict]] = {}
|
|
|
|
|
+ total_posts = 0
|
|
|
|
|
+ total_matched = 0
|
|
|
|
|
+ for query, results in zip(query_list, all_results):
|
|
|
|
|
+ grouped[query] = results
|
|
|
|
|
+ total_posts += len(results)
|
|
|
|
|
+ total_matched += sum(1 for r in results if r.get("persona_match_result"))
|
|
|
|
|
|
|
|
- matched_count = sum(1 for r in results if r.get("persona_match_result"))
|
|
|
|
|
- error_count = sum(1 for r in results if r.get("error"))
|
|
|
|
|
logger.info(
|
|
logger.info(
|
|
|
- "search_and_eval: done. total=%d persona_matched=%d errors=%d",
|
|
|
|
|
- len(results),
|
|
|
|
|
- matched_count,
|
|
|
|
|
- error_count,
|
|
|
|
|
|
|
+ "search_and_eval: done. queries=%d total_posts=%d persona_matched=%d",
|
|
|
|
|
+ len(query_list),
|
|
|
|
|
+ total_posts,
|
|
|
|
|
+ total_matched,
|
|
|
)
|
|
)
|
|
|
- output = json.dumps(results, ensure_ascii=False, indent=2)
|
|
|
|
|
- logger.info("search_and_eval: output=%s", output)
|
|
|
|
|
|
|
+ output = json.dumps(grouped, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
return ToolResult(
|
|
return ToolResult(
|
|
|
title=(
|
|
title=(
|
|
|
- f"搜索评估: {query} "
|
|
|
|
|
- f"(共 {len(results)} 条,{matched_count} 条匹配人设)"
|
|
|
|
|
|
|
+ f"搜索评估: {len(query_list)} 个 query "
|
|
|
|
|
+ f"(共 {total_posts} 条帖子,{total_matched} 条匹配人设)"
|
|
|
),
|
|
),
|
|
|
output=output,
|
|
output=output,
|
|
|
- long_term_memory=(
|
|
|
|
|
- f"search_and_eval: query='{query}', "
|
|
|
|
|
- f"found {len(results)} posts, {matched_count} matched persona"
|
|
|
|
|
- ),
|
|
|
|
|
- metadata={"items": results},
|
|
|
|
|
|
|
+ metadata={"search_and_eval summary": f"{len(query_list)} queries, found {total_posts} posts, {total_matched} matched persona"},
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -329,7 +399,7 @@ def main() -> None:
|
|
|
)
|
|
)
|
|
|
account_name = "家有大志"
|
|
account_name = "家有大志"
|
|
|
post_id = "68fb6a5c000000000302e5de"
|
|
post_id = "68fb6a5c000000000302e5de"
|
|
|
- query = "柴犬 鞋子"
|
|
|
|
|
|
|
+ query_list = ["柴犬 鞋子", "柴犬 日常"]
|
|
|
|
|
|
|
|
async def run():
|
|
async def run():
|
|
|
if ToolResult is None:
|
|
if ToolResult is None:
|
|
@@ -338,19 +408,21 @@ def main() -> None:
|
|
|
result = await search_and_eval(
|
|
result = await search_and_eval(
|
|
|
account_name=account_name,
|
|
account_name=account_name,
|
|
|
post_id=post_id,
|
|
post_id=post_id,
|
|
|
- query=query,
|
|
|
|
|
|
|
+ query_list=query_list,
|
|
|
)
|
|
)
|
|
|
if result.error:
|
|
if result.error:
|
|
|
print(f"Error: {result.error}")
|
|
print(f"Error: {result.error}")
|
|
|
else:
|
|
else:
|
|
|
print(result.title)
|
|
print(result.title)
|
|
|
- data = json.loads(result.output)
|
|
|
|
|
- for item in data:
|
|
|
|
|
- print(
|
|
|
|
|
- f" [{item.get('persona_match_result')}] {item.get('title', '')[:30]}"
|
|
|
|
|
- f" | keywords: {item.get('post_keywords')}"
|
|
|
|
|
- f" | matches: {len(item.get('point_match_results', []))}"
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ grouped = json.loads(result.output)
|
|
|
|
|
+ for query, items in grouped.items():
|
|
|
|
|
+ print(f"\n === query: {query} ({len(items)} posts) ===")
|
|
|
|
|
+ for item in items:
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" [{item.get('persona_match_result')}] {item.get('title', '')[:30]}"
|
|
|
|
|
+ f" | keywords: {item.get('post_keywords')}"
|
|
|
|
|
+ f" | matches: {len(item.get('point_match_results', []))}"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
asyncio.run(run())
|
|
asyncio.run(run())
|
|
|
|
|
|