midjourney.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package service
  2. import (
  3. "one-api/constant"
  4. "one-api/dto"
  5. relayconstant "one-api/relay/constant"
  6. "strconv"
  7. "strings"
  8. )
  9. func CoverActionToModelName(mjAction string) string {
  10. modelName := "mj_" + strings.ToLower(mjAction)
  11. return modelName
  12. }
  13. func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
  14. action := ""
  15. if relayMode == relayconstant.RelayModeMidjourneyAction {
  16. // plus request
  17. err := CoverPlusActionToNormalAction(midjRequest)
  18. if err != nil {
  19. return "", err, false
  20. }
  21. action = midjRequest.Action
  22. } else {
  23. switch relayMode {
  24. case relayconstant.RelayModeMidjourneyImagine:
  25. action = constant.MjActionImagine
  26. case relayconstant.RelayModeMidjourneyDescribe:
  27. action = constant.MjActionDescribe
  28. case relayconstant.RelayModeMidjourneyBlend:
  29. action = constant.MjActionBlend
  30. case relayconstant.RelayModeMidjourneyShorten:
  31. action = constant.MjActionShorten
  32. case relayconstant.RelayModeMidjourneyChange:
  33. action = midjRequest.Action
  34. case relayconstant.RelayModeMidjourneyModal:
  35. action = constant.MjActionInPaint
  36. case relayconstant.RelayModeMidjourneySimpleChange:
  37. params := ConvertSimpleChangeParams(midjRequest.Content)
  38. if params == nil {
  39. return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
  40. }
  41. action = params.Action
  42. case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
  43. return "", nil, true
  44. default:
  45. return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action"), false
  46. }
  47. }
  48. modelName := CoverActionToModelName(action)
  49. return modelName, nil, true
  50. }
  51. func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
  52. // "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
  53. customId := midjRequest.CustomId
  54. if customId == "" {
  55. return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
  56. }
  57. splits := strings.Split(customId, "::")
  58. var action string
  59. if splits[1] == "JOB" {
  60. action = splits[2]
  61. } else {
  62. action = splits[1]
  63. }
  64. if action == "" {
  65. return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
  66. }
  67. if strings.Contains(action, "upsample") {
  68. index, err := strconv.Atoi(splits[3])
  69. if err != nil {
  70. return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
  71. }
  72. midjRequest.Index = index
  73. midjRequest.Action = constant.MjActionUpscale
  74. } else if strings.Contains(action, "variation") {
  75. midjRequest.Index = 1
  76. if action == "variation" {
  77. index, err := strconv.Atoi(splits[3])
  78. if err != nil {
  79. return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
  80. }
  81. midjRequest.Index = index
  82. midjRequest.Action = constant.MjActionVariation
  83. } else if action == "low_variation" {
  84. midjRequest.Action = constant.MjActionLowVariation
  85. } else if action == "high_variation" {
  86. midjRequest.Action = constant.MjActionHighVariation
  87. }
  88. } else if strings.Contains(action, "pan") {
  89. midjRequest.Action = constant.MjActionPan
  90. midjRequest.Index = 1
  91. } else if action == "Outpaint" || action == "CustomZoom" {
  92. midjRequest.Action = constant.MjActionZoom
  93. midjRequest.Index = 1
  94. } else if action == "Inpaint" {
  95. midjRequest.Action = constant.MjActionInPaintPre
  96. midjRequest.Index = 1
  97. } else {
  98. return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
  99. }
  100. return nil
  101. }
  102. func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
  103. split := strings.Split(content, " ")
  104. if len(split) != 2 {
  105. return nil
  106. }
  107. action := strings.ToLower(split[1])
  108. changeParams := &dto.MidjourneyRequest{}
  109. changeParams.TaskId = split[0]
  110. if action[0] == 'u' {
  111. changeParams.Action = "UPSCALE"
  112. } else if action[0] == 'v' {
  113. changeParams.Action = "VARIATION"
  114. } else if action == "r" {
  115. changeParams.Action = "REROLL"
  116. return changeParams
  117. } else {
  118. return nil
  119. }
  120. index, err := strconv.Atoi(action[1:2])
  121. if err != nil || index < 1 || index > 4 {
  122. return nil
  123. }
  124. changeParams.Index = index
  125. return changeParams
  126. }