chat_via_responses.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. package openai
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "strings"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/dto"
  10. "github.com/QuantumNous/new-api/logger"
  11. relaycommon "github.com/QuantumNous/new-api/relay/common"
  12. "github.com/QuantumNous/new-api/relay/helper"
  13. "github.com/QuantumNous/new-api/service"
  14. "github.com/QuantumNous/new-api/types"
  15. "github.com/gin-gonic/gin"
  16. )
  17. func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  18. if resp == nil || resp.Body == nil {
  19. return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
  20. }
  21. defer service.CloseResponseBodyGracefully(resp)
  22. var responsesResp dto.OpenAIResponsesResponse
  23. body, err := io.ReadAll(resp.Body)
  24. if err != nil {
  25. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  26. }
  27. if err := common.Unmarshal(body, &responsesResp); err != nil {
  28. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  29. }
  30. if oaiError := responsesResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
  31. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
  32. }
  33. chatId := helper.GetResponseID(c)
  34. chatResp, usage, err := service.ResponsesResponseToChatCompletionsResponse(&responsesResp, chatId)
  35. if err != nil {
  36. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  37. }
  38. if usage == nil || usage.TotalTokens == 0 {
  39. text := service.ExtractOutputTextFromResponses(&responsesResp)
  40. usage = service.ResponseText2Usage(c, text, info.UpstreamModelName, info.GetEstimatePromptTokens())
  41. chatResp.Usage = *usage
  42. }
  43. chatBody, err := common.Marshal(chatResp)
  44. if err != nil {
  45. return nil, types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
  46. }
  47. service.IOCopyBytesGracefully(c, resp, chatBody)
  48. return usage, nil
  49. }
// OaiResponsesToChatStreamHandler relays a streaming Responses API response to
// the client as Chat Completions SSE chunks.
//
// Each upstream event (text deltas, function-call output items, argument
// deltas, completion, errors) is translated into a chat.completion.chunk
// object. The handler tracks usage from the response.completed event, falls
// back to a local token estimate when upstream reports none, and guarantees
// that a start chunk, a stop chunk, and (when requested via
// info.ShouldIncludeUsage) a final usage chunk are emitted even if the stream
// ends without a response.completed event. Returns the accumulated usage or a
// NewAPIError.
func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}
	defer service.CloseResponseBodyGracefully(resp)
	responseId := helper.GetResponseID(c)
	createAt := time.Now().Unix()
	model := info.UpstreamModelName
	var (
		usage       = &dto.Usage{}
		outputText  strings.Builder // assistant text streamed so far; non-empty suppresses tool-call chunks and selects finish_reason "stop"
		usageText   strings.Builder // text plus tool-call name/args, used for fallback token estimation
		sentStart   bool            // initial empty "start" chunk already sent
		sentStop    bool            // finish-reason chunk already sent
		sawToolCall bool            // at least one tool-call delta was forwarded
		streamErr   *types.NewAPIError
	)
	// Tool-call bookkeeping, keyed by the canonical call ID.
	toolCallIndexByID := make(map[string]int)              // stable chunk index assigned per call
	toolCallNameByID := make(map[string]string)            // last known function name per call
	toolCallArgsByID := make(map[string]string)            // cumulative arguments JSON per call
	toolCallNameSent := make(map[string]bool)              // name emitted once already
	toolCallCanonicalIDByItemID := make(map[string]string) // output-item ID -> canonical call ID
	// sendStartIfNeeded lazily emits the initial empty chunk exactly once;
	// returns false (and records streamErr) on write failure.
	sendStartIfNeeded := func() bool {
		if sentStart {
			return true
		}
		if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
			streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
			return false
		}
		sentStart = true
		return true
	}
	// sendToolCallDelta emits one tool-call delta chunk for callID; returns
	// false (and records streamErr) on write failure. No-ops when callID is
	// empty or assistant text has already been streamed.
	sendToolCallDelta := func(callID string, name string, argsDelta string) bool {
		if callID == "" {
			return true
		}
		if outputText.Len() > 0 {
			// Prefer streaming assistant text over tool calls to match non-stream behavior.
			return true
		}
		if !sendStartIfNeeded() {
			return false
		}
		// Assign a stable per-call index on first sight of this call ID.
		idx, ok := toolCallIndexByID[callID]
		if !ok {
			idx = len(toolCallIndexByID)
			toolCallIndexByID[callID] = idx
		}
		if name != "" {
			toolCallNameByID[callID] = name
		}
		if toolCallNameByID[callID] != "" {
			name = toolCallNameByID[callID]
		}
		tool := dto.ToolCallResponse{
			ID:   callID,
			Type: "function",
			Function: dto.FunctionResponse{
				Arguments: argsDelta,
			},
		}
		tool.SetIndex(idx)
		// Emit the function name only once per call ID.
		if name != "" && !toolCallNameSent[callID] {
			tool.Function.Name = name
			toolCallNameSent[callID] = true
		}
		chunk := &dto.ChatCompletionsStreamResponse{
			Id:      responseId,
			Object:  "chat.completion.chunk",
			Created: createAt,
			Model:   model,
			Choices: []dto.ChatCompletionsStreamResponseChoice{
				{
					Index: 0,
					Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
						ToolCalls: []dto.ToolCallResponse{tool},
					},
				},
			},
		}
		if err := helper.ObjectData(c, chunk); err != nil {
			streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
			return false
		}
		sawToolCall = true
		// Include tool call data in the local builder for fallback token estimation.
		if tool.Function.Name != "" {
			usageText.WriteString(tool.Function.Name)
		}
		if argsDelta != "" {
			usageText.WriteString(argsDelta)
		}
		return true
	}
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		if streamErr != nil {
			return false
		}
		var streamResp dto.ResponsesStreamResponse
		if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
			// Skip malformed events instead of aborting the whole stream.
			logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
			return true
		}
		switch streamResp.Type {
		case "response.created":
			// Adopt the upstream model name / creation time when provided.
			if streamResp.Response != nil {
				if streamResp.Response.Model != "" {
					model = streamResp.Response.Model
				}
				if streamResp.Response.CreatedAt != 0 {
					createAt = int64(streamResp.Response.CreatedAt)
				}
			}
		case "response.output_text.delta":
			if !sendStartIfNeeded() {
				return false
			}
			if streamResp.Delta != "" {
				outputText.WriteString(streamResp.Delta)
				usageText.WriteString(streamResp.Delta)
				delta := streamResp.Delta
				chunk := &dto.ChatCompletionsStreamResponse{
					Id:      responseId,
					Object:  "chat.completion.chunk",
					Created: createAt,
					Model:   model,
					Choices: []dto.ChatCompletionsStreamResponseChoice{
						{
							Index: 0,
							Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
								Content: &delta,
							},
						},
					},
				}
				if err := helper.ObjectData(c, chunk); err != nil {
					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
					return false
				}
			}
		case "response.output_item.added", "response.output_item.done":
			if streamResp.Item == nil {
				break
			}
			if streamResp.Item.Type != "function_call" {
				break
			}
			itemID := strings.TrimSpace(streamResp.Item.ID)
			callID := strings.TrimSpace(streamResp.Item.CallId)
			// Fall back to the item ID when no explicit call ID is present.
			if callID == "" {
				callID = itemID
			}
			if itemID != "" && callID != "" {
				toolCallCanonicalIDByItemID[itemID] = callID
			}
			name := strings.TrimSpace(streamResp.Item.Name)
			if name != "" {
				toolCallNameByID[callID] = name
			}
			// The item carries cumulative arguments: forward only the suffix
			// not yet sent, or the whole string when it doesn't extend what
			// we accumulated before.
			newArgs := streamResp.Item.Arguments
			prevArgs := toolCallArgsByID[callID]
			argsDelta := ""
			if newArgs != "" {
				if strings.HasPrefix(newArgs, prevArgs) {
					argsDelta = newArgs[len(prevArgs):]
				} else {
					argsDelta = newArgs
				}
				toolCallArgsByID[callID] = newArgs
			}
			if !sendToolCallDelta(callID, name, argsDelta) {
				return false
			}
		case "response.function_call_arguments.delta":
			// Map the event's item ID back to the canonical call ID recorded
			// when the output item was added.
			itemID := strings.TrimSpace(streamResp.ItemID)
			callID := toolCallCanonicalIDByItemID[itemID]
			if callID == "" {
				callID = itemID
			}
			if callID == "" {
				break
			}
			toolCallArgsByID[callID] += streamResp.Delta
			if !sendToolCallDelta(callID, "", streamResp.Delta) {
				return false
			}
		case "response.function_call_arguments.done":
			// Arguments were already forwarded incrementally; nothing to do.
		case "response.completed":
			if streamResp.Response != nil {
				if streamResp.Response.Model != "" {
					model = streamResp.Response.Model
				}
				if streamResp.Response.CreatedAt != 0 {
					createAt = int64(streamResp.Response.CreatedAt)
				}
				// Copy upstream-reported usage, deriving the total when absent.
				if streamResp.Response.Usage != nil {
					if streamResp.Response.Usage.InputTokens != 0 {
						usage.PromptTokens = streamResp.Response.Usage.InputTokens
						usage.InputTokens = streamResp.Response.Usage.InputTokens
					}
					if streamResp.Response.Usage.OutputTokens != 0 {
						usage.CompletionTokens = streamResp.Response.Usage.OutputTokens
						usage.OutputTokens = streamResp.Response.Usage.OutputTokens
					}
					if streamResp.Response.Usage.TotalTokens != 0 {
						usage.TotalTokens = streamResp.Response.Usage.TotalTokens
					} else {
						usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
					}
					if streamResp.Response.Usage.InputTokensDetails != nil {
						usage.PromptTokensDetails.CachedTokens = streamResp.Response.Usage.InputTokensDetails.CachedTokens
						usage.PromptTokensDetails.ImageTokens = streamResp.Response.Usage.InputTokensDetails.ImageTokens
						usage.PromptTokensDetails.AudioTokens = streamResp.Response.Usage.InputTokensDetails.AudioTokens
					}
					if streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens != 0 {
						usage.CompletionTokenDetails.ReasoningTokens = streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens
					}
				}
			}
			if !sendStartIfNeeded() {
				return false
			}
			// Emit the finish-reason chunk: "tool_calls" when only tool calls
			// were streamed, "stop" otherwise.
			if !sentStop {
				finishReason := "stop"
				if sawToolCall && outputText.Len() == 0 {
					finishReason = "tool_calls"
				}
				stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
				if err := helper.ObjectData(c, stop); err != nil {
					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
					return false
				}
				sentStop = true
			}
		case "response.error", "response.failed":
			// Prefer the structured upstream error when one is attached.
			if streamResp.Response != nil {
				if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
					streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
					return false
				}
			}
			streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
			return false
		default:
			// Ignore event types we do not translate.
		}
		return true
	})
	if streamErr != nil {
		return nil, streamErr
	}
	// Estimate usage locally when upstream never reported token counts.
	if usage.TotalTokens == 0 {
		usage = service.ResponseText2Usage(c, usageText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
	}
	// Guarantee the client receives start and stop chunks even when the
	// stream ended without a response.completed event.
	if !sentStart {
		if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
		}
	}
	if !sentStop {
		finishReason := "stop"
		if sawToolCall && outputText.Len() == 0 {
			finishReason = "tool_calls"
		}
		stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
		if err := helper.ObjectData(c, stop); err != nil {
			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
		}
	}
	// Optionally append the final usage chunk before terminating the stream.
	if info.ShouldIncludeUsage && usage != nil {
		if err := helper.ObjectData(c, helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)); err != nil {
			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
		}
	}
	helper.Done(c)
	return usage, nil
}