"""
cache.py — content-search cache, persisted to disk.

Search results are stored per ``trace_id``, so repeated CLI invocations
within the same Agent session can reuse each other's results. Only the most
recent search per platform is kept.

File location: ``/tmp/content_cache_{trace_id}.json`` (any ``/`` or ``..``
in the trace_id is replaced by ``_``). A file expires one hour after it was
last written.
"""
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
  11. _CACHE_DIR = Path("/tmp")
  12. _CACHE_TTL = 3600 # 1 小时过期
  13. def _cache_path(trace_id: str) -> Path:
  14. safe_id = trace_id.replace("/", "_").replace("..", "_")
  15. return _CACHE_DIR / f"content_cache_{safe_id}.json"
  16. def _load_raw(trace_id: str) -> dict:
  17. p = _cache_path(trace_id)
  18. if not p.exists():
  19. return {}
  20. try:
  21. data = json.loads(p.read_text("utf-8"))
  22. # 检查过期
  23. if time.time() - data.get("_ts", 0) > _CACHE_TTL:
  24. p.unlink(missing_ok=True)
  25. return {}
  26. return data
  27. except Exception:
  28. return {}
  29. def _save_raw(trace_id: str, data: dict) -> None:
  30. data["_ts"] = time.time()
  31. try:
  32. _cache_path(trace_id).write_text(
  33. json.dumps(data, ensure_ascii=False), encoding="utf-8"
  34. )
  35. except Exception:
  36. pass
  37. def save_search_results(
  38. trace_id: str,
  39. platform: str,
  40. keyword: str,
  41. posts: List[Dict[str, Any]],
  42. ) -> None:
  43. """保存搜索结果到磁盘缓存"""
  44. data = _load_raw(trace_id)
  45. # 每个 platform 只保留最近一次搜索
  46. data[f"search:{platform}"] = {
  47. "keyword": keyword,
  48. "posts": posts,
  49. }
  50. _save_raw(trace_id, data)
  51. def get_cached_post(
  52. trace_id: str,
  53. platform: str,
  54. index: int,
  55. ) -> Optional[Dict[str, Any]]:
  56. """按索引从缓存取一条完整记录(1-based)"""
  57. data = _load_raw(trace_id)
  58. entry = data.get(f"search:{platform}")
  59. if not entry:
  60. return None
  61. posts = entry.get("posts", [])
  62. if 1 <= index <= len(posts):
  63. return posts[index - 1]
  64. return None
  65. def get_cached_search_info(trace_id: str, platform: str) -> Optional[Dict[str, Any]]:
  66. """获取缓存的搜索信息(keyword + 总条数),用于错误提示"""
  67. data = _load_raw(trace_id)
  68. entry = data.get(f"search:{platform}")
  69. if not entry:
  70. return None
  71. return {"keyword": entry.get("keyword"), "total": len(entry.get("posts", []))}