text.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package ali
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. "strings"
  13. )
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r

// EnableSearchModelSuffix is the literal "-internet" suffix.
// NOTE(review): going by the name, a model name carrying this suffix is
// presumably meant to turn on the upstream search feature; the constant is not
// referenced in the code visible here, so confirm against its callers.
const EnableSearchModelSuffix = "-internet"
  16. func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
  17. if request.TopP >= 1 {
  18. request.TopP = 0.999
  19. } else if request.TopP <= 0 {
  20. request.TopP = 0.001
  21. }
  22. return &request
  23. }
  24. func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
  25. if request.Model == "" {
  26. request.Model = "text-embedding-v1"
  27. }
  28. return &AliEmbeddingRequest{
  29. Model: request.Model,
  30. Input: struct {
  31. Texts []string `json:"texts"`
  32. }{
  33. Texts: request.ParseInput(),
  34. },
  35. }
  36. }
  37. func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  38. var aliResponse AliEmbeddingResponse
  39. err := json.NewDecoder(resp.Body).Decode(&aliResponse)
  40. if err != nil {
  41. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  42. }
  43. err = resp.Body.Close()
  44. if err != nil {
  45. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  46. }
  47. if aliResponse.Code != "" {
  48. return &dto.OpenAIErrorWithStatusCode{
  49. Error: dto.OpenAIError{
  50. Message: aliResponse.Message,
  51. Type: aliResponse.Code,
  52. Param: aliResponse.RequestId,
  53. Code: aliResponse.Code,
  54. },
  55. StatusCode: resp.StatusCode,
  56. }, nil
  57. }
  58. fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
  59. jsonResponse, err := json.Marshal(fullTextResponse)
  60. if err != nil {
  61. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  62. }
  63. c.Writer.Header().Set("Content-Type", "application/json")
  64. c.Writer.WriteHeader(resp.StatusCode)
  65. _, err = c.Writer.Write(jsonResponse)
  66. return nil, &fullTextResponse.Usage
  67. }
  68. func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
  69. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  70. Object: "list",
  71. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
  72. Model: "text-embedding-v1",
  73. Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
  74. }
  75. for _, item := range response.Output.Embeddings {
  76. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  77. Object: `embedding`,
  78. Index: item.TextIndex,
  79. Embedding: item.Embedding,
  80. })
  81. }
  82. return &openAIEmbeddingResponse
  83. }
  84. func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
  85. content, _ := json.Marshal(response.Output.Text)
  86. choice := dto.OpenAITextResponseChoice{
  87. Index: 0,
  88. Message: dto.Message{
  89. Role: "assistant",
  90. Content: content,
  91. },
  92. FinishReason: response.Output.FinishReason,
  93. }
  94. fullTextResponse := dto.OpenAITextResponse{
  95. Id: response.RequestId,
  96. Object: "chat.completion",
  97. Created: common.GetTimestamp(),
  98. Choices: []dto.OpenAITextResponseChoice{choice},
  99. Usage: dto.Usage{
  100. PromptTokens: response.Usage.InputTokens,
  101. CompletionTokens: response.Usage.OutputTokens,
  102. TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
  103. },
  104. }
  105. return &fullTextResponse
  106. }
  107. func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
  108. var choice dto.ChatCompletionsStreamResponseChoice
  109. choice.Delta.SetContentString(aliResponse.Output.Text)
  110. if aliResponse.Output.FinishReason != "null" {
  111. finishReason := aliResponse.Output.FinishReason
  112. choice.FinishReason = &finishReason
  113. }
  114. response := dto.ChatCompletionsStreamResponse{
  115. Id: aliResponse.RequestId,
  116. Object: "chat.completion.chunk",
  117. Created: common.GetTimestamp(),
  118. Model: "ernie-bot",
  119. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  120. }
  121. return &response
  122. }
  123. func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  124. var usage dto.Usage
  125. scanner := bufio.NewScanner(resp.Body)
  126. scanner.Split(bufio.ScanLines)
  127. dataChan := make(chan string)
  128. stopChan := make(chan bool)
  129. go func() {
  130. for scanner.Scan() {
  131. data := scanner.Text()
  132. if len(data) < 5 { // ignore blank line or wrong format
  133. continue
  134. }
  135. if data[:5] != "data:" {
  136. continue
  137. }
  138. data = data[5:]
  139. dataChan <- data
  140. }
  141. stopChan <- true
  142. }()
  143. helper.SetEventStreamHeaders(c)
  144. lastResponseText := ""
  145. c.Stream(func(w io.Writer) bool {
  146. select {
  147. case data := <-dataChan:
  148. var aliResponse AliResponse
  149. err := json.Unmarshal([]byte(data), &aliResponse)
  150. if err != nil {
  151. common.SysError("error unmarshalling stream response: " + err.Error())
  152. return true
  153. }
  154. if aliResponse.Usage.OutputTokens != 0 {
  155. usage.PromptTokens = aliResponse.Usage.InputTokens
  156. usage.CompletionTokens = aliResponse.Usage.OutputTokens
  157. usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  158. }
  159. response := streamResponseAli2OpenAI(&aliResponse)
  160. response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
  161. lastResponseText = aliResponse.Output.Text
  162. jsonResponse, err := json.Marshal(response)
  163. if err != nil {
  164. common.SysError("error marshalling stream response: " + err.Error())
  165. return true
  166. }
  167. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  168. return true
  169. case <-stopChan:
  170. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  171. return false
  172. }
  173. })
  174. err := resp.Body.Close()
  175. if err != nil {
  176. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  177. }
  178. return nil, &usage
  179. }
  180. func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  181. var aliResponse AliResponse
  182. responseBody, err := io.ReadAll(resp.Body)
  183. if err != nil {
  184. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  185. }
  186. err = resp.Body.Close()
  187. if err != nil {
  188. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  189. }
  190. err = json.Unmarshal(responseBody, &aliResponse)
  191. if err != nil {
  192. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  193. }
  194. if aliResponse.Code != "" {
  195. return &dto.OpenAIErrorWithStatusCode{
  196. Error: dto.OpenAIError{
  197. Message: aliResponse.Message,
  198. Type: aliResponse.Code,
  199. Param: aliResponse.RequestId,
  200. Code: aliResponse.Code,
  201. },
  202. StatusCode: resp.StatusCode,
  203. }, nil
  204. }
  205. fullTextResponse := responseAli2OpenAI(&aliResponse)
  206. jsonResponse, err := json.Marshal(fullTextResponse)
  207. if err != nil {
  208. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  209. }
  210. c.Writer.Header().Set("Content-Type", "application/json")
  211. c.Writer.WriteHeader(resp.StatusCode)
  212. _, err = c.Writer.Write(jsonResponse)
  213. return nil, &fullTextResponse.Usage
  214. }