cache.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
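"""
Content search cache (persisted to disk).

Search results are isolated per trace_id, so multiple CLI invocations within
the same Agent session can reuse them.
File location: <system temp dir>/agent_content_cache/content_cache_{trace_id}.json
(trace_id is sanitized before being used in the file name).
"""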
  6. import json
  7. import os
  8. import time
  9. from pathlib import Path
  10. from typing import Any, Dict, List, Optional
  11. import tempfile
  12. _CACHE_DIR = Path(tempfile.gettempdir()) / "agent_content_cache"
  13. _CACHE_DIR.mkdir(parents=True, exist_ok=True)
  14. _CACHE_TTL = 3600 # 1 小时过期
  15. def _cache_path(trace_id: str) -> Path:
  16. safe_id = trace_id.replace("/", "_").replace("..", "_")
  17. return _CACHE_DIR / f"content_cache_{safe_id}.json"
  18. def _load_raw(trace_id: str) -> dict:
  19. p = _cache_path(trace_id)
  20. if not p.exists():
  21. return {}
  22. try:
  23. data = json.loads(p.read_text("utf-8"))
  24. # 检查过期
  25. if time.time() - data.get("_ts", 0) > _CACHE_TTL:
  26. p.unlink(missing_ok=True)
  27. return {}
  28. return data
  29. except Exception:
  30. return {}
  31. def _save_raw(trace_id: str, data: dict) -> None:
  32. data["_ts"] = time.time()
  33. try:
  34. _cache_path(trace_id).write_text(
  35. json.dumps(data, ensure_ascii=False), encoding="utf-8"
  36. )
  37. except Exception:
  38. pass
  39. def save_search_results(
  40. trace_id: str,
  41. platform: str,
  42. keyword: str,
  43. posts: List[Dict[str, Any]],
  44. ) -> None:
  45. """保存搜索结果到磁盘缓存"""
  46. data = _load_raw(trace_id)
  47. # 每个 platform 只保留最近一次搜索
  48. data[f"search:{platform}"] = {
  49. "keyword": keyword,
  50. "posts": posts,
  51. }
  52. _save_raw(trace_id, data)
  53. def get_cached_post(
  54. trace_id: str,
  55. platform: str,
  56. index: int,
  57. ) -> Optional[Dict[str, Any]]:
  58. """按索引从缓存取一条完整记录(1-based)"""
  59. data = _load_raw(trace_id)
  60. entry = data.get(f"search:{platform}")
  61. if not entry:
  62. return None
  63. posts = entry.get("posts", [])
  64. if 1 <= index <= len(posts):
  65. return posts[index - 1]
  66. return None
  67. def get_cached_search_info(trace_id: str, platform: str) -> Optional[Dict[str, Any]]:
  68. """获取缓存的搜索信息(keyword + 总条数),用于错误提示"""
  69. data = _load_raw(trace_id)
  70. entry = data.get(f"search:{platform}")
  71. if not entry:
  72. return None
  73. return {"keyword": entry.get("keyword"), "total": len(entry.get("posts", []))}