Browse Source

fix: 添加对Thinking字段的处理逻辑,确保推理内容正确传递

somnifex 5 months ago
parent
commit
62549717e0
1 changed file with 21 additions and 10 deletions
  1. 21 10
      relay/channel/ollama/stream.go

+ 21 - 10
relay/channel/ollama/stream.go

@@ -26,6 +26,7 @@ type ollamaChatStreamChunk struct {
     Message *struct {
         Role      string `json:"role"`
         Content   string `json:"content"`
+        Thinking  json.RawMessage `json:"thinking"`
         ToolCalls []struct {
             Function struct {
                 Name      string      `json:"name"`
@@ -41,7 +42,6 @@ type ollamaChatStreamChunk struct {
     LoadDuration  int64  `json:"load_duration"`
     PromptEvalCount int  `json:"prompt_eval_count"`
     EvalCount       int  `json:"eval_count"`
-    // generate mode may use these
     PromptEvalDuration int64 `json:"prompt_eval_duration"`
     EvalDuration       int64 `json:"eval_duration"`
 }
@@ -95,13 +95,18 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
                 } },
             }
             if content != "" { delta.Choices[0].Delta.SetContentString(content) }
+            if chunk.Message != nil && len(chunk.Message.Thinking) > 0 {
+                raw := strings.TrimSpace(string(chunk.Message.Thinking))
+                if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) }
+            }
             // tool calls
             if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 {
                 delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls))
                 for _, tc := range chunk.Message.ToolCalls {
                     // arguments -> string
                     argBytes, _ := json.Marshal(tc.Function.Arguments)
-                    tr := dto.ToolCallResponse{ID:"", Type:nil, Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
+                    toolId := fmt.Sprintf("call_%d", toolCallIndex)
+                    tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}}
                     tr.SetIndex(toolCallIndex)
                     toolCallIndex++
                     delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr)
@@ -115,8 +120,8 @@ func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
         usage.PromptTokens = chunk.PromptEvalCount
         usage.CompletionTokens = chunk.EvalCount
         usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-        finishReason := chunk.DoneReason
-        if finishReason == "" { finishReason = "stop" }
+    finishReason := chunk.DoneReason
+    if finishReason == "" { finishReason = "stop" }
         // emit stop delta
         if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil {
             if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) }
@@ -144,6 +149,7 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
     lines := strings.Split(raw, "\n")
     var (
         aggContent strings.Builder
+        reasoningBuilder strings.Builder
         lastChunk ollamaChatStreamChunk
         parsedAny bool
     )
@@ -157,18 +163,21 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
         }
         parsedAny = true
         lastChunk = ck
-        if !ck.Done { 
-            if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
-        } else { 
-            if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
+        if ck.Message != nil && len(ck.Message.Thinking) > 0 {
+            raw := strings.TrimSpace(string(ck.Message.Thinking))
+            if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) }
         }
+        if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) }
     }
 
     if !parsedAny {
         var single ollamaChatStreamChunk
         if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) }
         lastChunk = single
-        if single.Message != nil { aggContent.WriteString(single.Message.Content) } else { aggContent.WriteString(single.Response) }
+        if single.Message != nil {
+            if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } }
+            aggContent.WriteString(single.Message.Content)
+        } else { aggContent.WriteString(single.Response) }
     }
 
     model := lastChunk.Model
@@ -179,6 +188,8 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
     finishReason := lastChunk.DoneReason
     if finishReason == "" { finishReason = "stop" }
 
+    msg := dto.Message{Role: "assistant", Content: contentPtr(content)}
+    if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = &rc }
     full := dto.OpenAITextResponse{
         Id:      common.GetUUID(),
         Model:   model,
@@ -186,7 +197,7 @@ func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
         Created: created,
         Choices: []dto.OpenAITextResponseChoice{ {
             Index: 0,
-            Message: dto.Message{Role: "assistant", Content: contentPtr(content)},
+            Message: msg,
             FinishReason: finishReason,
         } },
         Usage: *usage,