channel.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package model
  2. import (
  3. "gorm.io/gorm"
  4. "one-api/common"
  5. )
  6. type Channel struct {
  7. Id int `json:"id"`
  8. Type int `json:"type" gorm:"default:0"`
  9. Key string `json:"key" gorm:"not null"`
  10. OpenAIOrganization *string `json:"openai_organization"`
  11. Status int `json:"status" gorm:"default:1"`
  12. Name string `json:"name" gorm:"index"`
  13. Weight *uint `json:"weight" gorm:"default:0"`
  14. CreatedTime int64 `json:"created_time" gorm:"bigint"`
  15. TestTime int64 `json:"test_time" gorm:"bigint"`
  16. ResponseTime int `json:"response_time"` // in milliseconds
  17. BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
  18. Other string `json:"other"`
  19. Balance float64 `json:"balance"` // in USD
  20. BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
  21. Models string `json:"models"`
  22. Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
  23. UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
  24. ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
  25. Priority *int64 `json:"priority" gorm:"bigint;default:0"`
  26. AutoBan *int `json:"auto_ban" gorm:"default:1"`
  27. }
  28. func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
  29. var channels []*Channel
  30. var err error
  31. order := "priority desc"
  32. if idSort {
  33. order = "id desc"
  34. }
  35. if selectAll {
  36. err = DB.Order(order).Find(&channels).Error
  37. } else {
  38. err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
  39. }
  40. return channels, err
  41. }
  42. func SearchChannels(keyword string, group string, model string) ([]*Channel, error) {
  43. var channels []*Channel
  44. keyCol := "`key`"
  45. groupCol := "`group`"
  46. modelsCol := "`models`"
  47. // 如果是 PostgreSQL,使用双引号
  48. if common.UsingPostgreSQL {
  49. keyCol = `"key"`
  50. groupCol = `"group"`
  51. modelsCol = `"models"`
  52. }
  53. // 构造基础查询
  54. baseQuery := DB.Model(&Channel{}).Omit(keyCol)
  55. // 构造WHERE子句
  56. var whereClause string
  57. var args []interface{}
  58. if group != "" {
  59. whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
  60. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
  61. } else {
  62. whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
  63. args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
  64. }
  65. // 执行查询
  66. err := baseQuery.Where(whereClause, args...).Find(&channels).Error
  67. if err != nil {
  68. return nil, err
  69. }
  70. return channels, nil
  71. }
  72. func GetChannelById(id int, selectAll bool) (*Channel, error) {
  73. channel := Channel{Id: id}
  74. var err error = nil
  75. if selectAll {
  76. err = DB.First(&channel, "id = ?", id).Error
  77. } else {
  78. err = DB.Omit("key").First(&channel, "id = ?", id).Error
  79. }
  80. return &channel, err
  81. }
  82. func BatchInsertChannels(channels []Channel) error {
  83. var err error
  84. err = DB.Create(&channels).Error
  85. if err != nil {
  86. return err
  87. }
  88. for _, channel_ := range channels {
  89. err = channel_.AddAbilities()
  90. if err != nil {
  91. return err
  92. }
  93. }
  94. return nil
  95. }
  96. func BatchDeleteChannels(ids []int) error {
  97. //使用事务 删除channel表和channel_ability表
  98. tx := DB.Begin()
  99. err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
  100. if err != nil {
  101. // 回滚事务
  102. tx.Rollback()
  103. return err
  104. }
  105. err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
  106. if err != nil {
  107. // 回滚事务
  108. tx.Rollback()
  109. return err
  110. }
  111. // 提交事务
  112. tx.Commit()
  113. return err
  114. }
  115. func (channel *Channel) GetPriority() int64 {
  116. if channel.Priority == nil {
  117. return 0
  118. }
  119. return *channel.Priority
  120. }
  121. func (channel *Channel) GetWeight() int {
  122. if channel.Weight == nil {
  123. return 0
  124. }
  125. return int(*channel.Weight)
  126. }
  127. func (channel *Channel) GetBaseURL() string {
  128. if channel.BaseURL == nil {
  129. return ""
  130. }
  131. return *channel.BaseURL
  132. }
  133. func (channel *Channel) GetModelMapping() string {
  134. if channel.ModelMapping == nil {
  135. return ""
  136. }
  137. return *channel.ModelMapping
  138. }
  139. func (channel *Channel) Insert() error {
  140. var err error
  141. err = DB.Create(channel).Error
  142. if err != nil {
  143. return err
  144. }
  145. err = channel.AddAbilities()
  146. return err
  147. }
  148. func (channel *Channel) Update() error {
  149. var err error
  150. err = DB.Model(channel).Updates(channel).Error
  151. if err != nil {
  152. return err
  153. }
  154. DB.Model(channel).First(channel, "id = ?", channel.Id)
  155. err = channel.UpdateAbilities()
  156. return err
  157. }
  158. func (channel *Channel) UpdateResponseTime(responseTime int64) {
  159. err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
  160. TestTime: common.GetTimestamp(),
  161. ResponseTime: int(responseTime),
  162. }).Error
  163. if err != nil {
  164. common.SysError("failed to update response time: " + err.Error())
  165. }
  166. }
  167. func (channel *Channel) UpdateBalance(balance float64) {
  168. err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
  169. BalanceUpdatedTime: common.GetTimestamp(),
  170. Balance: balance,
  171. }).Error
  172. if err != nil {
  173. common.SysError("failed to update balance: " + err.Error())
  174. }
  175. }
  176. func (channel *Channel) Delete() error {
  177. var err error
  178. err = DB.Delete(channel).Error
  179. if err != nil {
  180. return err
  181. }
  182. err = channel.DeleteAbilities()
  183. return err
  184. }
  185. func UpdateChannelStatusById(id int, status int) {
  186. err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
  187. if err != nil {
  188. common.SysError("failed to update ability status: " + err.Error())
  189. }
  190. err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
  191. if err != nil {
  192. common.SysError("failed to update channel status: " + err.Error())
  193. }
  194. }
  195. func UpdateChannelUsedQuota(id int, quota int) {
  196. if common.BatchUpdateEnabled {
  197. addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
  198. return
  199. }
  200. updateChannelUsedQuota(id, quota)
  201. }
  202. func updateChannelUsedQuota(id int, quota int) {
  203. err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
  204. if err != nil {
  205. common.SysError("failed to update channel used quota: " + err.Error())
  206. }
  207. }
  208. func DeleteChannelByStatus(status int64) (int64, error) {
  209. result := DB.Where("status = ?", status).Delete(&Channel{})
  210. return result.RowsAffected, result.Error
  211. }
  212. func DeleteDisabledChannel() (int64, error) {
  213. result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
  214. return result.RowsAffected, result.Error
  215. }