Sfoglia il codice sorgente

Update chat_service: add CrossAccountJWTOAuth. fuck coze

StrayWarrior 2 settimane fa
parent
commit
7ba178e911
2 ha cambiato i file con 38 aggiunte e 8 eliminazioni
  1. 3 1
      agent_service.py
  2. 35 7
      chat_service.py

+ 3 - 1
agent_service.py

@@ -55,7 +55,9 @@ class AgentService:
         self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
         coze_config = configs.get()['chat_api']['coze']
         coze_oauth_app = CozeChat.get_oauth_app(
-            coze_config['oauth_client_id'], coze_config['private_key_path'], str(coze_config['public_key_id']))
+            coze_config['oauth_client_id'], coze_config['private_key_path'], str(coze_config['public_key_id']),
+            account_id=coze_config.get('account_id', None)
+        )
         self.coze_client = CozeChat(
             base_url=chat_service.COZE_CN_BASE_URL,
             auth_app=coze_oauth_app

+ 35 - 7
chat_service.py

@@ -8,6 +8,7 @@ import threading
 from typing import List, Dict, Optional
 from enum import Enum, auto
 import logging
+import cozepy
 from cozepy import Coze, TokenAuth, Message, ChatStatus, MessageType, JWTOAuthApp, JWTAuth
 import time
 
@@ -22,6 +23,24 @@ class ChatServiceType(Enum):
     OPENAI_COMPATIBLE = auto
     COZE_CHAT = auto()
 
+class CrossAccountJWTOAuthApp(JWTOAuthApp):
+    def __init__(self, account_id: str, client_id: str, private_key: str, public_key_id: str, base_url):
+        self.account_id = account_id
+        super().__init__(client_id, private_key, public_key_id, base_url)
+
+    def get_access_token(
+            self, ttl: int = 900, scope: Optional[cozepy.Scope] = None, session_name: Optional[str] = None
+    ) -> cozepy.OAuthToken:
+        jwt_token = self._gen_jwt(self._public_key_id, self._private_key, 3600, session_name)
+        url = f"{self._base_url}/api/permission/oauth2/account/{self.account_id}/token"
+        headers = {"Authorization": f"Bearer {jwt_token}"}
+        body = {
+            "duration_seconds": ttl,
+            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
+            "scope": scope.model_dump() if scope else None,
+        }
+        return self._requester.request("post", url, False, cozepy.OAuthToken, headers=headers, body=body)
+
 class CozeChat:
     def __init__(self, base_url: str, auth_token: Optional[str] = None, auth_app: Optional[JWTOAuthApp] = None):
         if not auth_token and not auth_app:
@@ -59,17 +78,26 @@ class CozeChat:
                 self.auth_app.get_access_token(ttl=12*3600)
 
     @staticmethod
-    def get_oauth_app(client_id, private_key_path, public_key_id, base_url=None) -> JWTOAuthApp:
+    def get_oauth_app(client_id, private_key_path, public_key_id, base_url=None, account_id=None) -> JWTOAuthApp:
         if not base_url:
             base_url = COZE_CN_BASE_URL
         with open(private_key_path, "r") as f:
             private_key = f.read()
-        jwt_oauth_app = JWTOAuthApp(
-            client_id=str(client_id),
-            private_key=private_key,
-            public_key_id=public_key_id,
-            base_url=base_url,
-        )
+        if not account_id:
+            jwt_oauth_app = JWTOAuthApp(
+                client_id=str(client_id),
+                private_key=private_key,
+                public_key_id=public_key_id,
+                base_url=base_url,
+            )
+        else:
+            jwt_oauth_app = CrossAccountJWTOAuthApp(
+                account_id=account_id,
+                client_id=str(client_id),
+                private_key=private_key,
+                public_key_id=public_key_id,
+                base_url=base_url,
+            )
         return jwt_oauth_app
 
 if __name__ == '__main__':