relay_info.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package common
  2. import (
  3. "github.com/gin-gonic/gin"
  4. "github.com/gorilla/websocket"
  5. "one-api/common"
  6. "one-api/relay/constant"
  7. "strings"
  8. "time"
  9. )
  10. type RelayInfo struct {
  11. ChannelType int
  12. ChannelId int
  13. TokenId int
  14. UserId int
  15. Group string
  16. TokenUnlimited bool
  17. StartTime time.Time
  18. FirstResponseTime time.Time
  19. setFirstResponse bool
  20. ApiType int
  21. IsStream bool
  22. IsPlayground bool
  23. RelayMode int
  24. UpstreamModelName string
  25. OriginModelName string
  26. RequestURLPath string
  27. ApiVersion string
  28. PromptTokens int
  29. ApiKey string
  30. Organization string
  31. BaseUrl string
  32. SupportStreamOptions bool
  33. ShouldIncludeUsage bool
  34. ClientWs *websocket.Conn
  35. TargetWs *websocket.Conn
  36. }
  37. func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
  38. info := GenRelayInfo(c)
  39. info.ClientWs = ws
  40. return info
  41. }
  42. func GenRelayInfo(c *gin.Context) *RelayInfo {
  43. channelType := c.GetInt("channel_type")
  44. channelId := c.GetInt("channel_id")
  45. tokenId := c.GetInt("token_id")
  46. userId := c.GetInt("id")
  47. group := c.GetString("group")
  48. tokenUnlimited := c.GetBool("token_unlimited_quota")
  49. startTime := time.Now()
  50. // firstResponseTime = time.Now() - 1 second
  51. apiType, _ := constant.ChannelType2APIType(channelType)
  52. info := &RelayInfo{
  53. RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
  54. BaseUrl: c.GetString("base_url"),
  55. RequestURLPath: c.Request.URL.String(),
  56. ChannelType: channelType,
  57. ChannelId: channelId,
  58. TokenId: tokenId,
  59. UserId: userId,
  60. Group: group,
  61. TokenUnlimited: tokenUnlimited,
  62. StartTime: startTime,
  63. FirstResponseTime: startTime.Add(-time.Second),
  64. OriginModelName: c.GetString("original_model"),
  65. UpstreamModelName: c.GetString("original_model"),
  66. ApiType: apiType,
  67. ApiVersion: c.GetString("api_version"),
  68. ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
  69. Organization: c.GetString("channel_organization"),
  70. }
  71. if strings.HasPrefix(c.Request.URL.Path, "/pg") {
  72. info.IsPlayground = true
  73. info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg")
  74. info.RequestURLPath = "/v1" + info.RequestURLPath
  75. }
  76. if info.BaseUrl == "" {
  77. info.BaseUrl = common.ChannelBaseURLs[channelType]
  78. }
  79. if info.ChannelType == common.ChannelTypeAzure {
  80. info.ApiVersion = GetAPIVersion(c)
  81. }
  82. if info.ChannelType == common.ChannelTypeVertexAi {
  83. info.ApiVersion = c.GetString("region")
  84. }
  85. if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
  86. info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
  87. info.ChannelType == common.ChannelCloudflare {
  88. info.SupportStreamOptions = true
  89. }
  90. return info
  91. }
  92. func (info *RelayInfo) SetPromptTokens(promptTokens int) {
  93. info.PromptTokens = promptTokens
  94. }
  95. func (info *RelayInfo) SetIsStream(isStream bool) {
  96. info.IsStream = isStream
  97. }
  98. func (info *RelayInfo) SetFirstResponseTime() {
  99. if !info.setFirstResponse {
  100. info.FirstResponseTime = time.Now()
  101. info.setFirstResponse = true
  102. }
  103. }
  104. type TaskRelayInfo struct {
  105. ChannelType int
  106. ChannelId int
  107. TokenId int
  108. UserId int
  109. Group string
  110. StartTime time.Time
  111. ApiType int
  112. RelayMode int
  113. UpstreamModelName string
  114. RequestURLPath string
  115. ApiKey string
  116. BaseUrl string
  117. Action string
  118. OriginTaskID string
  119. ConsumeQuota bool
  120. }
  121. func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
  122. channelType := c.GetInt("channel_type")
  123. channelId := c.GetInt("channel_id")
  124. tokenId := c.GetInt("token_id")
  125. userId := c.GetInt("id")
  126. group := c.GetString("group")
  127. startTime := time.Now()
  128. apiType, _ := constant.ChannelType2APIType(channelType)
  129. info := &TaskRelayInfo{
  130. RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
  131. BaseUrl: c.GetString("base_url"),
  132. RequestURLPath: c.Request.URL.String(),
  133. ChannelType: channelType,
  134. ChannelId: channelId,
  135. TokenId: tokenId,
  136. UserId: userId,
  137. Group: group,
  138. StartTime: startTime,
  139. ApiType: apiType,
  140. ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
  141. }
  142. if info.BaseUrl == "" {
  143. info.BaseUrl = common.ChannelBaseURLs[channelType]
  144. }
  145. return info
  146. }
  147. func (info *TaskRelayInfo) ToRelayInfo() *RelayInfo {
  148. return &RelayInfo{
  149. ChannelType: info.ChannelType,
  150. ChannelId: info.ChannelId,
  151. TokenId: info.TokenId,
  152. UserId: info.UserId,
  153. Group: info.Group,
  154. StartTime: info.StartTime,
  155. ApiType: info.ApiType,
  156. RelayMode: info.RelayMode,
  157. UpstreamModelName: info.UpstreamModelName,
  158. RequestURLPath: info.RequestURLPath,
  159. ApiKey: info.ApiKey,
  160. BaseUrl: info.BaseUrl,
  161. }
  162. }