adaptor.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. package sora
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "mime/multipart"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "github.com/QuantumNous/new-api/common"
  11. "github.com/QuantumNous/new-api/constant"
  12. "github.com/QuantumNous/new-api/dto"
  13. "github.com/QuantumNous/new-api/model"
  14. "github.com/QuantumNous/new-api/relay/channel"
  15. taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
  16. relaycommon "github.com/QuantumNous/new-api/relay/common"
  17. "github.com/QuantumNous/new-api/service"
  18. "github.com/gin-gonic/gin"
  19. "github.com/pkg/errors"
  20. "github.com/tidwall/sjson"
  21. )
  22. // ============================
  23. // Request / Response structures
  24. // ============================
  25. type ContentItem struct {
  26. Type string `json:"type"` // "text" or "image_url"
  27. Text string `json:"text,omitempty"` // for text type
  28. ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
  29. }
  30. type ImageURL struct {
  31. URL string `json:"url"`
  32. }
  33. type responseTask struct {
  34. ID string `json:"id"`
  35. TaskID string `json:"task_id,omitempty"` //兼容旧接口
  36. Object string `json:"object"`
  37. Model string `json:"model"`
  38. Status string `json:"status"`
  39. Progress int `json:"progress"`
  40. CreatedAt int64 `json:"created_at"`
  41. CompletedAt int64 `json:"completed_at,omitempty"`
  42. ExpiresAt int64 `json:"expires_at,omitempty"`
  43. Seconds string `json:"seconds,omitempty"`
  44. Size string `json:"size,omitempty"`
  45. RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
  46. Error *struct {
  47. Message string `json:"message"`
  48. Code string `json:"code"`
  49. } `json:"error,omitempty"`
  50. }
  51. // ============================
  52. // Adaptor implementation
  53. // ============================
  54. type TaskAdaptor struct {
  55. taskcommon.BaseBilling
  56. ChannelType int
  57. apiKey string
  58. baseURL string
  59. }
  60. func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
  61. a.ChannelType = info.ChannelType
  62. a.baseURL = info.ChannelBaseUrl
  63. a.apiKey = info.ApiKey
  64. }
  65. func validateRemixRequest(c *gin.Context) *dto.TaskError {
  66. var req relaycommon.TaskSubmitReq
  67. if err := common.UnmarshalBodyReusable(c, &req); err != nil {
  68. return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
  69. }
  70. if strings.TrimSpace(req.Prompt) == "" {
  71. return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
  72. }
  73. // 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
  74. c.Set("task_request", req)
  75. return nil
  76. }
  77. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
  78. if info.Action == constant.TaskActionRemix {
  79. return validateRemixRequest(c)
  80. }
  81. return relaycommon.ValidateMultipartDirect(c, info)
  82. }
  83. // EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
  84. func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
  85. // remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
  86. if info.Action == constant.TaskActionRemix {
  87. return nil
  88. }
  89. req, err := relaycommon.GetTaskRequest(c)
  90. if err != nil {
  91. return nil
  92. }
  93. seconds, _ := strconv.Atoi(req.Seconds)
  94. if seconds == 0 {
  95. seconds = req.Duration
  96. }
  97. if seconds <= 0 {
  98. seconds = 4
  99. }
  100. size := req.Size
  101. if size == "" {
  102. size = "720x1280"
  103. }
  104. ratios := map[string]float64{
  105. "seconds": float64(seconds),
  106. "size": 1,
  107. }
  108. if size == "1792x1024" || size == "1024x1792" {
  109. ratios["size"] = 1.666667
  110. }
  111. return ratios
  112. }
  113. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
  114. if info.Action == constant.TaskActionRemix {
  115. return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
  116. }
  117. return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
  118. }
  119. // BuildRequestHeader sets required headers.
  120. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
  121. req.Header.Set("Authorization", "Bearer "+a.apiKey)
  122. req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
  123. return nil
  124. }
  125. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
  126. storage, err := common.GetBodyStorage(c)
  127. if err != nil {
  128. return nil, errors.Wrap(err, "get_request_body_failed")
  129. }
  130. cachedBody, err := storage.Bytes()
  131. if err != nil {
  132. return nil, errors.Wrap(err, "read_body_bytes_failed")
  133. }
  134. contentType := c.GetHeader("Content-Type")
  135. if strings.HasPrefix(contentType, "application/json") {
  136. var bodyMap map[string]interface{}
  137. if err := common.Unmarshal(cachedBody, &bodyMap); err == nil {
  138. bodyMap["model"] = info.UpstreamModelName
  139. if newBody, err := common.Marshal(bodyMap); err == nil {
  140. return bytes.NewReader(newBody), nil
  141. }
  142. }
  143. return bytes.NewReader(cachedBody), nil
  144. }
  145. if strings.Contains(contentType, "multipart/form-data") {
  146. formData, err := common.ParseMultipartFormReusable(c)
  147. if err != nil {
  148. return bytes.NewReader(cachedBody), nil
  149. }
  150. var buf bytes.Buffer
  151. writer := multipart.NewWriter(&buf)
  152. writer.WriteField("model", info.UpstreamModelName)
  153. for key, values := range formData.Value {
  154. if key == "model" {
  155. continue
  156. }
  157. for _, v := range values {
  158. writer.WriteField(key, v)
  159. }
  160. }
  161. for fieldName, fileHeaders := range formData.File {
  162. for _, fh := range fileHeaders {
  163. f, err := fh.Open()
  164. if err != nil {
  165. continue
  166. }
  167. part, err := writer.CreateFormFile(fieldName, fh.Filename)
  168. if err != nil {
  169. f.Close()
  170. continue
  171. }
  172. io.Copy(part, f)
  173. f.Close()
  174. }
  175. }
  176. writer.Close()
  177. c.Request.Header.Set("Content-Type", writer.FormDataContentType())
  178. return &buf, nil
  179. }
  180. return common.ReaderOnly(storage), nil
  181. }
  182. // DoRequest delegates to common helper.
  183. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
  184. return channel.DoTaskApiRequest(a, c, info, requestBody)
  185. }
  186. // DoResponse handles upstream response, returns taskID etc.
  187. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
  188. responseBody, err := io.ReadAll(resp.Body)
  189. if err != nil {
  190. taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
  191. return
  192. }
  193. _ = resp.Body.Close()
  194. // Parse Sora response
  195. var dResp responseTask
  196. if err := common.Unmarshal(responseBody, &dResp); err != nil {
  197. taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
  198. return
  199. }
  200. upstreamID := dResp.ID
  201. if upstreamID == "" {
  202. upstreamID = dResp.TaskID
  203. }
  204. if upstreamID == "" {
  205. taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
  206. return
  207. }
  208. // 使用公开 task_xxxx ID 返回给客户端
  209. dResp.ID = info.PublicTaskID
  210. dResp.TaskID = info.PublicTaskID
  211. c.JSON(http.StatusOK, dResp)
  212. return upstreamID, responseBody, nil
  213. }
  214. // FetchTask fetch task status
  215. func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
  216. taskID, ok := body["task_id"].(string)
  217. if !ok {
  218. return nil, fmt.Errorf("invalid task_id")
  219. }
  220. uri := fmt.Sprintf("%s/v1/videos/%s", baseUrl, taskID)
  221. req, err := http.NewRequest(http.MethodGet, uri, nil)
  222. if err != nil {
  223. return nil, err
  224. }
  225. req.Header.Set("Authorization", "Bearer "+key)
  226. client, err := service.GetHttpClientWithProxy(proxy)
  227. if err != nil {
  228. return nil, fmt.Errorf("new proxy http client failed: %w", err)
  229. }
  230. return client.Do(req)
  231. }
  232. func (a *TaskAdaptor) GetModelList() []string {
  233. return ModelList
  234. }
  235. func (a *TaskAdaptor) GetChannelName() string {
  236. return ChannelName
  237. }
  238. func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
  239. resTask := responseTask{}
  240. if err := common.Unmarshal(respBody, &resTask); err != nil {
  241. return nil, errors.Wrap(err, "unmarshal task result failed")
  242. }
  243. taskResult := relaycommon.TaskInfo{
  244. Code: 0,
  245. }
  246. switch resTask.Status {
  247. case "queued", "pending":
  248. taskResult.Status = model.TaskStatusQueued
  249. case "processing", "in_progress":
  250. taskResult.Status = model.TaskStatusInProgress
  251. case "completed":
  252. taskResult.Status = model.TaskStatusSuccess
  253. // Url intentionally left empty — the caller constructs the proxy URL using the public task ID
  254. case "failed", "cancelled":
  255. taskResult.Status = model.TaskStatusFailure
  256. if resTask.Error != nil {
  257. taskResult.Reason = resTask.Error.Message
  258. } else {
  259. taskResult.Reason = "task failed"
  260. }
  261. default:
  262. }
  263. if resTask.Progress > 0 && resTask.Progress < 100 {
  264. taskResult.Progress = fmt.Sprintf("%d%%", resTask.Progress)
  265. }
  266. return &taskResult, nil
  267. }
  268. func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
  269. data := task.Data
  270. var err error
  271. if data, err = sjson.SetBytes(data, "id", task.TaskID); err != nil {
  272. return nil, errors.Wrap(err, "set id failed")
  273. }
  274. return data, nil
  275. }