convert.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. package service
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "one-api/common"
  6. "one-api/dto"
  7. relaycommon "one-api/relay/common"
  8. "strings"
  9. )
  10. func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
  11. openAIRequest := dto.GeneralOpenAIRequest{
  12. Model: claudeRequest.Model,
  13. MaxTokens: claudeRequest.MaxTokens,
  14. Temperature: claudeRequest.Temperature,
  15. TopP: claudeRequest.TopP,
  16. Stream: claudeRequest.Stream,
  17. }
  18. if claudeRequest.Thinking != nil {
  19. if strings.HasSuffix(info.OriginModelName, "-thinking") &&
  20. !strings.HasSuffix(claudeRequest.Model, "-thinking") {
  21. openAIRequest.Model = openAIRequest.Model + "-thinking"
  22. }
  23. }
  24. // Convert stop sequences
  25. if len(claudeRequest.StopSequences) == 1 {
  26. openAIRequest.Stop = claudeRequest.StopSequences[0]
  27. } else if len(claudeRequest.StopSequences) > 1 {
  28. openAIRequest.Stop = claudeRequest.StopSequences
  29. }
  30. // Convert tools
  31. tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
  32. openAITools := make([]dto.ToolCallRequest, 0)
  33. for _, claudeTool := range tools {
  34. openAITool := dto.ToolCallRequest{
  35. Type: "function",
  36. Function: dto.FunctionRequest{
  37. Name: claudeTool.Name,
  38. Description: claudeTool.Description,
  39. Parameters: claudeTool.InputSchema,
  40. },
  41. }
  42. openAITools = append(openAITools, openAITool)
  43. }
  44. openAIRequest.Tools = openAITools
  45. // Convert messages
  46. openAIMessages := make([]dto.Message, 0)
  47. // Add system message if present
  48. if claudeRequest.System != nil {
  49. if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
  50. openAIMessage := dto.Message{
  51. Role: "system",
  52. }
  53. openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
  54. openAIMessages = append(openAIMessages, openAIMessage)
  55. } else {
  56. systems := claudeRequest.ParseSystem()
  57. if len(systems) > 0 {
  58. systemStr := ""
  59. openAIMessage := dto.Message{
  60. Role: "system",
  61. }
  62. for _, system := range systems {
  63. if system.Text != nil {
  64. systemStr += *system.Text
  65. }
  66. }
  67. openAIMessage.SetStringContent(systemStr)
  68. openAIMessages = append(openAIMessages, openAIMessage)
  69. }
  70. }
  71. }
  72. for _, claudeMessage := range claudeRequest.Messages {
  73. openAIMessage := dto.Message{
  74. Role: claudeMessage.Role,
  75. }
  76. //log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
  77. if claudeMessage.IsStringContent() {
  78. openAIMessage.SetStringContent(claudeMessage.GetStringContent())
  79. } else {
  80. content, err := claudeMessage.ParseContent()
  81. if err != nil {
  82. return nil, err
  83. }
  84. contents := content
  85. var toolCalls []dto.ToolCallRequest
  86. mediaMessages := make([]dto.MediaContent, 0, len(contents))
  87. for _, mediaMsg := range contents {
  88. switch mediaMsg.Type {
  89. case "text":
  90. message := dto.MediaContent{
  91. Type: "text",
  92. Text: mediaMsg.GetText(),
  93. }
  94. mediaMessages = append(mediaMessages, message)
  95. case "image":
  96. // Handle image conversion (base64 to URL or keep as is)
  97. imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
  98. //textContent += fmt.Sprintf("[Image: %s]", imageData)
  99. mediaMessage := dto.MediaContent{
  100. Type: "image_url",
  101. ImageUrl: &dto.MessageImageUrl{Url: imageData},
  102. }
  103. mediaMessages = append(mediaMessages, mediaMessage)
  104. case "tool_use":
  105. toolCall := dto.ToolCallRequest{
  106. ID: mediaMsg.Id,
  107. Type: "function",
  108. Function: dto.FunctionRequest{
  109. Name: mediaMsg.Name,
  110. Arguments: toJSONString(mediaMsg.Input),
  111. },
  112. }
  113. toolCalls = append(toolCalls, toolCall)
  114. case "tool_result":
  115. // Add tool result as a separate message
  116. oaiToolMessage := dto.Message{
  117. Role: "tool",
  118. Name: &mediaMsg.Name,
  119. ToolCallId: mediaMsg.ToolUseId,
  120. }
  121. //oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
  122. if mediaMsg.IsStringContent() {
  123. oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
  124. } else {
  125. mediaContents := mediaMsg.ParseMediaContent()
  126. encodeJson, _ := common.EncodeJson(mediaContents)
  127. oaiToolMessage.SetStringContent(string(encodeJson))
  128. }
  129. openAIMessages = append(openAIMessages, oaiToolMessage)
  130. }
  131. }
  132. if len(toolCalls) > 0 {
  133. openAIMessage.SetToolCalls(toolCalls)
  134. }
  135. if len(mediaMessages) > 0 && len(toolCalls) == 0 {
  136. openAIMessage.SetMediaContent(mediaMessages)
  137. }
  138. }
  139. if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
  140. openAIMessages = append(openAIMessages, openAIMessage)
  141. }
  142. }
  143. openAIRequest.Messages = openAIMessages
  144. return &openAIRequest, nil
  145. }
  146. func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
  147. claudeError := dto.ClaudeError{
  148. Type: "new_api_error",
  149. Message: openAIError.Error.Message,
  150. }
  151. return &dto.ClaudeErrorWithStatusCode{
  152. Error: claudeError,
  153. StatusCode: openAIError.StatusCode,
  154. }
  155. }
  156. func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
  157. openAIError := dto.OpenAIError{
  158. Message: claudeError.Error.Message,
  159. Type: "new_api_error",
  160. }
  161. return &dto.OpenAIErrorWithStatusCode{
  162. Error: openAIError,
  163. StatusCode: claudeError.StatusCode,
  164. }
  165. }
  166. func generateStopBlock(index int) *dto.ClaudeResponse {
  167. return &dto.ClaudeResponse{
  168. Type: "content_block_stop",
  169. Index: common.GetPointer[int](index),
  170. }
  171. }
  172. func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
  173. var claudeResponses []*dto.ClaudeResponse
  174. if info.SendResponseCount == 1 {
  175. msg := &dto.ClaudeMediaMessage{
  176. Id: openAIResponse.Id,
  177. Model: openAIResponse.Model,
  178. Type: "message",
  179. Role: "assistant",
  180. Usage: &dto.ClaudeUsage{
  181. InputTokens: info.PromptTokens,
  182. OutputTokens: 0,
  183. },
  184. }
  185. msg.SetContent(make([]any, 0))
  186. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  187. Type: "message_start",
  188. Message: msg,
  189. })
  190. claudeResponses = append(claudeResponses)
  191. //claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  192. // Type: "ping",
  193. //})
  194. if openAIResponse.IsToolCall() {
  195. resp := &dto.ClaudeResponse{
  196. Type: "content_block_start",
  197. ContentBlock: &dto.ClaudeMediaMessage{
  198. Id: openAIResponse.GetFirstToolCall().ID,
  199. Type: "tool_use",
  200. Name: openAIResponse.GetFirstToolCall().Function.Name,
  201. },
  202. }
  203. resp.SetIndex(0)
  204. claudeResponses = append(claudeResponses, resp)
  205. } else {
  206. //resp := &dto.ClaudeResponse{
  207. // Type: "content_block_start",
  208. // ContentBlock: &dto.ClaudeMediaMessage{
  209. // Type: "text",
  210. // Text: common.GetPointer[string](""),
  211. // },
  212. //}
  213. //resp.SetIndex(0)
  214. //claudeResponses = append(claudeResponses, resp)
  215. }
  216. return claudeResponses
  217. }
  218. if len(openAIResponse.Choices) == 0 {
  219. // no choices
  220. // TODO: handle this case
  221. return claudeResponses
  222. } else {
  223. chosenChoice := openAIResponse.Choices[0]
  224. if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
  225. // should be done
  226. info.FinishReason = *chosenChoice.FinishReason
  227. return claudeResponses
  228. }
  229. if info.Done {
  230. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  231. if info.ClaudeConvertInfo.Usage != nil {
  232. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  233. Type: "message_delta",
  234. Usage: &dto.ClaudeUsage{
  235. InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
  236. OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
  237. },
  238. Delta: &dto.ClaudeMediaMessage{
  239. StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
  240. },
  241. })
  242. }
  243. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  244. Type: "message_stop",
  245. })
  246. } else {
  247. var claudeResponse dto.ClaudeResponse
  248. var isEmpty bool
  249. claudeResponse.Type = "content_block_delta"
  250. if len(chosenChoice.Delta.ToolCalls) > 0 {
  251. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
  252. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  253. info.ClaudeConvertInfo.Index++
  254. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  255. Index: &info.ClaudeConvertInfo.Index,
  256. Type: "content_block_start",
  257. ContentBlock: &dto.ClaudeMediaMessage{
  258. Id: openAIResponse.GetFirstToolCall().ID,
  259. Type: "tool_use",
  260. Name: openAIResponse.GetFirstToolCall().Function.Name,
  261. Input: map[string]interface{}{},
  262. },
  263. })
  264. }
  265. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
  266. // tools delta
  267. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  268. Type: "input_json_delta",
  269. PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
  270. }
  271. } else {
  272. reasoning := chosenChoice.Delta.GetReasoningContent()
  273. textContent := chosenChoice.Delta.GetContentString()
  274. if reasoning != "" || textContent != "" {
  275. if reasoning != "" {
  276. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
  277. //info.ClaudeConvertInfo.Index++
  278. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  279. Index: &info.ClaudeConvertInfo.Index,
  280. Type: "content_block_start",
  281. ContentBlock: &dto.ClaudeMediaMessage{
  282. Type: "thinking",
  283. Thinking: "",
  284. },
  285. })
  286. }
  287. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
  288. // text delta
  289. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  290. Type: "thinking_delta",
  291. Thinking: reasoning,
  292. }
  293. } else {
  294. if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
  295. if info.LastMessagesType == relaycommon.LastMessageTypeThinking || info.LastMessagesType == relaycommon.LastMessageTypeTools {
  296. claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
  297. info.ClaudeConvertInfo.Index++
  298. }
  299. claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
  300. Index: &info.ClaudeConvertInfo.Index,
  301. Type: "content_block_start",
  302. ContentBlock: &dto.ClaudeMediaMessage{
  303. Type: "text",
  304. Text: common.GetPointer[string](""),
  305. },
  306. })
  307. }
  308. info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
  309. // text delta
  310. claudeResponse.Delta = &dto.ClaudeMediaMessage{
  311. Type: "text_delta",
  312. Text: common.GetPointer[string](textContent),
  313. }
  314. }
  315. } else {
  316. isEmpty = true
  317. }
  318. }
  319. claudeResponse.Index = &info.ClaudeConvertInfo.Index
  320. if !isEmpty {
  321. claudeResponses = append(claudeResponses, &claudeResponse)
  322. }
  323. }
  324. }
  325. return claudeResponses
  326. }
  327. func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
  328. var stopReason string
  329. contents := make([]dto.ClaudeMediaMessage, 0)
  330. claudeResponse := &dto.ClaudeResponse{
  331. Id: openAIResponse.Id,
  332. Type: "message",
  333. Role: "assistant",
  334. Model: openAIResponse.Model,
  335. }
  336. for _, choice := range openAIResponse.Choices {
  337. stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
  338. claudeContent := dto.ClaudeMediaMessage{}
  339. if choice.FinishReason == "tool_calls" {
  340. claudeContent.Type = "tool_use"
  341. claudeContent.Id = choice.Message.ToolCallId
  342. claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
  343. var mapParams map[string]interface{}
  344. if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
  345. claudeContent.Input = mapParams
  346. } else {
  347. claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
  348. }
  349. } else {
  350. claudeContent.Type = "text"
  351. claudeContent.SetText(choice.Message.StringContent())
  352. }
  353. contents = append(contents, claudeContent)
  354. }
  355. claudeResponse.Content = contents
  356. claudeResponse.StopReason = stopReason
  357. claudeResponse.Usage = &dto.ClaudeUsage{
  358. InputTokens: openAIResponse.PromptTokens,
  359. OutputTokens: openAIResponse.CompletionTokens,
  360. }
  361. return claudeResponse
  362. }
  363. func stopReasonOpenAI2Claude(reason string) string {
  364. switch reason {
  365. case "stop":
  366. return "end_turn"
  367. case "stop_sequence":
  368. return "stop_sequence"
  369. case "max_tokens":
  370. return "max_tokens"
  371. case "tool_calls":
  372. return "tool_use"
  373. default:
  374. return reason
  375. }
  376. }
  377. func toJSONString(v interface{}) string {
  378. b, err := json.Marshal(v)
  379. if err != nil {
  380. return "{}"
  381. }
  382. return string(b)
  383. }