codec_map.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "errors"
  7. "reflect"
  8. "sort"
  9. "google.golang.org/protobuf/encoding/protowire"
  10. "google.golang.org/protobuf/internal/genid"
  11. pref "google.golang.org/protobuf/reflect/protoreflect"
  12. )
  13. type mapInfo struct {
  14. goType reflect.Type
  15. keyWiretag uint64
  16. valWiretag uint64
  17. keyFuncs valueCoderFuncs
  18. valFuncs valueCoderFuncs
  19. keyZero pref.Value
  20. keyKind pref.Kind
  21. conv *mapConverter
  22. }
  23. func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
  24. // TODO: Consider generating specialized map coders.
  25. keyField := fd.MapKey()
  26. valField := fd.MapValue()
  27. keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
  28. valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
  29. keyFuncs := encoderFuncsForValue(keyField)
  30. valFuncs := encoderFuncsForValue(valField)
  31. conv := newMapConverter(ft, fd)
  32. mapi := &mapInfo{
  33. goType: ft,
  34. keyWiretag: keyWiretag,
  35. valWiretag: valWiretag,
  36. keyFuncs: keyFuncs,
  37. valFuncs: valFuncs,
  38. keyZero: keyField.Default(),
  39. keyKind: keyField.Kind(),
  40. conv: conv,
  41. }
  42. if valField.Kind() == pref.MessageKind {
  43. valueMessage = getMessageInfo(ft.Elem())
  44. }
  45. funcs = pointerCoderFuncs{
  46. size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
  47. return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
  48. },
  49. marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  50. return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
  51. },
  52. unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
  53. mp := p.AsValueOf(ft)
  54. if mp.Elem().IsNil() {
  55. mp.Elem().Set(reflect.MakeMap(mapi.goType))
  56. }
  57. if f.mi == nil {
  58. return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
  59. } else {
  60. return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
  61. }
  62. },
  63. }
  64. switch valField.Kind() {
  65. case pref.MessageKind:
  66. funcs.merge = mergeMapOfMessage
  67. case pref.BytesKind:
  68. funcs.merge = mergeMapOfBytes
  69. default:
  70. funcs.merge = mergeMap
  71. }
  72. if valFuncs.isInit != nil {
  73. funcs.isInit = func(p pointer, f *coderFieldInfo) error {
  74. return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
  75. }
  76. }
  77. return valueMessage, funcs
  78. }
  // Encoded sizes of the tags for a map entry's key (field 1) and value
// (field 2). A tag is num<<3|wiretype, and for field numbers below 16 it fits
// in a single varint byte, so both tags are one byte long.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
  83. func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
  84. if mapv.Len() == 0 {
  85. return 0
  86. }
  87. n := 0
  88. iter := mapRange(mapv)
  89. for iter.Next() {
  90. key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
  91. keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
  92. var valSize int
  93. value := mapi.conv.valConv.PBValueOf(iter.Value())
  94. if f.mi == nil {
  95. valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
  96. } else {
  97. p := pointerOfValue(iter.Value())
  98. valSize += mapValTagSize
  99. valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
  100. }
  101. n += f.tagsize + protowire.SizeBytes(keySize+valSize)
  102. }
  103. return n
  104. }
  105. func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  106. if wtyp != protowire.BytesType {
  107. return out, errUnknown
  108. }
  109. b, n := protowire.ConsumeBytes(b)
  110. if n < 0 {
  111. return out, protowire.ParseError(n)
  112. }
  113. var (
  114. key = mapi.keyZero
  115. val = mapi.conv.valConv.New()
  116. )
  117. for len(b) > 0 {
  118. num, wtyp, n := protowire.ConsumeTag(b)
  119. if n < 0 {
  120. return out, protowire.ParseError(n)
  121. }
  122. if num > protowire.MaxValidNumber {
  123. return out, errors.New("invalid field number")
  124. }
  125. b = b[n:]
  126. err := errUnknown
  127. switch num {
  128. case genid.MapEntry_Key_field_number:
  129. var v pref.Value
  130. var o unmarshalOutput
  131. v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
  132. if err != nil {
  133. break
  134. }
  135. key = v
  136. n = o.n
  137. case genid.MapEntry_Value_field_number:
  138. var v pref.Value
  139. var o unmarshalOutput
  140. v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
  141. if err != nil {
  142. break
  143. }
  144. val = v
  145. n = o.n
  146. }
  147. if err == errUnknown {
  148. n = protowire.ConsumeFieldValue(num, wtyp, b)
  149. if n < 0 {
  150. return out, protowire.ParseError(n)
  151. }
  152. } else if err != nil {
  153. return out, err
  154. }
  155. b = b[n:]
  156. }
  157. mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
  158. out.n = n
  159. return out, nil
  160. }
  161. func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  162. if wtyp != protowire.BytesType {
  163. return out, errUnknown
  164. }
  165. b, n := protowire.ConsumeBytes(b)
  166. if n < 0 {
  167. return out, protowire.ParseError(n)
  168. }
  169. var (
  170. key = mapi.keyZero
  171. val = reflect.New(f.mi.GoReflectType.Elem())
  172. )
  173. for len(b) > 0 {
  174. num, wtyp, n := protowire.ConsumeTag(b)
  175. if n < 0 {
  176. return out, protowire.ParseError(n)
  177. }
  178. if num > protowire.MaxValidNumber {
  179. return out, errors.New("invalid field number")
  180. }
  181. b = b[n:]
  182. err := errUnknown
  183. switch num {
  184. case 1:
  185. var v pref.Value
  186. var o unmarshalOutput
  187. v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
  188. if err != nil {
  189. break
  190. }
  191. key = v
  192. n = o.n
  193. case 2:
  194. if wtyp != protowire.BytesType {
  195. break
  196. }
  197. var v []byte
  198. v, n = protowire.ConsumeBytes(b)
  199. if n < 0 {
  200. return out, protowire.ParseError(n)
  201. }
  202. var o unmarshalOutput
  203. o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
  204. if o.initialized {
  205. // Consider this map item initialized so long as we see
  206. // an initialized value.
  207. out.initialized = true
  208. }
  209. }
  210. if err == errUnknown {
  211. n = protowire.ConsumeFieldValue(num, wtyp, b)
  212. if n < 0 {
  213. return out, protowire.ParseError(n)
  214. }
  215. } else if err != nil {
  216. return out, err
  217. }
  218. b = b[n:]
  219. }
  220. mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
  221. out.n = n
  222. return out, nil
  223. }
  224. func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  225. if f.mi == nil {
  226. key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
  227. val := mapi.conv.valConv.PBValueOf(valrv)
  228. size := 0
  229. size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
  230. size += mapi.valFuncs.size(val, mapValTagSize, opts)
  231. b = protowire.AppendVarint(b, uint64(size))
  232. b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
  233. if err != nil {
  234. return nil, err
  235. }
  236. return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
  237. } else {
  238. key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
  239. val := pointerOfValue(valrv)
  240. valSize := f.mi.sizePointer(val, opts)
  241. size := 0
  242. size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
  243. size += mapValTagSize + protowire.SizeBytes(valSize)
  244. b = protowire.AppendVarint(b, uint64(size))
  245. b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
  246. if err != nil {
  247. return nil, err
  248. }
  249. b = protowire.AppendVarint(b, mapi.valWiretag)
  250. b = protowire.AppendVarint(b, uint64(valSize))
  251. return f.mi.marshalAppendPointer(b, val, opts)
  252. }
  253. }
  254. func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  255. if mapv.Len() == 0 {
  256. return b, nil
  257. }
  258. if opts.Deterministic() {
  259. return appendMapDeterministic(b, mapv, mapi, f, opts)
  260. }
  261. iter := mapRange(mapv)
  262. for iter.Next() {
  263. var err error
  264. b = protowire.AppendVarint(b, f.wiretag)
  265. b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
  266. if err != nil {
  267. return b, err
  268. }
  269. }
  270. return b, nil
  271. }
  272. func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  273. keys := mapv.MapKeys()
  274. sort.Slice(keys, func(i, j int) bool {
  275. switch keys[i].Kind() {
  276. case reflect.Bool:
  277. return !keys[i].Bool() && keys[j].Bool()
  278. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  279. return keys[i].Int() < keys[j].Int()
  280. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  281. return keys[i].Uint() < keys[j].Uint()
  282. case reflect.Float32, reflect.Float64:
  283. return keys[i].Float() < keys[j].Float()
  284. case reflect.String:
  285. return keys[i].String() < keys[j].String()
  286. default:
  287. panic("invalid kind: " + keys[i].Kind().String())
  288. }
  289. })
  290. for _, key := range keys {
  291. var err error
  292. b = protowire.AppendVarint(b, f.wiretag)
  293. b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
  294. if err != nil {
  295. return b, err
  296. }
  297. }
  298. return b, nil
  299. }
  300. func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
  301. if mi := f.mi; mi != nil {
  302. mi.init()
  303. if !mi.needsInitCheck {
  304. return nil
  305. }
  306. iter := mapRange(mapv)
  307. for iter.Next() {
  308. val := pointerOfValue(iter.Value())
  309. if err := mi.checkInitializedPointer(val); err != nil {
  310. return err
  311. }
  312. }
  313. } else {
  314. iter := mapRange(mapv)
  315. for iter.Next() {
  316. val := mapi.conv.valConv.PBValueOf(iter.Value())
  317. if err := mapi.valFuncs.isInit(val); err != nil {
  318. return err
  319. }
  320. }
  321. }
  322. return nil
  323. }
  324. func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  325. dstm := dst.AsValueOf(f.ft).Elem()
  326. srcm := src.AsValueOf(f.ft).Elem()
  327. if srcm.Len() == 0 {
  328. return
  329. }
  330. if dstm.IsNil() {
  331. dstm.Set(reflect.MakeMap(f.ft))
  332. }
  333. iter := mapRange(srcm)
  334. for iter.Next() {
  335. dstm.SetMapIndex(iter.Key(), iter.Value())
  336. }
  337. }
  338. func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  339. dstm := dst.AsValueOf(f.ft).Elem()
  340. srcm := src.AsValueOf(f.ft).Elem()
  341. if srcm.Len() == 0 {
  342. return
  343. }
  344. if dstm.IsNil() {
  345. dstm.Set(reflect.MakeMap(f.ft))
  346. }
  347. iter := mapRange(srcm)
  348. for iter.Next() {
  349. dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
  350. }
  351. }
  352. func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  353. dstm := dst.AsValueOf(f.ft).Elem()
  354. srcm := src.AsValueOf(f.ft).Elem()
  355. if srcm.Len() == 0 {
  356. return
  357. }
  358. if dstm.IsNil() {
  359. dstm.Set(reflect.MakeMap(f.ft))
  360. }
  361. iter := mapRange(srcm)
  362. for iter.Next() {
  363. val := reflect.New(f.ft.Elem().Elem())
  364. if f.mi != nil {
  365. f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
  366. } else {
  367. opts.Merge(asMessage(val), asMessage(iter.Value()))
  368. }
  369. dstm.SetMapIndex(iter.Key(), val)
  370. }
  371. }