decode.go 10 KB


  1. package protocol
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "hash/crc32"
  7. "io"
  8. "io/ioutil"
  9. "reflect"
  10. "sync"
  11. "sync/atomic"
  12. )
  // discarder is implemented by readers that can skip input without copying
  // it out (for example *bufio.Reader). decoder.discard uses it as a fast path
  // when dropping bytes it does not need.
  type discarder interface {
  	// Discard skips up to the given number of bytes and reports how many
  	// were actually skipped.
  	Discard(int) (int, error)
  }
  // decoder reads protocol primitives from an underlying reader while enforcing
  // a byte budget (remain) and, when a CRC table is installed, maintaining a
  // running CRC-32 over every byte handed out by Read.
  //
  // Errors are sticky: the first error is kept in err and Read keeps
  // returning it.
  type decoder struct {
  	reader io.Reader    // source of the encoded bytes
  	remain int          // bytes still allowed to be consumed from reader
  	buffer [8]byte      // scratch space for fixed-width integer reads
  	err    error        // first recorded error
  	table  *crc32.Table // when non-nil, Read folds consumed bytes into crc32
  	crc32  uint32       // running checksum, restarted by setCRC
  }
  24. func (d *decoder) Reset(r io.Reader, n int) {
  25. d.reader = r
  26. d.remain = n
  27. d.buffer = [8]byte{}
  28. d.err = nil
  29. d.table = nil
  30. d.crc32 = 0
  31. }
  32. func (d *decoder) Read(b []byte) (int, error) {
  33. if d.err != nil {
  34. return 0, d.err
  35. }
  36. if d.remain == 0 {
  37. return 0, io.EOF
  38. }
  39. if len(b) > d.remain {
  40. b = b[:d.remain]
  41. }
  42. n, err := d.reader.Read(b)
  43. if n > 0 && d.table != nil {
  44. d.crc32 = crc32.Update(d.crc32, d.table, b[:n])
  45. }
  46. d.remain -= n
  47. return n, err
  48. }
  49. func (d *decoder) ReadByte() (byte, error) {
  50. c := d.readByte()
  51. return c, d.err
  52. }
  53. func (d *decoder) done() bool {
  54. return d.remain == 0 || d.err != nil
  55. }
  56. func (d *decoder) setCRC(table *crc32.Table) {
  57. d.table, d.crc32 = table, 0
  58. }
  59. func (d *decoder) decodeBool(v value) {
  60. v.setBool(d.readBool())
  61. }
  62. func (d *decoder) decodeInt8(v value) {
  63. v.setInt8(d.readInt8())
  64. }
  65. func (d *decoder) decodeInt16(v value) {
  66. v.setInt16(d.readInt16())
  67. }
  68. func (d *decoder) decodeInt32(v value) {
  69. v.setInt32(d.readInt32())
  70. }
  71. func (d *decoder) decodeInt64(v value) {
  72. v.setInt64(d.readInt64())
  73. }
  74. func (d *decoder) decodeString(v value) {
  75. v.setString(d.readString())
  76. }
  77. func (d *decoder) decodeCompactString(v value) {
  78. v.setString(d.readCompactString())
  79. }
  80. func (d *decoder) decodeBytes(v value) {
  81. v.setBytes(d.readBytes())
  82. }
  83. func (d *decoder) decodeCompactBytes(v value) {
  84. v.setBytes(d.readCompactBytes())
  85. }
  86. func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
  87. if n := d.readInt32(); n < 0 {
  88. v.setArray(array{})
  89. } else {
  90. a := makeArray(elemType, int(n))
  91. for i := 0; i < int(n) && d.remain > 0; i++ {
  92. decodeElem(d, a.index(i))
  93. }
  94. v.setArray(a)
  95. }
  96. }
  97. func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
  98. if n := d.readUnsignedVarInt(); n < 1 {
  99. v.setArray(array{})
  100. } else {
  101. a := makeArray(elemType, int(n-1))
  102. for i := 0; i < int(n-1) && d.remain > 0; i++ {
  103. decodeElem(d, a.index(i))
  104. }
  105. v.setArray(a)
  106. }
  107. }
  108. func (d *decoder) discardAll() {
  109. d.discard(d.remain)
  110. }
  111. func (d *decoder) discard(n int) {
  112. if n > d.remain {
  113. n = d.remain
  114. }
  115. var err error
  116. if r, _ := d.reader.(discarder); r != nil {
  117. n, err = r.Discard(n)
  118. d.remain -= n
  119. } else {
  120. _, err = io.Copy(ioutil.Discard, d)
  121. }
  122. d.setError(err)
  123. }
  124. func (d *decoder) read(n int) []byte {
  125. b := make([]byte, n)
  126. n, err := io.ReadFull(d, b)
  127. b = b[:n]
  128. d.setError(err)
  129. return b
  130. }
  131. func (d *decoder) writeTo(w io.Writer, n int) {
  132. limit := d.remain
  133. if n < limit {
  134. d.remain = n
  135. }
  136. c, err := io.Copy(w, d)
  137. if int(c) < n && err == nil {
  138. err = io.ErrUnexpectedEOF
  139. }
  140. d.remain = limit - int(c)
  141. d.setError(err)
  142. }
  143. func (d *decoder) setError(err error) {
  144. if d.err == nil && err != nil {
  145. d.err = err
  146. d.discardAll()
  147. }
  148. }
  149. func (d *decoder) readFull(b []byte) bool {
  150. n, err := io.ReadFull(d, b)
  151. d.setError(err)
  152. return n == len(b)
  153. }
  154. func (d *decoder) readByte() byte {
  155. if d.readFull(d.buffer[:1]) {
  156. return d.buffer[0]
  157. }
  158. return 0
  159. }
  160. func (d *decoder) readBool() bool {
  161. return d.readByte() != 0
  162. }
  163. func (d *decoder) readInt8() int8 {
  164. if d.readFull(d.buffer[:1]) {
  165. return readInt8(d.buffer[:1])
  166. }
  167. return 0
  168. }
  169. func (d *decoder) readInt16() int16 {
  170. if d.readFull(d.buffer[:2]) {
  171. return readInt16(d.buffer[:2])
  172. }
  173. return 0
  174. }
  175. func (d *decoder) readInt32() int32 {
  176. if d.readFull(d.buffer[:4]) {
  177. return readInt32(d.buffer[:4])
  178. }
  179. return 0
  180. }
  181. func (d *decoder) readInt64() int64 {
  182. if d.readFull(d.buffer[:8]) {
  183. return readInt64(d.buffer[:8])
  184. }
  185. return 0
  186. }
  187. func (d *decoder) readString() string {
  188. if n := d.readInt16(); n < 0 {
  189. return ""
  190. } else {
  191. return bytesToString(d.read(int(n)))
  192. }
  193. }
  194. func (d *decoder) readVarString() string {
  195. if n := d.readVarInt(); n < 0 {
  196. return ""
  197. } else {
  198. return bytesToString(d.read(int(n)))
  199. }
  200. }
  201. func (d *decoder) readCompactString() string {
  202. if n := d.readUnsignedVarInt(); n < 1 {
  203. return ""
  204. } else {
  205. return bytesToString(d.read(int(n - 1)))
  206. }
  207. }
  208. func (d *decoder) readBytes() []byte {
  209. if n := d.readInt32(); n < 0 {
  210. return nil
  211. } else {
  212. return d.read(int(n))
  213. }
  214. }
  215. func (d *decoder) readBytesTo(w io.Writer) bool {
  216. if n := d.readInt32(); n < 0 {
  217. return false
  218. } else {
  219. d.writeTo(w, int(n))
  220. return d.err == nil
  221. }
  222. }
  223. func (d *decoder) readVarBytes() []byte {
  224. if n := d.readVarInt(); n < 0 {
  225. return nil
  226. } else {
  227. return d.read(int(n))
  228. }
  229. }
  230. func (d *decoder) readVarBytesTo(w io.Writer) bool {
  231. if n := d.readVarInt(); n < 0 {
  232. return false
  233. } else {
  234. d.writeTo(w, int(n))
  235. return d.err == nil
  236. }
  237. }
  238. func (d *decoder) readCompactBytes() []byte {
  239. if n := d.readUnsignedVarInt(); n < 1 {
  240. return nil
  241. } else {
  242. return d.read(int(n - 1))
  243. }
  244. }
  245. func (d *decoder) readCompactBytesTo(w io.Writer) bool {
  246. if n := d.readUnsignedVarInt(); n < 1 {
  247. return false
  248. } else {
  249. d.writeTo(w, int(n-1))
  250. return d.err == nil
  251. }
  252. }
  253. func (d *decoder) readVarInt() int64 {
  254. n := 11 // varints are at most 11 bytes
  255. if n > d.remain {
  256. n = d.remain
  257. }
  258. x := uint64(0)
  259. s := uint(0)
  260. for n > 0 {
  261. b := d.readByte()
  262. if (b & 0x80) == 0 {
  263. x |= uint64(b) << s
  264. return int64(x>>1) ^ -(int64(x) & 1)
  265. }
  266. x |= uint64(b&0x7f) << s
  267. s += 7
  268. n--
  269. }
  270. d.setError(fmt.Errorf("cannot decode varint from input stream"))
  271. return 0
  272. }
  273. func (d *decoder) readUnsignedVarInt() uint64 {
  274. n := 11 // varints are at most 11 bytes
  275. if n > d.remain {
  276. n = d.remain
  277. }
  278. x := uint64(0)
  279. s := uint(0)
  280. for n > 0 {
  281. b := d.readByte()
  282. if (b & 0x80) == 0 {
  283. x |= uint64(b) << s
  284. return x
  285. }
  286. x |= uint64(b&0x7f) << s
  287. s += 7
  288. n--
  289. }
  290. d.setError(fmt.Errorf("cannot decode unsigned varint from input stream"))
  291. return 0
  292. }
  293. type decodeFunc func(*decoder, value)
  294. var (
  295. _ io.Reader = (*decoder)(nil)
  296. _ io.ByteReader = (*decoder)(nil)
  297. readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem()
  298. )
  299. func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
  300. if reflect.PtrTo(typ).Implements(readerFrom) {
  301. return readerDecodeFuncOf(typ)
  302. }
  303. switch typ.Kind() {
  304. case reflect.Bool:
  305. return (*decoder).decodeBool
  306. case reflect.Int8:
  307. return (*decoder).decodeInt8
  308. case reflect.Int16:
  309. return (*decoder).decodeInt16
  310. case reflect.Int32:
  311. return (*decoder).decodeInt32
  312. case reflect.Int64:
  313. return (*decoder).decodeInt64
  314. case reflect.String:
  315. return stringDecodeFuncOf(flexible, tag)
  316. case reflect.Struct:
  317. return structDecodeFuncOf(typ, version, flexible)
  318. case reflect.Slice:
  319. if typ.Elem().Kind() == reflect.Uint8 { // []byte
  320. return bytesDecodeFuncOf(flexible, tag)
  321. }
  322. return arrayDecodeFuncOf(typ, version, flexible, tag)
  323. default:
  324. panic("unsupported type: " + typ.String())
  325. }
  326. }
  327. func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
  328. if flexible {
  329. // In flexible messages, all strings are compact
  330. return (*decoder).decodeCompactString
  331. }
  332. return (*decoder).decodeString
  333. }
  334. func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
  335. if flexible {
  336. // In flexible messages, all arrays are compact
  337. return (*decoder).decodeCompactBytes
  338. }
  339. return (*decoder).decodeBytes
  340. }
  341. func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc {
  342. type field struct {
  343. decode decodeFunc
  344. index index
  345. tagID int
  346. }
  347. var fields []field
  348. taggedFields := map[int]*field{}
  349. forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
  350. forEachStructTag(tag, func(tag structTag) bool {
  351. if tag.MinVersion <= version && version <= tag.MaxVersion {
  352. f := field{
  353. decode: decodeFuncOf(typ, version, flexible, tag),
  354. index: index,
  355. tagID: tag.TagID,
  356. }
  357. if tag.TagID < -1 {
  358. // Normal required field
  359. fields = append(fields, f)
  360. } else {
  361. // Optional tagged field (flexible messages only)
  362. taggedFields[tag.TagID] = &f
  363. }
  364. return false
  365. }
  366. return true
  367. })
  368. })
  369. return func(d *decoder, v value) {
  370. for i := range fields {
  371. f := &fields[i]
  372. f.decode(d, v.fieldByIndex(f.index))
  373. }
  374. if flexible {
  375. // See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
  376. // for details of tag buffers in "flexible" messages.
  377. n := int(d.readUnsignedVarInt())
  378. for i := 0; i < n; i++ {
  379. tagID := int(d.readUnsignedVarInt())
  380. size := int(d.readUnsignedVarInt())
  381. f, ok := taggedFields[tagID]
  382. if ok {
  383. f.decode(d, v.fieldByIndex(f.index))
  384. } else {
  385. d.read(size)
  386. }
  387. }
  388. }
  389. }
  390. }
  391. func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
  392. elemType := typ.Elem()
  393. elemFunc := decodeFuncOf(elemType, version, flexible, tag)
  394. if flexible {
  395. // In flexible messages, all arrays are compact
  396. return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) }
  397. }
  398. return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) }
  399. }
  400. func readerDecodeFuncOf(typ reflect.Type) decodeFunc {
  401. typ = reflect.PtrTo(typ)
  402. return func(d *decoder, v value) {
  403. if d.err == nil {
  404. _, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d)
  405. if err != nil {
  406. d.setError(err)
  407. }
  408. }
  409. }
  410. }
  // readInt8 interprets b[0] as a two's-complement signed byte.
  func readInt8(b []byte) int8 {
  	u := b[0]
  	return int8(u)
  }

  // readInt16 decodes a big-endian signed 16-bit integer from b[0:2].
  func readInt16(b []byte) int16 {
  	u := binary.BigEndian.Uint16(b)
  	return int16(u)
  }

  // readInt32 decodes a big-endian signed 32-bit integer from b[0:4].
  func readInt32(b []byte) int32 {
  	u := binary.BigEndian.Uint32(b)
  	return int32(u)
  }

  // readInt64 decodes a big-endian signed 64-bit integer from b[0:8].
  func readInt64(b []byte) int64 {
  	u := binary.BigEndian.Uint64(b)
  	return int64(u)
  }
  423. func Unmarshal(data []byte, version int16, value interface{}) error {
  424. typ := elemTypeOf(value)
  425. cache, _ := unmarshalers.Load().(map[_type]decodeFunc)
  426. decode := cache[typ]
  427. if decode == nil {
  428. decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{
  429. MinVersion: -1,
  430. MaxVersion: -1,
  431. TagID: -2,
  432. Compact: true,
  433. Nullable: true,
  434. })
  435. newCache := make(map[_type]decodeFunc, len(cache)+1)
  436. newCache[typ] = decode
  437. for typ, fun := range cache {
  438. newCache[typ] = fun
  439. }
  440. unmarshalers.Store(newCache)
  441. }
  442. d, _ := decoders.Get().(*decoder)
  443. if d == nil {
  444. d = &decoder{reader: bytes.NewReader(nil)}
  445. }
  446. d.remain = len(data)
  447. r, _ := d.reader.(*bytes.Reader)
  448. r.Reset(data)
  449. defer func() {
  450. r.Reset(nil)
  451. d.Reset(r, 0)
  452. decoders.Put(d)
  453. }()
  454. decode(d, valueOf(value))
  455. return dontExpectEOF(d.err)
  456. }
  457. var (
  458. decoders sync.Pool // *decoder
  459. unmarshalers atomic.Value // map[_type]decodeFunc
  460. )