relay-utils.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package controller
import (
	"fmt"
	"strings"
	"sync"

	"github.com/pkoukk/tiktoken-go"

	"one-api/common"
)
  7. var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
  8. func getTokenEncoder(model string) *tiktoken.Tiktoken {
  9. if tokenEncoder, ok := tokenEncoderMap[model]; ok {
  10. return tokenEncoder
  11. }
  12. tokenEncoder, err := tiktoken.EncodingForModel(model)
  13. if err != nil {
  14. common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
  15. tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
  16. if err != nil {
  17. common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
  18. }
  19. }
  20. tokenEncoderMap[model] = tokenEncoder
  21. return tokenEncoder
  22. }
  23. func countTokenMessages(messages []Message, model string) int {
  24. tokenEncoder := getTokenEncoder(model)
  25. // Reference:
  26. // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  27. // https://github.com/pkoukk/tiktoken-go/issues/6
  28. //
  29. // Every message follows <|start|>{role/name}\n{content}<|end|>\n
  30. var tokensPerMessage int
  31. var tokensPerName int
  32. if model == "gpt-3.5-turbo-0301" {
  33. tokensPerMessage = 4
  34. tokensPerName = -1 // If there's a name, the role is omitted
  35. } else {
  36. tokensPerMessage = 3
  37. tokensPerName = 1
  38. }
  39. tokenNum := 0
  40. for _, message := range messages {
  41. tokenNum += tokensPerMessage
  42. tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
  43. tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
  44. if message.Name != nil {
  45. tokenNum += tokensPerName
  46. tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
  47. }
  48. }
  49. tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
  50. return tokenNum
  51. }
  52. func countTokenInput(input any, model string) int {
  53. switch input.(type) {
  54. case string:
  55. return countTokenText(input.(string), model)
  56. case []string:
  57. text := ""
  58. for _, s := range input.([]string) {
  59. text += s
  60. }
  61. return countTokenText(text, model)
  62. }
  63. return 0
  64. }
  65. func countTokenText(text string, model string) int {
  66. tokenEncoder := getTokenEncoder(model)
  67. token := tokenEncoder.Encode(text, nil, nil)
  68. return len(token)
  69. }
  70. func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
  71. openAIError := OpenAIError{
  72. Message: err.Error(),
  73. Type: "one_api_error",
  74. Code: code,
  75. }
  76. return &OpenAIErrorWithStatusCode{
  77. OpenAIError: openAIError,
  78. StatusCode: statusCode,
  79. }
  80. }