relay-claude.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672
  1. package claude
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. "one-api/setting/model_setting"
  13. "strings"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func stopReasonClaude2OpenAI(reason string) string {
  17. switch reason {
  18. case "stop_sequence":
  19. return "stop"
  20. case "end_turn":
  21. return "stop"
  22. case "max_tokens":
  23. return "max_tokens"
  24. case "tool_use":
  25. return "tool_calls"
  26. default:
  27. return reason
  28. }
  29. }
  30. func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
  31. claudeRequest := dto.ClaudeRequest{
  32. Model: textRequest.Model,
  33. Prompt: "",
  34. StopSequences: nil,
  35. Temperature: textRequest.Temperature,
  36. TopP: textRequest.TopP,
  37. TopK: textRequest.TopK,
  38. Stream: textRequest.Stream,
  39. }
  40. if claudeRequest.MaxTokensToSample == 0 {
  41. claudeRequest.MaxTokensToSample = 4096
  42. }
  43. prompt := ""
  44. for _, message := range textRequest.Messages {
  45. if message.Role == "user" {
  46. prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
  47. } else if message.Role == "assistant" {
  48. prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
  49. } else if message.Role == "system" {
  50. if prompt == "" {
  51. prompt = message.StringContent()
  52. }
  53. }
  54. }
  55. prompt += "\n\nAssistant:"
  56. claudeRequest.Prompt = prompt
  57. return &claudeRequest
  58. }
  59. func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
  60. claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
  61. for _, tool := range textRequest.Tools {
  62. if params, ok := tool.Function.Parameters.(map[string]any); ok {
  63. claudeTool := dto.Tool{
  64. Name: tool.Function.Name,
  65. Description: tool.Function.Description,
  66. }
  67. claudeTool.InputSchema = make(map[string]interface{})
  68. if params["type"] != nil {
  69. claudeTool.InputSchema["type"] = params["type"].(string)
  70. }
  71. claudeTool.InputSchema["properties"] = params["properties"]
  72. claudeTool.InputSchema["required"] = params["required"]
  73. for s, a := range params {
  74. if s == "type" || s == "properties" || s == "required" {
  75. continue
  76. }
  77. claudeTool.InputSchema[s] = a
  78. }
  79. claudeTools = append(claudeTools, claudeTool)
  80. }
  81. }
  82. claudeRequest := dto.ClaudeRequest{
  83. Model: textRequest.Model,
  84. MaxTokens: textRequest.MaxTokens,
  85. StopSequences: nil,
  86. Temperature: textRequest.Temperature,
  87. TopP: textRequest.TopP,
  88. TopK: textRequest.TopK,
  89. Stream: textRequest.Stream,
  90. Tools: claudeTools,
  91. }
  92. if claudeRequest.MaxTokens == 0 {
  93. claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
  94. }
  95. if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
  96. strings.HasSuffix(textRequest.Model, "-thinking") {
  97. // 因为BudgetTokens 必须大于1024
  98. if claudeRequest.MaxTokens < 1280 {
  99. claudeRequest.MaxTokens = 1280
  100. }
  101. // BudgetTokens 为 max_tokens 的 80%
  102. claudeRequest.Thinking = &dto.Thinking{
  103. Type: "enabled",
  104. BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
  105. }
  106. // TODO: 临时处理
  107. // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
  108. claudeRequest.TopP = 0
  109. claudeRequest.Temperature = common.GetPointer[float64](1.0)
  110. claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
  111. }
  112. if textRequest.Stop != nil {
  113. // stop maybe string/array string, convert to array string
  114. switch textRequest.Stop.(type) {
  115. case string:
  116. claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
  117. case []interface{}:
  118. stopSequences := make([]string, 0)
  119. for _, stop := range textRequest.Stop.([]interface{}) {
  120. stopSequences = append(stopSequences, stop.(string))
  121. }
  122. claudeRequest.StopSequences = stopSequences
  123. }
  124. }
  125. formatMessages := make([]dto.Message, 0)
  126. lastMessage := dto.Message{
  127. Role: "tool",
  128. }
  129. for i, message := range textRequest.Messages {
  130. if message.Role == "" {
  131. textRequest.Messages[i].Role = "user"
  132. }
  133. fmtMessage := dto.Message{
  134. Role: message.Role,
  135. Content: message.Content,
  136. }
  137. if message.Role == "tool" {
  138. fmtMessage.ToolCallId = message.ToolCallId
  139. }
  140. if message.Role == "assistant" && message.ToolCalls != nil {
  141. fmtMessage.ToolCalls = message.ToolCalls
  142. }
  143. if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
  144. if lastMessage.IsStringContent() && message.IsStringContent() {
  145. content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
  146. fmtMessage.Content = content
  147. // delete last message
  148. formatMessages = formatMessages[:len(formatMessages)-1]
  149. }
  150. }
  151. if fmtMessage.Content == nil {
  152. content, _ := json.Marshal("...")
  153. fmtMessage.Content = content
  154. }
  155. formatMessages = append(formatMessages, fmtMessage)
  156. lastMessage = fmtMessage
  157. }
  158. claudeMessages := make([]dto.ClaudeMessage, 0)
  159. isFirstMessage := true
  160. for _, message := range formatMessages {
  161. if message.Role == "system" {
  162. if message.IsStringContent() {
  163. claudeRequest.System = message.StringContent()
  164. } else {
  165. contents := message.ParseContent()
  166. content := ""
  167. for _, ctx := range contents {
  168. if ctx.Type == "text" {
  169. content += ctx.Text
  170. }
  171. }
  172. claudeRequest.System = content
  173. }
  174. } else {
  175. if isFirstMessage {
  176. isFirstMessage = false
  177. if message.Role != "user" {
  178. // fix: first message is assistant, add user message
  179. claudeMessage := dto.ClaudeMessage{
  180. Role: "user",
  181. Content: []dto.ClaudeMediaMessage{
  182. {
  183. Type: "text",
  184. Text: common.GetPointer[string]("..."),
  185. },
  186. },
  187. }
  188. claudeMessages = append(claudeMessages, claudeMessage)
  189. }
  190. }
  191. claudeMessage := dto.ClaudeMessage{
  192. Role: message.Role,
  193. }
  194. if message.Role == "tool" {
  195. if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
  196. lastMessage := claudeMessages[len(claudeMessages)-1]
  197. if content, ok := lastMessage.Content.(string); ok {
  198. lastMessage.Content = []dto.ClaudeMediaMessage{
  199. {
  200. Type: "text",
  201. Text: common.GetPointer[string](content),
  202. },
  203. }
  204. }
  205. lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
  206. Type: "tool_result",
  207. ToolUseId: message.ToolCallId,
  208. Content: message.Content,
  209. })
  210. claudeMessages[len(claudeMessages)-1] = lastMessage
  211. continue
  212. } else {
  213. claudeMessage.Role = "user"
  214. claudeMessage.Content = []dto.ClaudeMediaMessage{
  215. {
  216. Type: "tool_result",
  217. ToolUseId: message.ToolCallId,
  218. Content: message.Content,
  219. },
  220. }
  221. }
  222. } else if message.IsStringContent() && message.ToolCalls == nil {
  223. claudeMessage.Content = message.StringContent()
  224. } else {
  225. claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
  226. for _, mediaMessage := range message.ParseContent() {
  227. claudeMediaMessage := dto.ClaudeMediaMessage{
  228. Type: mediaMessage.Type,
  229. }
  230. if mediaMessage.Type == "text" {
  231. claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
  232. } else {
  233. imageUrl := mediaMessage.GetImageMedia()
  234. claudeMediaMessage.Type = "image"
  235. claudeMediaMessage.Source = &dto.ClaudeMessageSource{}
  236. // 判断是否是url
  237. if strings.HasPrefix(imageUrl.Url, "http") {
  238. claudeMediaMessage.Source.Type = "url"
  239. claudeMediaMessage.Source.Url = imageUrl.Url
  240. } else {
  241. _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
  242. if err != nil {
  243. return nil, err
  244. }
  245. claudeMediaMessage.Source.Type = "base64"
  246. claudeMediaMessage.Source.MediaType = "image/" + format
  247. claudeMediaMessage.Source.Data = base64String
  248. }
  249. }
  250. claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
  251. }
  252. if message.ToolCalls != nil {
  253. for _, toolCall := range message.ParseToolCalls() {
  254. inputObj := make(map[string]any)
  255. if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
  256. common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
  257. continue
  258. }
  259. claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
  260. Type: "tool_use",
  261. Id: toolCall.ID,
  262. Name: toolCall.Function.Name,
  263. Input: inputObj,
  264. })
  265. }
  266. }
  267. claudeMessage.Content = claudeMediaMessages
  268. }
  269. claudeMessages = append(claudeMessages, claudeMessage)
  270. }
  271. }
  272. claudeRequest.Prompt = ""
  273. claudeRequest.Messages = claudeMessages
  274. return &claudeRequest, nil
  275. }
  276. func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
  277. var response dto.ChatCompletionsStreamResponse
  278. response.Object = "chat.completion.chunk"
  279. response.Model = claudeResponse.Model
  280. response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
  281. tools := make([]dto.ToolCallResponse, 0)
  282. fcIdx := 0
  283. if claudeResponse.Index != nil {
  284. fcIdx = *claudeResponse.Index - 1
  285. if fcIdx < 0 {
  286. fcIdx = 0
  287. }
  288. }
  289. var choice dto.ChatCompletionsStreamResponseChoice
  290. if reqMode == RequestModeCompletion {
  291. choice.Delta.SetContentString(claudeResponse.Completion)
  292. finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
  293. if finishReason != "null" {
  294. choice.FinishReason = &finishReason
  295. }
  296. } else {
  297. if claudeResponse.Type == "message_start" {
  298. response.Id = claudeResponse.Message.Id
  299. response.Model = claudeResponse.Message.Model
  300. //claudeUsage = &claudeResponse.Message.Usage
  301. choice.Delta.SetContentString("")
  302. choice.Delta.Role = "assistant"
  303. } else if claudeResponse.Type == "content_block_start" {
  304. if claudeResponse.ContentBlock != nil {
  305. //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
  306. if claudeResponse.ContentBlock.Type == "tool_use" {
  307. tools = append(tools, dto.ToolCallResponse{
  308. Index: common.GetPointer(fcIdx),
  309. ID: claudeResponse.ContentBlock.Id,
  310. Type: "function",
  311. Function: dto.FunctionResponse{
  312. Name: claudeResponse.ContentBlock.Name,
  313. Arguments: "",
  314. },
  315. })
  316. }
  317. } else {
  318. return nil
  319. }
  320. } else if claudeResponse.Type == "content_block_delta" {
  321. if claudeResponse.Delta != nil {
  322. choice.Delta.Content = claudeResponse.Delta.Text
  323. switch claudeResponse.Delta.Type {
  324. case "input_json_delta":
  325. tools = append(tools, dto.ToolCallResponse{
  326. Type: "function",
  327. Index: common.GetPointer(fcIdx),
  328. Function: dto.FunctionResponse{
  329. Arguments: *claudeResponse.Delta.PartialJson,
  330. },
  331. })
  332. case "signature_delta":
  333. // 加密的不处理
  334. signatureContent := "\n"
  335. choice.Delta.ReasoningContent = &signatureContent
  336. case "thinking_delta":
  337. thinkingContent := claudeResponse.Delta.Thinking
  338. choice.Delta.ReasoningContent = &thinkingContent
  339. }
  340. }
  341. } else if claudeResponse.Type == "message_delta" {
  342. finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
  343. if finishReason != "null" {
  344. choice.FinishReason = &finishReason
  345. }
  346. //claudeUsage = &claudeResponse.Usage
  347. } else if claudeResponse.Type == "message_stop" {
  348. return nil
  349. } else {
  350. return nil
  351. }
  352. }
  353. if len(tools) > 0 {
  354. choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
  355. choice.Delta.ToolCalls = tools
  356. }
  357. response.Choices = append(response.Choices, choice)
  358. return &response
  359. }
  360. func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
  361. choices := make([]dto.OpenAITextResponseChoice, 0)
  362. fullTextResponse := dto.OpenAITextResponse{
  363. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  364. Object: "chat.completion",
  365. Created: common.GetTimestamp(),
  366. }
  367. var responseText string
  368. var responseThinking string
  369. if len(claudeResponse.Content) > 0 {
  370. responseText = claudeResponse.Content[0].GetText()
  371. responseThinking = claudeResponse.Content[0].Thinking
  372. }
  373. tools := make([]dto.ToolCallResponse, 0)
  374. thinkingContent := ""
  375. if reqMode == RequestModeCompletion {
  376. content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
  377. choice := dto.OpenAITextResponseChoice{
  378. Index: 0,
  379. Message: dto.Message{
  380. Role: "assistant",
  381. Content: content,
  382. Name: nil,
  383. },
  384. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  385. }
  386. choices = append(choices, choice)
  387. } else {
  388. fullTextResponse.Id = claudeResponse.Id
  389. for _, message := range claudeResponse.Content {
  390. switch message.Type {
  391. case "tool_use":
  392. args, _ := json.Marshal(message.Input)
  393. tools = append(tools, dto.ToolCallResponse{
  394. ID: message.Id,
  395. Type: "function", // compatible with other OpenAI derivative applications
  396. Function: dto.FunctionResponse{
  397. Name: message.Name,
  398. Arguments: string(args),
  399. },
  400. })
  401. case "thinking":
  402. // 加密的不管, 只输出明文的推理过程
  403. thinkingContent = message.Thinking
  404. case "text":
  405. responseText = message.GetText()
  406. }
  407. }
  408. }
  409. choice := dto.OpenAITextResponseChoice{
  410. Index: 0,
  411. Message: dto.Message{
  412. Role: "assistant",
  413. },
  414. FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
  415. }
  416. choice.SetStringContent(responseText)
  417. if len(responseThinking) > 0 {
  418. choice.ReasoningContent = responseThinking
  419. }
  420. if len(tools) > 0 {
  421. choice.Message.SetToolCalls(tools)
  422. }
  423. choice.Message.ReasoningContent = thinkingContent
  424. fullTextResponse.Model = claudeResponse.Model
  425. choices = append(choices, choice)
  426. fullTextResponse.Choices = choices
  427. return &fullTextResponse
  428. }
  429. type ClaudeResponseInfo struct {
  430. ResponseId string
  431. Created int64
  432. Model string
  433. ResponseText strings.Builder
  434. Usage *dto.Usage
  435. }
  436. func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
  437. if requestMode == RequestModeCompletion {
  438. claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
  439. } else {
  440. if claudeResponse.Type == "message_start" {
  441. // message_start, 获取usage
  442. claudeInfo.ResponseId = claudeResponse.Message.Id
  443. claudeInfo.Model = claudeResponse.Message.Model
  444. claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
  445. } else if claudeResponse.Type == "content_block_delta" {
  446. if claudeResponse.Delta.Text != nil {
  447. claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
  448. }
  449. } else if claudeResponse.Type == "message_delta" {
  450. claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  451. if claudeResponse.Usage.InputTokens > 0 {
  452. claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
  453. }
  454. claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
  455. } else if claudeResponse.Type == "content_block_start" {
  456. } else {
  457. return false
  458. }
  459. }
  460. if oaiResponse != nil {
  461. oaiResponse.Id = claudeInfo.ResponseId
  462. oaiResponse.Created = claudeInfo.Created
  463. oaiResponse.Model = claudeInfo.Model
  464. }
  465. return true
  466. }
  467. func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
  468. var claudeResponse dto.ClaudeResponse
  469. err := common.DecodeJsonStr(data, &claudeResponse)
  470. if err != nil {
  471. common.SysError("error unmarshalling stream response: " + err.Error())
  472. return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
  473. }
  474. if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
  475. return &dto.OpenAIErrorWithStatusCode{
  476. Error: dto.OpenAIError{
  477. Code: "stream_response_error",
  478. Type: claudeResponse.Error.Type,
  479. Message: claudeResponse.Error.Message,
  480. },
  481. StatusCode: http.StatusInternalServerError,
  482. }
  483. }
  484. if info.RelayFormat == relaycommon.RelayFormatClaude {
  485. if requestMode == RequestModeCompletion {
  486. claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
  487. } else {
  488. if claudeResponse.Type == "message_start" {
  489. // message_start, 获取usage
  490. info.UpstreamModelName = claudeResponse.Message.Model
  491. claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
  492. claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
  493. claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
  494. claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
  495. } else if claudeResponse.Type == "content_block_delta" {
  496. claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
  497. } else if claudeResponse.Type == "message_delta" {
  498. if claudeResponse.Usage.InputTokens > 0 {
  499. // 不叠加,只取最新的
  500. claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
  501. }
  502. claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  503. claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
  504. }
  505. }
  506. helper.ClaudeChunkData(c, claudeResponse, data)
  507. } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  508. response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
  509. if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
  510. return nil
  511. }
  512. err = helper.ObjectData(c, response)
  513. if err != nil {
  514. common.LogError(c, "send_stream_response_failed: "+err.Error())
  515. }
  516. }
  517. return nil
  518. }
  519. func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
  520. if info.RelayFormat == relaycommon.RelayFormatClaude {
  521. if requestMode == RequestModeCompletion {
  522. claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
  523. } else {
  524. // 说明流模式建立失败,可能为官方出错
  525. if claudeInfo.Usage.PromptTokens == 0 {
  526. //usage.PromptTokens = info.PromptTokens
  527. }
  528. if claudeInfo.Usage.CompletionTokens == 0 {
  529. claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
  530. }
  531. }
  532. } else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
  533. if requestMode == RequestModeCompletion {
  534. claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
  535. } else {
  536. if claudeInfo.Usage.PromptTokens == 0 {
  537. //上游出错
  538. }
  539. if claudeInfo.Usage.CompletionTokens == 0 {
  540. claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
  541. }
  542. }
  543. if info.ShouldIncludeUsage {
  544. response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
  545. err := helper.ObjectData(c, response)
  546. if err != nil {
  547. common.SysError("send final response failed: " + err.Error())
  548. }
  549. }
  550. helper.Done(c)
  551. }
  552. }
  553. func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  554. claudeInfo := &ClaudeResponseInfo{
  555. ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  556. Created: common.GetTimestamp(),
  557. Model: info.UpstreamModelName,
  558. ResponseText: strings.Builder{},
  559. Usage: &dto.Usage{},
  560. }
  561. var err *dto.OpenAIErrorWithStatusCode
  562. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  563. err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
  564. if err != nil {
  565. return false
  566. }
  567. return true
  568. })
  569. if err != nil {
  570. return err, nil
  571. }
  572. HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
  573. return nil, claudeInfo.Usage
  574. }
  575. func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
  576. var claudeResponse dto.ClaudeResponse
  577. err := common.DecodeJson(data, &claudeResponse)
  578. if err != nil {
  579. return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
  580. }
  581. if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
  582. return &dto.OpenAIErrorWithStatusCode{
  583. Error: dto.OpenAIError{
  584. Message: claudeResponse.Error.Message,
  585. Type: claudeResponse.Error.Type,
  586. Code: claudeResponse.Error.Type,
  587. },
  588. StatusCode: http.StatusInternalServerError,
  589. }
  590. }
  591. if requestMode == RequestModeCompletion {
  592. completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
  593. if err != nil {
  594. return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
  595. }
  596. claudeInfo.Usage.PromptTokens = info.PromptTokens
  597. claudeInfo.Usage.CompletionTokens = completionTokens
  598. claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
  599. } else {
  600. claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
  601. claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
  602. claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
  603. claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
  604. claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
  605. }
  606. var responseData []byte
  607. switch info.RelayFormat {
  608. case relaycommon.RelayFormatOpenAI:
  609. openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
  610. openaiResponse.Usage = *claudeInfo.Usage
  611. responseData, err = json.Marshal(openaiResponse)
  612. if err != nil {
  613. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
  614. }
  615. case relaycommon.RelayFormatClaude:
  616. responseData = data
  617. }
  618. c.Writer.Header().Set("Content-Type", "application/json")
  619. c.Writer.WriteHeader(http.StatusOK)
  620. _, err = c.Writer.Write(responseData)
  621. return nil
  622. }
  623. func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  624. claudeInfo := &ClaudeResponseInfo{
  625. ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  626. Created: common.GetTimestamp(),
  627. Model: info.UpstreamModelName,
  628. ResponseText: strings.Builder{},
  629. Usage: &dto.Usage{},
  630. }
  631. responseBody, err := io.ReadAll(resp.Body)
  632. if err != nil {
  633. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  634. }
  635. resp.Body.Close()
  636. if common.DebugEnabled {
  637. println("responseBody: ", string(responseBody))
  638. }
  639. handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
  640. if handleErr != nil {
  641. return handleErr, nil
  642. }
  643. return nil, claudeInfo.Usage
  644. }