relay-claude.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. package claude
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. "one-api/relay/channel/openrouter"
  10. relaycommon "one-api/relay/common"
  11. "one-api/relay/helper"
  12. "one-api/service"
  13. "one-api/setting/model_setting"
  14. "strings"
  15. "github.com/gin-gonic/gin"
  16. )
  17. func stopReasonClaude2OpenAI(reason string) string {
  18. switch reason {
  19. case "stop_sequence":
  20. return "stop"
  21. case "end_turn":
  22. return "stop"
  23. case "max_tokens":
  24. return "max_tokens"
  25. case "tool_use":
  26. return "tool_calls"
  27. default:
  28. return reason
  29. }
  30. }
  31. func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
  32. claudeRequest := dto.ClaudeRequest{
  33. Model: textRequest.Model,
  34. Prompt: "",
  35. StopSequences: nil,
  36. Temperature: textRequest.Temperature,
  37. TopP: textRequest.TopP,
  38. TopK: textRequest.TopK,
  39. Stream: textRequest.Stream,
  40. }
  41. if claudeRequest.MaxTokensToSample == 0 {
  42. claudeRequest.MaxTokensToSample = 4096
  43. }
  44. prompt := ""
  45. for _, message := range textRequest.Messages {
  46. if message.Role == "user" {
  47. prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
  48. } else if message.Role == "assistant" {
  49. prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
  50. } else if message.Role == "system" {
  51. if prompt == "" {
  52. prompt = message.StringContent()
  53. }
  54. }
  55. }
  56. prompt += "\n\nAssistant:"
  57. claudeRequest.Prompt = prompt
  58. return &claudeRequest
  59. }
  60. func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
  61. claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
  62. for _, tool := range textRequest.Tools {
  63. if params, ok := tool.Function.Parameters.(map[string]any); ok {
  64. claudeTool := dto.Tool{
  65. Name: tool.Function.Name,
  66. Description: tool.Function.Description,
  67. }
  68. claudeTool.InputSchema = make(map[string]interface{})
  69. if params["type"] != nil {
  70. claudeTool.InputSchema["type"] = params["type"].(string)
  71. }
  72. claudeTool.InputSchema["properties"] = params["properties"]
  73. claudeTool.InputSchema["required"] = params["required"]
  74. for s, a := range params {
  75. if s == "type" || s == "properties" || s == "required" {
  76. continue
  77. }
  78. claudeTool.InputSchema[s] = a
  79. }
  80. claudeTools = append(claudeTools, claudeTool)
  81. }
  82. }
  83. claudeRequest := dto.ClaudeRequest{
  84. Model: textRequest.Model,
  85. MaxTokens: textRequest.MaxTokens,
  86. StopSequences: nil,
  87. Temperature: textRequest.Temperature,
  88. TopP: textRequest.TopP,
  89. TopK: textRequest.TopK,
  90. Stream: textRequest.Stream,
  91. Tools: claudeTools,
  92. }
  93. if claudeRequest.MaxTokens == 0 {
  94. claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
  95. }
  96. if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
  97. strings.HasSuffix(textRequest.Model, "-thinking") {
  98. // 因为BudgetTokens 必须大于1024
  99. if claudeRequest.MaxTokens < 1280 {
  100. claudeRequest.MaxTokens = 1280
  101. }
  102. // BudgetTokens 为 max_tokens 的 80%
  103. claudeRequest.Thinking = &dto.Thinking{
  104. Type: "enabled",
  105. BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
  106. }
  107. // TODO: 临时处理
  108. // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
  109. claudeRequest.TopP = 0
  110. claudeRequest.Temperature = common.GetPointer[float64](1.0)
  111. claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
  112. }
  113. if textRequest.Reasoning != nil {
  114. var reasoning openrouter.RequestReasoning
  115. if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
  116. return nil, err
  117. }
  118. budgetTokens := reasoning.MaxTokens
  119. if budgetTokens > 0 {
  120. claudeRequest.Thinking = &dto.Thinking{
  121. Type: "enabled",
  122. BudgetTokens: &budgetTokens,
  123. }
  124. }
  125. }
  126. if textRequest.Stop != nil {
  127. // stop maybe string/array string, convert to array string
  128. switch textRequest.Stop.(type) {
  129. case string:
  130. claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
  131. case []interface{}:
  132. stopSequences := make([]string, 0)
  133. for _, stop := range textRequest.Stop.([]interface{}) {
  134. stopSequences = append(stopSequences, stop.(string))
  135. }
  136. claudeRequest.StopSequences = stopSequences
  137. }
  138. }
  139. formatMessages := make([]dto.Message, 0)
  140. lastMessage := dto.Message{
  141. Role: "tool",
  142. }
  143. for i, message := range textRequest.Messages {
  144. if message.Role == "" {
  145. textRequest.Messages[i].Role = "user"
  146. }
  147. fmtMessage := dto.Message{
  148. Role: message.Role,
  149. Content: message.Content,
  150. }
  151. if message.Role == "tool" {
  152. fmtMessage.ToolCallId = message.ToolCallId
  153. }
  154. if message.Role == "assistant" && message.ToolCalls != nil {
  155. fmtMessage.ToolCalls = message.ToolCalls
  156. }
  157. if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
  158. if lastMessage.IsStringContent() && message.IsStringContent() {
  159. fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
  160. // delete last message
  161. formatMessages = formatMessages[:len(formatMessages)-1]
  162. }
  163. }
  164. if fmtMessage.Content == nil {
  165. fmtMessage.SetStringContent("...")
  166. }
  167. formatMessages = append(formatMessages, fmtMessage)
  168. lastMessage = fmtMessage
  169. }
  170. claudeMessages := make([]dto.ClaudeMessage, 0)
  171. isFirstMessage := true
  172. for _, message := range formatMessages {
  173. if message.Role == "system" {
  174. if message.IsStringContent() {
  175. claudeRequest.System = message.StringContent()
  176. } else {
  177. contents := message.ParseContent()
  178. content := ""
  179. for _, ctx := range contents {
  180. if ctx.Type == "text" {
  181. content += ctx.Text
  182. }
  183. }
  184. claudeRequest.System = content
  185. }
  186. } else {
  187. if isFirstMessage {
  188. isFirstMessage = false
  189. if message.Role != "user" {
  190. // fix: first message is assistant, add user message
  191. claudeMessage := dto.ClaudeMessage{
  192. Role: "user",
  193. Content: []dto.ClaudeMediaMessage{
  194. {
  195. Type: "text",
  196. Text: common.GetPointer[string]("..."),
  197. },
  198. },
  199. }
  200. claudeMessages = append(claudeMessages, claudeMessage)
  201. }
  202. }
  203. claudeMessage := dto.ClaudeMessage{
  204. Role: message.Role,
  205. }
  206. if message.Role == "tool" {
  207. if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
  208. lastMessage := claudeMessages[len(claudeMessages)-1]
  209. if content, ok := lastMessage.Content.(string); ok {
  210. lastMessage.Content = []dto.ClaudeMediaMessage{
  211. {
  212. Type: "text",
  213. Text: common.GetPointer[string](content),
  214. },
  215. }
  216. }
  217. lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
  218. Type: "tool_result",
  219. ToolUseId: message.ToolCallId,
  220. Content: message.Content,
  221. })
  222. claudeMessages[len(claudeMessages)-1] = lastMessage
  223. continue
  224. } else {
  225. claudeMessage.Role = "user"
  226. claudeMessage.Content = []dto.ClaudeMediaMessage{
  227. {
  228. Type: "tool_result",
  229. ToolUseId: message.ToolCallId,
  230. Content: message.Content,
  231. },
  232. }
  233. }
  234. } else if message.IsStringContent() && message.ToolCalls == nil {
  235. claudeMessage.Content = message.StringContent()
  236. } else {
  237. claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
  238. for _, mediaMessage := range message.ParseContent() {
  239. claudeMediaMessage := dto.ClaudeMediaMessage{
  240. Type: mediaMessage.Type,
  241. }
  242. if mediaMessage.Type == "text" {
  243. claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
  244. } else {
  245. imageUrl := mediaMessage.GetImageMedia()
  246. claudeMediaMessage.Type = "image"
  247. claudeMediaMessage.Source = &dto.ClaudeMessageSource{
  248. Type: "base64",
  249. }
  250. // 判断是否是url
  251. if strings.HasPrefix(imageUrl.Url, "http") {
  252. // 是url,获取图片的类型和base64编码的数据
  253. fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
  254. if err != nil {
  255. return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
  256. }
  257. claudeMediaMessage.Source.MediaType = fileData.MimeType
  258. claudeMediaMessage.Source.Data = fileData.Base64Data
  259. } else {
  260. _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
  261. if err != nil {
  262. return nil, err
  263. }
  264. claudeMediaMessage.Source.MediaType = "image/" + format
  265. claudeMediaMessage.Source.Data = base64String
  266. }
  267. }
  268. claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
  269. }
  270. if message.ToolCalls != nil {
  271. for _, toolCall := range message.ParseToolCalls() {
  272. inputObj := make(map[string]any)
  273. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
  274. common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
  275. continue
  276. }
  277. claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
  278. Type: "tool_use",
  279. Id: toolCall.ID,
  280. Name: toolCall.Function.Name,
  281. Input: inputObj,
  282. })
  283. }
  284. }
  285. claudeMessage.Content = claudeMediaMessages
  286. }
  287. claudeMessages = append(claudeMessages, claudeMessage)
  288. }
  289. }
  290. claudeRequest.Prompt = ""
  291. claudeRequest.Messages = claudeMessages
  292. return &claudeRequest, nil
  293. }
  294. func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
  295. var response dto.ChatCompletionsStreamResponse
  296. response.Object = "chat.completion.chunk"
  297. response.Model = claudeResponse.Model
  298. response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
  299. tools := make([]dto.ToolCallResponse, 0)
  300. fcIdx := 0
  301. if claudeResponse.Index != nil {
  302. fcIdx = *claudeResponse.Index - 1
  303. if fcIdx < 0 {
  304. fcIdx = 0
  305. }
  306. }
  307. var choice dto.ChatCompletionsStreamResponseChoice
  308. if reqMode == RequestModeCompletion {
  309. choice.Delta.SetContentString(claudeResponse.Completion)
  310. finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
  311. if finishReason != "null" {
  312. choice.FinishReason = &finishReason
  313. }
  314. } else {
  315. if claudeResponse.Type == "message_start" {
  316. response.Id = claudeResponse.Message.Id
  317. response.Model = claudeResponse.Message.Model
  318. //claudeUsage = &claudeResponse.Message.Usage
  319. choice.Delta.SetContentString("")
  320. choice.Delta.Role = "assistant"
  321. } else if claudeResponse.Type == "content_block_start" {
  322. if claudeResponse.ContentBlock != nil {
  323. //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
  324. if claudeResponse.ContentBlock.Type == "tool_use" {
  325. tools = append(tools, dto.ToolCallResponse{
  326. Index: common.GetPointer(fcIdx),
  327. ID: claudeResponse.ContentBlock.Id,
  328. Type: "function",
  329. Function: dto.FunctionResponse{
  330. Name: claudeResponse.ContentBlock.Name,
  331. Arguments: "",
  332. },
  333. })
  334. }
  335. } else {
  336. return nil
  337. }
  338. } else if claudeResponse.Type == "content_block_delta" {
  339. if claudeResponse.Delta != nil {
  340. choice.Delta.Content = claudeResponse.Delta.Text
  341. switch claudeResponse.Delta.Type {
  342. case "input_json_delta":
  343. tools = append(tools, dto.ToolCallResponse{
  344. Type: "function",
  345. Index: common.GetPointer(fcIdx),
  346. Function: dto.FunctionResponse{
  347. Arguments: *claudeResponse.Delta.PartialJson,
  348. },
  349. })
  350. case "signature_delta":
  351. // 加密的不处理
  352. signatureContent := "\n"
  353. choice.Delta.ReasoningContent = &signatureContent
  354. case "thinking_delta":
  355. thinkingContent := claudeResponse.Delta.Thinking
  356. choice.Delta.ReasoningContent = &thinkingContent
  357. }
  358. }
  359. } else if claudeResponse.Type == "message_delta" {
  360. finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
  361. if finishReason != "null" {
  362. choice.FinishReason = &finishReason
  363. }
  364. //claudeUsage = &claudeResponse.Usage
  365. } else if claudeResponse.Type == "message_stop" {
  366. return nil
  367. } else {
  368. return nil
  369. }
  370. }
  371. if len(tools) > 0 {
  372. choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
  373. choice.Delta.ToolCalls = tools
  374. }
  375. response.Choices = append(response.Choices, choice)
  376. return &response
  377. }
  378. func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
  379. choices := make([]dto.OpenAITextResponseChoice, 0)
  380. fullTextResponse := dto.OpenAITextResponse{
  381. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  382. Object: "chat.completion",
  383. Created: common.GetTimestamp(),
  384. }
  385. var responseText string
  386. var responseThinking string
  387. if len(claudeResponse.Content) > 0 {
  388. responseText = claudeResponse.Content[0].GetText()
  389. responseThinking = claudeResponse.Content[0].Thinking
  390. }
  391. tools := make([]dto.ToolCallResponse, 0)
  392. thinkingContent := ""
  393. if reqMode == RequestModeCompletion {
  394. choice := dto.OpenAITextResponseChoice{
  395. Index: 0,
  396. Message: dto.Message{
  397. Role: "assistant",
  398. Content: strings.TrimPrefix(claudeResponse.Completion, " "),
  399. Name: nil,
  400. },
  401. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  402. }
  403. choices = append(choices, choice)
  404. } else {
  405. fullTextResponse.Id = claudeResponse.Id
  406. for _, message := range claudeResponse.Content {
  407. switch message.Type {
  408. case "tool_use":
  409. args, _ := json.Marshal(message.Input)
  410. tools = append(tools, dto.ToolCallResponse{
  411. ID: message.Id,
  412. Type: "function", // compatible with other OpenAI derivative applications
  413. Function: dto.FunctionResponse{
  414. Name: message.Name,
  415. Arguments: string(args),
  416. },
  417. })
  418. case "thinking":
  419. // 加密的不管, 只输出明文的推理过程
  420. thinkingContent = message.Thinking
  421. case "text":
  422. responseText = message.GetText()
  423. }
  424. }
  425. }
  426. choice := dto.OpenAITextResponseChoice{
  427. Index: 0,
  428. Message: dto.Message{
  429. Role: "assistant",
  430. },
  431. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  432. }
  433. choice.SetStringContent(responseText)
  434. if len(responseThinking) > 0 {
  435. choice.ReasoningContent = responseThinking
  436. }
  437. if len(tools) > 0 {
  438. choice.Message.SetToolCalls(tools)
  439. }
  440. choice.Message.ReasoningContent = thinkingContent
  441. fullTextResponse.Model = claudeResponse.Model
  442. choices = append(choices, choice)
  443. fullTextResponse.Choices = choices
  444. return &fullTextResponse
  445. }
  446. type ClaudeResponseInfo struct {
  447. ResponseId string
  448. Created int64
  449. Model string
  450. ResponseText strings.Builder
  451. Usage *dto.Usage
  452. Done bool
  453. }
  454. func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
  455. if requestMode == RequestModeCompletion {
  456. claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
  457. } else {
  458. if claudeResponse.Type == "message_start" {
  459. claudeInfo.ResponseId = claudeResponse.Message.Id
  460. claudeInfo.Model = claudeResponse.Message.Model
  461. // message_start, 获取usage
  462. claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
  463. claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
  464. claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
  465. claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
  466. } else if claudeResponse.Type == "content_block_delta" {
  467. if claudeResponse.Delta.Text != nil {
  468. claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
  469. }
  470. if claudeResponse.Delta.Thinking != "" {
  471. claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
  472. }
  473. } else if claudeResponse.Type == "message_delta" {
  474. // 最终的usage获取
  475. if claudeResponse.Usage.InputTokens > 0 {
  476. // 不叠加,只取最新的
  477. claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
  478. }
  479. claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  480. claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
  481. // 判断是否完整
  482. claudeInfo.Done = true
  483. } else if claudeResponse.Type == "content_block_start" {
  484. } else {
  485. return false
  486. }
  487. }
  488. if oaiResponse != nil {
  489. oaiResponse.Id = claudeInfo.ResponseId
  490. oaiResponse.Created = claudeInfo.Created
  491. oaiResponse.Model = claudeInfo.Model
  492. }
  493. return true
  494. }
  495. func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
  496. var claudeResponse dto.ClaudeResponse
  497. err := common.UnmarshalJsonStr(data, &claudeResponse)
  498. if err != nil {
  499. common.SysError("error unmarshalling stream response: " + err.Error())
  500. return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
  501. }
  502. if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
  503. return &dto.OpenAIErrorWithStatusCode{
  504. Error: dto.OpenAIError{
  505. Code: "stream_response_error",
  506. Type: claudeResponse.Error.Type,
  507. Message: claudeResponse.Error.Message,
  508. },
  509. StatusCode: http.StatusInternalServerError,
  510. }
  511. }
  512. if info.RelayFormat == relaycommon.RelayFormatClaude {
  513. FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
  514. if requestMode == RequestModeCompletion {
  515. } else {
  516. if claudeResponse.Type == "message_start" {
  517. // message_start, 获取usage
  518. info.UpstreamModelName = claudeResponse.Message.Model
  519. } else if claudeResponse.Type == "content_block_delta" {
  520. } else if claudeResponse.Type == "message_delta" {
  521. }
  522. }
  523. helper.ClaudeChunkData(c, claudeResponse, data)
  524. } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  525. response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
  526. if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
  527. return nil
  528. }
  529. err = helper.ObjectData(c, response)
  530. if err != nil {
  531. common.LogError(c, "send_stream_response_failed: "+err.Error())
  532. }
  533. }
  534. return nil
  535. }
  536. func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
  537. if requestMode == RequestModeCompletion {
  538. claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
  539. } else {
  540. if claudeInfo.Usage.PromptTokens == 0 {
  541. //上游出错
  542. }
  543. if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
  544. if common.DebugEnabled {
  545. common.SysError("claude response usage is not complete, maybe upstream error")
  546. }
  547. claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
  548. }
  549. }
  550. if info.RelayFormat == relaycommon.RelayFormatClaude {
  551. //
  552. } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  553. if info.ShouldIncludeUsage {
  554. response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
  555. err := helper.ObjectData(c, response)
  556. if err != nil {
  557. common.SysError("send final response failed: " + err.Error())
  558. }
  559. }
  560. helper.Done(c)
  561. }
  562. }
  563. func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  564. claudeInfo := &ClaudeResponseInfo{
  565. ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  566. Created: common.GetTimestamp(),
  567. Model: info.UpstreamModelName,
  568. ResponseText: strings.Builder{},
  569. Usage: &dto.Usage{},
  570. }
  571. var err *dto.OpenAIErrorWithStatusCode
  572. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  573. err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
  574. if err != nil {
  575. return false
  576. }
  577. return true
  578. })
  579. if err != nil {
  580. return err, nil
  581. }
  582. HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
  583. return nil, claudeInfo.Usage
  584. }
  585. func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
  586. var claudeResponse dto.ClaudeResponse
  587. err := common.UnmarshalJson(data, &claudeResponse)
  588. if err != nil {
  589. return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
  590. }
  591. if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
  592. return &dto.OpenAIErrorWithStatusCode{
  593. Error: dto.OpenAIError{
  594. Message: claudeResponse.Error.Message,
  595. Type: claudeResponse.Error.Type,
  596. Code: claudeResponse.Error.Type,
  597. },
  598. StatusCode: http.StatusInternalServerError,
  599. }
  600. }
  601. if requestMode == RequestModeCompletion {
  602. completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
  603. claudeInfo.Usage.PromptTokens = info.PromptTokens
  604. claudeInfo.Usage.CompletionTokens = completionTokens
  605. claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
  606. } else {
  607. claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
  608. claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  609. claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
  610. claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
  611. claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
  612. }
  613. var responseData []byte
  614. switch info.RelayFormat {
  615. case relaycommon.RelayFormatOpenAI:
  616. openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
  617. openaiResponse.Usage = *claudeInfo.Usage
  618. responseData, err = json.Marshal(openaiResponse)
  619. if err != nil {
  620. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
  621. }
  622. case relaycommon.RelayFormatClaude:
  623. responseData = data
  624. }
  625. common.IOCopyBytesGracefully(c, nil, responseData)
  626. return nil
  627. }
  628. func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  629. defer common.CloseResponseBodyGracefully(resp)
  630. claudeInfo := &ClaudeResponseInfo{
  631. ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  632. Created: common.GetTimestamp(),
  633. Model: info.UpstreamModelName,
  634. ResponseText: strings.Builder{},
  635. Usage: &dto.Usage{},
  636. }
  637. responseBody, err := io.ReadAll(resp.Body)
  638. if err != nil {
  639. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  640. }
  641. if common.DebugEnabled {
  642. println("responseBody: ", string(responseBody))
  643. }
  644. handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
  645. if handleErr != nil {
  646. return handleErr, nil
  647. }
  648. return nil, claudeInfo.Usage
  649. }