فهرست منبع

feat: codex oauth proxy

Seefs 2 هفته پیش
والد
کامیت
20c9002fde
4فایلهای تغییر یافته به همراه39 افزوده شده و 9 حذف شده
  1. 3 1
      controller/codex_oauth.go
  2. 2 3
      controller/codex_usage.go
  3. 1 1
      service/codex_credential_refresh.go
  4. 33 4
      service/codex_oauth.go

+ 3 - 1
controller/codex_oauth.go

@@ -145,6 +145,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
 		return
 	}
 
+	channelProxy := ""
 	if channelID > 0 {
 		ch, err := model.GetChannelById(channelID, false)
 		if err != nil {
@@ -159,6 +160,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
 			c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"})
 			return
 		}
+		channelProxy = ch.GetSetting().Proxy
 	}
 
 	session := sessions.Default(c)
@@ -176,7 +178,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
 	ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second)
 	defer cancel()
 
-	tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
+	tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy)
 	if err != nil {
 		common.SysError("failed to exchange codex authorization code: " + err.Error())
 		c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})

+ 2 - 3
controller/codex_usage.go

@@ -2,7 +2,6 @@ package controller
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"net/http"
 	"strconv"
@@ -80,7 +79,7 @@ func GetCodexChannelUsage(c *gin.Context) {
 		refreshCtx, refreshCancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
 		defer refreshCancel()
 
-		res, refreshErr := service.RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+		res, refreshErr := service.RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
 		if refreshErr == nil {
 			oauthKey.AccessToken = res.AccessToken
 			oauthKey.RefreshToken = res.RefreshToken
@@ -109,7 +108,7 @@ func GetCodexChannelUsage(c *gin.Context) {
 	}
 
 	var payload any
-	if json.Unmarshal(body, &payload) != nil {
+	if common.Unmarshal(body, &payload) != nil {
 		payload = string(body)
 	}
 

+ 1 - 1
service/codex_credential_refresh.go

@@ -62,7 +62,7 @@ func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts Code
 	refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
 	defer cancel()
 
-	res, err := RefreshCodexOAuthToken(refreshCtx, oauthKey.RefreshToken)
+	res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
 	if err != nil {
 		return nil, nil, err
 	}

+ 33 - 4
service/codex_oauth.go

@@ -12,6 +12,8 @@ import (
 	"net/url"
 	"strings"
 	"time"
+
+	"github.com/QuantumNous/new-api/common"
 )
 
 const (
@@ -38,12 +40,26 @@ type CodexOAuthAuthorizationFlow struct {
 }
 
 func RefreshCodexOAuthToken(ctx context.Context, refreshToken string) (*CodexOAuthTokenResult, error) {
-	client := &http.Client{Timeout: defaultHTTPTimeout}
+	return RefreshCodexOAuthTokenWithProxy(ctx, refreshToken, "")
+}
+
+func RefreshCodexOAuthTokenWithProxy(ctx context.Context, refreshToken string, proxyURL string) (*CodexOAuthTokenResult, error) {
+	client, err := getCodexOAuthHTTPClient(proxyURL)
+	if err != nil {
+		return nil, err
+	}
 	return refreshCodexOAuthToken(ctx, client, codexOAuthTokenURL, codexOAuthClientID, refreshToken)
 }
 
 func ExchangeCodexAuthorizationCode(ctx context.Context, code string, verifier string) (*CodexOAuthTokenResult, error) {
-	client := &http.Client{Timeout: defaultHTTPTimeout}
+	return ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, "")
+}
+
+func ExchangeCodexAuthorizationCodeWithProxy(ctx context.Context, code string, verifier string, proxyURL string) (*CodexOAuthTokenResult, error) {
+	client, err := getCodexOAuthHTTPClient(proxyURL)
+	if err != nil {
+		return nil, err
+	}
 	return exchangeCodexAuthorizationCode(ctx, client, codexOAuthTokenURL, codexOAuthClientID, code, verifier, codexOAuthRedirectURI)
 }
 
@@ -104,7 +120,7 @@ func refreshCodexOAuthToken(
 		ExpiresIn    int    `json:"expires_in"`
 	}
 
-	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+	if err := common.DecodeJson(resp.Body, &payload); err != nil {
 		return nil, err
 	}
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -165,7 +181,7 @@ func exchangeCodexAuthorizationCode(
 		RefreshToken string `json:"refresh_token"`
 		ExpiresIn    int    `json:"expires_in"`
 	}
-	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
+	if err := common.DecodeJson(resp.Body, &payload); err != nil {
 		return nil, err
 	}
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -181,6 +197,19 @@ func exchangeCodexAuthorizationCode(
 	}, nil
 }
 
+func getCodexOAuthHTTPClient(proxyURL string) (*http.Client, error) {
+	baseClient, err := GetHttpClientWithProxy(strings.TrimSpace(proxyURL))
+	if err != nil {
+		return nil, err
+	}
+	if baseClient == nil {
+		return &http.Client{Timeout: defaultHTTPTimeout}, nil
+	}
+	clientCopy := *baseClient
+	clientCopy.Timeout = defaultHTTPTimeout
+	return &clientCopy, nil
+}
+
 func buildCodexAuthorizeURL(state string, challenge string) (string, error) {
 	u, err := url.Parse(codexOAuthAuthorizeURL)
 	if err != nil {