text.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package ali
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. "one-api/relay/helper"
  10. "one-api/service"
  11. "strings"
  12. "github.com/gin-gonic/gin"
  13. )
  14. // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
  15. const EnableSearchModelSuffix = "-internet"
  16. func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
  17. if request.TopP >= 1 {
  18. request.TopP = 0.999
  19. } else if request.TopP <= 0 {
  20. request.TopP = 0.001
  21. }
  22. return &request
  23. }
  24. func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
  25. return &AliEmbeddingRequest{
  26. Model: request.Model,
  27. Input: struct {
  28. Texts []string `json:"texts"`
  29. }{
  30. Texts: request.ParseInput(),
  31. },
  32. }
  33. }
  34. func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  35. var aliResponse AliEmbeddingResponse
  36. err := json.NewDecoder(resp.Body).Decode(&aliResponse)
  37. if err != nil {
  38. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  39. }
  40. err = resp.Body.Close()
  41. if err != nil {
  42. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  43. }
  44. if aliResponse.Code != "" {
  45. return &dto.OpenAIErrorWithStatusCode{
  46. Error: dto.OpenAIError{
  47. Message: aliResponse.Message,
  48. Type: aliResponse.Code,
  49. Param: aliResponse.RequestId,
  50. Code: aliResponse.Code,
  51. },
  52. StatusCode: resp.StatusCode,
  53. }, nil
  54. }
  55. model := c.GetString("model")
  56. if model == "" {
  57. model = "text-embedding-v4"
  58. }
  59. fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
  60. jsonResponse, err := json.Marshal(fullTextResponse)
  61. if err != nil {
  62. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  63. }
  64. c.Writer.Header().Set("Content-Type", "application/json")
  65. c.Writer.WriteHeader(resp.StatusCode)
  66. _, err = c.Writer.Write(jsonResponse)
  67. return nil, &fullTextResponse.Usage
  68. }
  69. func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
  70. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  71. Object: "list",
  72. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
  73. Model: model,
  74. Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
  75. }
  76. for _, item := range response.Output.Embeddings {
  77. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  78. Object: `embedding`,
  79. Index: item.TextIndex,
  80. Embedding: item.Embedding,
  81. })
  82. }
  83. return &openAIEmbeddingResponse
  84. }
  85. func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
  86. content, _ := json.Marshal(response.Output.Text)
  87. choice := dto.OpenAITextResponseChoice{
  88. Index: 0,
  89. Message: dto.Message{
  90. Role: "assistant",
  91. Content: content,
  92. },
  93. FinishReason: response.Output.FinishReason,
  94. }
  95. fullTextResponse := dto.OpenAITextResponse{
  96. Id: response.RequestId,
  97. Object: "chat.completion",
  98. Created: common.GetTimestamp(),
  99. Choices: []dto.OpenAITextResponseChoice{choice},
  100. Usage: dto.Usage{
  101. PromptTokens: response.Usage.InputTokens,
  102. CompletionTokens: response.Usage.OutputTokens,
  103. TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
  104. },
  105. }
  106. return &fullTextResponse
  107. }
  108. func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
  109. var choice dto.ChatCompletionsStreamResponseChoice
  110. choice.Delta.SetContentString(aliResponse.Output.Text)
  111. if aliResponse.Output.FinishReason != "null" {
  112. finishReason := aliResponse.Output.FinishReason
  113. choice.FinishReason = &finishReason
  114. }
  115. response := dto.ChatCompletionsStreamResponse{
  116. Id: aliResponse.RequestId,
  117. Object: "chat.completion.chunk",
  118. Created: common.GetTimestamp(),
  119. Model: "ernie-bot",
  120. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  121. }
  122. return &response
  123. }
  124. func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  125. var usage dto.Usage
  126. scanner := bufio.NewScanner(resp.Body)
  127. scanner.Split(bufio.ScanLines)
  128. dataChan := make(chan string)
  129. stopChan := make(chan bool)
  130. go func() {
  131. for scanner.Scan() {
  132. data := scanner.Text()
  133. if len(data) < 5 { // ignore blank line or wrong format
  134. continue
  135. }
  136. if data[:5] != "data:" {
  137. continue
  138. }
  139. data = data[5:]
  140. dataChan <- data
  141. }
  142. stopChan <- true
  143. }()
  144. helper.SetEventStreamHeaders(c)
  145. lastResponseText := ""
  146. c.Stream(func(w io.Writer) bool {
  147. select {
  148. case data := <-dataChan:
  149. var aliResponse AliResponse
  150. err := json.Unmarshal([]byte(data), &aliResponse)
  151. if err != nil {
  152. common.SysError("error unmarshalling stream response: " + err.Error())
  153. return true
  154. }
  155. if aliResponse.Usage.OutputTokens != 0 {
  156. usage.PromptTokens = aliResponse.Usage.InputTokens
  157. usage.CompletionTokens = aliResponse.Usage.OutputTokens
  158. usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  159. }
  160. response := streamResponseAli2OpenAI(&aliResponse)
  161. response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
  162. lastResponseText = aliResponse.Output.Text
  163. jsonResponse, err := json.Marshal(response)
  164. if err != nil {
  165. common.SysError("error marshalling stream response: " + err.Error())
  166. return true
  167. }
  168. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  169. return true
  170. case <-stopChan:
  171. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  172. return false
  173. }
  174. })
  175. err := resp.Body.Close()
  176. if err != nil {
  177. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  178. }
  179. return nil, &usage
  180. }
  181. func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  182. var aliResponse AliResponse
  183. responseBody, err := io.ReadAll(resp.Body)
  184. if err != nil {
  185. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  186. }
  187. err = resp.Body.Close()
  188. if err != nil {
  189. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  190. }
  191. err = json.Unmarshal(responseBody, &aliResponse)
  192. if err != nil {
  193. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  194. }
  195. if aliResponse.Code != "" {
  196. return &dto.OpenAIErrorWithStatusCode{
  197. Error: dto.OpenAIError{
  198. Message: aliResponse.Message,
  199. Type: aliResponse.Code,
  200. Param: aliResponse.RequestId,
  201. Code: aliResponse.Code,
  202. },
  203. StatusCode: resp.StatusCode,
  204. }, nil
  205. }
  206. fullTextResponse := responseAli2OpenAI(&aliResponse)
  207. jsonResponse, err := json.Marshal(fullTextResponse)
  208. if err != nil {
  209. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  210. }
  211. c.Writer.Header().Set("Content-Type", "application/json")
  212. c.Writer.WriteHeader(resp.StatusCode)
  213. _, err = c.Writer.Write(jsonResponse)
  214. return nil, &fullTextResponse.Usage
  215. }