playground.go 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package controller
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/constant"
  8. "one-api/dto"
  9. "one-api/middleware"
  10. "one-api/model"
  11. "one-api/service"
  12. "one-api/setting"
  13. "time"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func Playground(c *gin.Context) {
  17. var openaiErr *dto.OpenAIErrorWithStatusCode
  18. defer func() {
  19. if openaiErr != nil {
  20. c.JSON(openaiErr.StatusCode, gin.H{
  21. "error": openaiErr.Error,
  22. })
  23. }
  24. }()
  25. useAccessToken := c.GetBool("use_access_token")
  26. if useAccessToken {
  27. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
  28. return
  29. }
  30. playgroundRequest := &dto.PlayGroundRequest{}
  31. err := common.UnmarshalBodyReusable(c, playgroundRequest)
  32. if err != nil {
  33. openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
  34. return
  35. }
  36. if playgroundRequest.Model == "" {
  37. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
  38. return
  39. }
  40. c.Set("original_model", playgroundRequest.Model)
  41. group := playgroundRequest.Group
  42. userGroup := c.GetString("group")
  43. if group == "" {
  44. group = userGroup
  45. } else {
  46. if !setting.GroupInUserUsableGroups(group) && group != userGroup {
  47. openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
  48. return
  49. }
  50. c.Set("group", group)
  51. }
  52. userId := c.GetInt("id")
  53. //c.Set("token_name", "playground-"+group)
  54. tempToken := &model.Token{
  55. UserId: userId,
  56. Name: fmt.Sprintf("playground-%s", group),
  57. Group: group,
  58. }
  59. _ = middleware.SetupContextForToken(c, tempToken)
  60. _, err = getChannel(c, group, playgroundRequest.Model, 0)
  61. if err != nil {
  62. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_playground_channel_failed", http.StatusInternalServerError)
  63. return
  64. }
  65. //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
  66. common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
  67. // Write user context to ensure acceptUnsetRatio is available
  68. userCache, err := model.GetUserCache(userId)
  69. if err != nil {
  70. openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
  71. return
  72. }
  73. userCache.WriteContext(c)
  74. Relay(c)
  75. }