cache.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """
  2. 内容搜索缓存(磁盘持久化)
  3. 搜索结果按 trace_id 隔离,同一 Agent session 内的 CLI 多次调用也能复用。
  4. 文件格式:<cwd>/.cache/content_search/{trace_id}.json
  5. 锚在调用方 CWD 的 .cache/ 下,每个项目隔离且 gitignore 友好。
  6. """
  7. import json
  8. import os
  9. import time
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional
  12. # CWD 锚定 —— 每个调用项目有独立缓存目录,避免 /tmp 跨会话污染
  13. _CACHE_DIR = Path(os.getcwd()) / ".cache" / "content_search"
  14. _CACHE_DIR.mkdir(parents=True, exist_ok=True)
  15. _CACHE_TTL = 3600 # 1 小时过期
  16. def _cache_path(trace_id: str) -> Path:
  17. safe_id = trace_id.replace("/", "_").replace("..", "_")
  18. return _CACHE_DIR / f"{safe_id}.json"
  19. def _load_raw(trace_id: str) -> dict:
  20. p = _cache_path(trace_id)
  21. if not p.exists():
  22. return {}
  23. try:
  24. data = json.loads(p.read_text("utf-8"))
  25. # 检查过期
  26. if time.time() - data.get("_ts", 0) > _CACHE_TTL:
  27. p.unlink(missing_ok=True)
  28. return {}
  29. return data
  30. except Exception:
  31. return {}
  32. def _save_raw(trace_id: str, data: dict) -> None:
  33. data["_ts"] = time.time()
  34. try:
  35. _cache_path(trace_id).write_text(
  36. json.dumps(data, ensure_ascii=False), encoding="utf-8"
  37. )
  38. except Exception:
  39. pass
  40. def save_search_results(
  41. trace_id: str,
  42. platform: str,
  43. keyword: str,
  44. posts: List[Dict[str, Any]],
  45. ) -> None:
  46. """保存搜索结果到磁盘缓存"""
  47. data = _load_raw(trace_id)
  48. # 每个 platform 只保留最近一次搜索
  49. data[f"search:{platform}"] = {
  50. "keyword": keyword,
  51. "posts": posts,
  52. }
  53. _save_raw(trace_id, data)
  54. def get_cached_post(
  55. trace_id: str,
  56. platform: str,
  57. index: int,
  58. ) -> Optional[Dict[str, Any]]:
  59. """按索引从缓存取一条完整记录(1-based)"""
  60. data = _load_raw(trace_id)
  61. entry = data.get(f"search:{platform}")
  62. if not entry:
  63. return None
  64. posts = entry.get("posts", [])
  65. if 1 <= index <= len(posts):
  66. return posts[index - 1]
  67. return None
  68. def get_cached_search_info(trace_id: str, platform: str) -> Optional[Dict[str, Any]]:
  69. """获取缓存的搜索信息(keyword + 总条数),用于错误提示"""
  70. data = _load_raw(trace_id)
  71. entry = data.get(f"search:{platform}")
  72. if not entry:
  73. return None
  74. return {"keyword": entry.get("keyword"), "total": len(entry.get("posts", []))}