// relay-text.go
  1. package controller
import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"one-api/common"
	"one-api/model"
)
  14. func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
  15. channelType := c.GetInt("channel")
  16. tokenId := c.GetInt("token_id")
  17. consumeQuota := c.GetBool("consume_quota")
  18. group := c.GetString("group")
  19. var textRequest GeneralOpenAIRequest
  20. if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
  21. err := common.UnmarshalBodyReusable(c, &textRequest)
  22. if err != nil {
  23. return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
  24. }
  25. }
  26. if relayMode == RelayModeModeration && textRequest.Model == "" {
  27. textRequest.Model = "text-moderation-latest"
  28. }
  29. baseURL := common.ChannelBaseURLs[channelType]
  30. requestURL := c.Request.URL.String()
  31. if channelType == common.ChannelTypeCustom {
  32. baseURL = c.GetString("base_url")
  33. } else if channelType == common.ChannelTypeOpenAI {
  34. if c.GetString("base_url") != "" {
  35. baseURL = c.GetString("base_url")
  36. }
  37. }
  38. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  39. if channelType == common.ChannelTypeAzure {
  40. // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
  41. query := c.Request.URL.Query()
  42. apiVersion := query.Get("api-version")
  43. if apiVersion == "" {
  44. apiVersion = c.GetString("api_version")
  45. }
  46. requestURL := strings.Split(requestURL, "?")[0]
  47. requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
  48. baseURL = c.GetString("base_url")
  49. task := strings.TrimPrefix(requestURL, "/v1/")
  50. model_ := textRequest.Model
  51. model_ = strings.Replace(model_, ".", "", -1)
  52. // https://github.com/songquanpeng/one-api/issues/67
  53. model_ = strings.TrimSuffix(model_, "-0301")
  54. model_ = strings.TrimSuffix(model_, "-0314")
  55. model_ = strings.TrimSuffix(model_, "-0613")
  56. fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
  57. } else if channelType == common.ChannelTypePaLM {
  58. err := relayPaLM(textRequest, c)
  59. return err
  60. }
  61. var promptTokens int
  62. switch relayMode {
  63. case RelayModeChatCompletions:
  64. promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
  65. case RelayModeCompletions:
  66. promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
  67. case RelayModeModeration:
  68. promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
  69. }
  70. preConsumedTokens := common.PreConsumedQuota
  71. if textRequest.MaxTokens != 0 {
  72. preConsumedTokens = promptTokens + textRequest.MaxTokens
  73. }
  74. modelRatio := common.GetModelRatio(textRequest.Model)
  75. groupRatio := common.GetGroupRatio(group)
  76. ratio := modelRatio * groupRatio
  77. preConsumedQuota := int(float64(preConsumedTokens) * ratio)
  78. if consumeQuota {
  79. err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
  80. if err != nil {
  81. return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
  82. }
  83. }
  84. req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
  85. if err != nil {
  86. return errorWrapper(err, "new_request_failed", http.StatusOK)
  87. }
  88. if channelType == common.ChannelTypeAzure {
  89. key := c.Request.Header.Get("Authorization")
  90. key = strings.TrimPrefix(key, "Bearer ")
  91. req.Header.Set("api-key", key)
  92. } else {
  93. req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
  94. }
  95. req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
  96. req.Header.Set("Accept", c.Request.Header.Get("Accept"))
  97. req.Header.Set("Connection", c.Request.Header.Get("Connection"))
  98. client := &http.Client{}
  99. resp, err := client.Do(req)
  100. if err != nil {
  101. return errorWrapper(err, "do_request_failed", http.StatusOK)
  102. }
  103. err = req.Body.Close()
  104. if err != nil {
  105. return errorWrapper(err, "close_request_body_failed", http.StatusOK)
  106. }
  107. err = c.Request.Body.Close()
  108. if err != nil {
  109. return errorWrapper(err, "close_request_body_failed", http.StatusOK)
  110. }
  111. var textResponse TextResponse
  112. isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
  113. var streamResponseText string
  114. defer func() {
  115. if consumeQuota {
  116. quota := 0
  117. completionRatio := 1.34 // default for gpt-3
  118. if strings.HasPrefix(textRequest.Model, "gpt-4") {
  119. completionRatio = 2
  120. }
  121. if isStream {
  122. responseTokens := countTokenText(streamResponseText, textRequest.Model)
  123. quota = promptTokens + int(float64(responseTokens)*completionRatio)
  124. } else {
  125. quota = textResponse.Usage.PromptTokens + int(float64(textResponse.Usage.CompletionTokens)*completionRatio)
  126. }
  127. quota = int(float64(quota) * ratio)
  128. if ratio != 0 && quota <= 0 {
  129. quota = 1
  130. }
  131. quotaDelta := quota - preConsumedQuota
  132. err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
  133. if err != nil {
  134. common.SysError("Error consuming token remain quota: " + err.Error())
  135. }
  136. tokenName := c.GetString("token_name")
  137. userId := c.GetInt("id")
  138. model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, common.LogQuota(quota), modelRatio, groupRatio))
  139. model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
  140. channelId := c.GetInt("channel_id")
  141. model.UpdateChannelUsedQuota(channelId, quota)
  142. }
  143. }()
  144. if isStream {
  145. scanner := bufio.NewScanner(resp.Body)
  146. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  147. if atEOF && len(data) == 0 {
  148. return 0, nil, nil
  149. }
  150. if i := strings.Index(string(data), "\n\n"); i >= 0 {
  151. return i + 2, data[0:i], nil
  152. }
  153. if atEOF {
  154. return len(data), data, nil
  155. }
  156. return 0, nil, nil
  157. })
  158. dataChan := make(chan string)
  159. stopChan := make(chan bool)
  160. go func() {
  161. for scanner.Scan() {
  162. data := scanner.Text()
  163. if len(data) < 6 { // must be something wrong!
  164. common.SysError("Invalid stream response: " + data)
  165. continue
  166. }
  167. dataChan <- data
  168. data = data[6:]
  169. if !strings.HasPrefix(data, "[DONE]") {
  170. switch relayMode {
  171. case RelayModeChatCompletions:
  172. var streamResponse ChatCompletionsStreamResponse
  173. err = json.Unmarshal([]byte(data), &streamResponse)
  174. if err != nil {
  175. common.SysError("Error unmarshalling stream response: " + err.Error())
  176. return
  177. }
  178. for _, choice := range streamResponse.Choices {
  179. streamResponseText += choice.Delta.Content
  180. }
  181. case RelayModeCompletions:
  182. var streamResponse CompletionsStreamResponse
  183. err = json.Unmarshal([]byte(data), &streamResponse)
  184. if err != nil {
  185. common.SysError("Error unmarshalling stream response: " + err.Error())
  186. return
  187. }
  188. for _, choice := range streamResponse.Choices {
  189. streamResponseText += choice.Text
  190. }
  191. }
  192. }
  193. }
  194. stopChan <- true
  195. }()
  196. c.Writer.Header().Set("Content-Type", "text/event-stream")
  197. c.Writer.Header().Set("Cache-Control", "no-cache")
  198. c.Writer.Header().Set("Connection", "keep-alive")
  199. c.Writer.Header().Set("Transfer-Encoding", "chunked")
  200. c.Writer.Header().Set("X-Accel-Buffering", "no")
  201. c.Stream(func(w io.Writer) bool {
  202. select {
  203. case data := <-dataChan:
  204. if strings.HasPrefix(data, "data: [DONE]") {
  205. data = data[:12]
  206. }
  207. c.Render(-1, common.CustomEvent{Data: data})
  208. return true
  209. case <-stopChan:
  210. return false
  211. }
  212. })
  213. err = resp.Body.Close()
  214. if err != nil {
  215. return errorWrapper(err, "close_response_body_failed", http.StatusOK)
  216. }
  217. return nil
  218. } else {
  219. if consumeQuota {
  220. responseBody, err := io.ReadAll(resp.Body)
  221. if err != nil {
  222. return errorWrapper(err, "read_response_body_failed", http.StatusOK)
  223. }
  224. err = resp.Body.Close()
  225. if err != nil {
  226. return errorWrapper(err, "close_response_body_failed", http.StatusOK)
  227. }
  228. err = json.Unmarshal(responseBody, &textResponse)
  229. if err != nil {
  230. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK)
  231. }
  232. if textResponse.Error.Type != "" {
  233. return &OpenAIErrorWithStatusCode{
  234. OpenAIError: textResponse.Error,
  235. StatusCode: resp.StatusCode,
  236. }
  237. }
  238. // Reset response body
  239. resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
  240. }
  241. // We shouldn't set the header before we parse the response body, because the parse part may fail.
  242. // And then we will have to send an error response, but in this case, the header has already been set.
  243. // So the client will be confused by the response.
  244. // For example, Postman will report error, and we cannot check the response at all.
  245. for k, v := range resp.Header {
  246. c.Writer.Header().Set(k, v[0])
  247. }
  248. c.Writer.WriteHeader(resp.StatusCode)
  249. _, err = io.Copy(c.Writer, resp.Body)
  250. if err != nil {
  251. return errorWrapper(err, "copy_response_body_failed", http.StatusOK)
  252. }
  253. err = resp.Body.Close()
  254. if err != nil {
  255. return errorWrapper(err, "close_response_body_failed", http.StatusOK)
  256. }
  257. return nil
  258. }
  259. }