123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- import json
- from datetime import datetime
- from typing import List, Dict
- from openai import OpenAI
- from pqai_agent import logging_service, chat_service
- from pqai_agent.response_type_detector import ResponseTypeDetector
- from pqai_agent.user_profile_extractor import UserProfileExtractor
- from pqai_agent.dialogue_manager import DialogueManager
- from pqai_agent.mq_message import MessageType
- from pqai_agent.utils.prompt_utils import format_agent_profile
- logger = logging_service.logger
def compose_openai_chat_messages_no_time(dialogue_history, multimodal=False):
    """Convert a dialogue history into OpenAI chat messages without timestamps.

    Args:
        dialogue_history: iterable of dicts with "role", "content" and an
            optional "type" (MessageType) key.
        multimodal: when True, image-type entries (IMAGE_GW / IMAGE_QW / GIF)
            are emitted as `image_url` content parts; otherwise they are
            replaced with a "[图片]" placeholder and a warning is logged.

    Returns:
        List of OpenAI-compatible message dicts.
    """
    messages = []
    for entry in dialogue_history:
        role = entry["role"]
        msg_type = entry.get("type", MessageType.TEXT)
        # NOTE: the original computed a formatted timestamp here but never
        # used it (this is the "no_time" variant) — dropped as dead work.
        if msg_type in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
            if multimodal:
                messages.append(
                    {
                        "role": role,
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {"url": entry["content"]},
                            }
                        ],
                    }
                )
            else:
                logger.warning("Image in non-multimodal mode")
                messages.append({"role": role, "content": "[图片]"})
        else:
            # Plain text entry; coerce to str for non-string content.
            messages.append({"role": role, "content": str(entry["content"])})
    return messages
def create_llm_client(model_name):
    """Build an OpenAI-compatible client for the platform serving `model_name`.

    Routes Volcengine models/bots and DeepSeek models to their respective
    API base URLs and tokens.

    Raises:
        ValueError: if `model_name` is not served by any known platform.
    """
    volcengine_models = [
        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
        chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
        chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
        chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
    ]
    deepseek_models = [
        chat_service.DEEPSEEK_CHAT_MODEL,
    ]
    volcengine_bots = [
        chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
    ]
    if model_name in volcengine_models:
        llm_client = OpenAI(
            api_key=chat_service.VOLCENGINE_API_TOKEN,
            base_url=chat_service.VOLCENGINE_BASE_URL,
        )
    elif model_name in volcengine_bots:
        llm_client = OpenAI(
            api_key=chat_service.VOLCENGINE_API_TOKEN,
            base_url=chat_service.VOLCENGINE_BOT_BASE_URL,
        )
    elif model_name in deepseek_models:
        llm_client = OpenAI(
            api_key=chat_service.DEEPSEEK_API_TOKEN,
            base_url=chat_service.DEEPSEEK_BASE_URL,
        )
    else:
        # ValueError (with the offending name) is more precise than a bare
        # Exception and is still caught by callers using `except Exception`.
        raise ValueError(f"model not supported: {model_name}")
    return llm_client
def run_openai_chat(messages, model_name, **kwargs):
    """Send `messages` to `model_name` and return the raw completion response.

    Extra keyword arguments (temperature, top_p, max_tokens, ...) are passed
    through to the chat-completions API unchanged.
    """
    client = create_llm_client(model_name)
    completion = client.chat.completions.create(
        model=model_name, messages=messages, **kwargs
    )
    logger.debug(completion)
    return completion
def run_extractor_prompt(req_data) -> Dict[str, str]:
    """Run the user-profile extraction prompt.

    Expects `req_data` to carry "prompt", "user_profile", "dialogue_history"
    and "model_name". Returns the profile fields to update, or an empty dict
    when the extractor produced nothing.
    """
    model_name = req_data["model_name"]
    extractor = UserProfileExtractor(
        model_name=model_name,
        llm_client=create_llm_client(model_name),
    )
    profile_to_update = extractor.extract_profile_info_v2(
        req_data["user_profile"],
        req_data["dialogue_history"],
        req_data["prompt"],
    )
    logger.info(profile_to_update)
    # Normalize a falsy extraction result to an empty dict.
    return profile_to_update or {}
def run_chat_prompt(req_data):
    """Render the chat prompt with context variables and run the completion.

    `req_data` keys: "prompt", "model_name", "scene", "current_timestamp"
    (milliseconds), plus optional "staff_profile", "user_profile" and
    "dialogue_history".
    """
    prompt = req_data["prompt"]
    staff_profile = req_data.get("staff_profile", {})
    user_profile = req_data.get("user_profile", {})
    dialogue_history = req_data.get("dialogue_history", [])
    model_name = req_data["model_name"]
    # Incoming timestamp is in milliseconds; convert once and reuse.
    current_timestamp = req_data["current_timestamp"] / 1000
    current_dt = datetime.fromtimestamp(current_timestamp)
    current_hour = current_dt.hour
    prompt_context = {
        'formatted_staff_profile': format_agent_profile(staff_profile),
        **user_profile
    }
    # Debugging runs have no real prior interaction, so the interval is fixed at 0.
    prompt_context["last_interaction_interval"] = 0
    prompt_context["current_time_period"] = DialogueManager.get_time_context(
        current_hour
    )
    prompt_context["current_hour"] = current_hour
    prompt_context["if_first_interaction"] = not dialogue_history
    last_message = dialogue_history[-1] if dialogue_history else {"role": "assistant"}
    # The greeting counts as "active" unless the last turn came from the user.
    prompt_context["if_active_greeting"] = last_message["role"] != "user"
    current_time_str = current_dt.strftime("%Y-%m-%d %H:%M:%S")
    system_prompt = {"role": "system", "content": prompt.format(**prompt_context)}
    messages = [system_prompt]
    if req_data["scene"] == "custom_debugging":
        messages.extend(compose_openai_chat_messages_no_time(dialogue_history))
        # If the rendered prompt mentions the avatar ("头像"), attach it as an image.
        if "头像" in system_prompt["content"]:
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": user_profile["avatar"]},
                        }
                    ],
                }
            )
    else:
        messages.extend(
            DialogueManager.compose_chat_messages_openai_compatible(
                dialogue_history, current_time_str
            )
        )
    return run_openai_chat(
        messages, model_name, temperature=1, top_p=0.7, max_tokens=1024
    )
def run_response_type_prompt(req_data):
    """Classify how the agent should respond to the newest dialogue message."""
    model_name = req_data["model_name"]
    dialogue_history = req_data["dialogue_history"]
    # All turns except the last form the context; the last is the message to classify.
    composed_dialogue = ResponseTypeDetector.compose_dialogue(dialogue_history[:-1])
    next_message = DialogueManager.format_dialogue_content(dialogue_history[-1])
    filled_prompt = req_data["prompt"].format(
        dialogue_history=composed_dialogue, message=next_message
    )
    return run_openai_chat(
        [
            {"role": "system", "content": "你是一个专业的智能助手"},
            {"role": "user", "content": filled_prompt},
        ],
        model_name,
        temperature=0.2,
        max_tokens=128,
    )
def format_dialogue_history(dialogue: List[Dict]) -> str:
    """Format a dialogue as one '[role][time][type]content' line per message.

    Messages with empty or missing content, or with roles other than
    user/assistant, are skipped. Timestamps are expected in milliseconds.

    Returns:
        A newline-joined string of the formatted messages.
    """
    role_map = {'user': '用户', 'assistant': '客服'}
    messages = []
    for msg in dialogue:
        # .get() avoids a KeyError on malformed entries with no content field;
        # empty content is skipped either way.
        if not msg.get('content'):
            continue
        if msg['role'] not in role_map:
            continue
        format_dt = datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
        msg_type = MessageType(msg.get('type', MessageType.TEXT.value)).description
        messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
    return '\n'.join(messages)
|