Просмотр исходного кода

feat: support cohere first response time

CalciumIon 1 год назад
Родитель
Сommit
a7e3168c17
2 измененных файлов с 11 добавлено и 4 удалено
  1. 1 1
      relay/channel/cohere/adaptor.go
  2. 10 3
      relay/channel/cohere/relay-cohere.go

+ 1 - 1
relay/channel/cohere/adaptor.go

@@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
+		err, usage = cohereStreamHandler(c, resp, info)
 	} else {
 		err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
 	}

+ 10 - 3
relay/channel/cohere/relay-cohere.go

@@ -9,8 +9,10 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
+	"time"
 )
 
 func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
 	}
 }
 
-func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createdTime := common.GetTimestamp()
 	usage := &dto.Usage{}
@@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
 		stopChan <- true
 	}()
 	service.SetEventStreamHeaders(c)
+	isFirst := true
 	c.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
+			if isFirst {
+				isFirst = false
+				info.FirstResponseTime = time.Now()
+			}
 			data = strings.TrimSuffix(data, "\r")
 			var cohereResp CohereResponse
 			err := json.Unmarshal([]byte(data), &cohereResp)
@@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
 			openaiResp.Id = responseId
 			openaiResp.Created = createdTime
 			openaiResp.Object = "chat.completion.chunk"
-			openaiResp.Model = modelName
+			openaiResp.Model = info.UpstreamModelName
 			if cohereResp.IsFinished {
 				finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
 				openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
@@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
 		}
 	})
 	if usage.PromptTokens == 0 {
-		usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
+		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
 	return nil, usage
 }