瀏覽代碼

Update image_describer: add cache

StrayWarrior 4 周之前
父節點
當前提交
3ce2ae69ab
共有 1 個文件被更改,包括 11 次插入1 次删除
  1. 11 1
      pqai_agent/toolkit/image_describer.py

+ 11 - 1
pqai_agent/toolkit/image_describer.py

@@ -1,3 +1,5 @@
+import diskcache
+
 from pqai_agent import chat_service
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
 from pqai_agent.logging_service import logger
@@ -6,9 +8,12 @@ from pqai_agent.toolkit.function_tool import FunctionTool
 
 
 class ImageDescriber(BaseToolkit):
-    def __init__(self):
+    def __init__(self, cache_dir: str = None):
         self.model = VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO
         self.llm_client = chat_service.OpenAICompatible.create_client(self.model)
+        if not cache_dir:
+            cache_dir = 'image_descriptions_cache'
+        self.cache = diskcache.Cache(cache_dir, size_limit=100*1024*1024)
         super().__init__()
 
     def analyse_image(self, image_url: str):
@@ -19,6 +24,10 @@ class ImageDescriber(BaseToolkit):
         Returns:
             str: A detailed description of the image.
         """
+        if image_url in self.cache:
+            logger.debug(f"Cache hit for image URL: {image_url}")
+            return self.cache[image_url]
+
         system_prompt = "你是一位图像分析专家。请提供输入图像的详细描述,包括图像中的文本内容(如果存在)"
 
         messages = [
@@ -33,6 +42,7 @@ class ImageDescriber(BaseToolkit):
         response = self.llm_client.chat.completions.create(messages=messages, model=self.model)
         response_content = response.choices[0].message.content
         logger.debug(f"ImageDescriber response: {response_content}")
+        self.cache[image_url] = response_content
         return response_content
 
     def get_tools(self):