Sfoglia il codice sorgente

feat: first response time support gemini and claude

CalciumIon 1 anno fa
parent
commit
f2654692e8

+ 1 - 1
relay/channel/claude/adaptor.go

@@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
+		err, usage = claudeStreamHandler(c, resp, info, a.RequestMode)
 	} else {
 		err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 13 - 6
relay/channel/claude/relay-claude.go

@@ -9,8 +9,10 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+	"time"
 )
 
 func stopReasonClaude2OpenAI(reason string) string {
@@ -246,7 +248,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 	return &fullTextResponse
 }
 
-func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	var usage *dto.Usage
 	usage = &dto.Usage{}
@@ -278,10 +280,15 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
 		}
 		stopChan <- true
 	}()
+	isFirst := true
 	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
+			if isFirst {
+				isFirst = false
+				info.FirstResponseTime = time.Now()
+			}
 			// some implementations may add \r at the end of data
 			data = strings.TrimSuffix(data, "\r")
 			var claudeResponse ClaudeResponse
@@ -302,7 +309,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
 				if claudeResponse.Type == "message_start" {
 					// message_start, 获取usage
 					responseId = claudeResponse.Message.Id
-					modelName = claudeResponse.Message.Model
+					info.UpstreamModelName = claudeResponse.Message.Model
 					usage.PromptTokens = claudeUsage.InputTokens
 				} else if claudeResponse.Type == "content_block_delta" {
 					responseText += claudeResponse.Delta.Text
@@ -316,7 +323,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
 			//response.Id = responseId
 			response.Id = responseId
 			response.Created = createdTime
-			response.Model = modelName
+			response.Model = info.UpstreamModelName
 
 			jsonStr, err := json.Marshal(response)
 			if err != nil {
@@ -335,13 +342,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
 		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 	if requestMode == RequestModeCompletion {
-		usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
+		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		if usage.PromptTokens == 0 {
-			usage.PromptTokens = promptTokens
+			usage.PromptTokens = info.PromptTokens
 		}
 		if usage.CompletionTokens == 0 {
-			usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens)
+			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
 		}
 	}
 	return nil, usage

+ 18 - 18
relay/channel/gemini/adaptor.go

@@ -20,27 +20,27 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq
 
 // 定义一个映射,存储模型名称和对应的版本
 var modelVersionMap = map[string]string{
-    "gemini-1.5-pro-latest": "v1beta",
-    "gemini-1.5-flash-latest": "v1beta",
-    "gemini-ultra":   "v1beta",
+	"gemini-1.5-pro-latest":   "v1beta",
+	"gemini-1.5-flash-latest": "v1beta",
+	"gemini-ultra":            "v1beta",
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-    // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
-    version, beta := modelVersionMap[info.UpstreamModelName]
-    if !beta {
-        if info.ApiVersion != "" {
-            version = info.ApiVersion
-        } else {
-            version = "v1"
-        }
-    }
+	// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
+	version, beta := modelVersionMap[info.UpstreamModelName]
+	if !beta {
+		if info.ApiVersion != "" {
+			version = info.ApiVersion
+		} else {
+			version = "v1"
+		}
+	}
 
-    action := "generateContent"
-    if info.IsStream {
-        action = "streamGenerateContent"
-    }
-    return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
+	action := "generateContent"
+	if info.IsStream {
+		action = "streamGenerateContent"
+	}
+	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
 		var responseText string
-		err, responseText = geminiChatStreamHandler(c, resp)
+		err, responseText = geminiChatStreamHandler(c, resp, info)
 		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	} else {
 		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)

+ 7 - 1
relay/channel/gemini/relay-gemini.go

@@ -11,6 +11,7 @@ import (
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+	"time"
 
 	"github.com/gin-gonic/gin"
 )
@@ -160,7 +161,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	return &response
 }
 
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
 	responseText := ""
 	dataChan := make(chan string)
 	stopChan := make(chan bool)
@@ -190,10 +191,15 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr
 		}
 		stopChan <- true
 	}()
+	isFirst := true
 	service.SetEventStreamHeaders(c)
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
+			if isFirst {
+				isFirst = false
+				info.FirstResponseTime = time.Now()
+			}
 			// this is used to prevent annoying \ related format bug
 			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
 			type dummyStruct struct {

+ 16 - 14
relay/common/relay_info.go

@@ -38,24 +38,26 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	group := c.GetString("group")
 	tokenUnlimited := c.GetBool("token_unlimited_quota")
 	startTime := time.Now()
+	// firstResponseTime = time.Now() - 1 second
 
 	apiType, _ := constant.ChannelType2APIType(channelType)
 
 	info := &RelayInfo{
-		RelayMode:      constant.Path2RelayMode(c.Request.URL.Path),
-		BaseUrl:        c.GetString("base_url"),
-		RequestURLPath: c.Request.URL.String(),
-		ChannelType:    channelType,
-		ChannelId:      channelId,
-		TokenId:        tokenId,
-		UserId:         userId,
-		Group:          group,
-		TokenUnlimited: tokenUnlimited,
-		StartTime:      startTime,
-		ApiType:        apiType,
-		ApiVersion:     c.GetString("api_version"),
-		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
-		Organization:   c.GetString("channel_organization"),
+		RelayMode:         constant.Path2RelayMode(c.Request.URL.Path),
+		BaseUrl:           c.GetString("base_url"),
+		RequestURLPath:    c.Request.URL.String(),
+		ChannelType:       channelType,
+		ChannelId:         channelId,
+		TokenId:           tokenId,
+		UserId:            userId,
+		Group:             group,
+		TokenUnlimited:    tokenUnlimited,
+		StartTime:         startTime,
+		FirstResponseTime: startTime.Add(-time.Second),
+		ApiType:           apiType,
+		ApiVersion:        c.GetString("api_version"),
+		ApiKey:            strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+		Organization:      c.GetString("channel_organization"),
 	}
 	if info.BaseUrl == "" {
 		info.BaseUrl = common.ChannelBaseURLs[channelType]