distributor.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. package middleware
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. )
  10. func Distribute() func(c *gin.Context) {
  11. return func(c *gin.Context) {
  12. var channel *model.Channel
  13. channelId, ok := c.Get("channelId")
  14. if ok {
  15. id, err := strconv.Atoi(channelId.(string))
  16. if err != nil {
  17. c.JSON(http.StatusOK, gin.H{
  18. "error": gin.H{
  19. "message": "无效的渠道 ID",
  20. "type": "one_api_error",
  21. },
  22. })
  23. c.Abort()
  24. return
  25. }
  26. channel, err = model.GetChannelById(id, true)
  27. if err != nil {
  28. c.JSON(200, gin.H{
  29. "error": gin.H{
  30. "message": "无效的渠道 ID",
  31. "type": "one_api_error",
  32. },
  33. })
  34. c.Abort()
  35. return
  36. }
  37. if channel.Status != common.ChannelStatusEnabled {
  38. c.JSON(200, gin.H{
  39. "error": gin.H{
  40. "message": "该渠道已被禁用",
  41. "type": "one_api_error",
  42. },
  43. })
  44. c.Abort()
  45. return
  46. }
  47. } else {
  48. // Select a channel for the user
  49. var err error
  50. channel, err = model.GetRandomChannel()
  51. if err != nil {
  52. c.JSON(200, gin.H{
  53. "error": gin.H{
  54. "message": "无可用渠道",
  55. "type": "one_api_error",
  56. },
  57. })
  58. c.Abort()
  59. return
  60. }
  61. }
  62. c.Set("channel", channel.Type)
  63. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  64. if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
  65. c.Set("base_url", channel.BaseURL)
  66. }
  67. c.Next()
  68. }
  69. }