decode.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "math/bits"
  7. "google.golang.org/protobuf/encoding/protowire"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/proto"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. preg "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. piface "google.golang.org/protobuf/runtime/protoiface"
  15. )
// unmarshalOptions is the internal form of the unmarshal settings that is
// threaded through the decoding functions in this file.
type unmarshalOptions struct {
	// flags are the protoiface input flags supplied by the caller.
	// DiscardUnknown tests the UnmarshalDiscardUnknown bit, and IsDefault
	// requires that no flag is set at all.
	flags protoiface.UnmarshalInputFlags
	// resolver looks up extension types. unmarshalExtension calls
	// FindExtensionByNumber for extensions not yet present in the message.
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
}
  23. func (o unmarshalOptions) Options() proto.UnmarshalOptions {
  24. return proto.UnmarshalOptions{
  25. Merge: true,
  26. AllowPartial: true,
  27. DiscardUnknown: o.DiscardUnknown(),
  28. Resolver: o.resolver,
  29. }
  30. }
  31. func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&piface.UnmarshalDiscardUnknown != 0 }
  32. func (o unmarshalOptions) IsDefault() bool {
  33. return o.flags == 0 && o.resolver == preg.GlobalTypes
  34. }
  35. var lazyUnmarshalOptions = unmarshalOptions{
  36. resolver: preg.GlobalTypes,
  37. }
// unmarshalOutput is the result of decoding a message, field, or extension.
type unmarshalOutput struct {
	n           int  // number of bytes consumed
	initialized bool // whether the decoded value is known to be fully initialized
}
  42. // unmarshal is protoreflect.Methods.Unmarshal.
  43. func (mi *MessageInfo) unmarshal(in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
  44. var p pointer
  45. if ms, ok := in.Message.(*messageState); ok {
  46. p = ms.pointer()
  47. } else {
  48. p = in.Message.(*messageReflectWrapper).pointer()
  49. }
  50. out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
  51. flags: in.Flags,
  52. resolver: in.Resolver,
  53. })
  54. var flags piface.UnmarshalOutputFlags
  55. if out.initialized {
  56. flags |= piface.UnmarshalInitialized
  57. }
  58. return piface.UnmarshalOutput{
  59. Flags: flags,
  60. }, err
  61. }
  62. // errUnknown is returned during unmarshaling to indicate a parse error that
  63. // should result in a field being placed in the unknown fields section (for example,
  64. // when the wire type doesn't match) as opposed to the entire unmarshal operation
  65. // failing (for example, when a field extends past the available input).
  66. //
  67. // This is a sentinel error which should never be visible to the user.
  68. var errUnknown = errors.New("unknown")
  69. func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
  70. mi.init()
  71. if flags.ProtoLegacy && mi.isMessageSet {
  72. return unmarshalMessageSet(mi, b, p, opts)
  73. }
  74. initialized := true
  75. var requiredMask uint64
  76. var exts *map[int32]ExtensionField
  77. start := len(b)
  78. for len(b) > 0 {
  79. // Parse the tag (field number and wire type).
  80. var tag uint64
  81. if b[0] < 0x80 {
  82. tag = uint64(b[0])
  83. b = b[1:]
  84. } else if len(b) >= 2 && b[1] < 128 {
  85. tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
  86. b = b[2:]
  87. } else {
  88. var n int
  89. tag, n = protowire.ConsumeVarint(b)
  90. if n < 0 {
  91. return out, protowire.ParseError(n)
  92. }
  93. b = b[n:]
  94. }
  95. var num protowire.Number
  96. if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
  97. return out, errors.New("invalid field number")
  98. } else {
  99. num = protowire.Number(n)
  100. }
  101. wtyp := protowire.Type(tag & 7)
  102. if wtyp == protowire.EndGroupType {
  103. if num != groupTag {
  104. return out, errors.New("mismatching end group marker")
  105. }
  106. groupTag = 0
  107. break
  108. }
  109. var f *coderFieldInfo
  110. if int(num) < len(mi.denseCoderFields) {
  111. f = mi.denseCoderFields[num]
  112. } else {
  113. f = mi.coderFields[num]
  114. }
  115. var n int
  116. err := errUnknown
  117. switch {
  118. case f != nil:
  119. if f.funcs.unmarshal == nil {
  120. break
  121. }
  122. var o unmarshalOutput
  123. o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
  124. n = o.n
  125. if err != nil {
  126. break
  127. }
  128. requiredMask |= f.validation.requiredBit
  129. if f.funcs.isInit != nil && !o.initialized {
  130. initialized = false
  131. }
  132. default:
  133. // Possible extension.
  134. if exts == nil && mi.extensionOffset.IsValid() {
  135. exts = p.Apply(mi.extensionOffset).Extensions()
  136. if *exts == nil {
  137. *exts = make(map[int32]ExtensionField)
  138. }
  139. }
  140. if exts == nil {
  141. break
  142. }
  143. var o unmarshalOutput
  144. o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
  145. if err != nil {
  146. break
  147. }
  148. n = o.n
  149. if !o.initialized {
  150. initialized = false
  151. }
  152. }
  153. if err != nil {
  154. if err != errUnknown {
  155. return out, err
  156. }
  157. n = protowire.ConsumeFieldValue(num, wtyp, b)
  158. if n < 0 {
  159. return out, protowire.ParseError(n)
  160. }
  161. if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
  162. u := p.Apply(mi.unknownOffset).Bytes()
  163. *u = protowire.AppendTag(*u, num, wtyp)
  164. *u = append(*u, b[:n]...)
  165. }
  166. }
  167. b = b[n:]
  168. }
  169. if groupTag != 0 {
  170. return out, errors.New("missing end group marker")
  171. }
  172. if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
  173. initialized = false
  174. }
  175. if initialized {
  176. out.initialized = true
  177. }
  178. out.n = start - len(b)
  179. return out, nil
  180. }
  181. func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
  182. x := exts[int32(num)]
  183. xt := x.Type()
  184. if xt == nil {
  185. var err error
  186. xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
  187. if err != nil {
  188. if err == preg.NotFound {
  189. return out, errUnknown
  190. }
  191. return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
  192. }
  193. }
  194. xi := getExtensionFieldInfo(xt)
  195. if xi.funcs.unmarshal == nil {
  196. return out, errUnknown
  197. }
  198. if flags.LazyUnmarshalExtensions {
  199. if opts.IsDefault() && x.canLazy(xt) {
  200. out, valid := skipExtension(b, xi, num, wtyp, opts)
  201. switch valid {
  202. case ValidationValid:
  203. if out.initialized {
  204. x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
  205. exts[int32(num)] = x
  206. return out, nil
  207. }
  208. case ValidationInvalid:
  209. return out, errors.New("invalid wire format")
  210. case ValidationUnknown:
  211. }
  212. }
  213. }
  214. ival := x.Value()
  215. if !ival.IsValid() && xi.unmarshalNeedsValue {
  216. // Create a new message, list, or map value to fill in.
  217. // For enums, create a prototype value to let the unmarshal func know the
  218. // concrete type.
  219. ival = xt.New()
  220. }
  221. v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
  222. if err != nil {
  223. return out, err
  224. }
  225. if xi.funcs.isInit == nil {
  226. out.initialized = true
  227. }
  228. x.Set(xt, v)
  229. exts[int32(num)] = x
  230. return out, nil
  231. }
  232. func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
  233. if xi.validation.mi == nil {
  234. return out, ValidationUnknown
  235. }
  236. xi.validation.mi.init()
  237. switch xi.validation.typ {
  238. case validationTypeMessage:
  239. if wtyp != protowire.BytesType {
  240. return out, ValidationUnknown
  241. }
  242. v, n := protowire.ConsumeBytes(b)
  243. if n < 0 {
  244. return out, ValidationUnknown
  245. }
  246. out, st := xi.validation.mi.validate(v, 0, opts)
  247. out.n = n
  248. return out, st
  249. case validationTypeGroup:
  250. if wtyp != protowire.StartGroupType {
  251. return out, ValidationUnknown
  252. }
  253. out, st := xi.validation.mi.validate(b, num, opts)
  254. return out, st
  255. default:
  256. return out, ValidationUnknown
  257. }
  258. }