Parcourir la source

feat: Add thinking-to-content conversion for stream responses

1808837298@qq.com il y a 1 an
Parent
commit
115a181db3

+ 3 - 2
constant/channel_setting.go

@@ -1,6 +1,7 @@
 package constant
 package constant
 
 
 var (
 var (
-	ForceFormat        = "force_format" // ForceFormat 强制格式化为OpenAI格式
-	ChanelSettingProxy = "proxy"        // Proxy 代理
+	ForceFormat                     = "force_format"        // ForceFormat 强制格式化为OpenAI格式
+	ChanelSettingProxy              = "proxy"               // Proxy 代理
+	ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
 )
 )

+ 5 - 0
docs/channel/other_setting.md

@@ -10,6 +10,10 @@
     - 用于配置网络代理
     - 用于配置网络代理
     - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
     - 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
 
 
+3. thinking_to_content
+   - 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
+   - 类型为布尔值,设置为 true 时启用思考内容转换
+
 --------------------------------------------------------------
 --------------------------------------------------------------
 
 
 ## JSON 格式示例
 ## JSON 格式示例
@@ -19,6 +23,7 @@
 ```json
 ```json
 {
 {
     "force_format": true,
     "force_format": true,
+   "thinking_to_content": true,
     "proxy": "socks5://xxxxxxx"
     "proxy": "socks5://xxxxxxx"
 }
 }
 ```
 ```

+ 18 - 0
dto/openai_response.go

@@ -86,6 +86,10 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) GetReasoningContent() string
 	return *c.ReasoningContent
 	return *c.ReasoningContent
 }
 }
 
 
+func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string) {
+	c.ReasoningContent = &s
+}
+
 type ToolCall struct {
 type ToolCall struct {
 	// Index is not nil only in chat completion chunk object
 	// Index is not nil only in chat completion chunk object
 	Index    *int         `json:"index,omitempty"`
 	Index    *int         `json:"index,omitempty"`
@@ -116,6 +120,20 @@ type ChatCompletionsStreamResponse struct {
 	Usage             *Usage                                `json:"usage"`
 	Usage             *Usage                                `json:"usage"`
 }
 }
 
 
+func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
+	choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
+	copy(choices, c.Choices)
+	return &ChatCompletionsStreamResponse{
+		Id:                c.Id,
+		Object:            c.Object,
+		Created:           c.Created,
+		Model:             c.Model,
+		SystemFingerprint: c.SystemFingerprint,
+		Choices:           choices,
+		Usage:             c.Usage,
+	}
+}
+
 func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
 func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
 	if c.SystemFingerprint == nil {
 	if c.SystemFingerprint == nil {
 		return ""
 		return ""

+ 62 - 17
relay/channel/openai/relay-openai.go

@@ -5,10 +5,6 @@ import (
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-	"github.com/bytedance/gopkg/util/gopool"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
-	"github.com/pkg/errors"
 	"io"
 	"io"
 	"math"
 	"math"
 	"mime/multipart"
 	"mime/multipart"
@@ -23,21 +19,66 @@ import (
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+	"github.com/pkg/errors"
 )
 )
 
 
-func sendStreamData(c *gin.Context, data string, forceFormat bool) error {
+func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
 	if data == "" {
 	if data == "" {
 		return nil
 		return nil
 	}
 	}
 
 
-	if forceFormat {
-		var lastStreamResponse dto.ChatCompletionsStreamResponse
-		if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
-			return err
+	if !forceFormat && !thinkToContent {
+		return service.StringData(c, data)
+	}
+
+	var lastStreamResponse dto.ChatCompletionsStreamResponse
+	if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
+		return err
+	}
+
+	if !thinkToContent {
+		return service.ObjectData(c, lastStreamResponse)
+	}
+
+	// Handle think to content conversion
+	if info.IsFirstResponse {
+		response := lastStreamResponse.Copy()
+		for i := range response.Choices {
+			response.Choices[i].Delta.SetContentString("<think>\n")
+			response.Choices[i].Delta.SetReasoningContent("")
 		}
 		}
+		service.ObjectData(c, response)
+	}
+
+	if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
 		return service.ObjectData(c, lastStreamResponse)
 		return service.ObjectData(c, lastStreamResponse)
 	}
 	}
-	return service.StringData(c, data)
+
+	// Process each choice
+	for i, choice := range lastStreamResponse.Choices {
+		// Handle transition from thinking to content
+		if len(choice.Delta.GetContentString()) > 0 && !info.SendLastReasoningResponse {
+			response := lastStreamResponse.Copy()
+			for j := range response.Choices {
+				response.Choices[j].Delta.SetContentString("\n</think>")
+				response.Choices[j].Delta.SetReasoningContent("")
+			}
+			info.SendLastReasoningResponse = true
+			service.ObjectData(c, response)
+		}
+
+		// Convert reasoning content to regular content
+		if len(choice.Delta.GetReasoningContent()) > 0 {
+			lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
+			lastStreamResponse.Choices[i].Delta.SetReasoningContent("")
+		}
+	}
+
+	return service.ObjectData(c, lastStreamResponse)
 }
 }
 
 
 func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -56,11 +97,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	var usage = &dto.Usage{}
 	var usage = &dto.Usage{}
 	var streamItems []string // store stream items
 	var streamItems []string // store stream items
 	var forceFormat bool
 	var forceFormat bool
+	var thinkToContent bool
 
 
-	if info.ChannelType == common.ChannelTypeCustom {
-		if forceFmt, ok := info.ChannelSetting["force_format"].(bool); ok {
-			forceFormat = forceFmt
-		}
+	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
+		forceFormat = forceFmt
+	}
+
+	if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok {
+		thinkToContent = think2Content
 	}
 	}
 
 
 	toolCount := 0
 	toolCount := 0
@@ -84,7 +128,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	)
 	)
 	gopool.Go(func() {
 	gopool.Go(func() {
 		for scanner.Scan() {
 		for scanner.Scan() {
-			info.SetFirstResponseTime()
+			//info.SetFirstResponseTime()
 			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
 			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
 			data := scanner.Text()
 			data := scanner.Text()
 			if common.DebugEnabled {
 			if common.DebugEnabled {
@@ -101,10 +145,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 			data = strings.TrimSpace(data)
 			data = strings.TrimSpace(data)
 			if !strings.HasPrefix(data, "[DONE]") {
 			if !strings.HasPrefix(data, "[DONE]") {
 				if lastStreamData != "" {
 				if lastStreamData != "" {
-					err := sendStreamData(c, lastStreamData, forceFormat)
+					err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
 					if err != nil {
 					if err != nil {
 						common.LogError(c, "streaming error: "+err.Error())
 						common.LogError(c, "streaming error: "+err.Error())
 					}
 					}
+					info.SetFirstResponseTime()
 				}
 				}
 				lastStreamData = data
 				lastStreamData = data
 				streamItems = append(streamItems, data)
 				streamItems = append(streamItems, data)
@@ -144,7 +189,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		}
 		}
 	}
 	}
 	if shouldSendLastResp {
 	if shouldSendLastResp {
-		sendStreamData(c, lastStreamData, forceFormat)
+		sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
 	}
 	}
 
 
 	// 计算token
 	// 计算token

+ 21 - 19
relay/common/relay_info.go

@@ -13,23 +13,24 @@ import (
 )
 )
 
 
 type RelayInfo struct {
 type RelayInfo struct {
-	ChannelType       int
-	ChannelId         int
-	TokenId           int
-	TokenKey          string
-	UserId            int
-	Group             string
-	TokenUnlimited    bool
-	StartTime         time.Time
-	FirstResponseTime time.Time
-	setFirstResponse  bool
-	ApiType           int
-	IsStream          bool
-	IsPlayground      bool
-	UsePrice          bool
-	RelayMode         int
-	UpstreamModelName string
-	OriginModelName   string
+	ChannelType               int
+	ChannelId                 int
+	TokenId                   int
+	TokenKey                  string
+	UserId                    int
+	Group                     string
+	TokenUnlimited            bool
+	StartTime                 time.Time
+	FirstResponseTime         time.Time
+	IsFirstResponse           bool
+	SendLastReasoningResponse bool
+	ApiType                   int
+	IsStream                  bool
+	IsPlayground              bool
+	UsePrice                  bool
+	RelayMode                 int
+	UpstreamModelName         string
+	OriginModelName           string
 	//RecodeModelName      string
 	//RecodeModelName      string
 	RequestURLPath       string
 	RequestURLPath       string
 	ApiVersion           string
 	ApiVersion           string
@@ -88,6 +89,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	apiType, _ := relayconstant.ChannelType2APIType(channelType)
 	apiType, _ := relayconstant.ChannelType2APIType(channelType)
 
 
 	info := &RelayInfo{
 	info := &RelayInfo{
+		IsFirstResponse:   true,
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
 		RelayMode:         relayconstant.Path2RelayMode(c.Request.URL.Path),
 		BaseUrl:           c.GetString("base_url"),
 		BaseUrl:           c.GetString("base_url"),
 		RequestURLPath:    c.Request.URL.String(),
 		RequestURLPath:    c.Request.URL.String(),
@@ -139,9 +141,9 @@ func (info *RelayInfo) SetIsStream(isStream bool) {
 }
 }
 
 
 func (info *RelayInfo) SetFirstResponseTime() {
 func (info *RelayInfo) SetFirstResponseTime() {
-	if !info.setFirstResponse {
+	if info.IsFirstResponse {
 		info.FirstResponseTime = time.Now()
 		info.FirstResponseTime = time.Now()
-		info.setFirstResponse = true
+		info.IsFirstResponse = false
 	}
 	}
 }
 }