| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- """
- 内容搜索缓存(磁盘持久化)
- 搜索结果按 trace_id 隔离,同一 Agent session 内的 CLI 多次调用也能复用。
- 文件格式:<cwd>/.cache/content_search/{trace_id}.json
- 锚在调用方 CWD 的 .cache/ 下,每个项目隔离且 gitignore 友好。
- """
- import json
- import os
- import time
- from pathlib import Path
- from typing import Any, Dict, List, Optional
# Anchor the cache to the caller's working directory so every project gets its
# own cache folder instead of sharing (and polluting) a global /tmp location.
_CACHE_DIR = Path.cwd().joinpath(".cache", "content_search")
_CACHE_DIR.mkdir(exist_ok=True, parents=True)

# Cache files older than this many seconds are treated as expired (one hour).
_CACHE_TTL = 60 * 60
def _cache_path(trace_id: str) -> Path:
    """Return the cache file path used for ``trace_id``.

    The id is flattened into a single file name: forward slashes,
    backslashes and ``..`` sequences are each replaced with ``_`` so the
    result always names a file directly inside ``_CACHE_DIR``.
    """
    # Backslash is a path separator on Windows; without neutralising it an id
    # could still point at another directory on that platform.
    safe_id = (
        trace_id.replace("/", "_").replace("\\", "_").replace("..", "_")
    )
    return _CACHE_DIR / f"{safe_id}.json"
def _load_raw(trace_id: str) -> dict:
    """Load the cached payload for ``trace_id`` from disk.

    Returns an empty dict when the cache file is missing, unreadable, not
    valid JSON, not a JSON object, or older than ``_CACHE_TTL`` seconds
    (judged by its ``_ts`` field). An expired file is deleted.
    """
    path = _cache_path(trace_id)

    # Read directly instead of checking exists() first: a missing file simply
    # surfaces as OSError, and there is no window between check and read.
    try:
        payload = json.loads(path.read_text("utf-8"))
    except (OSError, ValueError):
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        return {}

    if not isinstance(payload, dict):
        return {}

    stamp = payload.get("_ts", 0)
    if not isinstance(stamp, (int, float)):
        # A malformed timestamp makes the file unusable, but it is left on
        # disk (only files known to be expired are removed).
        return {}

    if time.time() - stamp > _CACHE_TTL:
        try:
            path.unlink(missing_ok=True)
        except OSError:
            # Failing to delete an expired file must not break the lookup.
            pass
        return {}

    return payload
def _save_raw(trace_id: str, data: dict) -> None:
    """Persist ``data`` as the cache payload for ``trace_id`` (best effort).

    ``data["_ts"]` is set in place to the current time before writing.
    Failures to serialise or write are swallowed on purpose: the cache is an
    optimisation and must never break the calling command.
    """
    data["_ts"] = time.time()
    target = _cache_path(trace_id)
    # Per-process scratch name so two concurrent writers never share a file.
    scratch = target.with_name(f"{target.name}.{os.getpid()}.tmp")

    try:
        # Serialise first so a non-serialisable payload never touches disk.
        text = json.dumps(data, ensure_ascii=False)
        # The directory may have been removed since it was first created.
        target.parent.mkdir(parents=True, exist_ok=True)
        scratch.write_text(text, encoding="utf-8")
        # Atomic swap: readers see either the old file or the complete new
        # one, never a partially written JSON document.
        os.replace(scratch, target)
    except (OSError, TypeError, ValueError):
        try:
            scratch.unlink(missing_ok=True)
        except OSError:
            pass
def save_search_results(
    trace_id: str,
    platform: str,
    keyword: str,
    posts: List[Dict[str, Any]],
) -> None:
    """Append one search result set to the on-disk cache, keeping history.

    Each platform is stored under ``search:<platform>`` as
    ``{"history": [...], "latest_index": n}``. An entry in the legacy
    single-search shape (a dict carrying ``keyword``) is migrated to become
    the first history record before the new search is appended.
    """
    cache = _load_raw(trace_id)
    slot = f"search:{platform}"
    existing = cache.get(slot)

    if isinstance(existing, dict) and "history" in existing:
        bucket = existing
    else:
        # Missing entry or legacy layout: start a fresh history container.
        bucket = {"history": [], "latest_index": -1}
        if isinstance(existing, dict) and "keyword" in existing:
            # Carry the legacy single search over as the first record.
            legacy_record = {
                "timestamp": time.time(),
                "keyword": existing["keyword"],
                "posts": existing.get("posts", []),
            }
            bucket["history"].append(legacy_record)
            bucket["latest_index"] = 0
        cache[slot] = bucket

    # Record the new search and mark it as the most recent one.
    records = bucket["history"]
    records.append({
        "timestamp": time.time(),
        "keyword": keyword,
        "posts": posts,
    })
    bucket["latest_index"] = len(records) - 1

    _save_raw(trace_id, cache)
def get_cached_post(
    trace_id: str,
    platform: str,
    index: int,
) -> Optional[Dict[str, Any]]:
    """Return the cached post at 1-based ``index`` from the latest search.

    Supports both the history layout (``{"history": [...],
    "latest_index": n}``) and the legacy single-search layout
    (``{"keyword": ..., "posts": [...]}``).

    Returns ``None`` when nothing is cached for the platform, the index is
    out of range, or the cached entry is malformed.
    """
    entry = _load_raw(trace_id).get(f"search:{platform}")
    # The cache is read back from disk, so its shape cannot be trusted; a
    # non-dict entry is treated as "nothing cached" instead of crashing.
    if not entry or not isinstance(entry, dict):
        return None

    if "history" in entry:
        history = entry["history"]
        latest_idx = entry.get("latest_index", -1)
        if (
            not isinstance(history, list)
            or not isinstance(latest_idx, int)
            or not 0 <= latest_idx < len(history)
        ):
            return None
        search = history[latest_idx]
        if not isinstance(search, dict):
            return None
    else:
        # Legacy layout: the entry itself is the single search record.
        search = entry

    # `posts` may be stored as null; treat that the same as an empty list.
    posts = search.get("posts") or []
    if not isinstance(posts, list):
        return None

    if 1 <= index <= len(posts):
        return posts[index - 1]
    return None
def update_post_field(
    trace_id: str,
    platform: str,
    content_id: str,
    field: str,
    value: Any,
) -> bool:
    """Set ``field`` to ``value`` on every cached post matching ``content_id``.

    A post matches when its ``channel_content_id`` or its ``video_id``,
    compared as strings, equals ``content_id``. All history records for the
    platform are scanned, and the cache is written back only if at least one
    post was changed.

    Returns ``True`` if something was updated, ``False`` when an argument is
    empty, the cache holds no dict entry for the platform, or no post matched.
    """
    if not (trace_id and content_id):
        return False

    cache = _load_raw(trace_id)
    entry = cache.get(f"search:{platform}")
    if not isinstance(entry, dict):
        return False

    # History layout holds many searches; the legacy layout is one search.
    searches = entry.get("history", []) if "history" in entry else [entry]

    wanted = str(content_id)
    # channel_content_id is checked first, video_id is the fallback key.
    id_keys = ("channel_content_id", "video_id")

    hits = 0
    for search in searches:
        for post in search.get("posts", []) or []:
            if not isinstance(post, dict):
                continue
            if any(str(post.get(id_key, "")) == wanted for id_key in id_keys):
                post[field] = value
                hits += 1

    if not hits:
        return False

    _save_raw(trace_id, cache)
    return True
def get_cached_search_info(trace_id: str, platform: str) -> Optional[Dict[str, Any]]:
    """Return ``{"keyword": ..., "total": ...}`` for the latest cached search.

    ``total`` is the number of cached posts. Supports both the history
    layout and the legacy single-search layout. Returns ``None`` when
    nothing is cached for the platform or the cached entry is malformed.
    """
    entry = _load_raw(trace_id).get(f"search:{platform}")
    # This feeds error hints, so a corrupt cache must yield None rather than
    # raise a second, unrelated exception.
    if not entry or not isinstance(entry, dict):
        return None

    if "history" in entry:
        history = entry["history"]
        latest_idx = entry.get("latest_index", -1)
        if (
            not isinstance(history, list)
            or not isinstance(latest_idx, int)
            or not 0 <= latest_idx < len(history)
        ):
            return None
        search = history[latest_idx]
        if not isinstance(search, dict):
            return None
    else:
        # Legacy layout: the entry itself is the single search record.
        search = entry

    # `posts` may be stored as null or in an unexpected shape; count it as 0.
    posts = search.get("posts") or []
    total = len(posts) if isinstance(posts, list) else 0
    return {"keyword": search.get("keyword"), "total": total}
|