relay-ali.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. package ali
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "one-api/dto"
  10. "one-api/service"
  11. "strings"
  12. )
  // EnableSearchModelSuffix is a model-name suffix (for example
  // "qwen-plus-internet") that requestOpenAI2Ali treats as a request to turn
  // on DashScope search via AliParameters.EnableSearch.
  //
  // See https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
  const (
  	EnableSearchModelSuffix = "-internet"
  )
  // requestOpenAI2Ali converts an OpenAI-style chat request into a DashScope
  // (Aliyun) chat request.
  //
  // Messages are copied in order, with their roles lower-cased. A model name
  // ending in EnableSearchModelSuffix sets Parameters.EnableSearch. The suffix
  // is stripped, so the upstream receives the bare model name: for example,
  // "qwen-plus-internet" is sent as "qwen-plus". Streaming requests set
  // Parameters.IncrementalOutput.
  func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest {
  	messages := make([]AliMessage, 0, len(request.Messages))
  	for _, message := range request.Messages {
  		messages = append(messages, AliMessage{
  			Content: message.StringContent(),
  			Role:    strings.ToLower(message.Role),
  		})
  	}

  	aliModel := request.Model
  	enableSearch := strings.HasSuffix(aliModel, EnableSearchModelSuffix)
  	if enableSearch {
  		aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
  	}

  	return &AliChatRequest{
  		// Send the model name without the search suffix. Search is requested
  		// through Parameters.EnableSearch, not through the model name.
  		Model: aliModel,
  		Input: AliInput{
  			Messages: messages,
  		},
  		Parameters: AliParameters{
  			IncrementalOutput: request.Stream,
  			Seed:              uint64(request.Seed),
  			EnableSearch:      enableSearch,
  		},
  	}
  }
  44. func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
  45. return &AliEmbeddingRequest{
  46. Model: "text-embedding-v1",
  47. Input: struct {
  48. Texts []string `json:"texts"`
  49. }{
  50. Texts: request.ParseInput(),
  51. },
  52. }
  53. }
  // aliEmbeddingHandler relays a DashScope embedding response to the client as
  // an OpenAI-style embedding list and returns the reported token usage.
  //
  // A DashScope error payload (non-empty Code) is returned as an
  // OpenAIErrorWithStatusCode that carries the upstream HTTP status. Nothing is
  // written to the client in that case.
  func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  	var aliResponse AliEmbeddingResponse
  	decodeErr := json.NewDecoder(resp.Body).Decode(&aliResponse)
  	// Close the body even when decoding failed, so the connection is not leaked.
  	closeErr := resp.Body.Close()
  	if decodeErr != nil {
  		return service.OpenAIErrorWrapper(decodeErr, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  	}
  	if closeErr != nil {
  		return service.OpenAIErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
  	}

  	if aliResponse.Code != "" {
  		return &dto.OpenAIErrorWithStatusCode{
  			Error: dto.OpenAIError{
  				Message: aliResponse.Message,
  				Type:    aliResponse.Code,
  				Param:   aliResponse.RequestId,
  				Code:    aliResponse.Code,
  			},
  			StatusCode: resp.StatusCode,
  		}, nil
  	}

  	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
  	jsonResponse, err := json.Marshal(fullTextResponse)
  	if err != nil {
  		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  	}
  	c.Writer.Header().Set("Content-Type", "application/json")
  	c.Writer.WriteHeader(resp.StatusCode)
  	// The status line has already been sent, so a write failure can no longer
  	// be reported to the client. Log it and still return the upstream usage.
  	if _, err := c.Writer.Write(jsonResponse); err != nil {
  		common.SysError("error writing embedding response: " + err.Error())
  	}
  	return nil, &fullTextResponse.Usage
  }
  85. func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
  86. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  87. Object: "list",
  88. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
  89. Model: "text-embedding-v1",
  90. Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
  91. }
  92. for _, item := range response.Output.Embeddings {
  93. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  94. Object: `embedding`,
  95. Index: item.TextIndex,
  96. Embedding: item.Embedding,
  97. })
  98. }
  99. return &openAIEmbeddingResponse
  100. }
  // responseAli2OpenAI converts a non-streaming DashScope chat response into
  // an OpenAI "chat.completion" object. The response has a single assistant
  // choice holding the JSON-encoded output text. The request id becomes the
  // object id, and total tokens are computed as input plus output tokens.
  func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse {
  	promptTokens, completionTokens := response.Usage.InputTokens, response.Usage.OutputTokens
  	// Marshalling a plain string cannot fail, so the error is safely dropped.
  	encodedText, _ := json.Marshal(response.Output.Text)

  	return &dto.OpenAITextResponse{
  		Id:      response.RequestId,
  		Object:  "chat.completion",
  		Created: common.GetTimestamp(),
  		Choices: []dto.OpenAITextResponseChoice{{
  			Index: 0,
  			Message: dto.Message{
  				Role:    "assistant",
  				Content: encodedText,
  			},
  			FinishReason: response.Output.FinishReason,
  		}},
  		Usage: dto.Usage{
  			PromptTokens:     promptTokens,
  			CompletionTokens: completionTokens,
  			TotalTokens:      promptTokens + completionTokens,
  		},
  	}
  }
  124. func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse {
  125. var choice dto.ChatCompletionsStreamResponseChoice
  126. choice.Delta.SetContentString(aliResponse.Output.Text)
  127. if aliResponse.Output.FinishReason != "null" {
  128. finishReason := aliResponse.Output.FinishReason
  129. choice.FinishReason = &finishReason
  130. }
  131. response := dto.ChatCompletionsStreamResponse{
  132. Id: aliResponse.RequestId,
  133. Object: "chat.completion.chunk",
  134. Created: common.GetTimestamp(),
  135. Model: "ernie-bot",
  136. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  137. }
  138. return &response
  139. }
  // aliStreamHandler relays a DashScope server-sent-event stream to the client
  // as OpenAI-style chat.completion.chunk events, ends it with "data: [DONE]",
  // and returns the usage from the last event that reported non-zero output
  // tokens.
  //
  // Each event's Output.Text is forwarded as-is. requestOpenAI2Ali sets
  // IncrementalOutput for streaming requests, and in that mode DashScope sends
  // only newly generated text per event. Deltas must therefore not be
  // de-duplicated against the previous event. Doing so drops legitimately
  // repeated text, e.g. the same token emitted twice in a row.
  func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  	var usage dto.Usage

  	scanner := bufio.NewScanner(resp.Body)
  	// Raise the cap above bufio's default 64 KiB token limit. A longer line
  	// would otherwise stop the scan with bufio.ErrTooLong and silently
  	// truncate the stream.
  	scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024)
  	scanner.Split(bufio.ScanLines)

  	dataChan := make(chan string)
  	// stopChan is closed once the reader goroutine has delivered every event.
  	stopChan := make(chan struct{})
  	// done is closed when relaying ends, either normally or because the client
  	// went away. It keeps the reader goroutine from blocking forever on a send.
  	done := make(chan struct{})
  	go func() {
  		defer close(stopChan)
  		for scanner.Scan() {
  			line := scanner.Text()
  			// Skip blank separator lines and non-data SSE fields.
  			if !strings.HasPrefix(line, "data:") {
  				continue
  			}
  			select {
  			case dataChan <- strings.TrimPrefix(line, "data:"):
  			case <-done:
  				return
  			}
  		}
  		if err := scanner.Err(); err != nil {
  			select {
  			case <-done:
  				// The body was closed after relaying stopped, so this read
  				// error is expected and not worth reporting.
  			default:
  				common.SysError("error reading stream response: " + err.Error())
  			}
  		}
  	}()

  	service.SetEventStreamHeaders(c)
  	c.Stream(func(w io.Writer) bool {
  		select {
  		case data := <-dataChan:
  			var aliResponse AliChatResponse
  			if err := json.Unmarshal([]byte(data), &aliResponse); err != nil {
  				common.SysError("error unmarshalling stream response: " + err.Error())
  				return true
  			}
  			if aliResponse.Usage.OutputTokens != 0 {
  				usage.PromptTokens = aliResponse.Usage.InputTokens
  				usage.CompletionTokens = aliResponse.Usage.OutputTokens
  				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  			}
  			response := streamResponseAli2OpenAI(&aliResponse)
  			jsonResponse, err := json.Marshal(response)
  			if err != nil {
  				common.SysError("error marshalling stream response: " + err.Error())
  				return true
  			}
  			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  			return true
  		case <-stopChan:
  			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  			return false
  		}
  	})
  	// Release the reader goroutine before closing the body. The goroutine then
  	// treats the resulting read error as expected.
  	close(done)

  	if err := resp.Body.Close(); err != nil {
  		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  	}
  	return nil, &usage
  }
  // aliHandler relays a non-streaming DashScope chat response to the client as
  // an OpenAI-style chat.completion object and returns the reported token
  // usage.
  //
  // A DashScope error payload (non-empty Code) is returned as an
  // OpenAIErrorWithStatusCode that carries the upstream HTTP status. Nothing is
  // written to the client in that case.
  func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
  	responseBody, err := io.ReadAll(resp.Body)
  	// Close the body even when reading failed, so the connection is not leaked.
  	closeErr := resp.Body.Close()
  	if err != nil {
  		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  	}
  	if closeErr != nil {
  		return service.OpenAIErrorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError), nil
  	}

  	var aliResponse AliChatResponse
  	if err := json.Unmarshal(responseBody, &aliResponse); err != nil {
  		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  	}
  	if aliResponse.Code != "" {
  		return &dto.OpenAIErrorWithStatusCode{
  			Error: dto.OpenAIError{
  				Message: aliResponse.Message,
  				Type:    aliResponse.Code,
  				Param:   aliResponse.RequestId,
  				Code:    aliResponse.Code,
  			},
  			StatusCode: resp.StatusCode,
  		}, nil
  	}

  	fullTextResponse := responseAli2OpenAI(&aliResponse)
  	jsonResponse, err := json.Marshal(fullTextResponse)
  	if err != nil {
  		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  	}
  	c.Writer.Header().Set("Content-Type", "application/json")
  	c.Writer.WriteHeader(resp.StatusCode)
  	// The status line has already been sent, so a write failure can no longer
  	// be reported to the client. Log it and still return the upstream usage.
  	if _, err := c.Writer.Write(jsonResponse); err != nil {
  		common.SysError("error writing response: " + err.Error())
  	}
  	return nil, &fullTextResponse.Usage
  }