distributor.go 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package middleware
  2. import (
  3. "fmt"
  4. "github.com/gin-gonic/gin"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/model"
  8. "strconv"
  9. "strings"
  10. )
  11. type ModelRequest struct {
  12. Model string `json:"model"`
  13. }
  14. func Distribute() func(c *gin.Context) {
  15. return func(c *gin.Context) {
  16. var channel *model.Channel
  17. channelId, ok := c.Get("channelId")
  18. if ok {
  19. id, err := strconv.Atoi(channelId.(string))
  20. if err != nil {
  21. c.JSON(http.StatusOK, gin.H{
  22. "error": gin.H{
  23. "message": "无效的渠道 ID",
  24. "type": "one_api_error",
  25. },
  26. })
  27. c.Abort()
  28. return
  29. }
  30. channel, err = model.GetChannelById(id, true)
  31. if err != nil {
  32. c.JSON(200, gin.H{
  33. "error": gin.H{
  34. "message": "无效的渠道 ID",
  35. "type": "one_api_error",
  36. },
  37. })
  38. c.Abort()
  39. return
  40. }
  41. if channel.Status != common.ChannelStatusEnabled {
  42. c.JSON(200, gin.H{
  43. "error": gin.H{
  44. "message": "该渠道已被禁用",
  45. "type": "one_api_error",
  46. },
  47. })
  48. c.Abort()
  49. return
  50. }
  51. } else {
  52. // Select a channel for the user
  53. var modelRequest ModelRequest
  54. err := common.UnmarshalBodyReusable(c, &modelRequest)
  55. if err != nil {
  56. c.JSON(200, gin.H{
  57. "error": gin.H{
  58. "message": "无效的请求",
  59. "type": "one_api_error",
  60. },
  61. })
  62. c.Abort()
  63. return
  64. }
  65. if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
  66. if modelRequest.Model == "" {
  67. modelRequest.Model = "text-moderation-stable"
  68. }
  69. }
  70. userId := c.GetInt("id")
  71. userGroup, _ := model.GetUserGroup(userId)
  72. channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
  73. if err != nil {
  74. c.JSON(200, gin.H{
  75. "error": gin.H{
  76. "message": "无可用渠道",
  77. "type": "one_api_error",
  78. },
  79. })
  80. c.Abort()
  81. return
  82. }
  83. }
  84. c.Set("channel", channel.Type)
  85. c.Set("channel_id", channel.Id)
  86. c.Set("channel_name", channel.Name)
  87. c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
  88. c.Set("base_url", channel.BaseURL)
  89. if channel.Type == common.ChannelTypeAzure {
  90. c.Set("api_version", channel.Other)
  91. }
  92. c.Next()
  93. }
  94. }