|
@@ -5,10 +5,7 @@
|
|
|
import json
|
|
|
from typing import Dict, Optional, List
|
|
|
|
|
|
-from sqlalchemy.testing.plugin.plugin_base import logging
|
|
|
-
|
|
|
from pqai_agent import chat_service
|
|
|
-from pqai_agent import configs
|
|
|
from pqai_agent.prompt_templates import USER_PROFILE_EXTRACT_PROMPT, USER_PROFILE_EXTRACT_PROMPT_V2
|
|
|
from openai import OpenAI
|
|
|
from pqai_agent.logging_service import logger
|
|
@@ -27,12 +24,17 @@ class UserProfileExtractor:
|
|
|
"interaction_frequency",
|
|
|
"flexible_params"
|
|
|
]
|
|
|
- def __init__(self):
|
|
|
- self.llm_client = OpenAI(
|
|
|
- api_key=chat_service.VOLCENGINE_API_TOKEN,
|
|
|
- base_url=chat_service.VOLCENGINE_BASE_URL
|
|
|
- )
|
|
|
- self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
|
|
|
+ def __init__(self, model_name=None, llm_client=None):
|
|
|
+ if not llm_client:
|
|
|
+ self.llm_client = OpenAI(
|
|
|
+ api_key=chat_service.VOLCENGINE_API_TOKEN,
|
|
|
+ base_url=chat_service.VOLCENGINE_BASE_URL
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.llm_client = llm_client
|
|
|
+ if not model_name:
|
|
|
+ model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
|
|
|
+ self.model_name = model_name
|
|
|
|
|
|
@staticmethod
|
|
|
def get_extraction_function() -> Dict:
|