| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- """
- 内容搜索缓存(磁盘持久化)
- 搜索结果按 trace_id 隔离,同一 Agent session 内的 CLI 多次调用也能复用。
- 文件格式:<cwd>/.cache/content_search/{trace_id}.json
- 锚在调用方 CWD 的 .cache/ 下,每个项目隔离且 gitignore 友好。
- """
- import json
- import os
- import time
- from pathlib import Path
- from typing import Any, Dict, List, Optional
# Anchor the cache to the caller's CWD so each calling project gets its own
# cache directory instead of sharing /tmp across sessions.
_CACHE_DIR = Path(os.getcwd()) / ".cache" / "content_search"
try:
    _CACHE_DIR.mkdir(parents=True, exist_ok=True)
except OSError:
    # The cache is best-effort: an unwritable CWD must not make importing
    # this module fail. Later reads simply find no cache file.
    pass

# Cache entries older than this many seconds (1 hour) are treated as expired.
_CACHE_TTL = 3600
def _cache_path(trace_id: str) -> Path:
    """Return the JSON cache file path for *trace_id* inside ``_CACHE_DIR``.

    Forward slashes, backslashes and ``..`` sequences in the id are replaced
    with underscores so the id cannot address a file outside the cache
    directory. Backslashes are included because they are path separators on
    Windows; ids without backslashes map to the same file name as before.
    """
    safe_id = trace_id
    # Separators are replaced first, then any remaining ".." sequences.
    for unsafe in ("/", "\\", ".."):
        safe_id = safe_id.replace(unsafe, "_")
    return _CACHE_DIR / f"{safe_id}.json"
def _load_raw(trace_id: str) -> dict:
    """Load the cached payload for *trace_id*, or ``{}`` if none is usable.

    An empty dict is returned when the cache file is missing, unreadable,
    not valid JSON, not a JSON object, carries an unusable ``_ts`` value, or
    is older than ``_CACHE_TTL`` seconds. Expired files are deleted.

    Returns:
        The decoded cache dict (including its ``_ts`` key), or ``{}``.
    """
    path = _cache_path(trace_id)

    # Read directly instead of checking exists() first: the file may vanish
    # between the check and the read, and a missing file is just an OSError.
    try:
        data = json.loads(path.read_text("utf-8"))
    except (OSError, ValueError, RecursionError):
        # ValueError covers both invalid JSON and invalid UTF-8.
        return {}

    if not isinstance(data, dict):
        return {}

    # A payload without "_ts" counts as written at epoch 0, i.e. expired.
    try:
        expired = time.time() - data.get("_ts", 0) > _CACHE_TTL
    except (TypeError, OverflowError):
        # "_ts" is not a usable number; treat the payload as unusable.
        return {}

    if expired:
        try:
            path.unlink(missing_ok=True)
        except OSError:
            pass  # Failing to delete a stale file is not fatal.
        return {}
    return data
def _save_raw(trace_id: str, data: dict) -> None:
    """Write *data* to the cache file for *trace_id* (best effort).

    ``data["_ts"]`` is set to the current time before writing, so the
    caller's dict is mutated. The payload is written to a temporary sibling
    file and then moved into place with ``os.replace``, so a reader never
    sees a half-written cache file. Serialization and I/O failures are
    swallowed: a failed cache write must not break the caller.
    """
    data["_ts"] = time.time()
    target = _cache_path(trace_id)
    # Per-process temp name so concurrent writers do not share a temp file.
    tmp = target.with_name(f"{target.name}.{os.getpid()}.tmp")
    try:
        serialized = json.dumps(data, ensure_ascii=False)
        # Recreate the directory in case it was removed after import.
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp.write_text(serialized, encoding="utf-8")
        os.replace(tmp, target)
    except (OSError, TypeError, ValueError, RecursionError):
        # TypeError/ValueError/RecursionError: data is not JSON-serializable.
        # OSError: the directory or file could not be written.
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass
def save_search_results(
    trace_id: str,
    platform: str,
    keyword: str,
    posts: List[Dict[str, Any]],
) -> None:
    """Append a search to the on-disk cache, keeping earlier searches.

    Each platform has one ``search:<platform>`` entry holding a ``history``
    list and a ``latest_index`` pointing at the newest record. An entry in
    the legacy single-search layout (a dict with ``keyword``) is converted
    into the first history record before the new search is appended.
    """
    cache = _load_raw(trace_id)
    slot = f"search:{platform}"
    bucket = cache.get(slot)

    has_history = isinstance(bucket, dict) and "history" in bucket
    if not has_history:
        # Start a fresh history, carrying over a legacy record if present.
        legacy, bucket = bucket, {"history": [], "latest_index": -1}
        cache[slot] = bucket
        if legacy and isinstance(legacy, dict) and "keyword" in legacy:
            bucket["history"].append(
                {
                    "timestamp": time.time(),
                    "keyword": legacy["keyword"],
                    "posts": legacy.get("posts", []),
                }
            )
            bucket["latest_index"] = 0

    # Record the new search and point latest_index at it.
    records = bucket["history"]
    records.append(
        {
            "timestamp": time.time(),
            "keyword": keyword,
            "posts": posts,
        }
    )
    bucket["latest_index"] = len(records) - 1

    _save_raw(trace_id, cache)
def get_cached_post(
    trace_id: str,
    platform: str,
    index: int,
) -> Optional[Dict[str, Any]]:
    """Return one post from the most recent cached search, by 1-based index.

    Both cache layouts are understood: the current one (a ``history`` list
    plus ``latest_index``) and the legacy one (a single dict with ``posts``).

    Returns:
        The post at position *index*, or ``None`` when nothing is cached for
        the platform, the cache entry is malformed, or *index* is out of
        range.
    """
    entry = _load_raw(trace_id).get(f"search:{platform}")
    # Anything other than a non-empty dict cannot be a valid entry.
    if not entry or not isinstance(entry, dict):
        return None

    if "history" in entry:
        # Current layout: take the record that latest_index points at.
        history = entry["history"]
        latest_idx = entry.get("latest_index", -1)
        if (
            not isinstance(history, list)
            or not isinstance(latest_idx, int)
            or not 0 <= latest_idx < len(history)
        ):
            return None
        record = history[latest_idx]
        if not isinstance(record, dict):
            return None
    else:
        # Legacy layout: the entry itself is the single search record.
        record = entry

    posts = record.get("posts", [])
    if not isinstance(posts, list):
        return None
    if 1 <= index <= len(posts):
        return posts[index - 1]
    return None
def get_cached_search_info(trace_id: str, platform: str) -> Optional[Dict[str, Any]]:
    """Return the keyword and post count of the most recent cached search.

    Intended for building error hints. Both cache layouts are understood:
    the current one (a ``history`` list plus ``latest_index``) and the legacy
    one (a single dict with ``keyword`` and ``posts``).

    Returns:
        ``{"keyword": ..., "total": ...}`` for the latest search, or ``None``
        when nothing is cached for the platform or the entry is malformed.
        ``total`` is 0 when the stored posts value is not a list.
    """
    entry = _load_raw(trace_id).get(f"search:{platform}")
    # Anything other than a non-empty dict cannot be a valid entry.
    if not entry or not isinstance(entry, dict):
        return None

    if "history" in entry:
        # Current layout: take the record that latest_index points at.
        history = entry["history"]
        latest_idx = entry.get("latest_index", -1)
        if (
            not isinstance(history, list)
            or not isinstance(latest_idx, int)
            or not 0 <= latest_idx < len(history)
        ):
            return None
        record = history[latest_idx]
        if not isinstance(record, dict):
            return None
    else:
        # Legacy layout: the entry itself is the single search record.
        record = entry

    posts = record.get("posts", [])
    total = len(posts) if isinstance(posts, list) else 0
    return {"keyword": record.get("keyword"), "total": total}
|