package protocol import ( "bytes" "encoding/binary" "fmt" "hash/crc32" "io" "io/ioutil" "reflect" "sync" "sync/atomic" ) type discarder interface { Discard(int) (int, error) } type decoder struct { reader io.Reader remain int buffer [8]byte err error table *crc32.Table crc32 uint32 } func (d *decoder) Reset(r io.Reader, n int) { d.reader = r d.remain = n d.buffer = [8]byte{} d.err = nil d.table = nil d.crc32 = 0 } func (d *decoder) Read(b []byte) (int, error) { if d.err != nil { return 0, d.err } if d.remain == 0 { return 0, io.EOF } if len(b) > d.remain { b = b[:d.remain] } n, err := d.reader.Read(b) if n > 0 && d.table != nil { d.crc32 = crc32.Update(d.crc32, d.table, b[:n]) } d.remain -= n return n, err } func (d *decoder) ReadByte() (byte, error) { c := d.readByte() return c, d.err } func (d *decoder) done() bool { return d.remain == 0 || d.err != nil } func (d *decoder) setCRC(table *crc32.Table) { d.table, d.crc32 = table, 0 } func (d *decoder) decodeBool(v value) { v.setBool(d.readBool()) } func (d *decoder) decodeInt8(v value) { v.setInt8(d.readInt8()) } func (d *decoder) decodeInt16(v value) { v.setInt16(d.readInt16()) } func (d *decoder) decodeInt32(v value) { v.setInt32(d.readInt32()) } func (d *decoder) decodeInt64(v value) { v.setInt64(d.readInt64()) } func (d *decoder) decodeString(v value) { v.setString(d.readString()) } func (d *decoder) decodeCompactString(v value) { v.setString(d.readCompactString()) } func (d *decoder) decodeBytes(v value) { v.setBytes(d.readBytes()) } func (d *decoder) decodeCompactBytes(v value) { v.setBytes(d.readCompactBytes()) } func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) { if n := d.readInt32(); n < 0 { v.setArray(array{}) } else { a := makeArray(elemType, int(n)) for i := 0; i < int(n) && d.remain > 0; i++ { decodeElem(d, a.index(i)) } v.setArray(a) } } func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) { if n := d.readUnsignedVarInt(); n < 1 { v.setArray(array{}) } else { a := makeArray(elemType, int(n-1)) for i := 0; i < int(n-1) && d.remain > 0; i++ { decodeElem(d, a.index(i)) } v.setArray(a) } } func (d *decoder) discardAll() { d.discard(d.remain) } func (d *decoder) discard(n int) { if n > d.remain { n = d.remain } var err error if r, _ := d.reader.(discarder); r != nil { n, err = r.Discard(n) d.remain -= n } else { _, err = io.Copy(ioutil.Discard, d) } d.setError(err) } func (d *decoder) read(n int) []byte { b := make([]byte, n) n, err := io.ReadFull(d, b) b = b[:n] d.setError(err) return b } func (d *decoder) writeTo(w io.Writer, n int) { limit := d.remain if n < limit { d.remain = n } c, err := io.Copy(w, d) if int(c) < n && err == nil { err = io.ErrUnexpectedEOF } d.remain = limit - int(c) d.setError(err) } func (d *decoder) setError(err error) { if d.err == nil && err != nil { d.err = err d.discardAll() } } func (d *decoder) readFull(b []byte) bool { n, err := io.ReadFull(d, b) d.setError(err) return n == len(b) } func (d *decoder) readByte() byte { if d.readFull(d.buffer[:1]) { return d.buffer[0] } return 0 } func (d *decoder) readBool() bool { return d.readByte() != 0 } func (d *decoder) readInt8() int8 { if d.readFull(d.buffer[:1]) { return readInt8(d.buffer[:1]) } return 0 } func (d *decoder) readInt16() int16 { if d.readFull(d.buffer[:2]) { return readInt16(d.buffer[:2]) } return 0 } func (d *decoder) readInt32() int32 { if d.readFull(d.buffer[:4]) { return readInt32(d.buffer[:4]) } return 0 } func (d *decoder) readInt64() int64 { if d.readFull(d.buffer[:8]) { return readInt64(d.buffer[:8]) } return 0 } func (d *decoder) readString() string { if n := d.readInt16(); n < 0 { return "" } else { return bytesToString(d.read(int(n))) } } func (d *decoder) readVarString() string { if n := d.readVarInt(); n < 0 { return "" } else { return bytesToString(d.read(int(n))) } } func (d *decoder) readCompactString() string { if n := d.readUnsignedVarInt(); n < 1 { return "" } else { return bytesToString(d.read(int(n - 1))) } } func (d *decoder) readBytes() []byte { if n := d.readInt32(); n < 0 { return nil } else { return d.read(int(n)) } } func (d *decoder) readBytesTo(w io.Writer) bool { if n := d.readInt32(); n < 0 { return false } else { d.writeTo(w, int(n)) return d.err == nil } } func (d *decoder) readVarBytes() []byte { if n := d.readVarInt(); n < 0 { return nil } else { return d.read(int(n)) } } func (d *decoder) readVarBytesTo(w io.Writer) bool { if n := d.readVarInt(); n < 0 { return false } else { d.writeTo(w, int(n)) return d.err == nil } } func (d *decoder) readCompactBytes() []byte { if n := d.readUnsignedVarInt(); n < 1 { return nil } else { return d.read(int(n - 1)) } } func (d *decoder) readCompactBytesTo(w io.Writer) bool { if n := d.readUnsignedVarInt(); n < 1 { return false } else { d.writeTo(w, int(n-1)) return d.err == nil } } func (d *decoder) readVarInt() int64 { n := 11 // varints are at most 11 bytes if n > d.remain { n = d.remain } x := uint64(0) s := uint(0) for n > 0 { b := d.readByte() if (b & 0x80) == 0 { x |= uint64(b) << s return int64(x>>1) ^ -(int64(x) & 1) } x |= uint64(b&0x7f) << s s += 7 n-- } d.setError(fmt.Errorf("cannot decode varint from input stream")) return 0 } func (d *decoder) readUnsignedVarInt() uint64 { n := 11 // varints are at most 11 bytes if n > d.remain { n = d.remain } x := uint64(0) s := uint(0) for n > 0 { b := d.readByte() if (b & 0x80) == 0 { x |= uint64(b) << s return x } x |= uint64(b&0x7f) << s s += 7 n-- } d.setError(fmt.Errorf("cannot decode unsigned varint from input stream")) return 0 } type decodeFunc func(*decoder, value) var ( _ io.Reader = (*decoder)(nil) _ io.ByteReader = (*decoder)(nil) readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem() ) func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { if reflect.PtrTo(typ).Implements(readerFrom) { return readerDecodeFuncOf(typ) } switch typ.Kind() { case reflect.Bool: return (*decoder).decodeBool case reflect.Int8: return (*decoder).decodeInt8 case reflect.Int16: return (*decoder).decodeInt16 case reflect.Int32: return (*decoder).decodeInt32 case reflect.Int64: return (*decoder).decodeInt64 case reflect.String: return stringDecodeFuncOf(flexible, tag) case reflect.Struct: return structDecodeFuncOf(typ, version, flexible) case reflect.Slice: if typ.Elem().Kind() == reflect.Uint8 { // []byte return bytesDecodeFuncOf(flexible, tag) } return arrayDecodeFuncOf(typ, version, flexible, tag) default: panic("unsupported type: " + typ.String()) } } func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc { if flexible { // In flexible messages, all strings are compact return (*decoder).decodeCompactString } return (*decoder).decodeString } func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc { if flexible { // In flexible messages, all arrays are compact return (*decoder).decodeCompactBytes } return (*decoder).decodeBytes } func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc { type field struct { decode decodeFunc index index tagID int } var fields []field taggedFields := map[int]*field{} forEachStructField(typ, func(typ reflect.Type, index index, tag string) { forEachStructTag(tag, func(tag structTag) bool { if tag.MinVersion <= version && version <= tag.MaxVersion { f := field{ decode: decodeFuncOf(typ, version, flexible, tag), index: index, tagID: tag.TagID, } if tag.TagID < -1 { // Normal required field fields = append(fields, f) } else { // Optional tagged field (flexible messages only) taggedFields[tag.TagID] = &f } return false } return true }) }) return func(d *decoder, v value) { for i := range fields { f := &fields[i] f.decode(d, v.fieldByIndex(f.index)) } if flexible { // See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields // for details of tag buffers in "flexible" messages. n := int(d.readUnsignedVarInt()) for i := 0; i < n; i++ { tagID := int(d.readUnsignedVarInt()) size := int(d.readUnsignedVarInt()) f, ok := taggedFields[tagID] if ok { f.decode(d, v.fieldByIndex(f.index)) } else { d.read(size) } } } } } func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { elemType := typ.Elem() elemFunc := decodeFuncOf(elemType, version, flexible, tag) if flexible { // In flexible messages, all arrays are compact return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) } } return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) } } func readerDecodeFuncOf(typ reflect.Type) decodeFunc { typ = reflect.PtrTo(typ) return func(d *decoder, v value) { if d.err == nil { _, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d) if err != nil { d.setError(err) } } } } func readInt8(b []byte) int8 { return int8(b[0]) } func readInt16(b []byte) int16 { return int16(binary.BigEndian.Uint16(b)) } func readInt32(b []byte) int32 { return int32(binary.BigEndian.Uint32(b)) } func readInt64(b []byte) int64 { return int64(binary.BigEndian.Uint64(b)) } func Unmarshal(data []byte, version int16, value interface{}) error { typ := elemTypeOf(value) cache, _ := unmarshalers.Load().(map[_type]decodeFunc) decode := cache[typ] if decode == nil { decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{ MinVersion: -1, MaxVersion: -1, TagID: -2, Compact: true, Nullable: true, }) newCache := make(map[_type]decodeFunc, len(cache)+1) newCache[typ] = decode for typ, fun := range cache { newCache[typ] = fun } unmarshalers.Store(newCache) } d, _ := decoders.Get().(*decoder) if d == nil { d = &decoder{reader: bytes.NewReader(nil)} } d.remain = len(data) r, _ := d.reader.(*bytes.Reader) r.Reset(data) defer func() { r.Reset(nil) d.Reset(r, 0) decoders.Put(d) }() decode(d, valueOf(value)) return dontExpectEOF(d.err) } var ( decoders sync.Pool // *decoder unmarshalers atomic.Value // map[_type]decodeFunc )