chat_service.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. #
  5. import os
  6. import threading
  7. from typing import List, Dict, Optional
  8. from enum import Enum, auto
  9. import logging
  10. from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
  11. import time
  12. COZE_API_TOKEN = os.getenv("COZE_API_TOKEN")
  13. COZE_CN_BASE_URL = 'https://api.coze.cn'
  14. VOLCENGINE_API_TOKEN = '5e275c38-44fd-415f-abcf-4b59f6377f72'
  15. VOLCENGINE_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3"
  16. VOLCENGINE_MODEL_DEEPSEEK_V3 = "ep-20250213194558-rrmr2"
  17. VOLCENGINE_MODEL_DOUBAO_PRO_1_5 = 'ep-20250307150409-4blz9'
  18. class ChatServiceType(Enum):
  19. OPENAI_COMPATIBLE = auto
  20. COZE_CHAT = auto()
  21. class CozeChat:
  22. def __init__(self, base_url: str, auth_token: Optional[str] = None, auth_app: Optional[JWTOAuthApp] = None):
  23. if not auth_token and not auth_app:
  24. raise ValueError("Either auth_token or auth_app must be provided.")
  25. if auth_token:
  26. self.coze = Coze(auth=TokenAuth(auth_token), base_url=base_url)
  27. else:
  28. self.auth_app = auth_app
  29. oauth_token = auth_app.get_access_token(ttl=12*3600)
  30. self.coze = Coze(auth=JWTAuth(oauth_app=auth_app), base_url=base_url)
  31. self.setup_token_refresh()
  32. def create(self, bot_id: str, user_id: str, messages: List, custom_variables: Dict):
  33. response = self.coze.chat.create_and_poll(
  34. bot_id=bot_id, user_id=user_id, additional_messages=messages,
  35. custom_variables=custom_variables)
  36. logging.debug("Coze response size: {}".format(len(response.messages)))
  37. if response.chat.status != ChatStatus.COMPLETED:
  38. logging.error("Coze chat not completed: {}".format(response.chat.status))
  39. return None
  40. final_response = None
  41. for message in response.messages:
  42. if message.type == MessageType.ANSWER:
  43. final_response = message.content
  44. return final_response
  45. def setup_token_refresh(self):
  46. thread = threading.Thread(target=self.refresh_token_loop)
  47. thread.start()
  48. def refresh_token_loop(self):
  49. while True:
  50. time.sleep(11*3600)
  51. if self.auth_app:
  52. self.auth_app.get_access_token(ttl=12*3600)
  53. @staticmethod
  54. def get_oauth_app(client_id, private_key_path, public_key_id, base_url=None) -> JWTOAuthApp:
  55. if not base_url:
  56. base_url = COZE_CN_BASE_URL
  57. with open(private_key_path, "r") as f:
  58. private_key = f.read()
  59. jwt_oauth_app = JWTOAuthApp(
  60. client_id=str(client_id),
  61. private_key=private_key,
  62. public_key_id=public_key_id,
  63. base_url=base_url,
  64. )
  65. return jwt_oauth_app
  66. if __name__ == '__main__':
  67. # Init the Coze client through the access_token.
  68. coze = Coze(auth=TokenAuth(token=COZE_API_TOKEN), base_url=COZE_CN_BASE_URL)
  69. # Create a bot instance in Coze, copy the last number from the web link as the bot's ID.
  70. bot_id = "7491250992952999973"
  71. # The user id identifies the identity of a user. Developers can use a custom business ID
  72. # or a random string.
  73. user_id = "dev_user"
  74. chat = coze.chat.create_and_poll(
  75. bot_id=bot_id,
  76. user_id=user_id,
  77. additional_messages=[Message.build_user_question_text("钱塘江边 樱花开得不错,推荐一个视频吧")],
  78. custom_variables={
  79. 'agent_name': '芳华',
  80. 'agent_age': '25',
  81. 'agent_region': '北京',
  82. 'name': '李明',
  83. 'preferred_nickname': '李叔',
  84. 'age': '70',
  85. 'last_interaction_interval': '12',
  86. 'current_time_period': '上午',
  87. 'if_first_interaction': 'False',
  88. 'if_active_greeting': 'False'
  89. }
  90. )
  91. for message in chat.messages:
  92. print(message, flush=True)
  93. if chat.chat.status == ChatStatus.COMPLETED:
  94. print("token usage:", chat.chat.usage.token_count)