relay-utils.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. package controller
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/pkoukk/tiktoken-go"
  7. "io"
  8. "net/http"
  9. "one-api/common"
  10. "strconv"
  11. "strings"
  12. "unicode/utf8"
  13. )
// stopFinishReason is the literal finish_reason value "stop"; presumably
// emitted when a completion stream finishes normally — confirm at call sites
// (not used within this file).
var stopFinishReason = "stop"

// tokenEncoderMap won't grow after initialization: InitTokenEncoders creates
// an entry for every model in common.ModelRatio, and getTokenEncoder only
// overwrites existing (nil) entries.
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}

// defaultTokenEncoder is set by InitTokenEncoders to the gpt-3.5-turbo
// encoder and used as the fallback for unknown models.
var defaultTokenEncoder *tiktoken.Tiktoken
  18. func InitTokenEncoders() {
  19. common.SysLog("initializing token encoders")
  20. gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
  21. if err != nil {
  22. common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
  23. }
  24. defaultTokenEncoder = gpt35TokenEncoder
  25. gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
  26. if err != nil {
  27. common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
  28. }
  29. for model, _ := range common.ModelRatio {
  30. if strings.HasPrefix(model, "gpt-3.5") {
  31. tokenEncoderMap[model] = gpt35TokenEncoder
  32. } else if strings.HasPrefix(model, "gpt-4") {
  33. tokenEncoderMap[model] = gpt4TokenEncoder
  34. } else {
  35. tokenEncoderMap[model] = nil
  36. }
  37. }
  38. common.SysLog("token encoders initialized")
  39. }
  40. func getTokenEncoder(model string) *tiktoken.Tiktoken {
  41. tokenEncoder, ok := tokenEncoderMap[model]
  42. if ok && tokenEncoder != nil {
  43. return tokenEncoder
  44. }
  45. if ok {
  46. tokenEncoder, err := tiktoken.EncodingForModel(model)
  47. if err != nil {
  48. common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
  49. tokenEncoder = defaultTokenEncoder
  50. }
  51. tokenEncoderMap[model] = tokenEncoder
  52. return tokenEncoder
  53. }
  54. return defaultTokenEncoder
  55. }
  56. func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
  57. return len(tokenEncoder.Encode(text, nil, nil))
  58. }
  59. func countTokenMessages(messages []Message, model string) int {
  60. tokenEncoder := getTokenEncoder(model)
  61. // Reference:
  62. // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  63. // https://github.com/pkoukk/tiktoken-go/issues/6
  64. //
  65. // Every message follows <|start|>{role/name}\n{content}<|end|>\n
  66. var tokensPerMessage int
  67. var tokensPerName int
  68. if model == "gpt-3.5-turbo-0301" {
  69. tokensPerMessage = 4
  70. tokensPerName = -1 // If there's a name, the role is omitted
  71. } else {
  72. tokensPerMessage = 3
  73. tokensPerName = 1
  74. }
  75. tokenNum := 0
  76. for _, message := range messages {
  77. tokenNum += tokensPerMessage
  78. tokenNum += getTokenNum(tokenEncoder, message.Content)
  79. tokenNum += getTokenNum(tokenEncoder, message.Role)
  80. if message.Name != nil {
  81. tokenNum += tokensPerName
  82. tokenNum += getTokenNum(tokenEncoder, *message.Name)
  83. }
  84. }
  85. tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
  86. return tokenNum
  87. }
  88. func countTokenInput(input any, model string) int {
  89. switch input.(type) {
  90. case string:
  91. return countTokenText(input.(string), model)
  92. case []string:
  93. text := ""
  94. for _, s := range input.([]string) {
  95. text += s
  96. }
  97. return countTokenText(text, model)
  98. }
  99. return 0
  100. }
  101. func countAudioToken(text string, model string) int {
  102. if strings.HasPrefix(model, "tts") {
  103. return utf8.RuneCountInString(text)
  104. } else {
  105. return countTokenText(text, model)
  106. }
  107. }
  108. func countTokenText(text string, model string) int {
  109. tokenEncoder := getTokenEncoder(model)
  110. return getTokenNum(tokenEncoder, text)
  111. }
  112. func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
  113. text := err.Error()
  114. // 定义一个正则表达式匹配URL
  115. if strings.Contains(text, "Post") {
  116. text = "请求上游地址失败"
  117. }
  118. //避免暴露内部错误
  119. openAIError := OpenAIError{
  120. Message: text,
  121. Type: "one_api_error",
  122. Code: code,
  123. }
  124. return &OpenAIErrorWithStatusCode{
  125. OpenAIError: openAIError,
  126. StatusCode: statusCode,
  127. }
  128. }
  129. func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
  130. if !common.AutomaticDisableChannelEnabled {
  131. return false
  132. }
  133. if err == nil {
  134. return false
  135. }
  136. if statusCode == http.StatusUnauthorized {
  137. return true
  138. }
  139. if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
  140. return true
  141. }
  142. return false
  143. }
  144. func setEventStreamHeaders(c *gin.Context) {
  145. c.Writer.Header().Set("Content-Type", "text/event-stream")
  146. c.Writer.Header().Set("Cache-Control", "no-cache")
  147. c.Writer.Header().Set("Connection", "keep-alive")
  148. c.Writer.Header().Set("Transfer-Encoding", "chunked")
  149. c.Writer.Header().Set("X-Accel-Buffering", "no")
  150. }
  151. func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
  152. openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
  153. StatusCode: resp.StatusCode,
  154. OpenAIError: OpenAIError{
  155. Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
  156. Type: "upstream_error",
  157. Code: "bad_response_status_code",
  158. Param: strconv.Itoa(resp.StatusCode),
  159. },
  160. }
  161. responseBody, err := io.ReadAll(resp.Body)
  162. if err != nil {
  163. return
  164. }
  165. err = resp.Body.Close()
  166. if err != nil {
  167. return
  168. }
  169. var textResponse TextResponse
  170. err = json.Unmarshal(responseBody, &textResponse)
  171. if err != nil {
  172. return
  173. }
  174. openAIErrorWithStatusCode.OpenAIError = textResponse.Error
  175. return
  176. }
  177. func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
  178. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  179. if channelType == common.ChannelTypeOpenAI {
  180. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  181. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  182. }
  183. }
  184. return fullRequestURL
  185. }