cache.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. """
  2. 内容搜索缓存(磁盘持久化)
  3. 搜索结果按 trace_id 隔离,同一 Agent session 内的 CLI 多次调用也能复用。
  4. 文件格式:<cwd>/.cache/content_search/{trace_id}.json
  5. 锚在调用方 CWD 的 .cache/ 下,每个项目隔离且 gitignore 友好。
  6. """
  7. import json
  8. import os
  9. import time
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional
  12. # CWD 锚定 —— 每个调用项目有独立缓存目录,避免 /tmp 跨会话污染
  13. _CACHE_DIR = Path(os.getcwd()) / ".cache" / "content_search"
  14. _CACHE_DIR.mkdir(parents=True, exist_ok=True)
  15. _CACHE_TTL = 3600 # 1 小时过期
  16. def _cache_path(trace_id: str) -> Path:
  17. safe_id = trace_id.replace("/", "_").replace("..", "_")
  18. return _CACHE_DIR / f"{safe_id}.json"
  19. def _load_raw(trace_id: str) -> dict:
  20. p = _cache_path(trace_id)
  21. if not p.exists():
  22. return {}
  23. try:
  24. data = json.loads(p.read_text("utf-8"))
  25. # 检查过期
  26. if time.time() - data.get("_ts", 0) > _CACHE_TTL:
  27. p.unlink(missing_ok=True)
  28. return {}
  29. return data
  30. except Exception:
  31. return {}
  32. def _save_raw(trace_id: str, data: dict) -> None:
  33. data["_ts"] = time.time()
  34. try:
  35. _cache_path(trace_id).write_text(
  36. json.dumps(data, ensure_ascii=False), encoding="utf-8"
  37. )
  38. except Exception:
  39. pass
  40. def save_search_results(
  41. trace_id: str,
  42. platform: str,
  43. keyword: str,
  44. posts: List[Dict[str, Any]],
  45. ) -> None:
  46. """保存搜索结果到磁盘缓存(保留历史搜索)"""
  47. data = _load_raw(trace_id)
  48. key = f"search:{platform}"
  49. # 初始化或获取现有的历史记录结构
  50. if key not in data or not isinstance(data[key], dict) or "history" not in data[key]:
  51. # 兼容旧格式:如果是旧的单次搜索格式,转换为历史格式
  52. old_data = data.get(key)
  53. data[key] = {"history": [], "latest_index": -1}
  54. if old_data and isinstance(old_data, dict) and "keyword" in old_data:
  55. # 保留旧数据作为第一条历史记录
  56. data[key]["history"].append({
  57. "timestamp": time.time(),
  58. "keyword": old_data["keyword"],
  59. "posts": old_data.get("posts", []),
  60. })
  61. data[key]["latest_index"] = 0
  62. # 添加新的搜索记录
  63. data[key]["history"].append({
  64. "timestamp": time.time(),
  65. "keyword": keyword,
  66. "posts": posts,
  67. })
  68. data[key]["latest_index"] = len(data[key]["history"]) - 1
  69. _save_raw(trace_id, data)
  70. def get_cached_post(
  71. trace_id: str,
  72. platform: str,
  73. index: int,
  74. ) -> Optional[Dict[str, Any]]:
  75. """按索引从缓存取一条完整记录(1-based),默认从最新搜索结果中获取"""
  76. data = _load_raw(trace_id)
  77. entry = data.get(f"search:{platform}")
  78. if not entry:
  79. return None
  80. # 支持新格式(历史列表)和旧格式(单次搜索)
  81. if isinstance(entry, dict) and "history" in entry:
  82. # 新格式:从最新的搜索结果中获取
  83. latest_idx = entry.get("latest_index", -1)
  84. if latest_idx < 0 or latest_idx >= len(entry["history"]):
  85. return None
  86. posts = entry["history"][latest_idx].get("posts", [])
  87. else:
  88. # 旧格式:兼容处理
  89. posts = entry.get("posts", [])
  90. if 1 <= index <= len(posts):
  91. return posts[index - 1]
  92. return None
  93. def update_post_field(
  94. trace_id: str,
  95. platform: str,
  96. content_id: str,
  97. field: str,
  98. value: Any,
  99. ) -> bool:
  100. """按 channel_content_id 定位缓存里的 post,更新其某个字段并写回磁盘。
  101. 返回是否实际写入(False 表示未找到匹配 post 或缓存不存在)。
  102. 在 detail() 异步拉到补充数据后调用,让数据持久化到 cache,
  103. 供后续 extract_sources 等离线流程读取。
  104. """
  105. if not trace_id or not content_id:
  106. return False
  107. data = _load_raw(trace_id)
  108. entry = data.get(f"search:{platform}")
  109. if not isinstance(entry, dict):
  110. return False
  111. if "history" in entry:
  112. histories = entry.get("history", [])
  113. else:
  114. histories = [entry]
  115. cid_str = str(content_id)
  116. updated = False
  117. # Match by channel_content_id (primary, used by X / aigc-channel platforms)
  118. # or video_id (fallback, used by YouTube whose post id field is named differently).
  119. for hist in histories:
  120. for post in hist.get("posts", []) or []:
  121. if not isinstance(post, dict):
  122. continue
  123. if (
  124. str(post.get("channel_content_id", "")) == cid_str
  125. or str(post.get("video_id", "")) == cid_str
  126. ):
  127. post[field] = value
  128. updated = True
  129. if updated:
  130. _save_raw(trace_id, data)
  131. return updated
  132. def get_cached_search_info(trace_id: str, platform: str) -> Optional[Dict[str, Any]]:
  133. """获取缓存的搜索信息(keyword + 总条数),用于错误提示"""
  134. data = _load_raw(trace_id)
  135. entry = data.get(f"search:{platform}")
  136. if not entry:
  137. return None
  138. # 支持新格式(历史列表)和旧格式(单次搜索)
  139. if isinstance(entry, dict) and "history" in entry:
  140. # 新格式:返回最新搜索的信息
  141. latest_idx = entry.get("latest_index", -1)
  142. if latest_idx < 0 or latest_idx >= len(entry["history"]):
  143. return None
  144. latest = entry["history"][latest_idx]
  145. return {"keyword": latest.get("keyword"), "total": len(latest.get("posts", []))}
  146. else:
  147. # 旧格式:兼容处理
  148. return {"keyword": entry.get("keyword"), "total": len(entry.get("posts", []))}