service_account.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package vertex
  2. import (
  3. "crypto/rsa"
  4. "crypto/x509"
  5. "encoding/json"
  6. "encoding/pem"
  7. "errors"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. "github.com/QuantumNous/new-api/service"
  13. "github.com/bytedance/gopkg/cache/asynccache"
  14. "github.com/golang-jwt/jwt/v5"
  15. "fmt"
  16. "time"
  17. )
  // Credentials holds the fields of a Google service-account JSON key that
  // this package reads; each field maps to the key named in its json tag.
  type Credentials struct {
  	// ProjectID is the "project_id" entry of the key file.
  	ProjectID string `json:"project_id"`
  	// PrivateKeyID is the "private_key_id" entry (not read in this file).
  	PrivateKeyID string `json:"private_key_id"`
  	// PrivateKey is the PEM-encoded RSA private key used to sign the JWT
  	// assertion in createSignedJWT.
  	PrivateKey string `json:"private_key"`
  	// ClientEmail is the service-account email; createSignedJWT uses it as
  	// the JWT "iss" claim.
  	ClientEmail string `json:"client_email"`
  	// ClientID is the "client_id" entry (not read in this file).
  	ClientID string `json:"client_id"`
  }
  25. var Cache = asynccache.NewAsyncCache(asynccache.Options{
  26. RefreshDuration: time.Minute * 35,
  27. EnableExpire: true,
  28. ExpireDuration: time.Minute * 30,
  29. Fetcher: func(key string) (interface{}, error) {
  30. return nil, errors.New("not found")
  31. },
  32. })
  33. func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
  34. var cacheKey string
  35. if info.ChannelIsMultiKey {
  36. cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
  37. } else {
  38. cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
  39. }
  40. val, err := Cache.Get(cacheKey)
  41. if err == nil {
  42. return val.(string), nil
  43. }
  44. signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
  45. if err != nil {
  46. return "", fmt.Errorf("failed to create signed JWT: %w", err)
  47. }
  48. newToken, err := exchangeJwtForAccessToken(signedJWT, info)
  49. if err != nil {
  50. return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
  51. }
  52. if err := Cache.SetDefault(cacheKey, newToken); err {
  53. return newToken, nil
  54. }
  55. return newToken, nil
  56. }
  57. func createSignedJWT(email, privateKeyPEM string) (string, error) {
  58. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
  59. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
  60. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
  61. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
  62. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
  63. block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
  64. if block == nil {
  65. return "", fmt.Errorf("failed to parse PEM block containing the private key")
  66. }
  67. privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
  68. if err != nil {
  69. return "", err
  70. }
  71. rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
  72. if !ok {
  73. return "", fmt.Errorf("not an RSA private key")
  74. }
  75. now := time.Now()
  76. claims := jwt.MapClaims{
  77. "iss": email,
  78. "scope": "https://www.googleapis.com/auth/cloud-platform",
  79. "aud": "https://www.googleapis.com/oauth2/v4/token",
  80. "exp": now.Add(time.Minute * 35).Unix(),
  81. "iat": now.Unix(),
  82. }
  83. token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
  84. signedToken, err := token.SignedString(rsaPrivateKey)
  85. if err != nil {
  86. return "", err
  87. }
  88. return signedToken, nil
  89. }
  90. func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
  91. authURL := "https://www.googleapis.com/oauth2/v4/token"
  92. data := url.Values{}
  93. data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
  94. data.Set("assertion", signedJWT)
  95. var client *http.Client
  96. var err error
  97. if info.ChannelSetting.Proxy != "" {
  98. client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
  99. if err != nil {
  100. return "", fmt.Errorf("new proxy http client failed: %w", err)
  101. }
  102. } else {
  103. client = service.GetHttpClient()
  104. }
  105. resp, err := client.PostForm(authURL, data)
  106. if err != nil {
  107. return "", err
  108. }
  109. defer resp.Body.Close()
  110. var result map[string]interface{}
  111. if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
  112. return "", err
  113. }
  114. if accessToken, ok := result["access_token"].(string); ok {
  115. return accessToken, nil
  116. }
  117. return "", fmt.Errorf("failed to get access token: %v", result)
  118. }
  119. func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
  120. signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
  121. if err != nil {
  122. return "", fmt.Errorf("failed to create signed JWT: %w", err)
  123. }
  124. return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
  125. }
  126. func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
  127. authURL := "https://www.googleapis.com/oauth2/v4/token"
  128. data := url.Values{}
  129. data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
  130. data.Set("assertion", signedJWT)
  131. var client *http.Client
  132. var err error
  133. if proxy != "" {
  134. client, err = service.NewProxyHttpClient(proxy)
  135. if err != nil {
  136. return "", fmt.Errorf("new proxy http client failed: %w", err)
  137. }
  138. } else {
  139. client = service.GetHttpClient()
  140. }
  141. resp, err := client.PostForm(authURL, data)
  142. if err != nil {
  143. return "", err
  144. }
  145. defer resp.Body.Close()
  146. var result map[string]interface{}
  147. if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
  148. return "", err
  149. }
  150. if accessToken, ok := result["access_token"].(string); ok {
  151. return accessToken, nil
  152. }
  153. return "", fmt.Errorf("failed to get access token: %v", result)
  154. }