relay-claude.go 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
  1. package claude
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. "one-api/setting/model_setting"
  13. "strings"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func stopReasonClaude2OpenAI(reason string) string {
  17. switch reason {
  18. case "stop_sequence":
  19. return "stop"
  20. case "end_turn":
  21. return "stop"
  22. case "max_tokens":
  23. return "max_tokens"
  24. case "tool_use":
  25. return "tool_calls"
  26. default:
  27. return reason
  28. }
  29. }
  30. func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
  31. claudeRequest := dto.ClaudeRequest{
  32. Model: textRequest.Model,
  33. Prompt: "",
  34. StopSequences: nil,
  35. Temperature: textRequest.Temperature,
  36. TopP: textRequest.TopP,
  37. TopK: textRequest.TopK,
  38. Stream: textRequest.Stream,
  39. }
  40. if claudeRequest.MaxTokensToSample == 0 {
  41. claudeRequest.MaxTokensToSample = 4096
  42. }
  43. prompt := ""
  44. for _, message := range textRequest.Messages {
  45. if message.Role == "user" {
  46. prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
  47. } else if message.Role == "assistant" {
  48. prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
  49. } else if message.Role == "system" {
  50. if prompt == "" {
  51. prompt = message.StringContent()
  52. }
  53. }
  54. }
  55. prompt += "\n\nAssistant:"
  56. claudeRequest.Prompt = prompt
  57. return &claudeRequest
  58. }
  59. func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
  60. claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
  61. for _, tool := range textRequest.Tools {
  62. if params, ok := tool.Function.Parameters.(map[string]any); ok {
  63. claudeTool := dto.Tool{
  64. Name: tool.Function.Name,
  65. Description: tool.Function.Description,
  66. }
  67. claudeTool.InputSchema = make(map[string]interface{})
  68. if params["type"] != nil {
  69. claudeTool.InputSchema["type"] = params["type"].(string)
  70. }
  71. claudeTool.InputSchema["properties"] = params["properties"]
  72. claudeTool.InputSchema["required"] = params["required"]
  73. for s, a := range params {
  74. if s == "type" || s == "properties" || s == "required" {
  75. continue
  76. }
  77. claudeTool.InputSchema[s] = a
  78. }
  79. claudeTools = append(claudeTools, claudeTool)
  80. }
  81. }
  82. claudeRequest := dto.ClaudeRequest{
  83. Model: textRequest.Model,
  84. MaxTokens: textRequest.MaxTokens,
  85. StopSequences: nil,
  86. Temperature: textRequest.Temperature,
  87. TopP: textRequest.TopP,
  88. TopK: textRequest.TopK,
  89. Stream: textRequest.Stream,
  90. Tools: claudeTools,
  91. }
  92. if claudeRequest.MaxTokens == 0 {
  93. claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
  94. }
  95. if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
  96. strings.HasSuffix(textRequest.Model, "-thinking") {
  97. // 因为BudgetTokens 必须大于1024
  98. if claudeRequest.MaxTokens < 1280 {
  99. claudeRequest.MaxTokens = 1280
  100. }
  101. // BudgetTokens 为 max_tokens 的 80%
  102. claudeRequest.Thinking = &dto.Thinking{
  103. Type: "enabled",
  104. BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
  105. }
  106. // TODO: 临时处理
  107. // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
  108. claudeRequest.TopP = 0
  109. claudeRequest.Temperature = common.GetPointer[float64](1.0)
  110. claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
  111. }
  112. if textRequest.Stop != nil {
  113. // stop maybe string/array string, convert to array string
  114. switch textRequest.Stop.(type) {
  115. case string:
  116. claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
  117. case []interface{}:
  118. stopSequences := make([]string, 0)
  119. for _, stop := range textRequest.Stop.([]interface{}) {
  120. stopSequences = append(stopSequences, stop.(string))
  121. }
  122. claudeRequest.StopSequences = stopSequences
  123. }
  124. }
  125. formatMessages := make([]dto.Message, 0)
  126. lastMessage := dto.Message{
  127. Role: "tool",
  128. }
  129. for i, message := range textRequest.Messages {
  130. if message.Role == "" {
  131. textRequest.Messages[i].Role = "user"
  132. }
  133. fmtMessage := dto.Message{
  134. Role: message.Role,
  135. Content: message.Content,
  136. }
  137. if message.Role == "tool" {
  138. fmtMessage.ToolCallId = message.ToolCallId
  139. }
  140. if message.Role == "assistant" && message.ToolCalls != nil {
  141. fmtMessage.ToolCalls = message.ToolCalls
  142. }
  143. if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
  144. if lastMessage.IsStringContent() && message.IsStringContent() {
  145. fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
  146. // delete last message
  147. formatMessages = formatMessages[:len(formatMessages)-1]
  148. }
  149. }
  150. if fmtMessage.Content == nil {
  151. fmtMessage.SetStringContent("...")
  152. }
  153. formatMessages = append(formatMessages, fmtMessage)
  154. lastMessage = fmtMessage
  155. }
  156. claudeMessages := make([]dto.ClaudeMessage, 0)
  157. isFirstMessage := true
  158. for _, message := range formatMessages {
  159. if message.Role == "system" {
  160. if message.IsStringContent() {
  161. claudeRequest.System = message.StringContent()
  162. } else {
  163. contents := message.ParseContent()
  164. content := ""
  165. for _, ctx := range contents {
  166. if ctx.Type == "text" {
  167. content += ctx.Text
  168. }
  169. }
  170. claudeRequest.System = content
  171. }
  172. } else {
  173. if isFirstMessage {
  174. isFirstMessage = false
  175. if message.Role != "user" {
  176. // fix: first message is assistant, add user message
  177. claudeMessage := dto.ClaudeMessage{
  178. Role: "user",
  179. Content: []dto.ClaudeMediaMessage{
  180. {
  181. Type: "text",
  182. Text: common.GetPointer[string]("..."),
  183. },
  184. },
  185. }
  186. claudeMessages = append(claudeMessages, claudeMessage)
  187. }
  188. }
  189. claudeMessage := dto.ClaudeMessage{
  190. Role: message.Role,
  191. }
  192. if message.Role == "tool" {
  193. if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
  194. lastMessage := claudeMessages[len(claudeMessages)-1]
  195. if content, ok := lastMessage.Content.(string); ok {
  196. lastMessage.Content = []dto.ClaudeMediaMessage{
  197. {
  198. Type: "text",
  199. Text: common.GetPointer[string](content),
  200. },
  201. }
  202. }
  203. lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
  204. Type: "tool_result",
  205. ToolUseId: message.ToolCallId,
  206. Content: message.Content,
  207. })
  208. claudeMessages[len(claudeMessages)-1] = lastMessage
  209. continue
  210. } else {
  211. claudeMessage.Role = "user"
  212. claudeMessage.Content = []dto.ClaudeMediaMessage{
  213. {
  214. Type: "tool_result",
  215. ToolUseId: message.ToolCallId,
  216. Content: message.Content,
  217. },
  218. }
  219. }
  220. } else if message.IsStringContent() && message.ToolCalls == nil {
  221. claudeMessage.Content = message.StringContent()
  222. } else {
  223. claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
  224. for _, mediaMessage := range message.ParseContent() {
  225. claudeMediaMessage := dto.ClaudeMediaMessage{
  226. Type: mediaMessage.Type,
  227. }
  228. if mediaMessage.Type == "text" {
  229. claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
  230. } else {
  231. imageUrl := mediaMessage.GetImageMedia()
  232. claudeMediaMessage.Type = "image"
  233. claudeMediaMessage.Source = &dto.ClaudeMessageSource{
  234. Type: "base64",
  235. }
  236. // 判断是否是url
  237. if strings.HasPrefix(imageUrl.Url, "http") {
  238. // 是url,获取图片的类型和base64编码的数据
  239. fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
  240. if err != nil {
  241. return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
  242. }
  243. claudeMediaMessage.Source.MediaType = fileData.MimeType
  244. claudeMediaMessage.Source.Data = fileData.Base64Data
  245. } else {
  246. _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
  247. if err != nil {
  248. return nil, err
  249. }
  250. claudeMediaMessage.Source.MediaType = "image/" + format
  251. claudeMediaMessage.Source.Data = base64String
  252. }
  253. }
  254. claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
  255. }
  256. if message.ToolCalls != nil {
  257. for _, toolCall := range message.ParseToolCalls() {
  258. inputObj := make(map[string]any)
  259. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
  260. common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
  261. continue
  262. }
  263. claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
  264. Type: "tool_use",
  265. Id: toolCall.ID,
  266. Name: toolCall.Function.Name,
  267. Input: inputObj,
  268. })
  269. }
  270. }
  271. claudeMessage.Content = claudeMediaMessages
  272. }
  273. claudeMessages = append(claudeMessages, claudeMessage)
  274. }
  275. }
  276. claudeRequest.Prompt = ""
  277. claudeRequest.Messages = claudeMessages
  278. return &claudeRequest, nil
  279. }
  280. func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
  281. var response dto.ChatCompletionsStreamResponse
  282. response.Object = "chat.completion.chunk"
  283. response.Model = claudeResponse.Model
  284. response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
  285. tools := make([]dto.ToolCallResponse, 0)
  286. fcIdx := 0
  287. if claudeResponse.Index != nil {
  288. fcIdx = *claudeResponse.Index - 1
  289. if fcIdx < 0 {
  290. fcIdx = 0
  291. }
  292. }
  293. var choice dto.ChatCompletionsStreamResponseChoice
  294. if reqMode == RequestModeCompletion {
  295. choice.Delta.SetContentString(claudeResponse.Completion)
  296. finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
  297. if finishReason != "null" {
  298. choice.FinishReason = &finishReason
  299. }
  300. } else {
  301. if claudeResponse.Type == "message_start" {
  302. response.Id = claudeResponse.Message.Id
  303. response.Model = claudeResponse.Message.Model
  304. //claudeUsage = &claudeResponse.Message.Usage
  305. choice.Delta.SetContentString("")
  306. choice.Delta.Role = "assistant"
  307. } else if claudeResponse.Type == "content_block_start" {
  308. if claudeResponse.ContentBlock != nil {
  309. //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
  310. if claudeResponse.ContentBlock.Type == "tool_use" {
  311. tools = append(tools, dto.ToolCallResponse{
  312. Index: common.GetPointer(fcIdx),
  313. ID: claudeResponse.ContentBlock.Id,
  314. Type: "function",
  315. Function: dto.FunctionResponse{
  316. Name: claudeResponse.ContentBlock.Name,
  317. Arguments: "",
  318. },
  319. })
  320. }
  321. } else {
  322. return nil
  323. }
  324. } else if claudeResponse.Type == "content_block_delta" {
  325. if claudeResponse.Delta != nil {
  326. choice.Delta.Content = claudeResponse.Delta.Text
  327. switch claudeResponse.Delta.Type {
  328. case "input_json_delta":
  329. tools = append(tools, dto.ToolCallResponse{
  330. Type: "function",
  331. Index: common.GetPointer(fcIdx),
  332. Function: dto.FunctionResponse{
  333. Arguments: *claudeResponse.Delta.PartialJson,
  334. },
  335. })
  336. case "signature_delta":
  337. // 加密的不处理
  338. signatureContent := "\n"
  339. choice.Delta.ReasoningContent = &signatureContent
  340. case "thinking_delta":
  341. thinkingContent := claudeResponse.Delta.Thinking
  342. choice.Delta.ReasoningContent = &thinkingContent
  343. }
  344. }
  345. } else if claudeResponse.Type == "message_delta" {
  346. finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
  347. if finishReason != "null" {
  348. choice.FinishReason = &finishReason
  349. }
  350. //claudeUsage = &claudeResponse.Usage
  351. } else if claudeResponse.Type == "message_stop" {
  352. return nil
  353. } else {
  354. return nil
  355. }
  356. }
  357. if len(tools) > 0 {
  358. choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
  359. choice.Delta.ToolCalls = tools
  360. }
  361. response.Choices = append(response.Choices, choice)
  362. return &response
  363. }
  364. func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
  365. choices := make([]dto.OpenAITextResponseChoice, 0)
  366. fullTextResponse := dto.OpenAITextResponse{
  367. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  368. Object: "chat.completion",
  369. Created: common.GetTimestamp(),
  370. }
  371. var responseText string
  372. var responseThinking string
  373. if len(claudeResponse.Content) > 0 {
  374. responseText = claudeResponse.Content[0].GetText()
  375. responseThinking = claudeResponse.Content[0].Thinking
  376. }
  377. tools := make([]dto.ToolCallResponse, 0)
  378. thinkingContent := ""
  379. if reqMode == RequestModeCompletion {
  380. choice := dto.OpenAITextResponseChoice{
  381. Index: 0,
  382. Message: dto.Message{
  383. Role: "assistant",
  384. Content: strings.TrimPrefix(claudeResponse.Completion, " "),
  385. Name: nil,
  386. },
  387. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  388. }
  389. choices = append(choices, choice)
  390. } else {
  391. fullTextResponse.Id = claudeResponse.Id
  392. for _, message := range claudeResponse.Content {
  393. switch message.Type {
  394. case "tool_use":
  395. args, _ := json.Marshal(message.Input)
  396. tools = append(tools, dto.ToolCallResponse{
  397. ID: message.Id,
  398. Type: "function", // compatible with other OpenAI derivative applications
  399. Function: dto.FunctionResponse{
  400. Name: message.Name,
  401. Arguments: string(args),
  402. },
  403. })
  404. case "thinking":
  405. // 加密的不管, 只输出明文的推理过程
  406. thinkingContent = message.Thinking
  407. case "text":
  408. responseText = message.GetText()
  409. }
  410. }
  411. }
  412. choice := dto.OpenAITextResponseChoice{
  413. Index: 0,
  414. Message: dto.Message{
  415. Role: "assistant",
  416. },
  417. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  418. }
  419. choice.SetStringContent(responseText)
  420. if len(responseThinking) > 0 {
  421. choice.ReasoningContent = responseThinking
  422. }
  423. if len(tools) > 0 {
  424. choice.Message.SetToolCalls(tools)
  425. }
  426. choice.Message.ReasoningContent = thinkingContent
  427. fullTextResponse.Model = claudeResponse.Model
  428. choices = append(choices, choice)
  429. fullTextResponse.Choices = choices
  430. return &fullTextResponse
  431. }
// ClaudeResponseInfo accumulates the state of a single Claude response while
// it is being relayed to the client.
type ClaudeResponseInfo struct {
	// ResponseId is the response id; it is replaced by the upstream message
	// id when a message_start event is processed.
	ResponseId string
	// Created is the creation timestamp copied into OpenAI-format chunks.
	Created int64
	// Model is the model name; it is replaced by the upstream model when a
	// message_start event is processed.
	Model string
	// ResponseText collects the completion, text and thinking deltas seen so
	// far; it is used to estimate usage when upstream usage is incomplete.
	ResponseText strings.Builder
	// Usage holds the token usage gathered from the response.
	Usage *dto.Usage
	// Done is set once a message_delta event has been processed.
	Done bool
}
  440. func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
  441. if requestMode == RequestModeCompletion {
  442. claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
  443. } else {
  444. if claudeResponse.Type == "message_start" {
  445. claudeInfo.ResponseId = claudeResponse.Message.Id
  446. claudeInfo.Model = claudeResponse.Message.Model
  447. // message_start, 获取usage
  448. claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
  449. claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
  450. claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
  451. claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
  452. } else if claudeResponse.Type == "content_block_delta" {
  453. if claudeResponse.Delta.Text != nil {
  454. claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
  455. }
  456. if claudeResponse.Delta.Thinking != "" {
  457. claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
  458. }
  459. } else if claudeResponse.Type == "message_delta" {
  460. // 最终的usage获取
  461. if claudeResponse.Usage.InputTokens > 0 {
  462. // 不叠加,只取最新的
  463. claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
  464. }
  465. claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  466. claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
  467. // 判断是否完整
  468. claudeInfo.Done = true
  469. } else if claudeResponse.Type == "content_block_start" {
  470. } else {
  471. return false
  472. }
  473. }
  474. if oaiResponse != nil {
  475. oaiResponse.Id = claudeInfo.ResponseId
  476. oaiResponse.Created = claudeInfo.Created
  477. oaiResponse.Model = claudeInfo.Model
  478. }
  479. return true
  480. }
  481. func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
  482. var claudeResponse dto.ClaudeResponse
  483. err := common.DecodeJsonStr(data, &claudeResponse)
  484. if err != nil {
  485. common.SysError("error unmarshalling stream response: " + err.Error())
  486. return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
  487. }
  488. if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
  489. return &dto.OpenAIErrorWithStatusCode{
  490. Error: dto.OpenAIError{
  491. Code: "stream_response_error",
  492. Type: claudeResponse.Error.Type,
  493. Message: claudeResponse.Error.Message,
  494. },
  495. StatusCode: http.StatusInternalServerError,
  496. }
  497. }
  498. if info.RelayFormat == relaycommon.RelayFormatClaude {
  499. FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
  500. if requestMode == RequestModeCompletion {
  501. } else {
  502. if claudeResponse.Type == "message_start" {
  503. // message_start, 获取usage
  504. info.UpstreamModelName = claudeResponse.Message.Model
  505. } else if claudeResponse.Type == "content_block_delta" {
  506. } else if claudeResponse.Type == "message_delta" {
  507. }
  508. }
  509. helper.ClaudeChunkData(c, claudeResponse, data)
  510. } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  511. response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
  512. if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
  513. return nil
  514. }
  515. err = helper.ObjectData(c, response)
  516. if err != nil {
  517. common.LogError(c, "send_stream_response_failed: "+err.Error())
  518. }
  519. }
  520. return nil
  521. }
  522. func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
  523. if requestMode == RequestModeCompletion {
  524. claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
  525. } else {
  526. if claudeInfo.Usage.PromptTokens == 0 {
  527. //上游出错
  528. }
  529. if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
  530. if common.DebugEnabled {
  531. common.SysError("claude response usage is not complete, maybe upstream error")
  532. }
  533. claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
  534. }
  535. }
  536. if info.RelayFormat == relaycommon.RelayFormatClaude {
  537. //
  538. } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  539. if info.ShouldIncludeUsage {
  540. response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
  541. err := helper.ObjectData(c, response)
  542. if err != nil {
  543. common.SysError("send final response failed: " + err.Error())
  544. }
  545. }
  546. helper.Done(c)
  547. }
  548. }
  549. func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  550. claudeInfo := &ClaudeResponseInfo{
  551. ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  552. Created: common.GetTimestamp(),
  553. Model: info.UpstreamModelName,
  554. ResponseText: strings.Builder{},
  555. Usage: &dto.Usage{},
  556. }
  557. var err *dto.OpenAIErrorWithStatusCode
  558. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  559. err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
  560. if err != nil {
  561. return false
  562. }
  563. return true
  564. })
  565. if err != nil {
  566. return err, nil
  567. }
  568. HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
  569. return nil, claudeInfo.Usage
  570. }
  571. func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
  572. var claudeResponse dto.ClaudeResponse
  573. err := common.DecodeJson(data, &claudeResponse)
  574. if err != nil {
  575. return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
  576. }
  577. if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
  578. return &dto.OpenAIErrorWithStatusCode{
  579. Error: dto.OpenAIError{
  580. Message: claudeResponse.Error.Message,
  581. Type: claudeResponse.Error.Type,
  582. Code: claudeResponse.Error.Type,
  583. },
  584. StatusCode: http.StatusInternalServerError,
  585. }
  586. }
  587. if requestMode == RequestModeCompletion {
  588. completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
  589. claudeInfo.Usage.PromptTokens = info.PromptTokens
  590. claudeInfo.Usage.CompletionTokens = completionTokens
  591. claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
  592. } else {
  593. claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
  594. claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  595. claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
  596. claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
  597. claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
  598. }
  599. var responseData []byte
  600. switch info.RelayFormat {
  601. case relaycommon.RelayFormatOpenAI:
  602. openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
  603. openaiResponse.Usage = *claudeInfo.Usage
  604. responseData, err = json.Marshal(openaiResponse)
  605. if err != nil {
  606. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
  607. }
  608. case relaycommon.RelayFormatClaude:
  609. responseData = data
  610. }
  611. c.Writer.Header().Set("Content-Type", "application/json")
  612. c.Writer.WriteHeader(http.StatusOK)
  613. _, err = c.Writer.Write(responseData)
  614. return nil
  615. }
  616. func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  617. claudeInfo := &ClaudeResponseInfo{
  618. ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  619. Created: common.GetTimestamp(),
  620. Model: info.UpstreamModelName,
  621. ResponseText: strings.Builder{},
  622. Usage: &dto.Usage{},
  623. }
  624. responseBody, err := io.ReadAll(resp.Body)
  625. if err != nil {
  626. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  627. }
  628. resp.Body.Close()
  629. if common.DebugEnabled {
  630. println("responseBody: ", string(responseBody))
  631. }
  632. handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
  633. if handleErr != nil {
  634. return handleErr, nil
  635. }
  636. return nil, claudeInfo.Usage
  637. }