text.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package ali
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. "one-api/service"
  11. "strings"
  12. )
  13. // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
  14. const EnableSearchModelSuffix = "-internet"
  15. func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
  16. if request.TopP >= 1 {
  17. request.TopP = 0.999
  18. } else if request.TopP <= 0 {
  19. request.TopP = 0.001
  20. }
  21. return &request
  22. }
  23. func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
  24. return &AliEmbeddingRequest{
  25. Model: "text-embedding-v1",
  26. Input: struct {
  27. Texts []string `json:"texts"`
  28. }{
  29. Texts: request.ParseInput(),
  30. },
  31. }
  32. }
  33. func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  34. var aliResponse AliEmbeddingResponse
  35. err := json.NewDecoder(resp.Body).Decode(&aliResponse)
  36. if err != nil {
  37. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  38. }
  39. err = resp.Body.Close()
  40. if err != nil {
  41. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  42. }
  43. if aliResponse.Code != "" {
  44. return &dto.OpenAIErrorWithStatusCode{
  45. Error: dto.OpenAIError{
  46. Message: aliResponse.Message,
  47. Type: aliResponse.Code,
  48. Param: aliResponse.RequestId,
  49. Code: aliResponse.Code,
  50. },
  51. StatusCode: resp.StatusCode,
  52. }, nil
  53. }
  54. fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
  55. jsonResponse, err := json.Marshal(fullTextResponse)
  56. if err != nil {
  57. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  58. }
  59. c.Writer.Header().Set("Content-Type", "application/json")
  60. c.Writer.WriteHeader(resp.StatusCode)
  61. _, err = c.Writer.Write(jsonResponse)
  62. return nil, &fullTextResponse.Usage
  63. }
  64. func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
  65. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  66. Object: "list",
  67. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
  68. Model: "text-embedding-v1",
  69. Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
  70. }
  71. for _, item := range response.Output.Embeddings {
  72. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  73. Object: `embedding`,
  74. Index: item.TextIndex,
  75. Embedding: item.Embedding,
  76. })
  77. }
  78. return &openAIEmbeddingResponse
  79. }
  80. func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
  81. content, _ := json.Marshal(response.Output.Text)
  82. choice := dto.OpenAITextResponseChoice{
  83. Index: 0,
  84. Message: dto.Message{
  85. Role: "assistant",
  86. Content: content,
  87. },
  88. FinishReason: response.Output.FinishReason,
  89. }
  90. fullTextResponse := dto.OpenAITextResponse{
  91. Id: response.RequestId,
  92. Object: "chat.completion",
  93. Created: common.GetTimestamp(),
  94. Choices: []dto.OpenAITextResponseChoice{choice},
  95. Usage: dto.Usage{
  96. PromptTokens: response.Usage.InputTokens,
  97. CompletionTokens: response.Usage.OutputTokens,
  98. TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
  99. },
  100. }
  101. return &fullTextResponse
  102. }
  103. func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
  104. var choice dto.ChatCompletionsStreamResponseChoice
  105. choice.Delta.SetContentString(aliResponse.Output.Text)
  106. if aliResponse.Output.FinishReason != "null" {
  107. finishReason := aliResponse.Output.FinishReason
  108. choice.FinishReason = &finishReason
  109. }
  110. response := dto.ChatCompletionsStreamResponse{
  111. Id: aliResponse.RequestId,
  112. Object: "chat.completion.chunk",
  113. Created: common.GetTimestamp(),
  114. Model: "ernie-bot",
  115. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  116. }
  117. return &response
  118. }
  119. func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  120. var usage dto.Usage
  121. scanner := bufio.NewScanner(resp.Body)
  122. scanner.Split(bufio.ScanLines)
  123. dataChan := make(chan string)
  124. stopChan := make(chan bool)
  125. go func() {
  126. for scanner.Scan() {
  127. data := scanner.Text()
  128. if len(data) < 5 { // ignore blank line or wrong format
  129. continue
  130. }
  131. if data[:5] != "data:" {
  132. continue
  133. }
  134. data = data[5:]
  135. dataChan <- data
  136. }
  137. stopChan <- true
  138. }()
  139. service.SetEventStreamHeaders(c)
  140. lastResponseText := ""
  141. c.Stream(func(w io.Writer) bool {
  142. select {
  143. case data := <-dataChan:
  144. var aliResponse AliResponse
  145. err := json.Unmarshal([]byte(data), &aliResponse)
  146. if err != nil {
  147. common.SysError("error unmarshalling stream response: " + err.Error())
  148. return true
  149. }
  150. if aliResponse.Usage.OutputTokens != 0 {
  151. usage.PromptTokens = aliResponse.Usage.InputTokens
  152. usage.CompletionTokens = aliResponse.Usage.OutputTokens
  153. usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  154. }
  155. response := streamResponseAli2OpenAI(&aliResponse)
  156. response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
  157. lastResponseText = aliResponse.Output.Text
  158. jsonResponse, err := json.Marshal(response)
  159. if err != nil {
  160. common.SysError("error marshalling stream response: " + err.Error())
  161. return true
  162. }
  163. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  164. return true
  165. case <-stopChan:
  166. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  167. return false
  168. }
  169. })
  170. err := resp.Body.Close()
  171. if err != nil {
  172. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  173. }
  174. return nil, &usage
  175. }
  176. func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  177. var aliResponse AliResponse
  178. responseBody, err := io.ReadAll(resp.Body)
  179. if err != nil {
  180. return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  181. }
  182. err = resp.Body.Close()
  183. if err != nil {
  184. return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  185. }
  186. err = json.Unmarshal(responseBody, &aliResponse)
  187. if err != nil {
  188. return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  189. }
  190. if aliResponse.Code != "" {
  191. return &dto.OpenAIErrorWithStatusCode{
  192. Error: dto.OpenAIError{
  193. Message: aliResponse.Message,
  194. Type: aliResponse.Code,
  195. Param: aliResponse.RequestId,
  196. Code: aliResponse.Code,
  197. },
  198. StatusCode: resp.StatusCode,
  199. }, nil
  200. }
  201. fullTextResponse := responseAli2OpenAI(&aliResponse)
  202. jsonResponse, err := json.Marshal(fullTextResponse)
  203. if err != nil {
  204. return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  205. }
  206. c.Writer.Header().Set("Content-Type", "application/json")
  207. c.Writer.WriteHeader(resp.StatusCode)
  208. _, err = c.Writer.Write(jsonResponse)
  209. return nil, &fullTextResponse.Usage
  210. }