relay_utils.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package common
  2. import (
  3. "fmt"
  4. "net/http"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/dto"
  8. "strings"
  9. "github.com/gin-gonic/gin"
  10. )
  11. type HasPrompt interface {
  12. GetPrompt() string
  13. }
  14. type HasImage interface {
  15. HasImage() bool
  16. }
  17. func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
  18. fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
  19. if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
  20. switch channelType {
  21. case constant.ChannelTypeOpenAI:
  22. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
  23. case constant.ChannelTypeAzure:
  24. fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
  25. }
  26. }
  27. return fullRequestURL
  28. }
  29. func GetAPIVersion(c *gin.Context) string {
  30. query := c.Request.URL.Query()
  31. apiVersion := query.Get("api-version")
  32. if apiVersion == "" {
  33. apiVersion = c.GetString("api_version")
  34. }
  35. return apiVersion
  36. }
  37. func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
  38. return &dto.TaskError{
  39. Code: code,
  40. Message: err.Error(),
  41. StatusCode: statusCode,
  42. LocalError: localError,
  43. Error: err,
  44. }
  45. }
  46. func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
  47. info.Action = action
  48. c.Set("task_request", requestObj)
  49. }
  50. func validatePrompt(prompt string) *dto.TaskError {
  51. if strings.TrimSpace(prompt) == "" {
  52. return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
  53. }
  54. return nil
  55. }
  56. func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
  57. var req TaskSubmitReq
  58. if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  59. return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
  60. }
  61. if taskErr := validatePrompt(req.Prompt); taskErr != nil {
  62. return taskErr
  63. }
  64. if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
  65. // 兼容单图上传
  66. req.Images = []string{req.Image}
  67. }
  68. if req.HasImage() {
  69. action = constant.TaskActionGenerate
  70. if info.ChannelType == constant.ChannelTypeVidu {
  71. // vidu 增加 首尾帧生视频和参考图生视频
  72. if len(req.Images) == 2 {
  73. action = constant.TaskActionFirstTailGenerate
  74. } else if len(req.Images) > 2 {
  75. action = constant.TaskActionReferenceGenerate
  76. }
  77. }
  78. }
  79. storeTaskRequest(c, info, action, req)
  80. return nil
  81. }