convert.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. package service
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "one-api/common"
  6. "one-api/dto"
  7. relaycommon "one-api/relay/common"
  8. "strings"
  9. )
  10. func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
  11. openAIRequest := dto.GeneralOpenAIRequest{
  12. Model: claudeRequest.Model,
  13. MaxTokens: claudeRequest.MaxTokens,
  14. Temperature: claudeRequest.Temperature,
  15. TopP: claudeRequest.TopP,
  16. Stream: claudeRequest.Stream,
  17. }
  18. if claudeRequest.Thinking != nil {
  19. if strings.HasSuffix(info.OriginModelName, "-thinking") &&
  20. !strings.HasSuffix(claudeRequest.Model, "-thinking") {
  21. openAIRequest.Model = openAIRequest.Model + "-thinking"
  22. }
  23. }
  24. // Convert stop sequences
  25. if len(claudeRequest.StopSequences) == 1 {
  26. openAIRequest.Stop = claudeRequest.StopSequences[0]
  27. } else if len(claudeRequest.StopSequences) > 1 {
  28. openAIRequest.Stop = claudeRequest.StopSequences
  29. }
  30. // Convert tools
  31. tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
  32. openAITools := make([]dto.ToolCallRequest, 0)
  33. for _, claudeTool := range tools {
  34. openAITool := dto.ToolCallRequest{
  35. Type: "function",
  36. Function: dto.FunctionRequest{
  37. Name: claudeTool.Name,
  38. Description: claudeTool.Description,
  39. Parameters: claudeTool.InputSchema,
  40. },
  41. }
  42. openAITools = append(openAITools, openAITool)
  43. }
  44. openAIRequest.Tools = openAITools
  45. // Convert messages
  46. openAIMessages := make([]dto.Message, 0)
  47. // Add system message if present
  48. if claudeRequest.System != nil {
  49. if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
  50. openAIMessage := dto.Message{
  51. Role: "system",
  52. }
  53. openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
  54. openAIMessages = append(openAIMessages, openAIMessage)
  55. } else {
  56. systems := claudeRequest.ParseSystem()
  57. if len(systems) > 0 {
  58. systemStr := ""
  59. openAIMessage := dto.Message{
  60. Role: "system",
  61. }
  62. for _, system := range systems {
  63. systemStr += system.Type
  64. }
  65. openAIMessage.SetStringContent(systemStr)
  66. openAIMessages = append(openAIMessages, openAIMessage)
  67. }
  68. }
  69. }
  70. for _, claudeMessage := range claudeRequest.Messages {
  71. openAIMessage := dto.Message{
  72. Role: claudeMessage.Role,
  73. }
  74. //log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
  75. if claudeMessage.IsStringContent() {
  76. openAIMessage.SetStringContent(claudeMessage.GetStringContent())
  77. } else {
  78. content, err := claudeMessage.ParseContent()
  79. if err != nil {
  80. return nil, err
  81. }
  82. contents := content
  83. var toolCalls []dto.ToolCallRequest
  84. mediaMessages := make([]dto.MediaContent, 0, len(contents))
  85. for _, mediaMsg := range contents {
  86. switch mediaMsg.Type {
  87. case "text":
  88. message := dto.MediaContent{
  89. Type: "text",
  90. Text: mediaMsg.GetText(),
  91. }
  92. mediaMessages = append(mediaMessages, message)
  93. case "image":
  94. // Handle image conversion (base64 to URL or keep as is)
  95. imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
  96. //textContent += fmt.Sprintf("[Image: %s]", imageData)
  97. mediaMessage := dto.MediaContent{
  98. Type: "image_url",
  99. ImageUrl: &dto.MessageImageUrl{Url: imageData},
  100. }
  101. mediaMessages = append(mediaMessages, mediaMessage)
  102. case "tool_use":
  103. toolCall := dto.ToolCallRequest{
  104. ID: mediaMsg.Id,
  105. Type: "function",
  106. Function: dto.FunctionRequest{
  107. Name: mediaMsg.Name,
  108. Arguments: toJSONString(mediaMsg.Input),
  109. },
  110. }
  111. toolCalls = append(toolCalls, toolCall)
  112. case "tool_result":
  113. // Add tool result as a separate message
  114. oaiToolMessage := dto.Message{
  115. Role: "tool",
  116. Name: &mediaMsg.Name,
  117. ToolCallId: mediaMsg.ToolUseId,
  118. }
  119. //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
  120. if mediaMsg.IsStringContent() {
  121. oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
  122. } else {
  123. mediaContents := mediaMsg.ParseMediaContent()
  124. encodeJson, _ := common.EncodeJson(mediaContents)
  125. oaiToolMessage.SetStringContent(string(encodeJson))
  126. }
  127. openAIMessages = append(openAIMessages, oaiToolMessage)
  128. }
  129. }
  130. if len(toolCalls) > 0 {
  131. openAIMessage.SetToolCalls(toolCalls)
  132. }
  133. if len(mediaMessages) > 0 && len(toolCalls) == 0 {
  134. openAIMessage.SetMediaContent(mediaMessages)
  135. }
  136. }
  137. if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
  138. openAIMessages = append(openAIMessages, openAIMessage)
  139. }
  140. }
  141. openAIRequest.Messages = openAIMessages
  142. return &openAIRequest, nil
  143. }
  144. func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
  145. claudeError := dto.ClaudeError{
  146. Type: "new_api_error",
  147. Message: openAIError.Error.Message,
  148. }
  149. return &dto.ClaudeErrorWithStatusCode{
  150. Error: claudeError,
  151. StatusCode: openAIError.StatusCode,
  152. }
  153. }
  154. func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
  155. openAIError := dto.OpenAIError{
  156. Message: claudeError.Error.Message,
  157. Type: "new_api_error",
  158. }
  159. return &dto.OpenAIErrorWithStatusCode{
  160. Error: openAIError,
  161. StatusCode: claudeError.StatusCode,
  162. }
  163. }
  164. func generateStopBlock(index int) *dto.ClaudeResponse {
  165. return &dto.ClaudeResponse{
  166. Type: "content_block_stop",
  167. Index: common.GetPointer[int](index),
  168. }
  169. }
  170. func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
  171. var claudeResponses []*dto.ClaudeResponse
  172. if info.SendResponseCount == 1 {
  173. msg := &dto.ClaudeMediaMessage{
  174. Id: openAIResponse.Id,
  175. Model: openAIResponse.Model,
  176. Type: "message",
  177. Role: "assistant",
  178. Usage: &dto.ClaudeUsage{
  179. InputTokens: info.PromptTokens,
  180. OutputTokens: 0,
  181. },
  182. }
  183. msg.SetContent(make([]any, 0))
  184. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  185. Type: "message_start",
  186. Message: msg,
  187. })
  188. claudeResponses = append(claudeResponses)
  189. //claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  190. // Type: "ping",
  191. //})
  192. if openAIResponse.IsToolCall() {
  193. resp := &dto.ClaudeResponse{
  194. Type: "content_block_start",
  195. ContentBlock: &dto.ClaudeMediaMessage{
  196. Id: openAIResponse.GetFirstToolCall().ID,
  197. Type: "tool_use",
  198. Name: openAIResponse.GetFirstToolCall().Function.Name,
  199. },
  200. }
  201. resp.SetIndex(0)
  202. claudeResponses = append(claudeResponses, resp)
  203. } else {
  204. //resp := &dto.ClaudeResponse{
  205. // Type: "content_block_start",
  206. // ContentBlock: &dto.ClaudeMediaMessage{
  207. // Type: "text",
  208. // Text: common.GetPointer[string](""),
  209. // },
  210. //}
  211. //resp.SetIndex(0)
  212. //claudeResponses = append(claudeResponses, resp)
  213. }
  214. return claudeResponses
  215. }
  216. if len(openAIResponse.Choices) == 0 {
  217. // no choices
  218. // TODO: handle this case
  219. return claudeResponses
  220. } else {
  221. chosenChoice := openAIResponse.Choices[0]
  222. if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
  223. // should be done
  224. info.FinishReason = *chosenChoice.FinishReason
  225. return claudeResponses
  226. }
  227. if info.Done {
  228. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  229. if info.ClaudeConvertInfo.Usage != nil {
  230. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  231. Type: "message_delta",
  232. Usage: &dto.ClaudeUsage{
  233. InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
  234. OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
  235. },
  236. Delta: &dto.ClaudeMediaMessage{
  237. StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
  238. },
  239. })
  240. }
  241. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  242. Type: "message_stop",
  243. })
  244. } else {
  245. var claudeResponse dto.ClaudeResponse
  246. var isEmpty bool
  247. claudeResponse.Type = "content_block_delta"
  248. if len(chosenChoice.Delta.ToolCalls) > 0 {
  249. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
  250. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  251. info.ClaudeConvertInfo.Index++
  252. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  253. Index: &info.ClaudeConvertInfo.Index,
  254. Type: "content_block_start",
  255. ContentBlock: &dto.ClaudeMediaMessage{
  256. Id: openAIResponse.GetFirstToolCall().ID,
  257. Type: "tool_use",
  258. Name: openAIResponse.GetFirstToolCall().Function.Name,
  259. Input: map[string]interface{}{},
  260. },
  261. })
  262. }
  263. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
  264. // tools delta
  265. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  266. Type: "input_json_delta",
  267. PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
  268. }
  269. } else {
  270. reasoning := chosenChoice.Delta.GetReasoningContent()
  271. textContent := chosenChoice.Delta.GetContentString()
  272. if reasoning != "" || textContent != "" {
  273. if reasoning != "" {
  274. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
  275. //info.ClaudeConvertInfo.Index++
  276. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  277. Index: &info.ClaudeConvertInfo.Index,
  278. Type: "content_block_start",
  279. ContentBlock: &dto.ClaudeMediaMessage{
  280. Type: "thinking",
  281. Thinking: "",
  282. },
  283. })
  284. }
  285. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
  286. // text delta
  287. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  288. Type: "thinking_delta",
  289. Thinking: reasoning,
  290. }
  291. } else {
  292. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
  293. if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
  294. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  295. info.ClaudeConvertInfo.Index++
  296. }
  297. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  298. Index: &info.ClaudeConvertInfo.Index,
  299. Type: "content_block_start",
  300. ContentBlock: &dto.ClaudeMediaMessage{
  301. Type: "text",
  302. Text: common.GetPointer[string](""),
  303. },
  304. })
  305. }
  306. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
  307. // text delta
  308. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  309. Type: "text_delta",
  310. Text: common.GetPointer[string](textContent),
  311. }
  312. }
  313. } else {
  314. isEmpty = true
  315. }
  316. }
  317. claudeResponse.Index = &info.ClaudeConvertInfo.Index
  318. if !isEmpty {
  319. claudeResponses = append(claudeResponses, &claudeResponse)
  320. }
  321. }
  322. }
  323. return claudeResponses
  324. }
  325. func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
  326. var stopReason string
  327. contents := make([]dto.ClaudeMediaMessage, 0)
  328. claudeResponse := &dto.ClaudeResponse{
  329. Id: openAIResponse.Id,
  330. Type: "message",
  331. Role: "assistant",
  332. Model: openAIResponse.Model,
  333. }
  334. for _, choice := range openAIResponse.Choices {
  335. stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
  336. claudeContent := dto.ClaudeMediaMessage{}
  337. if choice.FinishReason == "tool_calls" {
  338. claudeContent.Type = "tool_use"
  339. claudeContent.Id = choice.Message.ToolCallId
  340. claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
  341. var mapParams map[string]interface{}
  342. if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
  343. claudeContent.Input = mapParams
  344. } else {
  345. claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
  346. }
  347. } else {
  348. claudeContent.Type = "text"
  349. claudeContent.SetText(choice.Message.StringContent())
  350. }
  351. contents = append(contents, claudeContent)
  352. }
  353. claudeResponse.Content = contents
  354. claudeResponse.StopReason = stopReason
  355. claudeResponse.Usage = &dto.ClaudeUsage{
  356. InputTokens: openAIResponse.PromptTokens,
  357. OutputTokens: openAIResponse.CompletionTokens,
  358. }
  359. return claudeResponse
  360. }
  // stopReasonOpenAI2Claude maps an OpenAI finish_reason to the corresponding
// Claude stop_reason. Reasons without a known mapping are returned unchanged.
func stopReasonOpenAI2Claude(reason string) string {
	switch reason {
	case "stop":
		return "end_turn"
	case "stop_sequence":
		return "stop_sequence"
	case "length", "max_tokens":
		// OpenAI reports token-limit truncation as "length"; Claude calls it "max_tokens".
		return "max_tokens"
	case "tool_calls", "function_call":
		// "function_call" is the legacy (pre-tools) OpenAI spelling.
		return "tool_use"
	default:
		return reason
	}
}
  // toJSONString returns the JSON encoding of v, or "{}" when v cannot be encoded.
func toJSONString(v interface{}) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return "{}"
}