chat_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. #
  5. import os
  6. import threading
  7. from typing import List, Dict, Optional
  8. from enum import Enum, auto
  9. import httpx
  10. from pqai_agent import configs
  11. from pqai_agent.logging import logger
  12. import cozepy
  13. from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
  14. import time
  15. from openai import OpenAI, AsyncOpenAI, http_client
  16. COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
  17. COZE_CN_BASE_URL = 'https://api.coze.cn'
  18. VOLCENGINE_API_TOKEN = '5e275c38-44fd-415f-abcf-4b59f6377f72'
  19. VOLCENGINE_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
  20. VOLCENGINE_MODEL_DEEPSEEK_V3 = "deepseek-v3-250324"
  21. VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K = 'doubao-1-5-pro-32k-250115'
  22. VOLCENGINE_MODEL_DOUBAO_PRO_32K = 'doubao-pro-32k-241215'
  23. VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO = 'doubao-1-5-vision-pro-32k-250115'
  24. DEEPSEEK_API_TOKEN = 'sk-67daad8f424f4854bda7f1fed7ef220b'
  25. DEEPSEEK_BASE_URL = 'https://api.deepseek.com/'
  26. DEEPSEEK_CHAT_MODEL = 'deepseek-chat'
  27. VOLCENGINE_BOT_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3/bots"
  28. VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH = "bot-20250427173459-9h2xp"
  29. OPENAI_API_TOKEN = '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'
  30. OPENAI_BASE_URL = 'https://api.openai.com/v1'
  31. OPENAI_MODEL_GPT_4o = 'gpt-4o'
  32. OPENAI_MODEL_GPT_4o_mini = 'gpt-4o-mini'
  33. OPENROUTER_API_TOKEN = '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'
  34. OPENROUTER_BASE_URL = 'https://openrouter.ai/api/v1/'
  35. OPENROUTER_MODEL_CLAUDE_3_7_SONNET = 'anthropic/claude-3.7-sonnet'
  36. OPENROUTER_MODEL_GEMINI_2_5_PRO = 'google/gemini-2.5-pro'
  37. ALIYUN_API_TOKEN = 'sk-47381479425f4485af7673d3d2fd92b6'
  38. ALIYUN_BASE_URL = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
  39. class ChatServiceType(Enum):
  40. OPENAI_COMPATIBLE = auto()
  41. COZE_CHAT = auto()
  42. class ModelPrice:
  43. EXCHANGE_RATE_TO_CNY = {
  44. "USD": 7.2, # Example conversion rate, adjust as needed
  45. }
  46. def __init__(self, input_price: float, output_price: float, currency: str = 'CNY'):
  47. """
  48. :param input_price: input price for per million tokens
  49. :param output_price: output price for per million tokens
  50. """
  51. self.input_price = input_price
  52. self.output_price = output_price
  53. self.currency = currency
  54. def get_total_cost(self, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
  55. """
  56. Calculate the total cost based on input and output tokens.
  57. :param input_tokens: Number of input tokens
  58. :param output_tokens: Number of output tokens
  59. :param convert_to_cny: Whether to convert the cost to CNY (default is True)
  60. :return: Total cost in the specified currency
  61. """
  62. total_cost = (self.input_price * input_tokens / 1_000_000) + (self.output_price * output_tokens / 1_000_000)
  63. if convert_to_cny and self.currency != 'CNY':
  64. conversion_rate = self.EXCHANGE_RATE_TO_CNY.get(self.currency, 1.0)
  65. total_cost *= conversion_rate
  66. return total_cost
  67. def get_cny_brief(self) -> str:
  68. input_price = self.input_price * self.EXCHANGE_RATE_TO_CNY.get(self.currency, 1.0)
  69. output_price = self.output_price * self.EXCHANGE_RATE_TO_CNY.get(self.currency, 1.0)
  70. return f"{input_price:.0f}/{output_price:.0f}"
  71. def __repr__(self):
  72. return f"ModelPrice(input_price={self.input_price}, output_price={self.output_price}, currency={self.currency})"
  73. class OpenAICompatible:
  74. volcengine_models = [
  75. VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  76. VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  77. VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  78. VOLCENGINE_MODEL_DEEPSEEK_V3
  79. ]
  80. deepseek_models = [
  81. DEEPSEEK_CHAT_MODEL,
  82. ]
  83. openai_models = [
  84. OPENAI_MODEL_GPT_4o_mini,
  85. OPENAI_MODEL_GPT_4o
  86. ]
  87. openrouter_models = [
  88. OPENROUTER_MODEL_CLAUDE_3_7_SONNET,
  89. OPENROUTER_MODEL_GEMINI_2_5_PRO
  90. ]
  91. model_prices = {
  92. VOLCENGINE_MODEL_DEEPSEEK_V3: ModelPrice(input_price=2, output_price=8),
  93. VOLCENGINE_MODEL_DOUBAO_PRO_32K: ModelPrice(input_price=0.8, output_price=2),
  94. VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K: ModelPrice(input_price=0.8, output_price=2),
  95. VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO: ModelPrice(input_price=3, output_price=9),
  96. DEEPSEEK_CHAT_MODEL: ModelPrice(input_price=2, output_price=8),
  97. OPENAI_MODEL_GPT_4o: ModelPrice(input_price=2.5, output_price=10, currency='USD'),
  98. OPENAI_MODEL_GPT_4o_mini: ModelPrice(input_price=0.15, output_price=0.6, currency='USD'),
  99. OPENROUTER_MODEL_CLAUDE_3_7_SONNET: ModelPrice(input_price=3, output_price=15, currency='USD'),
  100. OPENROUTER_MODEL_GEMINI_2_5_PRO: ModelPrice(input_price=1.25, output_price=10, currency='USD'),
  101. }
  102. @staticmethod
  103. def create_client(model_name, **kwargs) -> OpenAI:
  104. if model_name in OpenAICompatible.volcengine_models:
  105. llm_client = OpenAI(api_key=VOLCENGINE_API_TOKEN, base_url=VOLCENGINE_BASE_URL, **kwargs)
  106. elif model_name in OpenAICompatible.deepseek_models:
  107. llm_client = OpenAI(api_key=DEEPSEEK_API_TOKEN, base_url=DEEPSEEK_BASE_URL, **kwargs)
  108. elif model_name in OpenAICompatible.openai_models:
  109. kwargs['http_client'] = OpenAICompatible.create_outside_proxy_http_client()
  110. llm_client = OpenAI(api_key=OPENAI_API_TOKEN, base_url=OPENAI_BASE_URL, **kwargs)
  111. elif model_name in OpenAICompatible.openrouter_models:
  112. # kwargs['http_client'] = OpenAICompatible.create_outside_proxy_http_client()
  113. llm_client = OpenAI(api_key=OPENROUTER_API_TOKEN, base_url=OPENROUTER_BASE_URL, **kwargs)
  114. else:
  115. raise Exception("Unsupported model: %s" % model_name)
  116. return llm_client
  117. @staticmethod
  118. def create_outside_proxy_http_client() -> httpx.Client:
  119. """
  120. Create an HTTP client with outside proxy settings.
  121. :return: Configured httpx.Client instance
  122. """
  123. socks_conf = configs.get().get('system', {}).get('outside_proxy', {}).get('socks5', {})
  124. if socks_conf:
  125. return httpx.Client(
  126. timeout=httpx.Timeout(600, connect=5.0),
  127. proxy=f"socks5://{socks_conf['hostname']}:{socks_conf['port']}"
  128. )
  129. # If no proxy is configured, return a standard client
  130. logger.error("Outside proxy not configured, using default httpx client.")
  131. return httpx.Client(timeout=httpx.Timeout(600, connect=5.0))
  132. @staticmethod
  133. def get_price(model_name: str) -> ModelPrice:
  134. """
  135. Get the price for a given model.
  136. :param model_name: Name of the model
  137. :return: ModelPrice object containing input and output prices
  138. """
  139. if model_name not in OpenAICompatible.model_prices:
  140. raise ValueError(f"Model {model_name} not found in price list.")
  141. return OpenAICompatible.model_prices[model_name]
  142. @staticmethod
  143. def calculate_cost(model_name: str, input_tokens: int, output_tokens: int, convert_to_cny: bool = True) -> float:
  144. """
  145. Calculate the cost for a given model based on input and output tokens.
  146. :param model_name: Name of the model
  147. :param input_tokens: Number of input tokens
  148. :param output_tokens: Number of output tokens
  149. :param convert_to_cny: Whether to convert the cost to CNY (default is True)
  150. :return: Total cost in the model's currency
  151. """
  152. if model_name not in OpenAICompatible.model_prices:
  153. raise ValueError(f"Model {model_name} not found in price list.")
  154. price = OpenAICompatible.model_prices[model_name]
  155. return price.get_total_cost(input_tokens, output_tokens, convert_to_cny)
  156. class CrossAccountJWTOAuthApp(JWTOAuthApp):
  157. def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
  158. self.account_id = account_id
  159. super().__init__(client_id, private_key, public_key_id, base_url)
  160. def get_access_token(
  161. self, ttl: int = 900, scope: Optional[cozepy.Scope] = None, session_name: Optional[str] = None
  162. ) -> cozepy.OAuthToken:
  163. jwt_token = self._gen_jwt(self._public_key_id, self._private_key, 3600, session_name)
  164. url = f"{self._base_url}/api/permission/oauth2/account/{self.account_id}/token"
  165. headers = {"Authorization": f"Bearer {jwt_token}"}
  166. body = {
  167. "duration_seconds": ttl,
  168. "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
  169. "scope": scope.model_dump() if scope else None,
  170. }
  171. return self._requester.request("post", url, False, cozepy.OAuthToken, headers=headers, body=body)
  172. class CozeChat:
  173. def __init__(self, base_url: str, auth_token: Optional[str] = None, auth_app: Optional[JWTOAuthApp] = None):
  174. if not auth_token and not auth_app:
  175. raise ValueError("Either auth_token or auth_app must be provided.")
  176. self.thread = None
  177. self.thread_running = False
  178. self.last_token_fresh = 0
  179. if auth_token:
  180. self.coze = Coze(auth=TokenAuth(auth_token), base_url=base_url)
  181. else:
  182. self.auth_app = auth_app
  183. oauth_token = auth_app.get_access_token(ttl=12*3600)
  184. self.last_token_fresh = time.time()
  185. self.coze = Coze(auth=JWTAuth(oauth_app=auth_app), base_url=base_url)
  186. self.setup_token_refresh()
  187. def create(self, bot_id: str, user_id: str, messages: List, custom_variables: Dict):
  188. response = self.coze.chat.create_and_poll(
  189. bot_id=bot_id, user_id=user_id, additional_messages=messages,
  190. custom_variables=custom_variables)
  191. logger.debug("Coze response size: {}".format(len(response.messages)))
  192. if response.chat.status != ChatStatus.COMPLETED:
  193. logger.error("Coze chat not completed: {}".format(response.chat.status))
  194. return None
  195. final_response = None
  196. for message in response.messages:
  197. if message.type == MessageType.ANSWER:
  198. final_response = message.content
  199. return final_response
  200. def setup_token_refresh(self):
  201. self.thread = threading.Thread(target=self.refresh_token_loop)
  202. self.thread.start()
  203. self.thread_running = True
  204. def refresh_token_loop(self):
  205. while self.thread_running:
  206. if time.time() - self.last_token_fresh < 11*3600:
  207. time.sleep(1)
  208. continue
  209. if self.auth_app:
  210. self.auth_app.get_access_token(ttl=12*3600)
  211. self.last_token_fresh = time.time()
  212. def __del__(self):
  213. self.thread_running = False
  214. @staticmethod
  215. def get_oauth_app(client_id, private_key_path, public_key_id, base_url=None, account_id=None) -> JWTOAuthApp:
  216. if not base_url:
  217. base_url = COZE_CN_BASE_URL
  218. with open(private_key_path, "r") as f:
  219. private_key = f.read()
  220. if not account_id:
  221. jwt_oauth_app = JWTOAuthApp(
  222. client_id=str(client_id),
  223. private_key=private_key,
  224. public_key_id=public_key_id,
  225. base_url=base_url,
  226. )
  227. else:
  228. jwt_oauth_app = CrossAccountJWTOAuthApp(
  229. account_id=account_id,
  230. client_id=str(client_id),
  231. private_key=private_key,
  232. public_key_id=public_key_id,
  233. base_url=base_url,
  234. )
  235. return jwt_oauth_app
  236. if __name__ == '__main__':
  237. # Init the Coze client through the access_token.
  238. coze = Coze(auth=TokenAuth(token=COZE_API_TOKEN), base_url=COZE_CN_BASE_URL)
  239. # Create a bot instance in Coze, copy the last number from the web link as the bot's ID.
  240. bot_id = "7491250992952999973"
  241. # The user id identifies the identity of a user. Developers can use a custom business ID
  242. # or a random string.
  243. user_id = "dev_user"
  244. chat = coze.chat.create_and_poll(
  245. bot_id=bot_id,
  246. user_id=user_id,
  247. additional_messages=[Message.build_user_question_text("钱塘江边 樱花开得不错,推荐一个视频吧")],
  248. custom_variables={
  249. 'agent_name': '芳华',
  250. 'agent_age': '25',
  251. 'agent_region': '北京',
  252. 'name': '李明',
  253. 'preferred_nickname': '李叔',
  254. 'age': '70',
  255. 'last_interaction_interval': '12',
  256. 'current_time_period': '上午',
  257. 'if_first_interaction': 'False',
  258. 'if_active_greeting': 'False'
  259. }
  260. )
  261. for message in chat.messages:
  262. print(message, flush=True)
  263. if chat.chat.status == ChatStatus.COMPLETED:
  264. print("token usage:", chat.chat.usage.token_count)