relay-utils.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. package controller
import (
	"fmt"
	"strings"
	"sync"

	"github.com/pkoukk/tiktoken-go"
	"one-api/common"
)
  8. var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
  9. func getTokenEncoder(model string) *tiktoken.Tiktoken {
  10. if tokenEncoder, ok := tokenEncoderMap[model]; ok {
  11. return tokenEncoder
  12. }
  13. tokenEncoder, err := tiktoken.EncodingForModel(model)
  14. if err != nil {
  15. common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
  16. }
  17. tokenEncoderMap[model] = tokenEncoder
  18. return tokenEncoder
  19. }
  20. func countTokenMessages(messages []Message, model string) int {
  21. tokenEncoder := getTokenEncoder(model)
  22. // Reference:
  23. // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  24. // https://github.com/pkoukk/tiktoken-go/issues/6
  25. //
  26. // Every message follows <|start|>{role/name}\n{content}<|end|>\n
  27. var tokensPerMessage int
  28. var tokensPerName int
  29. if strings.HasPrefix(model, "gpt-3.5") {
  30. tokensPerMessage = 4
  31. tokensPerName = -1 // If there's a name, the role is omitted
  32. } else if strings.HasPrefix(model, "gpt-4") {
  33. tokensPerMessage = 3
  34. tokensPerName = 1
  35. } else {
  36. tokensPerMessage = 3
  37. tokensPerName = 1
  38. }
  39. tokenNum := 0
  40. for _, message := range messages {
  41. tokenNum += tokensPerMessage
  42. tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
  43. tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
  44. if message.Name != nil {
  45. tokenNum += tokensPerName
  46. tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil))
  47. }
  48. }
  49. tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
  50. return tokenNum
  51. }
  52. func countTokenText(text string, model string) int {
  53. tokenEncoder := getTokenEncoder(model)
  54. token := tokenEncoder.Encode(text, nil, nil)
  55. return len(token)
  56. }