convert.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972
  1. package service
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/constant"
  8. "github.com/QuantumNous/new-api/dto"
  9. "github.com/QuantumNous/new-api/relay/channel/openrouter"
  10. relaycommon "github.com/QuantumNous/new-api/relay/common"
  11. "github.com/QuantumNous/new-api/relay/reasonmap"
  12. "github.com/samber/lo"
  13. )
  14. func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
  15. openAIRequest := dto.GeneralOpenAIRequest{
  16. Model: claudeRequest.Model,
  17. Temperature: claudeRequest.Temperature,
  18. }
  19. if claudeRequest.MaxTokens != nil {
  20. openAIRequest.MaxTokens = lo.ToPtr(lo.FromPtr(claudeRequest.MaxTokens))
  21. }
  22. if claudeRequest.TopP != nil {
  23. openAIRequest.TopP = lo.ToPtr(lo.FromPtr(claudeRequest.TopP))
  24. }
  25. if claudeRequest.TopK != nil {
  26. openAIRequest.TopK = lo.ToPtr(lo.FromPtr(claudeRequest.TopK))
  27. }
  28. if claudeRequest.Stream != nil {
  29. openAIRequest.Stream = lo.ToPtr(lo.FromPtr(claudeRequest.Stream))
  30. }
  31. isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
  32. if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
  33. if isOpenRouter {
  34. reasoning := openrouter.RequestReasoning{
  35. MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
  36. }
  37. reasoningJSON, err := json.Marshal(reasoning)
  38. if err != nil {
  39. return nil, fmt.Errorf("failed to marshal reasoning: %w", err)
  40. }
  41. openAIRequest.Reasoning = reasoningJSON
  42. } else {
  43. thinkingSuffix := "-thinking"
  44. if strings.HasSuffix(info.OriginModelName, thinkingSuffix) &&
  45. !strings.HasSuffix(openAIRequest.Model, thinkingSuffix) {
  46. openAIRequest.Model = openAIRequest.Model + thinkingSuffix
  47. }
  48. }
  49. }
  50. // Convert stop sequences
  51. if len(claudeRequest.StopSequences) == 1 {
  52. openAIRequest.Stop = claudeRequest.StopSequences[0]
  53. } else if len(claudeRequest.StopSequences) > 1 {
  54. openAIRequest.Stop = claudeRequest.StopSequences
  55. }
  56. // Convert tools
  57. tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
  58. openAITools := make([]dto.ToolCallRequest, 0)
  59. for _, claudeTool := range tools {
  60. openAITool := dto.ToolCallRequest{
  61. Type: "function",
  62. Function: dto.FunctionRequest{
  63. Name: claudeTool.Name,
  64. Description: claudeTool.Description,
  65. Parameters: claudeTool.InputSchema,
  66. },
  67. }
  68. openAITools = append(openAITools, openAITool)
  69. }
  70. openAIRequest.Tools = openAITools
  71. // Convert messages
  72. openAIMessages := make([]dto.Message, 0)
  73. // Add system message if present
  74. if claudeRequest.System != nil {
  75. if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
  76. openAIMessage := dto.Message{
  77. Role: "system",
  78. }
  79. openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
  80. openAIMessages = append(openAIMessages, openAIMessage)
  81. } else {
  82. systems := claudeRequest.ParseSystem()
  83. if len(systems) > 0 {
  84. openAIMessage := dto.Message{
  85. Role: "system",
  86. }
  87. isOpenRouterClaude := isOpenRouter && strings.HasPrefix(info.UpstreamModelName, "anthropic/claude")
  88. if isOpenRouterClaude {
  89. systemMediaMessages := make([]dto.MediaContent, 0, len(systems))
  90. for _, system := range systems {
  91. message := dto.MediaContent{
  92. Type: "text",
  93. Text: system.GetText(),
  94. CacheControl: system.CacheControl,
  95. }
  96. systemMediaMessages = append(systemMediaMessages, message)
  97. }
  98. openAIMessage.SetMediaContent(systemMediaMessages)
  99. } else {
  100. systemStr := ""
  101. for _, system := range systems {
  102. if system.Text != nil {
  103. systemStr += *system.Text
  104. }
  105. }
  106. openAIMessage.SetStringContent(systemStr)
  107. }
  108. openAIMessages = append(openAIMessages, openAIMessage)
  109. }
  110. }
  111. }
  112. for _, claudeMessage := range claudeRequest.Messages {
  113. openAIMessage := dto.Message{
  114. Role: claudeMessage.Role,
  115. }
  116. //log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
  117. if claudeMessage.IsStringContent() {
  118. openAIMessage.SetStringContent(claudeMessage.GetStringContent())
  119. } else {
  120. content, err := claudeMessage.ParseContent()
  121. if err != nil {
  122. return nil, err
  123. }
  124. contents := content
  125. var toolCalls []dto.ToolCallRequest
  126. mediaMessages := make([]dto.MediaContent, 0, len(contents))
  127. for _, mediaMsg := range contents {
  128. switch mediaMsg.Type {
  129. case "text", "input_text":
  130. message := dto.MediaContent{
  131. Type: "text",
  132. Text: mediaMsg.GetText(),
  133. CacheControl: mediaMsg.CacheControl,
  134. }
  135. mediaMessages = append(mediaMessages, message)
  136. case "image":
  137. // Handle image conversion (base64 to URL or keep as is)
  138. imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
  139. //textContent += fmt.Sprintf("[Image: %s]", imageData)
  140. mediaMessage := dto.MediaContent{
  141. Type: "image_url",
  142. ImageUrl: &dto.MessageImageUrl{Url: imageData},
  143. }
  144. mediaMessages = append(mediaMessages, mediaMessage)
  145. case "tool_use":
  146. toolCall := dto.ToolCallRequest{
  147. ID: mediaMsg.Id,
  148. Type: "function",
  149. Function: dto.FunctionRequest{
  150. Name: mediaMsg.Name,
  151. Arguments: toJSONString(mediaMsg.Input),
  152. },
  153. }
  154. toolCalls = append(toolCalls, toolCall)
  155. case "tool_result":
  156. // Add tool result as a separate message
  157. toolName := mediaMsg.Name
  158. if toolName == "" {
  159. toolName = claudeRequest.SearchToolNameByToolCallId(mediaMsg.ToolUseId)
  160. }
  161. oaiToolMessage := dto.Message{
  162. Role: "tool",
  163. Name: &toolName,
  164. ToolCallId: mediaMsg.ToolUseId,
  165. }
  166. //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
  167. if mediaMsg.IsStringContent() {
  168. oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
  169. } else {
  170. mediaContents := mediaMsg.ParseMediaContent()
  171. encodeJson, _ := common.Marshal(mediaContents)
  172. oaiToolMessage.SetStringContent(string(encodeJson))
  173. }
  174. openAIMessages = append(openAIMessages, oaiToolMessage)
  175. }
  176. }
  177. if len(toolCalls) > 0 {
  178. openAIMessage.SetToolCalls(toolCalls)
  179. }
  180. if len(mediaMessages) > 0 && len(toolCalls) == 0 {
  181. openAIMessage.SetMediaContent(mediaMessages)
  182. }
  183. }
  184. if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
  185. openAIMessages = append(openAIMessages, openAIMessage)
  186. }
  187. }
  188. openAIRequest.Messages = openAIMessages
  189. return &openAIRequest, nil
  190. }
  191. func generateStopBlock(index int) *dto.ClaudeResponse {
  192. return &dto.ClaudeResponse{
  193. Type: "content_block_stop",
  194. Index: common.GetPointer[int](index),
  195. }
  196. }
// StreamResponseOpenAI2Claude converts one OpenAI chat-completions stream chunk
// into zero or more Anthropic Messages SSE events.
//
// Conversion state is carried across calls in info.ClaudeConvertInfo:
//   - Index: the content block index currently in use;
//   - LastMessagesType: the type of the currently open block (none/text/thinking/tools);
//   - ToolCallBaseIndex / ToolCallMaxIndexOffset: the block index range used by
//     the current run of (possibly parallel) tool calls;
//   - Usage: fallback usage for the final message_delta when the chunk has none;
//   - Done: set once message_stop has been emitted; later calls return nil.
//
// info.SendResponseCount == 1 marks the first chunk, which additionally emits
// message_start (with info.GetEstimatePromptTokens() as input_tokens). A
// non-empty finish_reason records info.FinishReason, closes the open block(s),
// emits message_delta (only when usage is available) and then message_stop.
//
// NOTE(review): Done is checked at the top and returns nil, and inside this
// function Done is only set immediately before returning, so the later
// `info.ClaudeConvertInfo.Done` checks look unreachable within a single call.
// Confirm whether callers expect to trigger the final events by setting Done
// before calling.
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
	if info.ClaudeConvertInfo.Done {
		return nil
	}
	var claudeResponses []*dto.ClaudeResponse
	// stopOpenBlocks emits the required content_block_stop event(s) for the currently open block(s)
	// according to Anthropic's SSE streaming state machine:
	// content_block_start -> content_block_delta* -> content_block_stop (per index).
	//
	// For text/thinking, there is at most one open block at info.ClaudeConvertInfo.Index.
	// For tools, OpenAI tool_calls can stream multiple parallel tool_use blocks (indexed from 0),
	// so we may have multiple open blocks and must stop each one explicitly.
	stopOpenBlocks := func() {
		switch info.ClaudeConvertInfo.LastMessagesType {
		case relaycommon.LastMessageTypeText, relaycommon.LastMessageTypeThinking:
			claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
		case relaycommon.LastMessageTypeTools:
			base := info.ClaudeConvertInfo.ToolCallBaseIndex
			for offset := 0; offset <= info.ClaudeConvertInfo.ToolCallMaxIndexOffset; offset++ {
				claudeResponses = append(claudeResponses, generateStopBlock(base+offset))
			}
		}
	}
	// stopOpenBlocksAndAdvance closes the currently open block(s) and advances the content block index
	// to the next available slot for subsequent content_block_start events.
	//
	// This prevents invalid streams where a content_block_delta (e.g. thinking_delta) is emitted for an
	// index whose active content_block type is different (the typical cause of "Mismatched content block type").
	stopOpenBlocksAndAdvance := func() {
		if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeNone {
			return
		}
		stopOpenBlocks()
		switch info.ClaudeConvertInfo.LastMessagesType {
		case relaycommon.LastMessageTypeTools:
			// Skip past every tool block that was opened in this run.
			info.ClaudeConvertInfo.Index = info.ClaudeConvertInfo.ToolCallBaseIndex + info.ClaudeConvertInfo.ToolCallMaxIndexOffset + 1
			info.ClaudeConvertInfo.ToolCallBaseIndex = 0
			info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
		default:
			info.ClaudeConvertInfo.Index++
		}
		info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeNone
	}
	// First chunk: open the message, then convert whatever this chunk carries.
	if info.SendResponseCount == 1 {
		msg := &dto.ClaudeMediaMessage{
			Id:    openAIResponse.Id,
			Model: openAIResponse.Model,
			Type:  "message",
			Role:  "assistant",
			Usage: &dto.ClaudeUsage{
				InputTokens:  info.GetEstimatePromptTokens(),
				OutputTokens: 0,
			},
		}
		msg.SetContent(make([]any, 0))
		claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
			Type:    "message_start",
			Message: msg,
		})
		//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
		//	Type: "ping",
		//})
		if openAIResponse.IsToolCall() {
			// Only the first tool call of the first chunk is converted here; it is
			// opened as a tool_use block at index 0.
			info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
			info.ClaudeConvertInfo.ToolCallBaseIndex = 0
			info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
			var toolCall dto.ToolCallResponse
			if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.ToolCalls) > 0 {
				toolCall = openAIResponse.Choices[0].Delta.ToolCalls[0]
			} else {
				first := openAIResponse.GetFirstToolCall()
				if first != nil {
					toolCall = *first
				} else {
					toolCall = dto.ToolCallResponse{}
				}
			}
			resp := &dto.ClaudeResponse{
				Type: "content_block_start",
				ContentBlock: &dto.ClaudeMediaMessage{
					Id:    toolCall.ID,
					Type:  "tool_use",
					Name:  toolCall.Function.Name,
					Input: map[string]interface{}{},
				},
			}
			resp.SetIndex(0)
			claudeResponses = append(claudeResponses, resp)
			// If the first chunk already carries tool-call arguments, forward them as an input_json_delta.
			if toolCall.Function.Arguments != "" {
				idx := 0
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Index: &idx,
					Type:  "content_block_delta",
					Delta: &dto.ClaudeMediaMessage{
						Type:        "input_json_delta",
						PartialJson: &toolCall.Function.Arguments,
					},
				})
			}
		} else {
		}
		// Check whether the first chunk already carries content (non-standard OpenAI response).
		// Reasoning takes precedence over text when both are present.
		if len(openAIResponse.Choices) > 0 {
			reasoning := openAIResponse.Choices[0].Delta.GetReasoningContent()
			content := openAIResponse.Choices[0].Delta.GetContentString()
			if reasoning != "" {
				if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
					stopOpenBlocksAndAdvance()
				}
				idx := info.ClaudeConvertInfo.Index
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Index: &idx,
					Type:  "content_block_start",
					ContentBlock: &dto.ClaudeMediaMessage{
						Type:     "thinking",
						Thinking: common.GetPointer[string](""),
					},
				})
				idx2 := idx
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Index: &idx2,
					Type:  "content_block_delta",
					Delta: &dto.ClaudeMediaMessage{
						Type:     "thinking_delta",
						Thinking: &reasoning,
					},
				})
				info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
			} else if content != "" {
				if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
					stopOpenBlocksAndAdvance()
				}
				idx := info.ClaudeConvertInfo.Index
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Index: &idx,
					Type:  "content_block_start",
					ContentBlock: &dto.ClaudeMediaMessage{
						Type: "text",
						Text: common.GetPointer[string](""),
					},
				})
				idx2 := idx
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Index: &idx2,
					Type:  "content_block_delta",
					Delta: &dto.ClaudeMediaMessage{
						Type: "text_delta",
						Text: common.GetPointer[string](content),
					},
				})
				info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
			}
		}
		// If the first chunk already carries a finish_reason, close everything immediately.
		if len(openAIResponse.Choices) > 0 && openAIResponse.Choices[0].FinishReason != nil && *openAIResponse.Choices[0].FinishReason != "" {
			info.FinishReason = *openAIResponse.Choices[0].FinishReason
			stopOpenBlocks()
			oaiUsage := openAIResponse.Usage
			if oaiUsage == nil {
				oaiUsage = info.ClaudeConvertInfo.Usage
			}
			if oaiUsage != nil {
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Type: "message_delta",
					Usage: &dto.ClaudeUsage{
						InputTokens:              oaiUsage.PromptTokens,
						OutputTokens:             oaiUsage.CompletionTokens,
						CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
						CacheReadInputTokens:     oaiUsage.PromptTokensDetails.CachedTokens,
					},
					Delta: &dto.ClaudeMediaMessage{
						StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
					},
				})
			}
			claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
				Type: "message_stop",
			})
			info.ClaudeConvertInfo.Done = true
		}
		return claudeResponses
	}
	if len(openAIResponse.Choices) == 0 {
		// No choices: possibly a non-standard OpenAI response (e.g. a usage-only
		// chunk); check whether the stream has already completed.
		if info.ClaudeConvertInfo.Done {
			stopOpenBlocks()
			oaiUsage := info.ClaudeConvertInfo.Usage
			if oaiUsage != nil {
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Type: "message_delta",
					Usage: &dto.ClaudeUsage{
						InputTokens:              oaiUsage.PromptTokens,
						OutputTokens:             oaiUsage.CompletionTokens,
						CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
						CacheReadInputTokens:     oaiUsage.PromptTokensDetails.CachedTokens,
					},
					Delta: &dto.ClaudeMediaMessage{
						StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
					},
				})
			}
			claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
				Type: "message_stop",
			})
		}
		return claudeResponses
	} else {
		chosenChoice := openAIResponse.Choices[0]
		doneChunk := chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != ""
		if doneChunk {
			info.FinishReason = *chosenChoice.FinishReason
		}
		var claudeResponse dto.ClaudeResponse
		var isEmpty bool
		claudeResponse.Type = "content_block_delta"
		if len(chosenChoice.Delta.ToolCalls) > 0 {
			// Tool-call deltas: OpenAI tool_call index i maps to block base+i. A new
			// run of tool calls closes the previous block and starts at the next index.
			toolCalls := chosenChoice.Delta.ToolCalls
			if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
				stopOpenBlocksAndAdvance()
				info.ClaudeConvertInfo.ToolCallBaseIndex = info.ClaudeConvertInfo.Index
				info.ClaudeConvertInfo.ToolCallMaxIndexOffset = 0
			}
			info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
			base := info.ClaudeConvertInfo.ToolCallBaseIndex
			maxOffset := info.ClaudeConvertInfo.ToolCallMaxIndexOffset
			for i, toolCall := range toolCalls {
				// Fall back to the position in this delta when the upstream omits index.
				offset := 0
				if toolCall.Index != nil {
					offset = *toolCall.Index
				} else {
					offset = i
				}
				if offset > maxOffset {
					maxOffset = offset
				}
				blockIndex := base + offset
				idx := blockIndex
				// A delta that carries a function name starts a new tool_use block.
				if toolCall.Function.Name != "" {
					claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
						Index: &idx,
						Type:  "content_block_start",
						ContentBlock: &dto.ClaudeMediaMessage{
							Id:    toolCall.ID,
							Type:  "tool_use",
							Name:  toolCall.Function.Name,
							Input: map[string]interface{}{},
						},
					})
				}
				if len(toolCall.Function.Arguments) > 0 {
					claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
						Index: &idx,
						Type:  "content_block_delta",
						Delta: &dto.ClaudeMediaMessage{
							Type:        "input_json_delta",
							PartialJson: &toolCall.Function.Arguments,
						},
					})
				}
			}
			info.ClaudeConvertInfo.ToolCallMaxIndexOffset = maxOffset
			info.ClaudeConvertInfo.Index = base + maxOffset
		} else {
			// Text/reasoning deltas: open a new block only when the block type
			// changes; the delta itself is emitted below via claudeResponse.
			reasoning := chosenChoice.Delta.GetReasoningContent()
			textContent := chosenChoice.Delta.GetContentString()
			if reasoning != "" || textContent != "" {
				if reasoning != "" {
					if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
						stopOpenBlocksAndAdvance()
						idx := info.ClaudeConvertInfo.Index
						claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
							Index: &idx,
							Type:  "content_block_start",
							ContentBlock: &dto.ClaudeMediaMessage{
								Type:     "thinking",
								Thinking: common.GetPointer[string](""),
							},
						})
					}
					info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
					claudeResponse.Delta = &dto.ClaudeMediaMessage{
						Type:     "thinking_delta",
						Thinking: &reasoning,
					}
				} else {
					if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
						stopOpenBlocksAndAdvance()
						idx := info.ClaudeConvertInfo.Index
						claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
							Index: &idx,
							Type:  "content_block_start",
							ContentBlock: &dto.ClaudeMediaMessage{
								Type: "text",
								Text: common.GetPointer[string](""),
							},
						})
					}
					info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
					claudeResponse.Delta = &dto.ClaudeMediaMessage{
						Type: "text_delta",
						Text: common.GetPointer[string](textContent),
					}
				}
			} else {
				isEmpty = true
			}
		}
		// Tool-call chunks leave claudeResponse.Delta nil, so nothing extra is appended for them.
		claudeResponse.Index = common.GetPointer[int](info.ClaudeConvertInfo.Index)
		if !isEmpty && claudeResponse.Delta != nil {
			claudeResponses = append(claudeResponses, &claudeResponse)
		}
		if doneChunk || info.ClaudeConvertInfo.Done {
			stopOpenBlocks()
			oaiUsage := openAIResponse.Usage
			if oaiUsage == nil {
				oaiUsage = info.ClaudeConvertInfo.Usage
			}
			if oaiUsage != nil {
				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
					Type: "message_delta",
					Usage: &dto.ClaudeUsage{
						InputTokens:              oaiUsage.PromptTokens,
						OutputTokens:             oaiUsage.CompletionTokens,
						CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
						CacheReadInputTokens:     oaiUsage.PromptTokensDetails.CachedTokens,
					},
					Delta: &dto.ClaudeMediaMessage{
						StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
					},
				})
			}
			claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
				Type: "message_stop",
			})
			info.ClaudeConvertInfo.Done = true
			return claudeResponses
		}
	}
	return claudeResponses
}
  539. func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
  540. var stopReason string
  541. contents := make([]dto.ClaudeMediaMessage, 0)
  542. claudeResponse := &dto.ClaudeResponse{
  543. Id: openAIResponse.Id,
  544. Type: "message",
  545. Role: "assistant",
  546. Model: openAIResponse.Model,
  547. }
  548. for _, choice := range openAIResponse.Choices {
  549. stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
  550. if choice.FinishReason == "tool_calls" {
  551. for _, toolUse := range choice.Message.ParseToolCalls() {
  552. claudeContent := dto.ClaudeMediaMessage{}
  553. claudeContent.Type = "tool_use"
  554. claudeContent.Id = toolUse.ID
  555. claudeContent.Name = toolUse.Function.Name
  556. var mapParams map[string]interface{}
  557. if err := common.Unmarshal([]byte(toolUse.Function.Arguments), &mapParams); err == nil {
  558. claudeContent.Input = mapParams
  559. } else {
  560. claudeContent.Input = toolUse.Function.Arguments
  561. }
  562. contents = append(contents, claudeContent)
  563. }
  564. } else {
  565. claudeContent := dto.ClaudeMediaMessage{}
  566. claudeContent.Type = "text"
  567. claudeContent.SetText(choice.Message.StringContent())
  568. contents = append(contents, claudeContent)
  569. }
  570. }
  571. claudeResponse.Content = contents
  572. claudeResponse.StopReason = stopReason
  573. claudeResponse.Usage = &dto.ClaudeUsage{
  574. InputTokens: openAIResponse.PromptTokens,
  575. OutputTokens: openAIResponse.CompletionTokens,
  576. }
  577. return claudeResponse
  578. }
  579. func stopReasonOpenAI2Claude(reason string) string {
  580. return reasonmap.OpenAIFinishReasonToClaudeStopReason(reason)
  581. }
  582. func toJSONString(v interface{}) string {
  583. b, err := json.Marshal(v)
  584. if err != nil {
  585. return "{}"
  586. }
  587. return string(b)
  588. }
  589. func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
  590. openaiRequest := &dto.GeneralOpenAIRequest{
  591. Model: info.UpstreamModelName,
  592. Stream: lo.ToPtr(info.IsStream),
  593. }
  594. // 转换 messages
  595. var messages []dto.Message
  596. for _, content := range geminiRequest.Contents {
  597. message := dto.Message{
  598. Role: convertGeminiRoleToOpenAI(content.Role),
  599. }
  600. // 处理 parts
  601. var mediaContents []dto.MediaContent
  602. var toolCalls []dto.ToolCallRequest
  603. for _, part := range content.Parts {
  604. if part.Text != "" {
  605. mediaContent := dto.MediaContent{
  606. Type: "text",
  607. Text: part.Text,
  608. }
  609. mediaContents = append(mediaContents, mediaContent)
  610. } else if part.InlineData != nil {
  611. mediaContent := dto.MediaContent{
  612. Type: "image_url",
  613. ImageUrl: &dto.MessageImageUrl{
  614. Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
  615. Detail: "auto",
  616. MimeType: part.InlineData.MimeType,
  617. },
  618. }
  619. mediaContents = append(mediaContents, mediaContent)
  620. } else if part.FileData != nil {
  621. mediaContent := dto.MediaContent{
  622. Type: "image_url",
  623. ImageUrl: &dto.MessageImageUrl{
  624. Url: part.FileData.FileUri,
  625. Detail: "auto",
  626. MimeType: part.FileData.MimeType,
  627. },
  628. }
  629. mediaContents = append(mediaContents, mediaContent)
  630. } else if part.FunctionCall != nil {
  631. // 处理 Gemini 的工具调用
  632. toolCall := dto.ToolCallRequest{
  633. ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
  634. Type: "function",
  635. Function: dto.FunctionRequest{
  636. Name: part.FunctionCall.FunctionName,
  637. Arguments: toJSONString(part.FunctionCall.Arguments),
  638. },
  639. }
  640. toolCalls = append(toolCalls, toolCall)
  641. } else if part.FunctionResponse != nil {
  642. // 处理 Gemini 的工具响应,创建单独的 tool 消息
  643. toolMessage := dto.Message{
  644. Role: "tool",
  645. ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
  646. }
  647. toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
  648. messages = append(messages, toolMessage)
  649. }
  650. }
  651. // 设置消息内容
  652. if len(toolCalls) > 0 {
  653. // 如果有工具调用,设置工具调用
  654. message.SetToolCalls(toolCalls)
  655. } else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
  656. // 如果只有一个文本内容,直接设置字符串
  657. message.Content = mediaContents[0].Text
  658. } else if len(mediaContents) > 0 {
  659. // 如果有多个内容或包含媒体,设置为数组
  660. message.SetMediaContent(mediaContents)
  661. }
  662. // 只有当消息有内容或工具调用时才添加
  663. if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
  664. messages = append(messages, message)
  665. }
  666. }
  667. openaiRequest.Messages = messages
  668. if geminiRequest.GenerationConfig.Temperature != nil {
  669. openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
  670. }
  671. if geminiRequest.GenerationConfig.TopP != nil && *geminiRequest.GenerationConfig.TopP > 0 {
  672. openaiRequest.TopP = lo.ToPtr(*geminiRequest.GenerationConfig.TopP)
  673. }
  674. if geminiRequest.GenerationConfig.TopK != nil && *geminiRequest.GenerationConfig.TopK > 0 {
  675. openaiRequest.TopK = lo.ToPtr(int(*geminiRequest.GenerationConfig.TopK))
  676. }
  677. if geminiRequest.GenerationConfig.MaxOutputTokens != nil && *geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
  678. openaiRequest.MaxTokens = lo.ToPtr(*geminiRequest.GenerationConfig.MaxOutputTokens)
  679. }
  680. // gemini stop sequences 最多 5 个,openai stop 最多 4 个
  681. if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
  682. openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
  683. }
  684. if geminiRequest.GenerationConfig.CandidateCount != nil && *geminiRequest.GenerationConfig.CandidateCount > 0 {
  685. openaiRequest.N = lo.ToPtr(*geminiRequest.GenerationConfig.CandidateCount)
  686. }
  687. // 转换工具调用
  688. if len(geminiRequest.GetTools()) > 0 {
  689. var tools []dto.ToolCallRequest
  690. for _, tool := range geminiRequest.GetTools() {
  691. if tool.FunctionDeclarations != nil {
  692. functionDeclarations, err := common.Any2Type[[]dto.FunctionRequest](tool.FunctionDeclarations)
  693. if err != nil {
  694. common.SysError(fmt.Sprintf("failed to parse gemini function declarations: %v (type=%T)", err, tool.FunctionDeclarations))
  695. continue
  696. }
  697. for _, function := range functionDeclarations {
  698. openAITool := dto.ToolCallRequest{
  699. Type: "function",
  700. Function: dto.FunctionRequest{
  701. Name: function.Name,
  702. Description: function.Description,
  703. Parameters: function.Parameters,
  704. },
  705. }
  706. tools = append(tools, openAITool)
  707. }
  708. }
  709. }
  710. if len(tools) > 0 {
  711. openaiRequest.Tools = tools
  712. }
  713. }
  714. // gemini system instructions
  715. if geminiRequest.SystemInstructions != nil {
  716. // 将系统指令作为第一条消息插入
  717. systemMessage := dto.Message{
  718. Role: "system",
  719. Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
  720. }
  721. openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
  722. }
  723. return openaiRequest, nil
  724. }
  725. func convertGeminiRoleToOpenAI(geminiRole string) string {
  726. switch geminiRole {
  727. case "user":
  728. return "user"
  729. case "model":
  730. return "assistant"
  731. case "function":
  732. return "function"
  733. default:
  734. return "user"
  735. }
  736. }
  737. func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
  738. var texts []string
  739. for _, part := range parts {
  740. if part.Text != "" {
  741. texts = append(texts, part.Text)
  742. }
  743. }
  744. return strings.Join(texts, "\n")
  745. }
  746. // ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
  747. func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
  748. geminiResponse := &dto.GeminiChatResponse{
  749. Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
  750. UsageMetadata: dto.GeminiUsageMetadata{
  751. PromptTokenCount: openAIResponse.PromptTokens,
  752. CandidatesTokenCount: openAIResponse.CompletionTokens,
  753. TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
  754. },
  755. }
  756. for _, choice := range openAIResponse.Choices {
  757. candidate := dto.GeminiChatCandidate{
  758. Index: int64(choice.Index),
  759. SafetyRatings: []dto.GeminiChatSafetyRating{},
  760. }
  761. // 设置结束原因
  762. var finishReason string
  763. switch choice.FinishReason {
  764. case "stop":
  765. finishReason = "STOP"
  766. case "length":
  767. finishReason = "MAX_TOKENS"
  768. case "content_filter":
  769. finishReason = "SAFETY"
  770. case "tool_calls":
  771. finishReason = "STOP"
  772. default:
  773. finishReason = "STOP"
  774. }
  775. candidate.FinishReason = &finishReason
  776. // 转换消息内容
  777. content := dto.GeminiChatContent{
  778. Role: "model",
  779. Parts: make([]dto.GeminiPart, 0),
  780. }
  781. // 处理工具调用
  782. toolCalls := choice.Message.ParseToolCalls()
  783. if len(toolCalls) > 0 {
  784. for _, toolCall := range toolCalls {
  785. // 解析参数
  786. var args map[string]interface{}
  787. if toolCall.Function.Arguments != "" {
  788. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
  789. args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
  790. }
  791. } else {
  792. args = make(map[string]interface{})
  793. }
  794. part := dto.GeminiPart{
  795. FunctionCall: &dto.FunctionCall{
  796. FunctionName: toolCall.Function.Name,
  797. Arguments: args,
  798. },
  799. }
  800. content.Parts = append(content.Parts, part)
  801. }
  802. } else {
  803. // 处理文本内容
  804. textContent := choice.Message.StringContent()
  805. if textContent != "" {
  806. part := dto.GeminiPart{
  807. Text: textContent,
  808. }
  809. content.Parts = append(content.Parts, part)
  810. }
  811. }
  812. candidate.Content = content
  813. geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
  814. }
  815. return geminiResponse
  816. }
  817. // StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
  818. func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
  819. // 检查是否有实际内容或结束标志
  820. hasContent := false
  821. hasFinishReason := false
  822. for _, choice := range openAIResponse.Choices {
  823. if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
  824. hasContent = true
  825. }
  826. if choice.FinishReason != nil {
  827. hasFinishReason = true
  828. }
  829. }
  830. // 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
  831. if !hasContent && !hasFinishReason {
  832. return nil
  833. }
  834. geminiResponse := &dto.GeminiChatResponse{
  835. Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
  836. UsageMetadata: dto.GeminiUsageMetadata{
  837. PromptTokenCount: info.GetEstimatePromptTokens(),
  838. CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
  839. TotalTokenCount: info.GetEstimatePromptTokens(),
  840. },
  841. }
  842. if openAIResponse.Usage != nil {
  843. geminiResponse.UsageMetadata.PromptTokenCount = openAIResponse.Usage.PromptTokens
  844. geminiResponse.UsageMetadata.CandidatesTokenCount = openAIResponse.Usage.CompletionTokens
  845. geminiResponse.UsageMetadata.TotalTokenCount = openAIResponse.Usage.TotalTokens
  846. }
  847. for _, choice := range openAIResponse.Choices {
  848. candidate := dto.GeminiChatCandidate{
  849. Index: int64(choice.Index),
  850. SafetyRatings: []dto.GeminiChatSafetyRating{},
  851. }
  852. // 设置结束原因
  853. if choice.FinishReason != nil {
  854. var finishReason string
  855. switch *choice.FinishReason {
  856. case "stop":
  857. finishReason = "STOP"
  858. case "length":
  859. finishReason = "MAX_TOKENS"
  860. case "content_filter":
  861. finishReason = "SAFETY"
  862. case "tool_calls":
  863. finishReason = "STOP"
  864. default:
  865. finishReason = "STOP"
  866. }
  867. candidate.FinishReason = &finishReason
  868. }
  869. // 转换消息内容
  870. content := dto.GeminiChatContent{
  871. Role: "model",
  872. Parts: make([]dto.GeminiPart, 0),
  873. }
  874. // 处理工具调用
  875. if choice.Delta.ToolCalls != nil {
  876. for _, toolCall := range choice.Delta.ToolCalls {
  877. // 解析参数
  878. var args map[string]interface{}
  879. if toolCall.Function.Arguments != "" {
  880. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
  881. args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
  882. }
  883. } else {
  884. args = make(map[string]interface{})
  885. }
  886. part := dto.GeminiPart{
  887. FunctionCall: &dto.FunctionCall{
  888. FunctionName: toolCall.Function.Name,
  889. Arguments: args,
  890. },
  891. }
  892. content.Parts = append(content.Parts, part)
  893. }
  894. } else {
  895. // 处理文本内容
  896. textContent := choice.Delta.GetContentString()
  897. if textContent != "" {
  898. part := dto.GeminiPart{
  899. Text: textContent,
  900. }
  901. content.Parts = append(content.Parts, part)
  902. }
  903. }
  904. candidate.Content = content
  905. geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
  906. }
  907. return geminiResponse
  908. }