distributor.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package middleware
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. )
  // ModelRequest captures only the "model" field of an incoming JSON request
// body; Distribute decodes into it to learn which model a channel must serve.
type ModelRequest struct{ Model string `json:"model"` }
  13. func Distribute() func(c *gin.Context) {
  14. return func(c *gin.Context) {
  15. var channel *model.Channel
  16. channelId, ok := c.Get("channelId")
  17. if ok {
  18. id, err := strconv.Atoi(channelId.(string))
  19. if err != nil {
  20. c.JSON(http.StatusOK, gin.H{
  21. "error": gin.H{
  22. "message": "无效的渠道 ID",
  23. "type": "one_api_error",
  24. },
  25. })
  26. c.Abort()
  27. return
  28. }
  29. channel, err = model.GetChannelById(id, true)
  30. if err != nil {
  31. c.JSON(200, gin.H{
  32. "error": gin.H{
  33. "message": "无效的渠道 ID",
  34. "type": "one_api_error",
  35. },
  36. })
  37. c.Abort()
  38. return
  39. }
  40. if channel.Status != common.ChannelStatusEnabled {
  41. c.JSON(200, gin.H{
  42. "error": gin.H{
  43. "message": "该渠道已被禁用",
  44. "type": "one_api_error",
  45. },
  46. })
  47. c.Abort()
  48. return
  49. }
  50. } else {
  51. // Select a channel for the user
  52. var modelRequest ModelRequest
  53. err := common.UnmarshalBodyReusable(c, &modelRequest)
  54. if err != nil {
  55. c.JSON(200, gin.H{
  56. "error": gin.H{
  57. "message": "无效的请求",
  58. "type": "one_api_error",
  59. },
  60. })
  61. c.Abort()
  62. return
  63. }
  64. userId := c.GetInt("id")
  65. userGroup, _ := model.GetUserGroup(userId)
  66. channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
  67. if err != nil {
  68. c.JSON(200, gin.H{
  69. "error": gin.H{
  70. "message": "无可用渠道",
  71. "type": "one_api_error",
  72. },
  73. })
  74. c.Abort()
  75. return
  76. }
  77. }
  78. c.Set("channel", channel.Type)
  79. c.Set("channel_id", channel.Id)
  80. c.Set("channel_name", channel.Name)
  81. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  82. if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
  83. c.Set("base_url", channel.BaseURL)
  84. if channel.Type == common.ChannelTypeAzure {
  85. c.Set("api_version", channel.Other)
  86. }
  87. }
  88. c.Next()
  89. }
  90. }