relay-gemini.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. package gemini
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/constant"
  9. "one-api/dto"
  10. relaycommon "one-api/relay/common"
  11. "one-api/relay/helper"
  12. "one-api/service"
  13. "one-api/setting/model_setting"
  14. "strings"
  15. "unicode/utf8"
  16. "github.com/gin-gonic/gin"
  17. )
// Setting safety to the lowest possible values since Gemini is already powerless enough
//
// CovertGemini2OpenAI converts an OpenAI-format chat completion request into
// Gemini's native request shape: generation config, safety settings, tools
// (including the built-in googleSearch/codeExecution pseudo-tools), JSON
// response schema, and the full message history (system messages, tool
// results, text and image parts). It returns an error when tool-call
// arguments are not valid JSON or the message exceeds the image limit.
// NOTE(review): the exported name spells "Covert" instead of "Convert";
// renaming would break callers elsewhere, so it is left unchanged.
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
	geminiRequest := GeminiChatRequest{
		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
		//SafetySettings: []GeminiChatSafetySettings{},
		GenerationConfig: GeminiChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.MaxTokens,
			Seed:            int64(textRequest.Seed),
		},
	}
	// Apply the configured threshold for every known harm category.
	safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
	for _, category := range SafetySettingList {
		safetySettings = append(safetySettings, GeminiChatSafetySettings{
			Category:  category,
			Threshold: model_setting.GetGeminiSafetySetting(category),
		})
	}
	geminiRequest.SafetySettings = safetySettings
	// openaiContent.FuncToToolCalls()
	if textRequest.Tools != nil {
		functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
		googleSearch := false
		codeExecution := false
		for _, tool := range textRequest.Tools {
			// "googleSearch" and "codeExecution" are pseudo-tools that map to
			// Gemini built-in tools rather than function declarations.
			if tool.Function.Name == "googleSearch" {
				googleSearch = true
				continue
			}
			if tool.Function.Name == "codeExecution" {
				codeExecution = true
				continue
			}
			if tool.Function.Parameters != nil {
				params, ok := tool.Function.Parameters.(map[string]interface{})
				if ok {
					if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
						// An empty "properties" object is rejected upstream,
						// so drop the parameters entirely in that case.
						if len(props) == 0 {
							tool.Function.Parameters = nil
						}
					}
				}
			}
			functions = append(functions, tool.Function)
		}
		if codeExecution {
			geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
				CodeExecution: make(map[string]string),
			})
		}
		if googleSearch {
			geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
				GoogleSearch: make(map[string]string),
			})
		}
		if len(functions) > 0 {
			geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
				FunctionDeclarations: functions,
			})
		}
		// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
		// json_data, _ := json.Marshal(geminiRequest.Tools)
		// common.SysLog("tools_json: " + string(json_data))
	} else if textRequest.Functions != nil {
		// Legacy "functions" field from the pre-tools OpenAI API.
		geminiRequest.Tools = []GeminiChatTool{
			{
				FunctionDeclarations: textRequest.Functions,
			},
		}
	}
	// Map OpenAI JSON response formats onto Gemini's MIME type + schema.
	if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
		geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
		if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
			cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
			geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
		}
	}
	// tool_call_ids maps tool_call_id -> function name so later "tool"
	// messages can be attributed to the function that was called.
	tool_call_ids := make(map[string]string)
	var system_content []string
	//shouldAddDummyModelMessage := false
	for _, message := range textRequest.Messages {
		if message.Role == "system" {
			// System messages are collected and emitted as systemInstruction.
			system_content = append(system_content, message.StringContent())
			continue
		} else if message.Role == "tool" || message.Role == "function" {
			// Tool results must live inside a "user" turn; reuse the last
			// content entry unless it belongs to the model.
			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
				geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
					Role: "user",
				})
			}
			var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
			name := ""
			if message.Name != nil {
				name = *message.Name
			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
				name = val
			}
			content := common.StrToMap(message.StringContent())
			functionResp := &FunctionResponse{
				Name: name,
				Response: GeminiFunctionResponseContent{
					Name:    name,
					Content: content,
				},
			}
			// Fall back to the raw string when the content is not a JSON object.
			if content == nil {
				functionResp.Response.Content = message.StringContent()
			}
			*parts = append(*parts, GeminiPart{
				FunctionResponse: functionResp,
			})
			continue
		}
		var parts []GeminiPart
		content := GeminiChatContent{
			Role: message.Role,
		}
		// isToolCall := false
		if message.ToolCalls != nil {
			// message.Role = "model"
			// isToolCall = true
			for _, call := range message.ParseToolCalls() {
				args := map[string]interface{}{}
				if call.Function.Arguments != "" {
					if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
						return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
					}
				}
				toolCall := GeminiPart{
					FunctionCall: &FunctionCall{
						FunctionName: call.Function.Name,
						Arguments:    args,
					},
				}
				parts = append(parts, toolCall)
				// Remember the id->name mapping for the matching tool result.
				tool_call_ids[call.ID] = call.Function.Name
			}
		}
		openaiContent := message.ParseContent()
		imageNum := 0
		for _, part := range openaiContent {
			if part.Type == dto.ContentTypeText {
				if part.Text == "" {
					continue
				}
				parts = append(parts, GeminiPart{
					Text: part.Text,
				})
			} else if part.Type == dto.ContentTypeImageURL {
				imageNum += 1
				// A limit of -1 means unlimited images.
				if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
					return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
				}
				// Check whether the image is referenced by URL.
				if strings.HasPrefix(part.GetImageMedia().Url, "http") {
					// It is a URL: fetch it and obtain MIME type + base64 data.
					fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
					if err != nil {
						return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
					}
					parts = append(parts, GeminiPart{
						InlineData: &GeminiInlineData{
							MimeType: fileData.MimeType,
							Data:     fileData.Base64Data,
						},
					})
				} else {
					// Otherwise assume a data URI / raw base64 payload.
					format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
					if err != nil {
						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
					}
					parts = append(parts, GeminiPart{
						InlineData: &GeminiInlineData{
							MimeType: format,
							Data:     base64String,
						},
					})
				}
			}
		}
		content.Parts = parts
		// there's no assistant role in gemini and API shall vomit if Role is not user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		geminiRequest.Contents = append(geminiRequest.Contents, content)
	}
	if len(system_content) > 0 {
		geminiRequest.SystemInstructions = &GeminiChatContent{
			Parts: []GeminiPart{
				{
					Text: strings.Join(system_content, "\n"),
				},
			},
		}
	}
	return &geminiRequest, nil
}
  217. func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
  218. if depth >= 5 {
  219. return schema
  220. }
  221. v, ok := schema.(map[string]interface{})
  222. if !ok || len(v) == 0 {
  223. return schema
  224. }
  225. // 删除所有的title字段
  226. delete(v, "title")
  227. // 如果type不为object和array,则直接返回
  228. if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
  229. return schema
  230. }
  231. switch v["type"] {
  232. case "object":
  233. delete(v, "additionalProperties")
  234. // 处理 properties
  235. if properties, ok := v["properties"].(map[string]interface{}); ok {
  236. for key, value := range properties {
  237. properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
  238. }
  239. }
  240. for _, field := range []string{"allOf", "anyOf", "oneOf"} {
  241. if nested, ok := v[field].([]interface{}); ok {
  242. for i, item := range nested {
  243. nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
  244. }
  245. }
  246. }
  247. case "array":
  248. if items, ok := v["items"].(map[string]interface{}); ok {
  249. v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
  250. }
  251. }
  252. return v
  253. }
  254. func unescapeString(s string) (string, error) {
  255. var result []rune
  256. escaped := false
  257. i := 0
  258. for i < len(s) {
  259. r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
  260. if r == utf8.RuneError {
  261. return "", fmt.Errorf("invalid UTF-8 encoding")
  262. }
  263. if escaped {
  264. // 如果是转义符后的字符,检查其类型
  265. switch r {
  266. case '"':
  267. result = append(result, '"')
  268. case '\\':
  269. result = append(result, '\\')
  270. case '/':
  271. result = append(result, '/')
  272. case 'b':
  273. result = append(result, '\b')
  274. case 'f':
  275. result = append(result, '\f')
  276. case 'n':
  277. result = append(result, '\n')
  278. case 'r':
  279. result = append(result, '\r')
  280. case 't':
  281. result = append(result, '\t')
  282. case '\'':
  283. result = append(result, '\'')
  284. default:
  285. // 如果遇到一个非法的转义字符,直接按原样输出
  286. result = append(result, '\\', r)
  287. }
  288. escaped = false
  289. } else {
  290. if r == '\\' {
  291. escaped = true // 记录反斜杠作为转义符
  292. } else {
  293. result = append(result, r)
  294. }
  295. }
  296. i += size // 移动到下一个字符
  297. }
  298. return string(result), nil
  299. }
  300. func unescapeMapOrSlice(data interface{}) interface{} {
  301. switch v := data.(type) {
  302. case map[string]interface{}:
  303. for k, val := range v {
  304. v[k] = unescapeMapOrSlice(val)
  305. }
  306. case []interface{}:
  307. for i, val := range v {
  308. v[i] = unescapeMapOrSlice(val)
  309. }
  310. case string:
  311. if unescaped, err := unescapeString(v); err != nil {
  312. return v
  313. } else {
  314. return unescaped
  315. }
  316. }
  317. return data
  318. }
  319. func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
  320. var argsBytes []byte
  321. var err error
  322. if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
  323. argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
  324. } else {
  325. argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
  326. }
  327. if err != nil {
  328. return nil
  329. }
  330. return &dto.ToolCallResponse{
  331. ID: fmt.Sprintf("call_%s", common.GetUUID()),
  332. Type: "function",
  333. Function: dto.FunctionResponse{
  334. Arguments: string(argsBytes),
  335. Name: item.FunctionCall.FunctionName,
  336. },
  337. }
  338. }
  339. func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
  340. fullTextResponse := dto.OpenAITextResponse{
  341. Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
  342. Object: "chat.completion",
  343. Created: common.GetTimestamp(),
  344. Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
  345. }
  346. content, _ := json.Marshal("")
  347. isToolCall := false
  348. for _, candidate := range response.Candidates {
  349. choice := dto.OpenAITextResponseChoice{
  350. Index: int(candidate.Index),
  351. Message: dto.Message{
  352. Role: "assistant",
  353. Content: content,
  354. },
  355. FinishReason: constant.FinishReasonStop,
  356. }
  357. if len(candidate.Content.Parts) > 0 {
  358. var texts []string
  359. var toolCalls []dto.ToolCallResponse
  360. for _, part := range candidate.Content.Parts {
  361. if part.FunctionCall != nil {
  362. choice.FinishReason = constant.FinishReasonToolCalls
  363. if call := getResponseToolCall(&part); call != nil {
  364. toolCalls = append(toolCalls, *call)
  365. }
  366. } else {
  367. if part.ExecutableCode != nil {
  368. texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
  369. } else if part.CodeExecutionResult != nil {
  370. texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
  371. } else {
  372. // 过滤掉空行
  373. if part.Text != "\n" {
  374. texts = append(texts, part.Text)
  375. }
  376. }
  377. }
  378. }
  379. if len(toolCalls) > 0 {
  380. choice.Message.SetToolCalls(toolCalls)
  381. isToolCall = true
  382. }
  383. choice.Message.SetStringContent(strings.Join(texts, "\n"))
  384. }
  385. if candidate.FinishReason != nil {
  386. switch *candidate.FinishReason {
  387. case "STOP":
  388. choice.FinishReason = constant.FinishReasonStop
  389. case "MAX_TOKENS":
  390. choice.FinishReason = constant.FinishReasonLength
  391. default:
  392. choice.FinishReason = constant.FinishReasonContentFilter
  393. }
  394. }
  395. if isToolCall {
  396. choice.FinishReason = constant.FinishReasonToolCalls
  397. }
  398. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  399. }
  400. return &fullTextResponse
  401. }
// streamResponseGeminiChat2OpenAI converts one streamed Gemini chunk into an
// OpenAI chat.completion.chunk. The boolean return reports whether the
// upstream signalled STOP for any candidate; the STOP reason itself is
// cleared here because the caller emits a dedicated stop chunk.
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
	choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
	isStop := false
	for _, candidate := range geminiResponse.Candidates {
		// Swallow upstream STOP; the caller sends the finish chunk separately.
		if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
			isStop = true
			candidate.FinishReason = nil
		}
		choice := dto.ChatCompletionsStreamResponseChoice{
			Index: int(candidate.Index),
			Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
				Role: "assistant",
			},
		}
		var texts []string
		isTools := false
		if candidate.FinishReason != nil {
			// p := GeminiConvertFinishReason(*candidate.FinishReason)
			// The "STOP" case is unreachable here (cleared above); kept for
			// symmetry with the non-streaming converter.
			switch *candidate.FinishReason {
			case "STOP":
				choice.FinishReason = &constant.FinishReasonStop
			case "MAX_TOKENS":
				choice.FinishReason = &constant.FinishReasonLength
			default:
				choice.FinishReason = &constant.FinishReasonContentFilter
			}
		}
		for _, part := range candidate.Content.Parts {
			if part.FunctionCall != nil {
				isTools = true
				if call := getResponseToolCall(&part); call != nil {
					// Index tool calls sequentially within this delta.
					call.SetIndex(len(choice.Delta.ToolCalls))
					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
				}
			} else {
				// Render built-in code-execution parts as fenced code blocks.
				if part.ExecutableCode != nil {
					texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
				} else if part.CodeExecutionResult != nil {
					texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
				} else {
					// Drop parts that are a lone newline.
					if part.Text != "\n" {
						texts = append(texts, part.Text)
					}
				}
			}
		}
		choice.Delta.SetContentString(strings.Join(texts, "\n"))
		// Tool calls override whatever finish reason was set above.
		if isTools {
			choice.FinishReason = &constant.FinishReasonToolCalls
		}
		choices = append(choices, choice)
	}
	var response dto.ChatCompletionsStreamResponse
	response.Object = "chat.completion.chunk"
	response.Choices = choices
	return &response, isStop
}
// GeminiChatStreamHandler relays a streaming Gemini response to the client in
// OpenAI SSE format. Each upstream JSON chunk is converted and forwarded, its
// usage metadata is accumulated, a stop chunk is emitted when upstream
// signals STOP, and (when requested) a final usage chunk precedes [DONE].
// Returns the aggregated usage; the error return is always nil here.
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	// responseText := ""
	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
	createAt := common.GetTimestamp()
	var usage = &dto.Usage{}
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		var geminiResponse GeminiChatResponse
		err := json.Unmarshal([]byte(data), &geminiResponse)
		if err != nil {
			// Malformed chunk: log and skip it (returning false drops this
			// chunk without tearing down the stream handler).
			common.LogError(c, "error unmarshalling stream response: "+err.Error())
			return false
		}
		response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
		response.Id = id
		response.Created = createAt
		response.Model = info.UpstreamModelName
		// responseText += response.Choices[0].Delta.GetContentString()
		// Later chunks overwrite earlier counts; the last chunk with a
		// non-zero total carries the authoritative token counts.
		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
		}
		err = helper.ObjectData(c, response)
		if err != nil {
			common.LogError(c, err.Error())
		}
		if isStop {
			// Emit the dedicated finish_reason:"stop" chunk.
			response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
			helper.ObjectData(c, response)
		}
		return true
	})
	var response *dto.ChatCompletionsStreamResponse
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
	usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
	if info.ShouldIncludeUsage {
		response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
		err := helper.ObjectData(c, response)
		if err != nil {
			common.SysError("send final response failed: " + err.Error())
		}
	}
	helper.Done(c)
	//resp.Body.Close()
	return nil, usage
}
// GeminiChatHandler relays a non-streaming Gemini chat completion response.
// It reads and closes the upstream body, converts the payload to OpenAI
// format, and writes the JSON result to the client using the upstream status
// code. Returns the usage from Gemini's usageMetadata, or an OpenAI-style
// error when the body cannot be read/parsed or contains no candidates.
func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var geminiResponse GeminiChatResponse
	err = json.Unmarshal(responseBody, &geminiResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	// An empty candidate list means the upstream produced nothing usable.
	if len(geminiResponse.Candidates) == 0 {
		return &dto.OpenAIErrorWithStatusCode{
			Error: dto.OpenAIError{
				Message: "No candidates returned",
				Type:    "server_error",
				Param:   "",
				Code:    500,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
	fullTextResponse.Model = info.UpstreamModelName
	usage := dto.Usage{
		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// NOTE(review): a failed client write is silently discarded here — the
	// function returns success regardless; confirm this is intentional.
	_, err = c.Writer.Write(jsonResponse)
	return nil, &usage
}
// GeminiEmbeddingHandler relays a Gemini embedding response, converting it to
// the OpenAI embeddings list format (a single item at index 0, since the
// upstream response carries one embedding). Usage is billed on input tokens
// only, mirroring OpenAI's embedding billing — see the links below.
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
	responseBody, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
	}
	// Close error deliberately ignored; the body has been fully read.
	_ = resp.Body.Close()
	var geminiResponse GeminiEmbeddingResponse
	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
		return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
	}
	// convert to openai format response
	openAIResponse := dto.OpenAIEmbeddingResponse{
		Object: "list",
		Data: []dto.OpenAIEmbeddingResponseItem{
			{
				Object:    "embedding",
				Embedding: geminiResponse.Embedding.Values,
				Index:     0,
			},
		},
		Model: info.UpstreamModelName,
	}
	// calculate usage
	// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
	// Google has not yet clarified how embedding models will be billed
	// refer to openai billing method to use input tokens billing
	// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
	usage = &dto.Usage{
		PromptTokens:     info.PromptTokens,
		CompletionTokens: 0,
		TotalTokens:      info.PromptTokens,
	}
	openAIResponse.Usage = *usage.(*dto.Usage)
	jsonResponse, jsonErr := json.Marshal(openAIResponse)
	if jsonErr != nil {
		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	// Write error deliberately ignored; the response is already committed.
	_, _ = c.Writer.Write(jsonResponse)
	return usage, nil
}