Sfoglia il codice sorgente

feat: implement OpenAI responses handling and streaming support with built-in tool tracking

CaIon 10 mesi fa
parent
commit
18b3300ff1

+ 8 - 3
dto/openai_response.go

@@ -238,8 +238,12 @@ type ResponsesOutputContent struct {
 }
 
 const (
-	BuildInTools_WebSearch  = "web_search_preview"
-	BuildInTools_FileSearch = "file_search"
+	BuildInToolWebSearchPreview = "web_search_preview"
+	BuildInToolFileSearch       = "file_search"
+)
+
+const (
+	BuildInCallWebSearchCall = "web_search_call"
 )
 
 const (
@@ -250,6 +254,7 @@ const (
 // ResponsesStreamResponse 用于处理 /v1/responses 流式响应
 type ResponsesStreamResponse struct {
 	Type     string                   `json:"type"`
-	Response *OpenAIResponsesResponse `json:"response"`
+	Response *OpenAIResponsesResponse `json:"response,omitempty"`
 	Delta    string                   `json:"delta,omitempty"`
+	Item     *ResponsesOutput         `json:"item,omitempty"`
 }

+ 1 - 1
relay/channel/openai/adaptor.go

@@ -429,7 +429,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		if info.IsStream {
 			err, usage = OaiResponsesStreamHandler(c, resp, info)
 		} else {
-			err, usage = OpenaiResponsesHandler(c, resp, info)
+			err, usage = OaiResponsesHandler(c, resp, info)
 		}
 	default:
 		if info.IsStream {

+ 7 - 0
relay/channel/openai/helper.go

@@ -187,3 +187,10 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
 		}
 	}
 }
+
+func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
+	if data == "" {
+		return
+	}
+	helper.ResponseChunkData(c, streamResponse, data)
+}

+ 0 - 99
relay/channel/openai/relay-openai.go

@@ -644,102 +644,3 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
 	}
 	return nil, &usageResp.Usage
 }
-
-func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	// read response body
-	var responsesResponse dto.OpenAIResponsesResponse
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = common.DecodeJson(responseBody, &responsesResponse)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
-	}
-	if responsesResponse.Error != nil {
-		return &dto.OpenAIErrorWithStatusCode{
-			Error: dto.OpenAIError{
-				Message: responsesResponse.Error.Message,
-				Type:    "openai_error",
-				Code:    responsesResponse.Error.Code,
-			},
-			StatusCode: resp.StatusCode,
-		}, nil
-	}
-
-	// reset response body
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the httpClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.WriteHeader(resp.StatusCode)
-	// copy response body
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		common.SysError("error copying response body: " + err.Error())
-	}
-	resp.Body.Close()
-	// compute usage
-	usage := dto.Usage{}
-	usage.PromptTokens = responsesResponse.Usage.InputTokens
-	usage.CompletionTokens = responsesResponse.Usage.OutputTokens
-	usage.TotalTokens = responsesResponse.Usage.TotalTokens
-	return nil, &usage
-}
-
-func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	if resp == nil || resp.Body == nil {
-		common.LogError(c, "invalid response or response body")
-		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
-	}
-
-	var usage = &dto.Usage{}
-	var responseTextBuilder strings.Builder
-
-	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-
-		// 检查当前数据是否包含 completed 状态和 usage 信息
-		var streamResponse dto.ResponsesStreamResponse
-		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
-			sendResponsesStreamData(c, streamResponse, data)
-			switch streamResponse.Type {
-			case "response.completed":
-				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
-				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
-				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
-			case "response.output_text.delta":
-				// 处理输出文本
-				responseTextBuilder.WriteString(streamResponse.Delta)
-
-			}
-		}
-		return true
-	})
-
-	if usage.CompletionTokens == 0 {
-		// 计算输出文本的 token 数量
-		tempStr := responseTextBuilder.String()
-		if len(tempStr) > 0 {
-			// 非正常结束,使用输出文本的 token 数量
-			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
-			usage.CompletionTokens = completionTokens
-		}
-	}
-
-	return nil, usage
-}
-
-func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
-	if data == "" {
-		return
-	}
-	helper.ResponseChunkData(c, streamResponse, data)
-}

+ 114 - 0
relay/channel/openai/relay_responses.go

@@ -0,0 +1,114 @@
+package openai
+
+import (
+	"bytes"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"strings"
+)
+
+func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	// read response body
+	var responsesResponse dto.OpenAIResponsesResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = common.DecodeJson(responseBody, &responsesResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if responsesResponse.Error != nil {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: responsesResponse.Error.Message,
+				Type:    "openai_error",
+				Code:    responsesResponse.Error.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	// reset response body
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+	// We shouldn't set the header before we parse the response body, because the parse part may fail.
+	// And then we will have to send an error response, but in this case, the header has already been set.
+	// So the httpClient will be confused by the response.
+	// For example, Postman will report error, and we cannot check the response at all.
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+	// copy response body
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		common.SysError("error copying response body: " + err.Error())
+	}
+	resp.Body.Close()
+	// compute usage
+	usage := dto.Usage{}
+	usage.PromptTokens = responsesResponse.Usage.InputTokens
+	usage.CompletionTokens = responsesResponse.Usage.OutputTokens
+	usage.TotalTokens = responsesResponse.Usage.TotalTokens
+	return nil, &usage
+}
+
+func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	if resp == nil || resp.Body == nil {
+		common.LogError(c, "invalid response or response body")
+		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
+	}
+
+	var usage = &dto.Usage{}
+	var responseTextBuilder strings.Builder
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+
+		// 检查当前数据是否包含 completed 状态和 usage 信息
+		var streamResponse dto.ResponsesStreamResponse
+		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
+			sendResponsesStreamData(c, streamResponse, data)
+			switch streamResponse.Type {
+			case "response.completed":
+				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+			case "response.output_text.delta":
+				// 处理输出文本
+				responseTextBuilder.WriteString(streamResponse.Delta)
+			case dto.ResponsesOutputTypeItemDone:
+				// 函数调用处理
+				if streamResponse.Item != nil {
+					switch streamResponse.Item.Type {
+					case dto.BuildInCallWebSearchCall:
+						info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
+					}
+				}
+			}
+		}
+		return true
+	})
+
+	if usage.CompletionTokens == 0 {
+		// 计算输出文本的 token 数量
+		tempStr := responseTextBuilder.String()
+		if len(tempStr) > 0 {
+			// 非正常结束,使用输出文本的 token 数量
+			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			usage.CompletionTokens = completionTokens
+		}
+	}
+
+	return nil, usage
+}

+ 1 - 1
relay/channel/vertex/adaptor.go

@@ -11,8 +11,8 @@ import (
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/openai"
-	"one-api/setting/model_setting"
 	relaycommon "one-api/relay/common"
+	"one-api/setting/model_setting"
 	"strings"
 
 	"github.com/gin-gonic/gin"

+ 37 - 0
relay/common/relay_info.go

@@ -36,6 +36,7 @@ type ClaudeConvertInfo struct {
 const (
 	RelayFormatOpenAI = "openai"
 	RelayFormatClaude = "claude"
+	RelayFormatGemini = "gemini"
 )
 
 type RerankerInfo struct {
@@ -43,6 +44,16 @@ type RerankerInfo struct {
 	ReturnDocuments bool
 }
 
+type BuildInToolInfo struct {
+	ToolName          string
+	CallCount         int
+	SearchContextSize string
+}
+
+type ResponsesUsageInfo struct {
+	BuiltInTools map[string]*BuildInToolInfo
+}
+
 type RelayInfo struct {
 	ChannelType       int
 	ChannelId         int
@@ -90,6 +101,7 @@ type RelayInfo struct {
 	ThinkingContentInfo
 	*ClaudeConvertInfo
 	*RerankerInfo
+	*ResponsesUsageInfo
 }
 
 // 定义支持流式选项的通道类型
@@ -134,6 +146,31 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
 	return info
 }
 
+func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
+	info := GenRelayInfo(c)
+	info.RelayMode = relayconstant.RelayModeResponses
+	info.ResponsesUsageInfo = &ResponsesUsageInfo{
+		BuiltInTools: make(map[string]*BuildInToolInfo),
+	}
+	if len(req.Tools) > 0 {
+		for _, tool := range req.Tools {
+			info.ResponsesUsageInfo.BuiltInTools[tool.Type] = &BuildInToolInfo{
+				ToolName:  tool.Type,
+				CallCount: 0,
+			}
+			switch tool.Type {
+			case dto.BuildInToolWebSearchPreview:
+				if tool.SearchContextSize == "" {
+					tool.SearchContextSize = "medium"
+				}
+				info.ResponsesUsageInfo.BuiltInTools[tool.Type].SearchContextSize = tool.SearchContextSize
+			}
+		}
+	}
+	info.IsStream = req.Stream
+	return info
+}
+
 func GenRelayInfo(c *gin.Context) *RelayInfo {
 	channelType := c.GetInt("channel_type")
 	channelId := c.GetInt("channel_id")

+ 5 - 5
relay/relay-responses.go

@@ -19,7 +19,7 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) (*dto.OpenAIResponsesRequest, error) {
+func getAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
 	request := &dto.OpenAIResponsesRequest{}
 	err := common.UnmarshalBodyReusable(c, request)
 	if err != nil {
@@ -31,13 +31,11 @@ func getAndValidateResponsesRequest(c *gin.Context, relayInfo *relaycommon.Relay
 	if len(request.Input) == 0 {
 		return nil, errors.New("input is required")
 	}
-	relayInfo.IsStream = request.Stream
 	return request, nil
 
 }
 
 func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) ([]string, error) {
-
 	sensitiveWords, err := service.CheckSensitiveInput(textRequest.Input)
 	return sensitiveWords, err
 }
@@ -49,12 +47,14 @@ func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo
 }
 
 func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
-	relayInfo := relaycommon.GenRelayInfo(c)
-	req, err := getAndValidateResponsesRequest(c, relayInfo)
+	req, err := getAndValidateResponsesRequest(c)
 	if err != nil {
 		common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
 		return service.OpenAIErrorWrapperLocal(err, "invalid_responses_request", http.StatusBadRequest)
 	}
+
+	relayInfo := relaycommon.GenRelayInfoResponses(c, req)
+
 	if setting.ShouldCheckPromptSensitive() {
 		sensitiveWords, err := checkInputSensitive(req, relayInfo)
 		if err != nil {