service_account.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. package vertex
  2. import (
  3. "crypto/rsa"
  4. "crypto/x509"
  5. "encoding/json"
  6. "encoding/pem"
  7. "errors"
  8. "github.com/bytedance/gopkg/cache/asynccache"
  9. "github.com/golang-jwt/jwt"
  10. "net/http"
  11. "net/url"
  12. relaycommon "one-api/relay/common"
  13. "strings"
  14. "fmt"
  15. "time"
  16. )
  // Credentials mirrors the fields of a Google Cloud service-account key
  // (the JSON key file) that this package decodes. Other fields present in
  // the key file are ignored by encoding/json when decoding.
  type Credentials struct {
  	// ProjectID is the "project_id" field of the key file.
  	ProjectID string `json:"project_id"`
  	// PrivateKeyID is the "private_key_id" field; it is decoded but not
  	// included in the JWT built by createSignedJWT.
  	PrivateKeyID string `json:"private_key_id"`
  	// PrivateKey is the PEM private key that signs the JWT assertion.
  	PrivateKey string `json:"private_key"`
  	// ClientEmail is the service-account email, used as the JWT "iss" claim.
  	ClientEmail string `json:"client_email"`
  	// ClientID is the "client_id" field of the key file.
  	ClientID string `json:"client_id"`
  }
  24. var Cache = asynccache.NewAsyncCache(asynccache.Options{
  25. RefreshDuration: time.Minute * 35,
  26. EnableExpire: true,
  27. ExpireDuration: time.Minute * 30,
  28. Fetcher: func(key string) (interface{}, error) {
  29. return nil, errors.New("not found")
  30. },
  31. })
  32. func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
  33. cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
  34. val, err := Cache.Get(cacheKey)
  35. if err == nil {
  36. return val.(string), nil
  37. }
  38. signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
  39. if err != nil {
  40. return "", fmt.Errorf("failed to create signed JWT: %w", err)
  41. }
  42. newToken, err := exchangeJwtForAccessToken(signedJWT)
  43. if err != nil {
  44. return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
  45. }
  46. if err := Cache.SetDefault(cacheKey, newToken); err {
  47. return newToken, nil
  48. }
  49. return newToken, nil
  50. }
  51. func createSignedJWT(email, privateKeyPEM string) (string, error) {
  52. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
  53. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
  54. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
  55. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
  56. privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
  57. block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
  58. if block == nil {
  59. return "", fmt.Errorf("failed to parse PEM block containing the private key")
  60. }
  61. privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
  62. if err != nil {
  63. return "", err
  64. }
  65. rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
  66. if !ok {
  67. return "", fmt.Errorf("not an RSA private key")
  68. }
  69. now := time.Now()
  70. claims := jwt.MapClaims{
  71. "iss": email,
  72. "scope": "https://www.googleapis.com/auth/cloud-platform",
  73. "aud": "https://www.googleapis.com/oauth2/v4/token",
  74. "exp": now.Add(time.Minute * 35).Unix(),
  75. "iat": now.Unix(),
  76. }
  77. token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
  78. signedToken, err := token.SignedString(rsaPrivateKey)
  79. if err != nil {
  80. return "", err
  81. }
  82. return signedToken, nil
  83. }
  // exchangeJwtForAccessToken trades a signed JWT assertion (see
  // createSignedJWT) for an OAuth2 access token, using the jwt-bearer grant
  // at Google's token endpoint.
  //
  // An error is returned when:
  //   - the request fails or times out;
  //   - the endpoint answers with a non-200 status (the OAuth "error" and
  //     "error_description" fields are included when the body carries them);
  //   - the body is not valid JSON;
  //   - a 200 response carries no access_token.
  func exchangeJwtForAccessToken(signedJWT string) (string, error) {
  	const authURL = "https://www.googleapis.com/oauth2/v4/token"

  	form := url.Values{}
  	form.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
  	form.Set("assertion", signedJWT)

  	// http.PostForm uses http.DefaultClient, which has no timeout, so a
  	// stalled token endpoint would block the relay request indefinitely.
  	// A nil Transport still means http.DefaultTransport, so connections
  	// stay pooled across calls.
  	client := &http.Client{Timeout: 30 * time.Second}
  	resp, err := client.PostForm(authURL, form)
  	if err != nil {
  		return "", err
  	}
  	defer resp.Body.Close()

  	var result struct {
  		AccessToken      string `json:"access_token"`
  		Error            string `json:"error"`
  		ErrorDescription string `json:"error_description"`
  	}
  	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
  		// Error pages (e.g. from a proxy on 5xx) are often not JSON.
  		// Report the status so the failure is diagnosable.
  		if resp.StatusCode != http.StatusOK {
  			return "", fmt.Errorf("token endpoint returned status %d with undecodable body: %w", resp.StatusCode, err)
  		}
  		return "", fmt.Errorf("decoding token response: %w", err)
  	}
  	if resp.StatusCode != http.StatusOK {
  		return "", fmt.Errorf("token endpoint returned status %d (error=%q, description=%q)", resp.StatusCode, result.Error, result.ErrorDescription)
  	}
  	if result.AccessToken == "" {
  		return "", errors.New("token endpoint response contains no access_token")
  	}
  	return result.AccessToken, nil
  }