service_account.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. package vertex
  2. import (
  3. "crypto/rsa"
  4. "crypto/x509"
  5. "encoding/json"
  6. "encoding/pem"
  7. "errors"
  8. "github.com/bytedance/gopkg/cache/asynccache"
  9. "github.com/golang-jwt/jwt"
  10. "net/http"
  11. "net/url"
  12. relaycommon "one-api/relay/common"
  13. "one-api/service"
  14. "strings"
  15. "fmt"
  16. "time"
  17. )
  // Credentials holds the fields this adaptor decodes from a service-account
// credentials JSON document (presumably a Google Cloud service-account key —
// confirm against callers). ClientEmail and PrivateKey are what
// getAccessToken passes to createSignedJWT to mint the JWT assertion.
type Credentials struct {
	// ProjectID is read from the "project_id" key.
	ProjectID string `json:"project_id"`
	// PrivateKeyID is read from the "private_key_id" key.
	PrivateKeyID string `json:"private_key_id"`
	// PrivateKey is the PEM-armored private key text ("private_key").
	PrivateKey string `json:"private_key"`
	// ClientEmail is the service-account identity ("client_email"); it is
	// used as the JWT issuer.
	ClientEmail string `json:"client_email"`
	// ClientID is read from the "client_id" key.
	ClientID string `json:"client_id"`
}
  25. var Cache = asynccache.NewAsyncCache(asynccache.Options{
  26. RefreshDuration: time.Minute * 35,
  27. EnableExpire: true,
  28. ExpireDuration: time.Minute * 30,
  29. Fetcher: func(key string) (interface{}, error) {
  30. return nil, errors.New("not found")
  31. },
  32. })
  33. func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
  34. cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
  35. val, err := Cache.Get(cacheKey)
  36. if err == nil {
  37. return val.(string), nil
  38. }
  39. signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
  40. if err != nil {
  41. return "", fmt.Errorf("failed to create signed JWT: %w", err)
  42. }
  43. newToken, err := exchangeJwtForAccessToken(signedJWT, info)
  44. if err != nil {
  45. return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
  46. }
  47. if err := Cache.SetDefault(cacheKey, newToken); err {
  48. return newToken, nil
  49. }
  50. return newToken, nil
  51. }
  52. func createSignedJWT(email, privateKeyPEM string) (string, error) {
  53. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
  54. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
  55. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
  56. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
  57. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
  58. block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
  59. if block == nil {
  60. return "", fmt.Errorf("failed to parse PEM block containing the private key")
  61. }
  62. privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
  63. if err != nil {
  64. return "", err
  65. }
  66. rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
  67. if !ok {
  68. return "", fmt.Errorf("not an RSA private key")
  69. }
  70. now := time.Now()
  71. claims := jwt.MapClaims{
  72. "iss": email,
  73. "scope": "https://www.googleapis.com/auth/cloud-platform",
  74. "aud": "https://www.googleapis.com/oauth2/v4/token",
  75. "exp": now.Add(time.Minute * 35).Unix(),
  76. "iat": now.Unix(),
  77. }
  78. token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
  79. signedToken, err := token.SignedString(rsaPrivateKey)
  80. if err != nil {
  81. return "", err
  82. }
  83. return signedToken, nil
  84. }
  85. func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
  86. authURL := "https://www.googleapis.com/oauth2/v4/token"
  87. data := url.Values{}
  88. data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
  89. data.Set("assertion", signedJWT)
  90. var client *http.Client
  91. var err error
  92. if info.ChannelSetting.Proxy != "" {
  93. client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
  94. if err != nil {
  95. return "", fmt.Errorf("new proxy http client failed: %w", err)
  96. }
  97. } else {
  98. client = service.GetHttpClient()
  99. }
  100. resp, err := client.PostForm(authURL, data)
  101. if err != nil {
  102. return "", err
  103. }
  104. defer resp.Body.Close()
  105. var result map[string]interface{}
  106. if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
  107. return "", err
  108. }
  109. if accessToken, ok := result["access_token"].(string); ok {
  110. return accessToken, nil
  111. }
  112. return "", fmt.Errorf("failed to get access token: %v", result)
  113. }