distributor.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. package middleware
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. )
  10. func Distribute() func(c *gin.Context) {
  11. return func(c *gin.Context) {
  12. var channel *model.Channel
  13. channelId, ok := c.Get("channelId")
  14. if ok {
  15. id, err := strconv.Atoi(channelId.(string))
  16. if err != nil {
  17. c.JSON(http.StatusOK, gin.H{
  18. "error": gin.H{
  19. "message": "无效的渠道 ID",
  20. "type": "one_api_error",
  21. },
  22. })
  23. c.Abort()
  24. return
  25. }
  26. channel, err = model.GetChannelById(id, true)
  27. if err != nil {
  28. c.JSON(200, gin.H{
  29. "error": gin.H{
  30. "message": "无效的渠道 ID",
  31. "type": "one_api_error",
  32. },
  33. })
  34. c.Abort()
  35. return
  36. }
  37. if channel.Status != common.ChannelStatusEnabled {
  38. c.JSON(200, gin.H{
  39. "error": gin.H{
  40. "message": "该渠道已被禁用",
  41. "type": "one_api_error",
  42. },
  43. })
  44. c.Abort()
  45. return
  46. }
  47. } else {
  48. // Select a channel for the user
  49. var err error
  50. channel, err = model.GetRandomChannel()
  51. if err != nil {
  52. c.JSON(200, gin.H{
  53. "error": gin.H{
  54. "message": "无可用渠道",
  55. "type": "one_api_error",
  56. },
  57. })
  58. c.Abort()
  59. return
  60. }
  61. }
  62. c.Set("channel", channel.Type)
  63. c.Set("channel_id", channel.Id)
  64. c.Set("channel_name", channel.Name)
  65. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  66. if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure {
  67. c.Set("base_url", channel.BaseURL)
  68. if channel.Type == common.ChannelTypeAzure {
  69. c.Set("api_version", channel.Other)
  70. }
  71. }
  72. c.Next()
  73. }
  74. }