cache_manager.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. """
  2. 缓存管理模块
  3. 提供统一的缓存读写接口,支持基于问题的分级缓存
  4. """
  5. import os
  6. import json
  7. import hashlib
  8. from typing import Any, Optional
  9. from loguru import logger
  10. class CacheManager:
  11. """缓存管理器"""
  12. def __init__(self, base_cache_dir: str = None):
  13. """
  14. 初始化缓存管理器
  15. Args:
  16. base_cache_dir: 缓存根目录,默认为当前目录下的 .cache
  17. """
  18. if base_cache_dir is None:
  19. current_dir = os.path.dirname(os.path.abspath(__file__))
  20. base_cache_dir = os.path.join(current_dir, '.cache')
  21. self.base_cache_dir = base_cache_dir
  22. os.makedirs(base_cache_dir, exist_ok=True)
  23. logger.info(f"缓存管理器初始化,缓存目录: {base_cache_dir}")
  24. def _get_question_hash(self, question: str) -> str:
  25. """
  26. 获取问题的hash值,用作文件夹名
  27. Args:
  28. question: 问题文本
  29. Returns:
  30. str: hash值(MD5的前12位)
  31. """
  32. return hashlib.md5(question.encode('utf-8')).hexdigest()[:12]
  33. def _get_cache_path(self, question: str, cache_type: str, filename: str) -> str:
  34. """
  35. 获取缓存文件的完整路径
  36. Args:
  37. question: 问题文本
  38. cache_type: 缓存类型(如 'function_knowledge', 'llm_search', 'multi_search')
  39. filename: 缓存文件名
  40. Returns:
  41. str: 缓存文件完整路径
  42. """
  43. question_hash = self._get_question_hash(question)
  44. cache_dir = os.path.join(self.base_cache_dir, question_hash, cache_type)
  45. os.makedirs(cache_dir, exist_ok=True)
  46. # 同时保存原始问题文本以便查看
  47. question_file = os.path.join(self.base_cache_dir, question_hash, 'question.txt')
  48. if not os.path.exists(question_file):
  49. with open(question_file, 'w', encoding='utf-8') as f:
  50. f.write(question)
  51. return os.path.join(cache_dir, filename)
  52. def get(self, question: str, cache_type: str, filename: str) -> Optional[Any]:
  53. """
  54. 读取缓存
  55. Args:
  56. question: 问题文本
  57. cache_type: 缓存类型
  58. filename: 缓存文件名
  59. Returns:
  60. 缓存内容,如果缓存不存在则返回 None
  61. """
  62. cache_path = self._get_cache_path(question, cache_type, filename)
  63. if not os.path.exists(cache_path):
  64. logger.debug(f"缓存未命中: {cache_type}/{filename}")
  65. return None
  66. try:
  67. with open(cache_path, 'r', encoding='utf-8') as f:
  68. content = f.read()
  69. # 尝试解析为JSON
  70. if filename.endswith('.json'):
  71. content = json.loads(content)
  72. logger.info(f"✓ 缓存命中: {cache_type}/{filename}")
  73. return content
  74. except Exception as e:
  75. logger.error(f"读取缓存失败 {cache_type}/{filename}: {e}")
  76. return None
  77. def set(self, question: str, cache_type: str, filename: str, content: Any) -> bool:
  78. """
  79. 写入缓存
  80. Args:
  81. question: 问题文本
  82. cache_type: 缓存类型
  83. filename: 缓存文件名
  84. content: 缓存内容
  85. Returns:
  86. bool: 是否写入成功
  87. """
  88. cache_path = self._get_cache_path(question, cache_type, filename)
  89. try:
  90. # 如果是字典或列表,转换为JSON
  91. if isinstance(content, (dict, list)):
  92. content = json.dumps(content, ensure_ascii=False, indent=2)
  93. with open(cache_path, 'w', encoding='utf-8') as f:
  94. f.write(str(content))
  95. logger.debug(f"缓存已保存: {cache_type}/{filename}")
  96. return True
  97. except Exception as e:
  98. logger.error(f"写入缓存失败 {cache_type}/{filename}: {e}")
  99. return False
  100. def clear(self, question: str = None):
  101. """
  102. 清除缓存
  103. Args:
  104. question: 如果指定,只清除该问题的缓存;否则清除所有缓存
  105. """
  106. if question:
  107. question_hash = self._get_question_hash(question)
  108. cache_dir = os.path.join(self.base_cache_dir, question_hash)
  109. if os.path.exists(cache_dir):
  110. import shutil
  111. shutil.rmtree(cache_dir)
  112. logger.info(f"已清除问题缓存: {question[:30]}...")
  113. else:
  114. import shutil
  115. if os.path.exists(self.base_cache_dir):
  116. shutil.rmtree(self.base_cache_dir)
  117. os.makedirs(self.base_cache_dir)
  118. logger.info("已清除所有缓存")