relay-claude.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535
  1. package claude
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/service"
  12. "strings"
  13. "github.com/gin-gonic/gin"
  14. )
  15. func stopReasonClaude2OpenAI(reason string) string {
  16. switch reason {
  17. case "stop_sequence":
  18. return "stop"
  19. case "end_turn":
  20. return "stop"
  21. case "max_tokens":
  22. return "max_tokens"
  23. default:
  24. return reason
  25. }
  26. }
  27. func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
  28. claudeRequest := ClaudeRequest{
  29. Model: textRequest.Model,
  30. Prompt: "",
  31. StopSequences: nil,
  32. Temperature: textRequest.Temperature,
  33. TopP: textRequest.TopP,
  34. TopK: textRequest.TopK,
  35. Stream: textRequest.Stream,
  36. }
  37. if claudeRequest.MaxTokensToSample == 0 {
  38. claudeRequest.MaxTokensToSample = 4096
  39. }
  40. prompt := ""
  41. for _, message := range textRequest.Messages {
  42. if message.Role == "user" {
  43. prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
  44. } else if message.Role == "assistant" {
  45. prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
  46. } else if message.Role == "system" {
  47. if prompt == "" {
  48. prompt = message.StringContent()
  49. }
  50. }
  51. }
  52. prompt += "\n\nAssistant:"
  53. claudeRequest.Prompt = prompt
  54. return &claudeRequest
  55. }
  56. func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
  57. claudeTools := make([]Tool, 0, len(textRequest.Tools))
  58. for _, tool := range textRequest.Tools {
  59. if params, ok := tool.Function.Parameters.(map[string]any); ok {
  60. claudeTool := Tool{
  61. Name: tool.Function.Name,
  62. Description: tool.Function.Description,
  63. }
  64. claudeTool.InputSchema = make(map[string]interface{})
  65. claudeTool.InputSchema["type"] = params["type"].(string)
  66. claudeTool.InputSchema["properties"] = params["properties"]
  67. claudeTool.InputSchema["required"] = params["required"]
  68. for s, a := range params {
  69. if s == "type" || s == "properties" || s == "required" {
  70. continue
  71. }
  72. claudeTool.InputSchema[s] = a
  73. }
  74. claudeTools = append(claudeTools, claudeTool)
  75. }
  76. }
  77. claudeRequest := ClaudeRequest{
  78. Model: textRequest.Model,
  79. MaxTokens: textRequest.MaxTokens,
  80. StopSequences: nil,
  81. Temperature: textRequest.Temperature,
  82. TopP: textRequest.TopP,
  83. TopK: textRequest.TopK,
  84. Stream: textRequest.Stream,
  85. Tools: claudeTools,
  86. }
  87. if claudeRequest.MaxTokens == 0 {
  88. claudeRequest.MaxTokens = 4096
  89. }
  90. if textRequest.Stop != nil {
  91. // stop maybe string/array string, convert to array string
  92. switch textRequest.Stop.(type) {
  93. case string:
  94. claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
  95. case []interface{}:
  96. stopSequences := make([]string, 0)
  97. for _, stop := range textRequest.Stop.([]interface{}) {
  98. stopSequences = append(stopSequences, stop.(string))
  99. }
  100. claudeRequest.StopSequences = stopSequences
  101. }
  102. }
  103. formatMessages := make([]dto.Message, 0)
  104. lastMessage := dto.Message{
  105. Role: "tool",
  106. }
  107. for i, message := range textRequest.Messages {
  108. if message.Role == "" {
  109. textRequest.Messages[i].Role = "user"
  110. }
  111. fmtMessage := dto.Message{
  112. Role: message.Role,
  113. Content: message.Content,
  114. }
  115. if message.Role == "tool" {
  116. fmtMessage.ToolCallId = message.ToolCallId
  117. }
  118. if message.Role == "assistant" && message.ToolCalls != nil {
  119. fmtMessage.ToolCalls = message.ToolCalls
  120. }
  121. if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
  122. if lastMessage.IsStringContent() && message.IsStringContent() {
  123. content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
  124. fmtMessage.Content = content
  125. // delete last message
  126. formatMessages = formatMessages[:len(formatMessages)-1]
  127. }
  128. }
  129. if fmtMessage.Content == nil {
  130. content, _ := json.Marshal("...")
  131. fmtMessage.Content = content
  132. }
  133. formatMessages = append(formatMessages, fmtMessage)
  134. lastMessage = fmtMessage
  135. }
  136. claudeMessages := make([]ClaudeMessage, 0)
  137. isFirstMessage := true
  138. for _, message := range formatMessages {
  139. if message.Role == "system" {
  140. if message.IsStringContent() {
  141. claudeRequest.System = message.StringContent()
  142. } else {
  143. contents := message.ParseContent()
  144. content := ""
  145. for _, ctx := range contents {
  146. if ctx.Type == "text" {
  147. content += ctx.Text
  148. }
  149. }
  150. claudeRequest.System = content
  151. }
  152. } else {
  153. if isFirstMessage {
  154. isFirstMessage = false
  155. if message.Role != "user" {
  156. // fix: first message is assistant, add user message
  157. claudeMessage := ClaudeMessage{
  158. Role: "user",
  159. Content: []ClaudeMediaMessage{
  160. {
  161. Type: "text",
  162. Text: "...",
  163. },
  164. },
  165. }
  166. claudeMessages = append(claudeMessages, claudeMessage)
  167. }
  168. }
  169. claudeMessage := ClaudeMessage{
  170. Role: message.Role,
  171. }
  172. if message.Role == "tool" {
  173. if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
  174. lastMessage := claudeMessages[len(claudeMessages)-1]
  175. if content, ok := lastMessage.Content.(string); ok {
  176. lastMessage.Content = []ClaudeMediaMessage{
  177. {
  178. Type: "text",
  179. Text: content,
  180. },
  181. }
  182. }
  183. lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
  184. Type: "tool_result",
  185. ToolUseId: message.ToolCallId,
  186. Content: message.StringContent(),
  187. })
  188. claudeMessages[len(claudeMessages)-1] = lastMessage
  189. continue
  190. } else {
  191. claudeMessage.Role = "user"
  192. claudeMessage.Content = []ClaudeMediaMessage{
  193. {
  194. Type: "tool_result",
  195. ToolUseId: message.ToolCallId,
  196. Content: message.StringContent(),
  197. },
  198. }
  199. }
  200. } else if message.IsStringContent() && message.ToolCalls == nil {
  201. claudeMessage.Content = message.StringContent()
  202. } else {
  203. claudeMediaMessages := make([]ClaudeMediaMessage, 0)
  204. for _, mediaMessage := range message.ParseContent() {
  205. claudeMediaMessage := ClaudeMediaMessage{
  206. Type: mediaMessage.Type,
  207. }
  208. if mediaMessage.Type == "text" {
  209. claudeMediaMessage.Text = mediaMessage.Text
  210. } else {
  211. imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
  212. claudeMediaMessage.Type = "image"
  213. claudeMediaMessage.Source = &ClaudeMessageSource{
  214. Type: "base64",
  215. }
  216. // 判断是否是url
  217. if strings.HasPrefix(imageUrl.Url, "http") {
  218. // 是url,获取图片的类型和base64编码的数据
  219. mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
  220. claudeMediaMessage.Source.MediaType = mimeType
  221. claudeMediaMessage.Source.Data = data
  222. } else {
  223. _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
  224. if err != nil {
  225. return nil, err
  226. }
  227. claudeMediaMessage.Source.MediaType = "image/" + format
  228. claudeMediaMessage.Source.Data = base64String
  229. }
  230. }
  231. claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
  232. }
  233. if message.ToolCalls != nil {
  234. for _, tc := range message.ToolCalls.([]interface{}) {
  235. toolCallJSON, _ := json.Marshal(tc)
  236. var toolCall dto.ToolCall
  237. err := json.Unmarshal(toolCallJSON, &toolCall)
  238. if err != nil {
  239. common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
  240. continue
  241. }
  242. inputObj := make(map[string]any)
  243. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
  244. common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
  245. continue
  246. }
  247. claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
  248. Type: "tool_use",
  249. Id: toolCall.ID,
  250. Name: toolCall.Function.Name,
  251. Input: inputObj,
  252. })
  253. }
  254. }
  255. claudeMessage.Content = claudeMediaMessages
  256. }
  257. claudeMessages = append(claudeMessages, claudeMessage)
  258. }
  259. }
  260. claudeRequest.Prompt = ""
  261. claudeRequest.Messages = claudeMessages
  262. return &claudeRequest, nil
  263. }
  264. func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
  265. var response dto.ChatCompletionsStreamResponse
  266. var claudeUsage *ClaudeUsage
  267. response.Object = "chat.completion.chunk"
  268. response.Model = claudeResponse.Model
  269. response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
  270. tools := make([]dto.ToolCall, 0)
  271. var choice dto.ChatCompletionsStreamResponseChoice
  272. if reqMode == RequestModeCompletion {
  273. choice.Delta.SetContentString(claudeResponse.Completion)
  274. finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
  275. if finishReason != "null" {
  276. choice.FinishReason = &finishReason
  277. }
  278. } else {
  279. if claudeResponse.Type == "message_start" {
  280. response.Id = claudeResponse.Message.Id
  281. response.Model = claudeResponse.Message.Model
  282. claudeUsage = &claudeResponse.Message.Usage
  283. choice.Delta.SetContentString("")
  284. choice.Delta.Role = "assistant"
  285. } else if claudeResponse.Type == "content_block_start" {
  286. if claudeResponse.ContentBlock != nil {
  287. //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
  288. if claudeResponse.ContentBlock.Type == "tool_use" {
  289. tools = append(tools, dto.ToolCall{
  290. ID: claudeResponse.ContentBlock.Id,
  291. Type: "function",
  292. Function: dto.FunctionCall{
  293. Name: claudeResponse.ContentBlock.Name,
  294. Arguments: "",
  295. },
  296. })
  297. }
  298. } else {
  299. return nil, nil
  300. }
  301. } else if claudeResponse.Type == "content_block_delta" {
  302. if claudeResponse.Delta != nil {
  303. choice.Index = claudeResponse.Index
  304. choice.Delta.SetContentString(claudeResponse.Delta.Text)
  305. if claudeResponse.Delta.Type == "input_json_delta" {
  306. tools = append(tools, dto.ToolCall{
  307. Function: dto.FunctionCall{
  308. Arguments: claudeResponse.Delta.PartialJson,
  309. },
  310. })
  311. }
  312. }
  313. } else if claudeResponse.Type == "message_delta" {
  314. finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
  315. if finishReason != "null" {
  316. choice.FinishReason = &finishReason
  317. }
  318. claudeUsage = &claudeResponse.Usage
  319. } else if claudeResponse.Type == "message_stop" {
  320. return nil, nil
  321. } else {
  322. return nil, nil
  323. }
  324. }
  325. if claudeUsage == nil {
  326. claudeUsage = &ClaudeUsage{}
  327. }
  328. if len(tools) > 0 {
  329. choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
  330. choice.Delta.ToolCalls = tools
  331. }
  332. response.Choices = append(response.Choices, choice)
  333. return &response, claudeUsage
  334. }
  335. func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
  336. choices := make([]dto.OpenAITextResponseChoice, 0)
  337. fullTextResponse := dto.OpenAITextResponse{
  338. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  339. Object: "chat.completion",
  340. Created: common.GetTimestamp(),
  341. }
  342. var responseText string
  343. if len(claudeResponse.Content) > 0 {
  344. responseText = claudeResponse.Content[0].Text
  345. }
  346. tools := make([]dto.ToolCall, 0)
  347. if reqMode == RequestModeCompletion {
  348. content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
  349. choice := dto.OpenAITextResponseChoice{
  350. Index: 0,
  351. Message: dto.Message{
  352. Role: "assistant",
  353. Content: content,
  354. Name: nil,
  355. },
  356. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  357. }
  358. choices = append(choices, choice)
  359. } else {
  360. fullTextResponse.Id = claudeResponse.Id
  361. for _, message := range claudeResponse.Content {
  362. if message.Type == "tool_use" {
  363. args, _ := json.Marshal(message.Input)
  364. tools = append(tools, dto.ToolCall{
  365. ID: message.Id,
  366. Type: "function", // compatible with other OpenAI derivative applications
  367. Function: dto.FunctionCall{
  368. Name: message.Name,
  369. Arguments: string(args),
  370. },
  371. })
  372. }
  373. }
  374. }
  375. choice := dto.OpenAITextResponseChoice{
  376. Index: 0,
  377. Message: dto.Message{
  378. Role: "assistant",
  379. },
  380. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  381. }
  382. choice.SetStringContent(responseText)
  383. if len(tools) > 0 {
  384. choice.Message.ToolCalls = tools
  385. }
  386. fullTextResponse.Model = claudeResponse.Model
  387. choices = append(choices, choice)
  388. fullTextResponse.Choices = choices
  389. return &fullTextResponse
  390. }
  391. func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  392. responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
  393. var usage *dto.Usage
  394. usage = &dto.Usage{}
  395. responseText := ""
  396. createdTime := common.GetTimestamp()
  397. scanner := bufio.NewScanner(resp.Body)
  398. scanner.Split(bufio.ScanLines)
  399. service.SetEventStreamHeaders(c)
  400. for scanner.Scan() {
  401. data := scanner.Text()
  402. info.SetFirstResponseTime()
  403. if len(data) < 6 || !strings.HasPrefix(data, "data:") {
  404. continue
  405. }
  406. data = strings.TrimPrefix(data, "data:")
  407. data = strings.TrimSpace(data)
  408. var claudeResponse ClaudeResponse
  409. err := json.Unmarshal([]byte(data), &claudeResponse)
  410. if err != nil {
  411. common.SysError("error unmarshalling stream response: " + err.Error())
  412. continue
  413. }
  414. response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
  415. if response == nil {
  416. continue
  417. }
  418. if requestMode == RequestModeCompletion {
  419. responseText += claudeResponse.Completion
  420. responseId = response.Id
  421. } else {
  422. if claudeResponse.Type == "message_start" {
  423. // message_start, 获取usage
  424. responseId = claudeResponse.Message.Id
  425. info.UpstreamModelName = claudeResponse.Message.Model
  426. usage.PromptTokens = claudeUsage.InputTokens
  427. } else if claudeResponse.Type == "content_block_delta" {
  428. responseText += claudeResponse.Delta.Text
  429. } else if claudeResponse.Type == "message_delta" {
  430. usage.CompletionTokens = claudeUsage.OutputTokens
  431. usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
  432. } else if claudeResponse.Type == "content_block_start" {
  433. } else {
  434. continue
  435. }
  436. }
  437. //response.Id = responseId
  438. response.Id = responseId
  439. response.Created = createdTime
  440. response.Model = info.UpstreamModelName
  441. err = service.ObjectData(c, response)
  442. if err != nil {
  443. common.LogError(c, "send_stream_response_failed: "+err.Error())
  444. }
  445. }
  446. if requestMode == RequestModeCompletion {
  447. usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
  448. } else {
  449. if usage.PromptTokens == 0 {
  450. usage.PromptTokens = info.PromptTokens
  451. }
  452. if usage.CompletionTokens == 0 {
  453. usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
  454. }
  455. }
  456. if info.ShouldIncludeUsage {
  457. response := service.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
  458. err := service.ObjectData(c, response)
  459. if err != nil {
  460. common.SysError("send final response failed: " + err.Error())
  461. }
  462. }
  463. service.Done(c)
  464. resp.Body.Close()
  465. return nil, usage
  466. }
  467. func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  468. responseBody, err := io.ReadAll(resp.Body)
  469. if err != nil {
  470. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  471. }
  472. err = resp.Body.Close()
  473. if err != nil {
  474. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  475. }
  476. var claudeResponse ClaudeResponse
  477. err = json.Unmarshal(responseBody, &claudeResponse)
  478. if err != nil {
  479. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  480. }
  481. if claudeResponse.Error.Type != "" {
  482. return &dto.OpenAIErrorWithStatusCode{
  483. Error: dto.OpenAIError{
  484. Message: claudeResponse.Error.Message,
  485. Type: claudeResponse.Error.Type,
  486. Param: "",
  487. Code: claudeResponse.Error.Type,
  488. },
  489. StatusCode: resp.StatusCode,
  490. }, nil
  491. }
  492. fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
  493. completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
  494. if err != nil {
  495. return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
  496. }
  497. usage := dto.Usage{}
  498. if requestMode == RequestModeCompletion {
  499. usage.PromptTokens = info.PromptTokens
  500. usage.CompletionTokens = completionTokens
  501. usage.TotalTokens = info.PromptTokens + completionTokens
  502. } else {
  503. usage.PromptTokens = claudeResponse.Usage.InputTokens
  504. usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  505. usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
  506. }
  507. fullTextResponse.Usage = usage
  508. jsonResponse, err := json.Marshal(fullTextResponse)
  509. if err != nil {
  510. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  511. }
  512. c.Writer.Header().Set("Content-Type", "application/json")
  513. c.Writer.WriteHeader(resp.StatusCode)
  514. _, err = c.Writer.Write(jsonResponse)
  515. return nil, &usage
  516. }