openai_image.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. package dto
  2. import (
  3. "encoding/json"
  4. "reflect"
  5. "strings"
  6. "github.com/QuantumNous/new-api/common"
  7. "github.com/QuantumNous/new-api/types"
  8. "github.com/gin-gonic/gin"
  9. )
// ImageRequest is the body of an OpenAI-style image request (this file is
// openai_image.go).
//
// Model, Prompt, N, Size, Quality and ResponseFormat are decoded into concrete
// Go types; GetTokenCountMeta reads Model, Size, Quality and Prompt. Most other
// options are kept as json.RawMessage, so the JSON the client sent is carried
// through MarshalJSON without being interpreted here. Top-level keys that match
// none of the json tags below are collected into Extra by UnmarshalJSON.
type ImageRequest struct {
	Model string `json:"model"`
	// Prompt is tagged binding:"required" for gin's validator.
	Prompt string `json:"prompt" binding:"required"`
	// N is a pointer so an absent "n" (nil) is distinguishable from an explicit 0.
	N *uint `json:"n,omitempty"`
	Size string `json:"size,omitempty"`
	Quality string `json:"quality,omitempty"`
	ResponseFormat string `json:"response_format,omitempty"`
	Style json.RawMessage `json:"style,omitempty"`
	User json.RawMessage `json:"user,omitempty"`
	// ExtraFields is the literal "extra_fields" request key; it is unrelated
	// to the Extra map at the end of this struct.
	ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
	Background json.RawMessage `json:"background,omitempty"`
	Moderation json.RawMessage `json:"moderation,omitempty"`
	OutputFormat json.RawMessage `json:"output_format,omitempty"`
	OutputCompression json.RawMessage `json:"output_compression,omitempty"`
	PartialImages json.RawMessage `json:"partial_images,omitempty"`
	// Stream is not declared (it is left commented out below), so a "stream"
	// key in the request ends up in Extra instead.
	// Stream bool `json:"stream,omitempty"`
	Images json.RawMessage `json:"images,omitempty"`
	Mask json.RawMessage `json:"mask,omitempty"`
	InputFidelity json.RawMessage `json:"input_fidelity,omitempty"`
	// Watermark is a pointer so "not sent" (nil) differs from an explicit false.
	Watermark *bool `json:"watermark,omitempty"`
	// zhipu 4v: presumably options specific to Zhipu GLM-4V style requests —
	// confirm exactly which of the following fields this note covers.
	WatermarkEnabled json.RawMessage `json:"watermark_enabled,omitempty"`
	UserId json.RawMessage `json:"user_id,omitempty"`
	Image json.RawMessage `json:"image,omitempty"`
	// Extra receives every top-level key that matches no json tag above; it is
	// filled by UnmarshalJSON. The "-" tag keeps encoding/json from reading or
	// writing it directly, and MarshalJSON does not emit it.
	Extra map[string]json.RawMessage `json:"-"`
}
  37. func (i *ImageRequest) UnmarshalJSON(data []byte) error {
  38. // 先解析成 map[string]interface{}
  39. var rawMap map[string]json.RawMessage
  40. if err := common.Unmarshal(data, &rawMap); err != nil {
  41. return err
  42. }
  43. // 用 struct tag 获取所有已定义字段名
  44. knownFields := GetJSONFieldNames(reflect.TypeOf(*i))
  45. // 再正常解析已定义字段
  46. type Alias ImageRequest
  47. var known Alias
  48. if err := common.Unmarshal(data, &known); err != nil {
  49. return err
  50. }
  51. *i = ImageRequest(known)
  52. // 提取多余字段
  53. i.Extra = make(map[string]json.RawMessage)
  54. for k, v := range rawMap {
  55. if _, ok := knownFields[k]; !ok {
  56. i.Extra[k] = v
  57. }
  58. }
  59. return nil
  60. }
  61. // 序列化时需要重新把字段平铺
  62. func (r ImageRequest) MarshalJSON() ([]byte, error) {
  63. // 将已定义字段转为 map
  64. type Alias ImageRequest
  65. alias := Alias(r)
  66. base, err := common.Marshal(alias)
  67. if err != nil {
  68. return nil, err
  69. }
  70. var baseMap map[string]json.RawMessage
  71. if err := common.Unmarshal(base, &baseMap); err != nil {
  72. return nil, err
  73. }
  74. // 不能合并ExtraFields!!!!!!!!
  75. // 合并 ExtraFields
  76. //for k, v := range r.Extra {
  77. // if _, exists := baseMap[k]; !exists {
  78. // baseMap[k] = v
  79. // }
  80. //}
  81. return common.Marshal(baseMap)
  82. }
  83. func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
  84. fields := make(map[string]struct{})
  85. for i := 0; i < t.NumField(); i++ {
  86. field := t.Field(i)
  87. // 跳过匿名字段(例如 ExtraFields)
  88. if field.Anonymous {
  89. continue
  90. }
  91. tag := field.Tag.Get("json")
  92. if tag == "-" || tag == "" {
  93. continue
  94. }
  95. // 取逗号前字段名(排除 omitempty 等)
  96. name := tag
  97. if commaIdx := indexComma(tag); commaIdx != -1 {
  98. name = tag[:commaIdx]
  99. }
  100. fields[name] = struct{}{}
  101. }
  102. return fields
  103. }
  104. func indexComma(s string) int {
  105. for i := 0; i < len(s); i++ {
  106. if s[i] == ',' {
  107. return i
  108. }
  109. }
  110. return -1
  111. }
  112. func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
  113. var sizeRatio = 1.0
  114. var qualityRatio = 1.0
  115. if strings.HasPrefix(i.Model, "dall-e") {
  116. // Size
  117. if i.Size == "256x256" {
  118. sizeRatio = 0.4
  119. } else if i.Size == "512x512" {
  120. sizeRatio = 0.45
  121. } else if i.Size == "1024x1024" {
  122. sizeRatio = 1
  123. } else if i.Size == "1024x1792" || i.Size == "1792x1024" {
  124. sizeRatio = 2
  125. }
  126. if i.Model == "dall-e-3" && i.Quality == "hd" {
  127. qualityRatio = 2.0
  128. if i.Size == "1024x1792" || i.Size == "1792x1024" {
  129. qualityRatio = 1.5
  130. }
  131. }
  132. }
  133. // n is NOT included here; it is handled via OtherRatio("n") in
  134. // image_handler.go (default) or channel adaptors (actual count).
  135. // Including n here caused double-counting for channels that also
  136. // set OtherRatio("n") (e.g. Ali/Bailian).
  137. return &types.TokenCountMeta{
  138. CombineText: i.Prompt,
  139. MaxTokens: 1584,
  140. ImagePriceRatio: sizeRatio * qualityRatio,
  141. }
  142. }
  143. func (i *ImageRequest) IsStream(c *gin.Context) bool {
  144. return false
  145. }
  146. func (i *ImageRequest) SetModelName(modelName string) {
  147. if modelName != "" {
  148. i.Model = modelName
  149. }
  150. }
// ImageResponse is the JSON body of an OpenAI-style image response.
type ImageResponse struct {
	// Data holds the returned images. It has no omitempty, so the "data" key
	// is always emitted.
	Data []ImageData `json:"data"`
	// Created is presumably a Unix timestamp in seconds — confirm against producers.
	Created int64 `json:"created"`
	// Metadata is passed through as raw JSON and omitted when empty.
	Metadata json.RawMessage `json:"metadata,omitempty"`
}
// ImageData describes one returned image. None of its fields use omitempty,
// so all three keys are always emitted (as "" when unset).
type ImageData struct {
	// Url presumably points at the image when it is returned by URL.
	Url string `json:"url"`
	// B64Json presumably carries the base64-encoded image when returned inline.
	B64Json string `json:"b64_json"`
	// RevisedPrompt is the prompt text reported back with the image, if any.
	RevisedPrompt string `json:"revised_prompt"`
}