|
- package protocol
- import (
- "bytes"
- "encoding/binary"
- "fmt"
- "hash/crc32"
- "io"
- "io/ioutil"
- "reflect"
- "sync"
- "sync/atomic"
- )
// discarder is implemented by readers that can skip input without copying it
// out (for example, bufio.Reader has a method with this signature). discard
// uses it as a fast path when the underlying reader supports it.
type discarder interface {
	// Discard skips the next n bytes and reports how many were skipped.
	Discard(n int) (discarded int, err error)
}
// decoder reads encoded protocol values from reader while never consuming
// more than remain bytes from it. The first error encountered is stored in
// err and is sticky: once set, Read returns it, and setError asks discardAll
// to skip whatever input remains.
type decoder struct {
	reader io.Reader    // source of encoded bytes
	remain int          // number of bytes that may still be read from reader
	buffer [8]byte      // scratch space for fixed-size reads of up to 8 bytes
	err    error        // first error encountered (see setError)
	table  *crc32.Table // when non-nil, Read folds every byte it returns into crc32
	crc32  uint32       // running checksum; reset to 0 by setCRC and Reset
}
- func (d *decoder) Reset(r io.Reader, n int) {
- d.reader = r
- d.remain = n
- d.buffer = [8]byte{}
- d.err = nil
- d.table = nil
- d.crc32 = 0
- }
- func (d *decoder) Read(b []byte) (int, error) {
- if d.err != nil {
- return 0, d.err
- }
- if d.remain == 0 {
- return 0, io.EOF
- }
- if len(b) > d.remain {
- b = b[:d.remain]
- }
- n, err := d.reader.Read(b)
- if n > 0 && d.table != nil {
- d.crc32 = crc32.Update(d.crc32, d.table, b[:n])
- }
- d.remain -= n
- return n, err
- }
- func (d *decoder) ReadByte() (byte, error) {
- c := d.readByte()
- return c, d.err
- }
- func (d *decoder) done() bool {
- return d.remain == 0 || d.err != nil
- }
- func (d *decoder) setCRC(table *crc32.Table) {
- d.table, d.crc32 = table, 0
- }
- func (d *decoder) decodeBool(v value) {
- v.setBool(d.readBool())
- }
- func (d *decoder) decodeInt8(v value) {
- v.setInt8(d.readInt8())
- }
- func (d *decoder) decodeInt16(v value) {
- v.setInt16(d.readInt16())
- }
- func (d *decoder) decodeInt32(v value) {
- v.setInt32(d.readInt32())
- }
- func (d *decoder) decodeInt64(v value) {
- v.setInt64(d.readInt64())
- }
- func (d *decoder) decodeString(v value) {
- v.setString(d.readString())
- }
- func (d *decoder) decodeCompactString(v value) {
- v.setString(d.readCompactString())
- }
- func (d *decoder) decodeBytes(v value) {
- v.setBytes(d.readBytes())
- }
- func (d *decoder) decodeCompactBytes(v value) {
- v.setBytes(d.readCompactBytes())
- }
- func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
- if n := d.readInt32(); n < 0 {
- v.setArray(array{})
- } else {
- a := makeArray(elemType, int(n))
- for i := 0; i < int(n) && d.remain > 0; i++ {
- decodeElem(d, a.index(i))
- }
- v.setArray(a)
- }
- }
- func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
- if n := d.readUnsignedVarInt(); n < 1 {
- v.setArray(array{})
- } else {
- a := makeArray(elemType, int(n-1))
- for i := 0; i < int(n-1) && d.remain > 0; i++ {
- decodeElem(d, a.index(i))
- }
- v.setArray(a)
- }
- }
- func (d *decoder) discardAll() {
- d.discard(d.remain)
- }
- func (d *decoder) discard(n int) {
- if n > d.remain {
- n = d.remain
- }
- var err error
- if r, _ := d.reader.(discarder); r != nil {
- n, err = r.Discard(n)
- d.remain -= n
- } else {
- _, err = io.Copy(ioutil.Discard, d)
- }
- d.setError(err)
- }
- func (d *decoder) read(n int) []byte {
- b := make([]byte, n)
- n, err := io.ReadFull(d, b)
- b = b[:n]
- d.setError(err)
- return b
- }
- func (d *decoder) writeTo(w io.Writer, n int) {
- limit := d.remain
- if n < limit {
- d.remain = n
- }
- c, err := io.Copy(w, d)
- if int(c) < n && err == nil {
- err = io.ErrUnexpectedEOF
- }
- d.remain = limit - int(c)
- d.setError(err)
- }
- func (d *decoder) setError(err error) {
- if d.err == nil && err != nil {
- d.err = err
- d.discardAll()
- }
- }
- func (d *decoder) readFull(b []byte) bool {
- n, err := io.ReadFull(d, b)
- d.setError(err)
- return n == len(b)
- }
- func (d *decoder) readByte() byte {
- if d.readFull(d.buffer[:1]) {
- return d.buffer[0]
- }
- return 0
- }
- func (d *decoder) readBool() bool {
- return d.readByte() != 0
- }
- func (d *decoder) readInt8() int8 {
- if d.readFull(d.buffer[:1]) {
- return readInt8(d.buffer[:1])
- }
- return 0
- }
- func (d *decoder) readInt16() int16 {
- if d.readFull(d.buffer[:2]) {
- return readInt16(d.buffer[:2])
- }
- return 0
- }
- func (d *decoder) readInt32() int32 {
- if d.readFull(d.buffer[:4]) {
- return readInt32(d.buffer[:4])
- }
- return 0
- }
- func (d *decoder) readInt64() int64 {
- if d.readFull(d.buffer[:8]) {
- return readInt64(d.buffer[:8])
- }
- return 0
- }
- func (d *decoder) readString() string {
- if n := d.readInt16(); n < 0 {
- return ""
- } else {
- return bytesToString(d.read(int(n)))
- }
- }
- func (d *decoder) readVarString() string {
- if n := d.readVarInt(); n < 0 {
- return ""
- } else {
- return bytesToString(d.read(int(n)))
- }
- }
- func (d *decoder) readCompactString() string {
- if n := d.readUnsignedVarInt(); n < 1 {
- return ""
- } else {
- return bytesToString(d.read(int(n - 1)))
- }
- }
- func (d *decoder) readBytes() []byte {
- if n := d.readInt32(); n < 0 {
- return nil
- } else {
- return d.read(int(n))
- }
- }
- func (d *decoder) readBytesTo(w io.Writer) bool {
- if n := d.readInt32(); n < 0 {
- return false
- } else {
- d.writeTo(w, int(n))
- return d.err == nil
- }
- }
- func (d *decoder) readVarBytes() []byte {
- if n := d.readVarInt(); n < 0 {
- return nil
- } else {
- return d.read(int(n))
- }
- }
- func (d *decoder) readVarBytesTo(w io.Writer) bool {
- if n := d.readVarInt(); n < 0 {
- return false
- } else {
- d.writeTo(w, int(n))
- return d.err == nil
- }
- }
- func (d *decoder) readCompactBytes() []byte {
- if n := d.readUnsignedVarInt(); n < 1 {
- return nil
- } else {
- return d.read(int(n - 1))
- }
- }
- func (d *decoder) readCompactBytesTo(w io.Writer) bool {
- if n := d.readUnsignedVarInt(); n < 1 {
- return false
- } else {
- d.writeTo(w, int(n-1))
- return d.err == nil
- }
- }
- func (d *decoder) readVarInt() int64 {
- n := 11 // varints are at most 11 bytes
- if n > d.remain {
- n = d.remain
- }
- x := uint64(0)
- s := uint(0)
- for n > 0 {
- b := d.readByte()
- if (b & 0x80) == 0 {
- x |= uint64(b) << s
- return int64(x>>1) ^ -(int64(x) & 1)
- }
- x |= uint64(b&0x7f) << s
- s += 7
- n--
- }
- d.setError(fmt.Errorf("cannot decode varint from input stream"))
- return 0
- }
- func (d *decoder) readUnsignedVarInt() uint64 {
- n := 11 // varints are at most 11 bytes
- if n > d.remain {
- n = d.remain
- }
- x := uint64(0)
- s := uint(0)
- for n > 0 {
- b := d.readByte()
- if (b & 0x80) == 0 {
- x |= uint64(b) << s
- return x
- }
- x |= uint64(b&0x7f) << s
- s += 7
- n--
- }
- d.setError(fmt.Errorf("cannot decode unsigned varint from input stream"))
- return 0
- }
- type decodeFunc func(*decoder, value)
- var (
- _ io.Reader = (*decoder)(nil)
- _ io.ByteReader = (*decoder)(nil)
- readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem()
- )
- func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
- if reflect.PtrTo(typ).Implements(readerFrom) {
- return readerDecodeFuncOf(typ)
- }
- switch typ.Kind() {
- case reflect.Bool:
- return (*decoder).decodeBool
- case reflect.Int8:
- return (*decoder).decodeInt8
- case reflect.Int16:
- return (*decoder).decodeInt16
- case reflect.Int32:
- return (*decoder).decodeInt32
- case reflect.Int64:
- return (*decoder).decodeInt64
- case reflect.String:
- return stringDecodeFuncOf(flexible, tag)
- case reflect.Struct:
- return structDecodeFuncOf(typ, version, flexible)
- case reflect.Slice:
- if typ.Elem().Kind() == reflect.Uint8 { // []byte
- return bytesDecodeFuncOf(flexible, tag)
- }
- return arrayDecodeFuncOf(typ, version, flexible, tag)
- default:
- panic("unsupported type: " + typ.String())
- }
- }
- func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
- if flexible {
- // In flexible messages, all strings are compact
- return (*decoder).decodeCompactString
- }
- return (*decoder).decodeString
- }
- func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
- if flexible {
- // In flexible messages, all arrays are compact
- return (*decoder).decodeCompactBytes
- }
- return (*decoder).decodeBytes
- }
- func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc {
- type field struct {
- decode decodeFunc
- index index
- tagID int
- }
- var fields []field
- taggedFields := map[int]*field{}
- forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
- forEachStructTag(tag, func(tag structTag) bool {
- if tag.MinVersion <= version && version <= tag.MaxVersion {
- f := field{
- decode: decodeFuncOf(typ, version, flexible, tag),
- index: index,
- tagID: tag.TagID,
- }
- if tag.TagID < -1 {
- // Normal required field
- fields = append(fields, f)
- } else {
- // Optional tagged field (flexible messages only)
- taggedFields[tag.TagID] = &f
- }
- return false
- }
- return true
- })
- })
- return func(d *decoder, v value) {
- for i := range fields {
- f := &fields[i]
- f.decode(d, v.fieldByIndex(f.index))
- }
- if flexible {
- // See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
- // for details of tag buffers in "flexible" messages.
- n := int(d.readUnsignedVarInt())
- for i := 0; i < n; i++ {
- tagID := int(d.readUnsignedVarInt())
- size := int(d.readUnsignedVarInt())
- f, ok := taggedFields[tagID]
- if ok {
- f.decode(d, v.fieldByIndex(f.index))
- } else {
- d.read(size)
- }
- }
- }
- }
- }
- func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
- elemType := typ.Elem()
- elemFunc := decodeFuncOf(elemType, version, flexible, tag)
- if flexible {
- // In flexible messages, all arrays are compact
- return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) }
- }
- return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) }
- }
- func readerDecodeFuncOf(typ reflect.Type) decodeFunc {
- typ = reflect.PtrTo(typ)
- return func(d *decoder, v value) {
- if d.err == nil {
- _, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d)
- if err != nil {
- d.setError(err)
- }
- }
- }
- }
// readInt8 interprets the first byte of b as a two's-complement int8.
func readInt8(b []byte) int8 {
	u := b[0]
	return int8(u)
}
// readInt16 interprets the first two bytes of b as a big-endian int16.
func readInt16(b []byte) int16 {
	u := binary.BigEndian.Uint16(b)
	return int16(u)
}
// readInt32 interprets the first four bytes of b as a big-endian int32.
func readInt32(b []byte) int32 {
	u := binary.BigEndian.Uint32(b)
	return int32(u)
}
// readInt64 interprets the first eight bytes of b as a big-endian int64.
func readInt64(b []byte) int64 {
	u := binary.BigEndian.Uint64(b)
	return int64(u)
}
- func Unmarshal(data []byte, version int16, value interface{}) error {
- typ := elemTypeOf(value)
- cache, _ := unmarshalers.Load().(map[_type]decodeFunc)
- decode := cache[typ]
- if decode == nil {
- decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{
- MinVersion: -1,
- MaxVersion: -1,
- TagID: -2,
- Compact: true,
- Nullable: true,
- })
- newCache := make(map[_type]decodeFunc, len(cache)+1)
- newCache[typ] = decode
- for typ, fun := range cache {
- newCache[typ] = fun
- }
- unmarshalers.Store(newCache)
- }
- d, _ := decoders.Get().(*decoder)
- if d == nil {
- d = &decoder{reader: bytes.NewReader(nil)}
- }
- d.remain = len(data)
- r, _ := d.reader.(*bytes.Reader)
- r.Reset(data)
- defer func() {
- r.Reset(nil)
- d.Reset(r, 0)
- decoders.Put(d)
- }()
- decode(d, valueOf(value))
- return dontExpectEOF(d.err)
- }
- var (
- decoders sync.Pool // *decoder
- unmarshalers atomic.Value // map[_type]decodeFunc
- )
|