convert.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. package service
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "one-api/common"
  6. "one-api/dto"
  7. relaycommon "one-api/relay/common"
  8. "strings"
  9. )
  10. func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
  11. openAIRequest := dto.GeneralOpenAIRequest{
  12. Model: claudeRequest.Model,
  13. MaxTokens: claudeRequest.MaxTokens,
  14. Temperature: claudeRequest.Temperature,
  15. TopP: claudeRequest.TopP,
  16. Stream: claudeRequest.Stream,
  17. }
  18. if claudeRequest.Thinking != nil {
  19. if strings.HasSuffix(info.OriginModelName, "-thinking") &&
  20. !strings.HasSuffix(claudeRequest.Model, "-thinking") {
  21. openAIRequest.Model = openAIRequest.Model + "-thinking"
  22. }
  23. }
  24. // Convert stop sequences
  25. if len(claudeRequest.StopSequences) == 1 {
  26. openAIRequest.Stop = claudeRequest.StopSequences[0]
  27. } else if len(claudeRequest.StopSequences) > 1 {
  28. openAIRequest.Stop = claudeRequest.StopSequences
  29. }
  30. // Convert tools
  31. tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
  32. openAITools := make([]dto.ToolCallRequest, 0)
  33. for _, claudeTool := range tools {
  34. openAITool := dto.ToolCallRequest{
  35. Type: "function",
  36. Function: dto.FunctionRequest{
  37. Name: claudeTool.Name,
  38. Description: claudeTool.Description,
  39. Parameters: claudeTool.InputSchema,
  40. },
  41. }
  42. openAITools = append(openAITools, openAITool)
  43. }
  44. openAIRequest.Tools = openAITools
  45. // Convert messages
  46. openAIMessages := make([]dto.Message, 0)
  47. // Add system message if present
  48. if claudeRequest.System != nil {
  49. if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
  50. openAIMessage := dto.Message{
  51. Role: "system",
  52. }
  53. openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
  54. openAIMessages = append(openAIMessages, openAIMessage)
  55. } else {
  56. systems := claudeRequest.ParseSystem()
  57. if len(systems) > 0 {
  58. systemStr := ""
  59. openAIMessage := dto.Message{
  60. Role: "system",
  61. }
  62. for _, system := range systems {
  63. if system.Text != nil {
  64. systemStr += *system.Text
  65. }
  66. }
  67. openAIMessage.SetStringContent(systemStr)
  68. openAIMessages = append(openAIMessages, openAIMessage)
  69. }
  70. }
  71. }
  72. for _, claudeMessage := range claudeRequest.Messages {
  73. openAIMessage := dto.Message{
  74. Role: claudeMessage.Role,
  75. }
  76. //log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
  77. if claudeMessage.IsStringContent() {
  78. openAIMessage.SetStringContent(claudeMessage.GetStringContent())
  79. } else {
  80. content, err := claudeMessage.ParseContent()
  81. if err != nil {
  82. return nil, err
  83. }
  84. contents := content
  85. var toolCalls []dto.ToolCallRequest
  86. mediaMessages := make([]dto.MediaContent, 0, len(contents))
  87. for _, mediaMsg := range contents {
  88. switch mediaMsg.Type {
  89. case "text":
  90. message := dto.MediaContent{
  91. Type: "text",
  92. Text: mediaMsg.GetText(),
  93. }
  94. mediaMessages = append(mediaMessages, message)
  95. case "image":
  96. // Handle image conversion (base64 to URL or keep as is)
  97. imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
  98. //textContent += fmt.Sprintf("[Image: %s]", imageData)
  99. mediaMessage := dto.MediaContent{
  100. Type: "image_url",
  101. ImageUrl: &dto.MessageImageUrl{Url: imageData},
  102. }
  103. mediaMessages = append(mediaMessages, mediaMessage)
  104. case "tool_use":
  105. toolCall := dto.ToolCallRequest{
  106. ID: mediaMsg.Id,
  107. Type: "function",
  108. Function: dto.FunctionRequest{
  109. Name: mediaMsg.Name,
  110. Arguments: toJSONString(mediaMsg.Input),
  111. },
  112. }
  113. toolCalls = append(toolCalls, toolCall)
  114. case "tool_result":
  115. // Add tool result as a separate message
  116. oaiToolMessage := dto.Message{
  117. Role: "tool",
  118. Name: &mediaMsg.Name,
  119. ToolCallId: mediaMsg.ToolUseId,
  120. }
  121. //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
  122. if mediaMsg.IsStringContent() {
  123. oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
  124. } else {
  125. mediaContents := mediaMsg.ParseMediaContent()
  126. encodeJson, _ := common.EncodeJson(mediaContents)
  127. oaiToolMessage.SetStringContent(string(encodeJson))
  128. }
  129. openAIMessages = append(openAIMessages, oaiToolMessage)
  130. }
  131. }
  132. if len(toolCalls) > 0 {
  133. openAIMessage.SetToolCalls(toolCalls)
  134. }
  135. if len(mediaMessages) > 0 && len(toolCalls) == 0 {
  136. openAIMessage.SetMediaContent(mediaMessages)
  137. }
  138. }
  139. if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
  140. openAIMessages = append(openAIMessages, openAIMessage)
  141. }
  142. }
  143. openAIRequest.Messages = openAIMessages
  144. return &openAIRequest, nil
  145. }
  146. func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
  147. claudeError := dto.ClaudeError{
  148. Type: "new_api_error",
  149. Message: openAIError.Error.Message,
  150. }
  151. return &dto.ClaudeErrorWithStatusCode{
  152. Error: claudeError,
  153. StatusCode: openAIError.StatusCode,
  154. }
  155. }
  156. func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
  157. openAIError := dto.OpenAIError{
  158. Message: claudeError.Error.Message,
  159. Type: "new_api_error",
  160. }
  161. return &dto.OpenAIErrorWithStatusCode{
  162. Error: openAIError,
  163. StatusCode: claudeError.StatusCode,
  164. }
  165. }
  166. func generateStopBlock(index int) *dto.ClaudeResponse {
  167. return &dto.ClaudeResponse{
  168. Type: "content_block_stop",
  169. Index: common.GetPointer[int](index),
  170. }
  171. }
  172. func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
  173. var claudeResponses []*dto.ClaudeResponse
  174. if info.SendResponseCount == 1 {
  175. msg := &dto.ClaudeMediaMessage{
  176. Id: openAIResponse.Id,
  177. Model: openAIResponse.Model,
  178. Type: "message",
  179. Role: "assistant",
  180. Usage: &dto.ClaudeUsage{
  181. InputTokens: info.PromptTokens,
  182. OutputTokens: 0,
  183. },
  184. }
  185. msg.SetContent(make([]any, 0))
  186. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  187. Type: "message_start",
  188. Message: msg,
  189. })
  190. claudeResponses = append(claudeResponses)
  191. //claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  192. // Type: "ping",
  193. //})
  194. if openAIResponse.IsToolCall() {
  195. resp := &dto.ClaudeResponse{
  196. Type: "content_block_start",
  197. ContentBlock: &dto.ClaudeMediaMessage{
  198. Id: openAIResponse.GetFirstToolCall().ID,
  199. Type: "tool_use",
  200. Name: openAIResponse.GetFirstToolCall().Function.Name,
  201. },
  202. }
  203. resp.SetIndex(0)
  204. claudeResponses = append(claudeResponses, resp)
  205. } else {
  206. //resp := &dto.ClaudeResponse{
  207. // Type: "content_block_start",
  208. // ContentBlock: &dto.ClaudeMediaMessage{
  209. // Type: "text",
  210. // Text: common.GetPointer[string](""),
  211. // },
  212. //}
  213. //resp.SetIndex(0)
  214. //claudeResponses = append(claudeResponses, resp)
  215. }
  216. return claudeResponses
  217. }
  218. if len(openAIResponse.Choices) == 0 {
  219. // no choices
  220. // TODO: handle this case
  221. return claudeResponses
  222. } else {
  223. chosenChoice := openAIResponse.Choices[0]
  224. if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
  225. // should be done
  226. info.FinishReason = *chosenChoice.FinishReason
  227. return claudeResponses
  228. }
  229. if info.Done {
  230. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  231. oaiUsage := info.ClaudeConvertInfo.Usage
  232. if oaiUsage != nil {
  233. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  234. Type: "message_delta",
  235. Usage: &dto.ClaudeUsage{
  236. InputTokens: oaiUsage.PromptTokens,
  237. OutputTokens: oaiUsage.CompletionTokens,
  238. CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
  239. CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
  240. },
  241. Delta: &dto.ClaudeMediaMessage{
  242. StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
  243. },
  244. })
  245. }
  246. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  247. Type: "message_stop",
  248. })
  249. } else {
  250. var claudeResponse dto.ClaudeResponse
  251. var isEmpty bool
  252. claudeResponse.Type = "content_block_delta"
  253. if len(chosenChoice.Delta.ToolCalls) > 0 {
  254. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
  255. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  256. info.ClaudeConvertInfo.Index++
  257. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  258. Index: &info.ClaudeConvertInfo.Index,
  259. Type: "content_block_start",
  260. ContentBlock: &dto.ClaudeMediaMessage{
  261. Id: openAIResponse.GetFirstToolCall().ID,
  262. Type: "tool_use",
  263. Name: openAIResponse.GetFirstToolCall().Function.Name,
  264. Input: map[string]interface{}{},
  265. },
  266. })
  267. }
  268. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
  269. // tools delta
  270. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  271. Type: "input_json_delta",
  272. PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
  273. }
  274. } else {
  275. reasoning := chosenChoice.Delta.GetReasoningContent()
  276. textContent := chosenChoice.Delta.GetContentString()
  277. if reasoning != "" || textContent != "" {
  278. if reasoning != "" {
  279. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
  280. //info.ClaudeConvertInfo.Index++
  281. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  282. Index: &info.ClaudeConvertInfo.Index,
  283. Type: "content_block_start",
  284. ContentBlock: &dto.ClaudeMediaMessage{
  285. Type: "thinking",
  286. Thinking: "",
  287. },
  288. })
  289. }
  290. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
  291. // text delta
  292. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  293. Type: "thinking_delta",
  294. Thinking: reasoning,
  295. }
  296. } else {
  297. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
  298. if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
  299. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  300. info.ClaudeConvertInfo.Index++
  301. }
  302. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  303. Index: &info.ClaudeConvertInfo.Index,
  304. Type: "content_block_start",
  305. ContentBlock: &dto.ClaudeMediaMessage{
  306. Type: "text",
  307. Text: common.GetPointer[string](""),
  308. },
  309. })
  310. }
  311. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
  312. // text delta
  313. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  314. Type: "text_delta",
  315. Text: common.GetPointer[string](textContent),
  316. }
  317. }
  318. } else {
  319. isEmpty = true
  320. }
  321. }
  322. claudeResponse.Index = &info.ClaudeConvertInfo.Index
  323. if !isEmpty {
  324. claudeResponses = append(claudeResponses, &claudeResponse)
  325. }
  326. }
  327. }
  328. return claudeResponses
  329. }
  330. func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
  331. var stopReason string
  332. contents := make([]dto.ClaudeMediaMessage, 0)
  333. claudeResponse := &dto.ClaudeResponse{
  334. Id: openAIResponse.Id,
  335. Type: "message",
  336. Role: "assistant",
  337. Model: openAIResponse.Model,
  338. }
  339. for _, choice := range openAIResponse.Choices {
  340. stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
  341. claudeContent := dto.ClaudeMediaMessage{}
  342. if choice.FinishReason == "tool_calls" {
  343. claudeContent.Type = "tool_use"
  344. claudeContent.Id = choice.Message.ToolCallId
  345. claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
  346. var mapParams map[string]interface{}
  347. if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
  348. claudeContent.Input = mapParams
  349. } else {
  350. claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
  351. }
  352. } else {
  353. claudeContent.Type = "text"
  354. claudeContent.SetText(choice.Message.StringContent())
  355. }
  356. contents = append(contents, claudeContent)
  357. }
  358. claudeResponse.Content = contents
  359. claudeResponse.StopReason = stopReason
  360. claudeResponse.Usage = &dto.ClaudeUsage{
  361. InputTokens: openAIResponse.PromptTokens,
  362. OutputTokens: openAIResponse.CompletionTokens,
  363. }
  364. return claudeResponse
  365. }
  366. func stopReasonOpenAI2Claude(reason string) string {
  367. switch reason {
  368. case "stop":
  369. return "end_turn"
  370. case "stop_sequence":
  371. return "stop_sequence"
  372. case "max_tokens":
  373. return "max_tokens"
  374. case "tool_calls":
  375. return "tool_use"
  376. default:
  377. return reason
  378. }
  379. }
  380. func toJSONString(v interface{}) string {
  381. b, err := json.Marshal(v)
  382. if err != nil {
  383. return "{}"
  384. }
  385. return string(b)
  386. }