# cache.py
"""
Disk-persisted cache for content-search results.

Search results are isolated per trace_id, so repeated CLI invocations within
the same agent session can reuse them.
File layout: <cwd>/.cache/content_search/{trace_id}.json
Anchored under the caller's CWD in .cache/, so each project gets an isolated,
gitignore-friendly cache directory.
"""
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

# CWD-anchored — each calling project gets its own cache directory,
# avoiding cross-session pollution in a shared /tmp.
_CACHE_DIR = Path(os.getcwd()) / ".cache" / "content_search"
_CACHE_DIR.mkdir(parents=True, exist_ok=True)  # NOTE: runs at import time

_CACHE_TTL = 3600  # entries expire after 1 hour
  16. def _cache_path(trace_id: str) -> Path:
  17. safe_id = trace_id.replace("/", "_").replace("..", "_")
  18. return _CACHE_DIR / f"{safe_id}.json"
  19. def _load_raw(trace_id: str) -> dict:
  20. p = _cache_path(trace_id)
  21. if not p.exists():
  22. return {}
  23. try:
  24. data = json.loads(p.read_text("utf-8"))
  25. # 检查过期
  26. if time.time() - data.get("_ts", 0) > _CACHE_TTL:
  27. p.unlink(missing_ok=True)
  28. return {}
  29. return data
  30. except Exception:
  31. return {}
  32. def _save_raw(trace_id: str, data: dict) -> None:
  33. data["_ts"] = time.time()
  34. try:
  35. _cache_path(trace_id).write_text(
  36. json.dumps(data, ensure_ascii=False), encoding="utf-8"
  37. )
  38. except Exception:
  39. pass
  40. def save_search_results(
  41. trace_id: str,
  42. platform: str,
  43. keyword: str,
  44. posts: List[Dict[str, Any]],
  45. ) -> None:
  46. """保存搜索结果到磁盘缓存(保留历史搜索)"""
  47. data = _load_raw(trace_id)
  48. key = f"search:{platform}"
  49. # 初始化或获取现有的历史记录结构
  50. if key not in data or not isinstance(data[key], dict) or "history" not in data[key]:
  51. # 兼容旧格式:如果是旧的单次搜索格式,转换为历史格式
  52. old_data = data.get(key)
  53. data[key] = {"history": [], "latest_index": -1}
  54. if old_data and isinstance(old_data, dict) and "keyword" in old_data:
  55. # 保留旧数据作为第一条历史记录
  56. data[key]["history"].append({
  57. "timestamp": time.time(),
  58. "keyword": old_data["keyword"],
  59. "posts": old_data.get("posts", []),
  60. })
  61. data[key]["latest_index"] = 0
  62. # 添加新的搜索记录
  63. data[key]["history"].append({
  64. "timestamp": time.time(),
  65. "keyword": keyword,
  66. "posts": posts,
  67. })
  68. data[key]["latest_index"] = len(data[key]["history"]) - 1
  69. _save_raw(trace_id, data)
  70. def get_cached_post(
  71. trace_id: str,
  72. platform: str,
  73. index: int,
  74. ) -> Optional[Dict[str, Any]]:
  75. """按索引从缓存取一条完整记录(1-based),默认从最新搜索结果中获取"""
  76. data = _load_raw(trace_id)
  77. entry = data.get(f"search:{platform}")
  78. if not entry:
  79. return None
  80. # 支持新格式(历史列表)和旧格式(单次搜索)
  81. if isinstance(entry, dict) and "history" in entry:
  82. # 新格式:从最新的搜索结果中获取
  83. latest_idx = entry.get("latest_index", -1)
  84. if latest_idx < 0 or latest_idx >= len(entry["history"]):
  85. return None
  86. posts = entry["history"][latest_idx].get("posts", [])
  87. else:
  88. # 旧格式:兼容处理
  89. posts = entry.get("posts", [])
  90. if 1 <= index <= len(posts):
  91. return posts[index - 1]
  92. return None
  93. def get_cached_search_info(trace_id: str, platform: str) -> Optional[Dict[str, Any]]:
  94. """获取缓存的搜索信息(keyword + 总条数),用于错误提示"""
  95. data = _load_raw(trace_id)
  96. entry = data.get(f"search:{platform}")
  97. if not entry:
  98. return None
  99. # 支持新格式(历史列表)和旧格式(单次搜索)
  100. if isinstance(entry, dict) and "history" in entry:
  101. # 新格式:返回最新搜索的信息
  102. latest_idx = entry.get("latest_index", -1)
  103. if latest_idx < 0 or latest_idx >= len(entry["history"]):
  104. return None
  105. latest = entry["history"][latest_idx]
  106. return {"keyword": latest.get("keyword"), "total": len(latest.get("posts", []))}
  107. else:
  108. # 旧格式:兼容处理
  109. return {"keyword": entry.get("keyword"), "total": len(entry.get("posts", []))}