dalle.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package dto
  import (
  	"encoding/json"
  	"reflect"
  	"strings"
  )
  // ImageRequest carries the parameters of an image request in JSON form.
  //
  // Besides the declared fields, the type round-trips unrecognised top-level
  // keys: its UnmarshalJSON method stores them in Extra and its MarshalJSON
  // method writes them back out.
  type ImageRequest struct {
  	Model          string          `json:"model"`
  	Prompt         string          `json:"prompt" binding:"required"`
  	N              int             `json:"n,omitempty"`
  	Size           string          `json:"size,omitempty"`
  	Quality        string          `json:"quality,omitempty"`
  	ResponseFormat string          `json:"response_format,omitempty"`
  	Style          string          `json:"style,omitempty"`
  	User           string          `json:"user,omitempty"`
  	// ExtraFields is decoded only from the literal "extra_fields" key; it is a
  	// separate field from Extra below.
  	ExtraFields  json.RawMessage `json:"extra_fields,omitempty"`
  	Background   string          `json:"background,omitempty"`
  	Moderation   string          `json:"moderation,omitempty"`
  	OutputFormat string          `json:"output_format,omitempty"`

  	// Extra collects top-level keys that match no declared field, with their
  	// values kept as raw JSON. It is an ordinary named field (not an embedded
  	// one); the "-" tag hides it from the default encoder so the custom
  	// marshal/unmarshal methods can handle it instead.
  	Extra map[string]json.RawMessage `json:"-"`
  }
  22. func (r *ImageRequest) UnmarshalJSON(data []byte) error {
  23. // 先解析成 map[string]interface{}
  24. var rawMap map[string]json.RawMessage
  25. if err := json.Unmarshal(data, &rawMap); err != nil {
  26. return err
  27. }
  28. // 用 struct tag 获取所有已定义字段名
  29. knownFields := GetJSONFieldNames(reflect.TypeOf(*r))
  30. // 再正常解析已定义字段
  31. type Alias ImageRequest
  32. var known Alias
  33. if err := json.Unmarshal(data, &known); err != nil {
  34. return err
  35. }
  36. *r = ImageRequest(known)
  37. // 提取多余字段
  38. r.Extra = make(map[string]json.RawMessage)
  39. for k, v := range rawMap {
  40. if _, ok := knownFields[k]; !ok {
  41. r.Extra[k] = v
  42. }
  43. }
  44. return nil
  45. }
  46. func (r ImageRequest) MarshalJSON() ([]byte, error) {
  47. // 将已定义字段转为 map
  48. type Alias ImageRequest
  49. alias := Alias(r)
  50. base, err := json.Marshal(alias)
  51. if err != nil {
  52. return nil, err
  53. }
  54. var baseMap map[string]json.RawMessage
  55. if err := json.Unmarshal(base, &baseMap); err != nil {
  56. return nil, err
  57. }
  58. // 合并 ExtraFields
  59. for k, v := range r.Extra {
  60. baseMap[k] = v
  61. }
  62. return json.Marshal(baseMap)
  63. }
  64. type ImageResponse struct {
  65. Data []ImageData `json:"data"`
  66. Created int64 `json:"created"`
  67. }
  // ImageData is one entry of ImageResponse.Data. None of its tags use
  // omitempty, so all three keys are always encoded, even when empty.
  // NOTE(review): presumably only one of Url / B64Json is filled, depending on
  // the requested response_format — confirm.
  type ImageData struct {
  	Url           string `json:"url"`
  	B64Json       string `json:"b64_json"`
  	RevisedPrompt string `json:"revised_prompt"`
  }
  // GetJSONFieldNames returns the set of JSON object keys that encoding/json maps
  // onto the fields of the struct type t, following the same selection rules:
  //
  //   - pointer types are dereferenced; nil and non-struct types give an empty set;
  //   - unexported fields are skipped, except embedded structs, whose exported
  //     fields are promoted;
  //   - fields tagged `json:"-"` are skipped;
  //   - the key is the tag name before the first comma, or the Go field name when
  //     the tag has no name (untagged fields, or tags such as `json:",omitempty"`);
  //   - an embedded struct without a tag name contributes its own fields rather
  //     than a key of its own.
  //
  // Names are returned exactly as declared. Note that encoding/json also accepts
  // case-insensitive matches when decoding.
  func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
  	fields := make(map[string]struct{})
  	collectJSONFieldNames(t, fields, make(map[reflect.Type]bool))
  	return fields
  }

  // collectJSONFieldNames adds the JSON keys of t to fields, descending into
  // untagged embedded structs; seen stops infinite recursion when a struct
  // embeds a pointer to itself (directly or indirectly).
  func collectJSONFieldNames(t reflect.Type, fields map[string]struct{}, seen map[reflect.Type]bool) {
  	for t != nil && t.Kind() == reflect.Ptr {
  		t = t.Elem()
  	}
  	if t == nil || t.Kind() != reflect.Struct || seen[t] {
  		return
  	}
  	seen[t] = true
  	for i := 0; i < t.NumField(); i++ {
  		field := t.Field(i)
  		fieldType := field.Type
  		if fieldType.Kind() == reflect.Ptr {
  			fieldType = fieldType.Elem()
  		}
  		embeddedStruct := field.Anonymous && fieldType.Kind() == reflect.Struct
  		// PkgPath is non-empty exactly for unexported fields. encoding/json ignores
  		// those, except embedded structs that may still promote exported fields.
  		if field.PkgPath != "" && !embeddedStruct {
  			continue
  		}
  		tag := field.Tag.Get("json")
  		if tag == "-" {
  			continue
  		}
  		name := tag
  		if comma := strings.IndexByte(tag, ','); comma != -1 {
  			name = tag[:comma]
  		}
  		switch {
  		case name != "":
  			fields[name] = struct{}{}
  		case embeddedStruct:
  			collectJSONFieldNames(fieldType, fields, seen)
  		default:
  			fields[field.Name] = struct{}{}
  		}
  	}
  }
  // indexComma returns the byte index of the first ',' in s, or -1 when s has
  // no comma — e.g. 4 for the struct-tag value "name,omitempty".
  func indexComma(s string) int {
  	return strings.IndexByte(s, ',')
  }