text.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package ali
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. "one-api/relay/helper"
  10. "one-api/service"
  11. "strings"
  12. "one-api/types"
  13. "github.com/gin-gonic/gin"
  14. )
  15. // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
  16. const EnableSearchModelSuffix = "-internet"
  17. func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
  18. if request.TopP >= 1 {
  19. request.TopP = 0.999
  20. } else if request.TopP <= 0 {
  21. request.TopP = 0.001
  22. }
  23. return &request
  24. }
  25. func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
  26. return &AliEmbeddingRequest{
  27. Model: request.Model,
  28. Input: struct {
  29. Texts []string `json:"texts"`
  30. }{
  31. Texts: request.ParseInput(),
  32. },
  33. }
  34. }
  35. func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  36. var fullTextResponse dto.FlexibleEmbeddingResponse
  37. err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
  38. if err != nil {
  39. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  40. }
  41. service.CloseResponseBodyGracefully(resp)
  42. model := c.GetString("model")
  43. if model == "" {
  44. model = "text-embedding-v4"
  45. }
  46. jsonResponse, err := json.Marshal(fullTextResponse)
  47. if err != nil {
  48. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  49. }
  50. c.Writer.Header().Set("Content-Type", "application/json")
  51. c.Writer.WriteHeader(resp.StatusCode)
  52. c.Writer.Write(jsonResponse)
  53. return nil, &fullTextResponse.Usage
  54. }
  55. func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
  56. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  57. Object: "list",
  58. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
  59. Model: model,
  60. Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
  61. }
  62. for _, item := range response.Output.Embeddings {
  63. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  64. Object: `embedding`,
  65. Index: item.TextIndex,
  66. Embedding: item.Embedding,
  67. })
  68. }
  69. return &openAIEmbeddingResponse
  70. }
  71. func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
  72. choice := dto.OpenAITextResponseChoice{
  73. Index: 0,
  74. Message: dto.Message{
  75. Role: "assistant",
  76. Content: response.Output.Text,
  77. },
  78. FinishReason: response.Output.FinishReason,
  79. }
  80. fullTextResponse := dto.OpenAITextResponse{
  81. Id: response.RequestId,
  82. Object: "chat.completion",
  83. Created: common.GetTimestamp(),
  84. Choices: []dto.OpenAITextResponseChoice{choice},
  85. Usage: dto.Usage{
  86. PromptTokens: response.Usage.InputTokens,
  87. CompletionTokens: response.Usage.OutputTokens,
  88. TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
  89. },
  90. }
  91. return &fullTextResponse
  92. }
  93. func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
  94. var choice dto.ChatCompletionsStreamResponseChoice
  95. choice.Delta.SetContentString(aliResponse.Output.Text)
  96. if aliResponse.Output.FinishReason != "null" {
  97. finishReason := aliResponse.Output.FinishReason
  98. choice.FinishReason = &finishReason
  99. }
  100. response := dto.ChatCompletionsStreamResponse{
  101. Id: aliResponse.RequestId,
  102. Object: "chat.completion.chunk",
  103. Created: common.GetTimestamp(),
  104. Model: "ernie-bot",
  105. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  106. }
  107. return &response
  108. }
  109. func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  110. var usage dto.Usage
  111. scanner := bufio.NewScanner(resp.Body)
  112. scanner.Split(bufio.ScanLines)
  113. dataChan := make(chan string)
  114. stopChan := make(chan bool)
  115. go func() {
  116. for scanner.Scan() {
  117. data := scanner.Text()
  118. if len(data) < 5 { // ignore blank line or wrong format
  119. continue
  120. }
  121. if data[:5] != "data:" {
  122. continue
  123. }
  124. data = data[5:]
  125. dataChan <- data
  126. }
  127. stopChan <- true
  128. }()
  129. helper.SetEventStreamHeaders(c)
  130. lastResponseText := ""
  131. c.Stream(func(w io.Writer) bool {
  132. select {
  133. case data := <-dataChan:
  134. var aliResponse AliResponse
  135. err := json.Unmarshal([]byte(data), &aliResponse)
  136. if err != nil {
  137. common.SysLog("error unmarshalling stream response: " + err.Error())
  138. return true
  139. }
  140. if aliResponse.Usage.OutputTokens != 0 {
  141. usage.PromptTokens = aliResponse.Usage.InputTokens
  142. usage.CompletionTokens = aliResponse.Usage.OutputTokens
  143. usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  144. }
  145. response := streamResponseAli2OpenAI(&aliResponse)
  146. response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
  147. lastResponseText = aliResponse.Output.Text
  148. jsonResponse, err := json.Marshal(response)
  149. if err != nil {
  150. common.SysLog("error marshalling stream response: " + err.Error())
  151. return true
  152. }
  153. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  154. return true
  155. case <-stopChan:
  156. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  157. return false
  158. }
  159. })
  160. service.CloseResponseBodyGracefully(resp)
  161. return nil, &usage
  162. }
  163. func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  164. var aliResponse AliResponse
  165. responseBody, err := io.ReadAll(resp.Body)
  166. if err != nil {
  167. return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
  168. }
  169. service.CloseResponseBodyGracefully(resp)
  170. err = json.Unmarshal(responseBody, &aliResponse)
  171. if err != nil {
  172. return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
  173. }
  174. if aliResponse.Code != "" {
  175. return types.WithOpenAIError(types.OpenAIError{
  176. Message: aliResponse.Message,
  177. Type: "ali_error",
  178. Param: aliResponse.RequestId,
  179. Code: aliResponse.Code,
  180. }, resp.StatusCode), nil
  181. }
  182. fullTextResponse := responseAli2OpenAI(&aliResponse)
  183. jsonResponse, err := common.Marshal(fullTextResponse)
  184. if err != nil {
  185. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  186. }
  187. c.Writer.Header().Set("Content-Type", "application/json")
  188. c.Writer.WriteHeader(resp.StatusCode)
  189. _, err = c.Writer.Write(jsonResponse)
  190. return nil, &fullTextResponse.Usage
  191. }