distributor.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package middleware
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. )
  10. type ModelRequest struct {
  11. Model string `json:"model"`
  12. }
  13. func Distribute() func(c *gin.Context) {
  14. return func(c *gin.Context) {
  15. var channel *model.Channel
  16. channelId, ok := c.Get("channelId")
  17. if ok {
  18. id, err := strconv.Atoi(channelId.(string))
  19. if err != nil {
  20. c.JSON(http.StatusOK, gin.H{
  21. "error": gin.H{
  22. "message": "无效的渠道 ID",
  23. "type": "one_api_error",
  24. },
  25. })
  26. c.Abort()
  27. return
  28. }
  29. channel, err = model.GetChannelById(id, true)
  30. if err != nil {
  31. c.JSON(200, gin.H{
  32. "error": gin.H{
  33. "message": "无效的渠道 ID",
  34. "type": "one_api_error",
  35. },
  36. })
  37. c.Abort()
  38. return
  39. }
  40. if channel.Status != common.ChannelStatusEnabled {
  41. c.JSON(200, gin.H{
  42. "error": gin.H{
  43. "message": "该渠道已被禁用",
  44. "type": "one_api_error",
  45. },
  46. })
  47. c.Abort()
  48. return
  49. }
  50. } else {
  51. // Select a channel for the user
  52. var modelRequest ModelRequest
  53. err := common.UnmarshalBodyReusable(c, &modelRequest)
  54. if err != nil {
  55. c.JSON(200, gin.H{
  56. "error": gin.H{
  57. "message": "无效的请求",
  58. "type": "one_api_error",
  59. },
  60. })
  61. c.Abort()
  62. return
  63. }
  64. userId := c.GetInt("id")
  65. userGroup, _ := model.GetUserGroup(userId)
  66. channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
  67. if err != nil {
  68. c.JSON(200, gin.H{
  69. "error": gin.H{
  70. "message": "无可用渠道",
  71. "type": "one_api_error",
  72. },
  73. })
  74. c.Abort()
  75. return
  76. }
  77. }
  78. c.Set("channel", channel.Type)
  79. c.Set("channel_id", channel.Id)
  80. c.Set("channel_name", channel.Name)
  81. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  82. c.Set("base_url", channel.BaseURL)
  83. if channel.Type == common.ChannelTypeAzure {
  84. c.Set("api_version", channel.Other)
  85. }
  86. c.Next()
  87. }
  88. }