|
|
@@ -9,8 +9,10 @@ import (
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
"one-api/dto"
|
|
|
+ relaycommon "one-api/relay/common"
|
|
|
"one-api/service"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
|
|
@@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
+func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
|
createdTime := common.GetTimestamp()
|
|
|
usage := &dto.Usage{}
|
|
|
@@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
|
|
|
stopChan <- true
|
|
|
}()
|
|
|
service.SetEventStreamHeaders(c)
|
|
|
+ isFirst := true
|
|
|
c.Stream(func(w io.Writer) bool {
|
|
|
select {
|
|
|
case data := <-dataChan:
|
|
|
+ if isFirst {
|
|
|
+ isFirst = false
|
|
|
+ info.FirstResponseTime = time.Now()
|
|
|
+ }
|
|
|
data = strings.TrimSuffix(data, "\r")
|
|
|
var cohereResp CohereResponse
|
|
|
err := json.Unmarshal([]byte(data), &cohereResp)
|
|
|
@@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
|
|
|
openaiResp.Id = responseId
|
|
|
openaiResp.Created = createdTime
|
|
|
openaiResp.Object = "chat.completion.chunk"
|
|
|
- openaiResp.Model = modelName
|
|
|
+ openaiResp.Model = info.UpstreamModelName
|
|
|
if cohereResp.IsFinished {
|
|
|
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
|
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
|
|
@@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
|
|
|
}
|
|
|
})
|
|
|
if usage.PromptTokens == 0 {
|
|
|
- usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
|
|
+ usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
|
}
|
|
|
return nil, usage
|
|
|
}
|