package protocol

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"hash/crc32"
	"io"
	"reflect"
	"sync"
	"sync/atomic"
)

type encoder struct {
	writer io.Writer
	err    error
	table  *crc32.Table
	crc32  uint32
	buffer [32]byte
}

type encoderChecksum struct {
	reader  io.Reader
	encoder *encoder
}

func (e *encoderChecksum) Read(b []byte) (int, error) {
	n, err := e.reader.Read(b)
	if n > 0 {
		e.encoder.update(b[:n])
	}
	return n, err
}

func (e *encoder) Reset(w io.Writer) {
	e.writer = w
	e.err = nil
	e.table = nil
	e.crc32 = 0
	e.buffer = [32]byte{}
}

func (e *encoder) ReadFrom(r io.Reader) (int64, error) {
	if e.table != nil {
		r = &encoderChecksum{
			reader:  r,
			encoder: e,
		}
	}
	return io.Copy(e.writer, r)
}

func (e *encoder) Write(b []byte) (int, error) {
	if e.err != nil {
		return 0, e.err
	}
	n, err := e.writer.Write(b)
	if n > 0 {
		e.update(b[:n])
	}
	if err != nil {
		e.err = err
	}
	return n, err
}

func (e *encoder) WriteByte(b byte) error {
	e.buffer[0] = b
	_, err := e.Write(e.buffer[:1])
	return err
}

func (e *encoder) WriteString(s string) (int, error) {
	// This implementation is an optimization to avoid the heap allocation that
	// would occur when converting the string to a []byte to call crc32.Update.
	//
	// Strings are rarely long in the kafka protocol, so the use of a 32 byte
	// buffer is a good comprise between keeping the encoder value small and
	// limiting the number of calls to Write.
	//
	// We introduced this optimization because memory profiles on the benchmarks
	// showed that most heap allocations were caused by this code path.
	n := 0

	for len(s) != 0 {
		c := copy(e.buffer[:], s)
		w, err := e.Write(e.buffer[:c])
		n += w
		if err != nil {
			return n, err
		}
		s = s[c:]
	}

	return n, nil
}

func (e *encoder) setCRC(table *crc32.Table) {
	e.table, e.crc32 = table, 0
}

func (e *encoder) update(b []byte) {
	if e.table != nil {
		e.crc32 = crc32.Update(e.crc32, e.table, b)
	}
}

func (e *encoder) encodeBool(v value) {
	b := int8(0)
	if v.bool() {
		b = 1
	}
	e.writeInt8(b)
}

func (e *encoder) encodeInt8(v value) {
	e.writeInt8(v.int8())
}

func (e *encoder) encodeInt16(v value) {
	e.writeInt16(v.int16())
}

func (e *encoder) encodeInt32(v value) {
	e.writeInt32(v.int32())
}

func (e *encoder) encodeInt64(v value) {
	e.writeInt64(v.int64())
}

func (e *encoder) encodeString(v value) {
	e.writeString(v.string())
}

func (e *encoder) encodeVarString(v value) {
	e.writeVarString(v.string())
}

func (e *encoder) encodeCompactString(v value) {
	e.writeCompactString(v.string())
}

func (e *encoder) encodeNullString(v value) {
	e.writeNullString(v.string())
}

func (e *encoder) encodeVarNullString(v value) {
	e.writeVarNullString(v.string())
}

func (e *encoder) encodeCompactNullString(v value) {
	e.writeCompactNullString(v.string())
}

func (e *encoder) encodeBytes(v value) {
	e.writeBytes(v.bytes())
}

func (e *encoder) encodeVarBytes(v value) {
	e.writeVarBytes(v.bytes())
}

func (e *encoder) encodeCompactBytes(v value) {
	e.writeCompactBytes(v.bytes())
}

func (e *encoder) encodeNullBytes(v value) {
	e.writeNullBytes(v.bytes())
}

func (e *encoder) encodeVarNullBytes(v value) {
	e.writeVarNullBytes(v.bytes())
}

func (e *encoder) encodeCompactNullBytes(v value) {
	e.writeCompactNullBytes(v.bytes())
}

func (e *encoder) encodeArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
	a := v.array(elemType)
	n := a.length()
	e.writeInt32(int32(n))

	for i := 0; i < n; i++ {
		encodeElem(e, a.index(i))
	}
}

func (e *encoder) encodeCompactArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
	a := v.array(elemType)
	n := a.length()
	e.writeUnsignedVarInt(uint64(n + 1))

	for i := 0; i < n; i++ {
		encodeElem(e, a.index(i))
	}
}

func (e *encoder) encodeNullArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
	a := v.array(elemType)
	if a.isNil() {
		e.writeInt32(-1)
		return
	}

	n := a.length()
	e.writeInt32(int32(n))

	for i := 0; i < n; i++ {
		encodeElem(e, a.index(i))
	}
}

func (e *encoder) encodeCompactNullArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
	a := v.array(elemType)
	if a.isNil() {
		e.writeUnsignedVarInt(0)
		return
	}

	n := a.length()
	e.writeUnsignedVarInt(uint64(n + 1))
	for i := 0; i < n; i++ {
		encodeElem(e, a.index(i))
	}
}

func (e *encoder) writeInt8(i int8) {
	writeInt8(e.buffer[:1], i)
	e.Write(e.buffer[:1])
}

func (e *encoder) writeInt16(i int16) {
	writeInt16(e.buffer[:2], i)
	e.Write(e.buffer[:2])
}

func (e *encoder) writeInt32(i int32) {
	writeInt32(e.buffer[:4], i)
	e.Write(e.buffer[:4])
}

func (e *encoder) writeInt64(i int64) {
	writeInt64(e.buffer[:8], i)
	e.Write(e.buffer[:8])
}

func (e *encoder) writeString(s string) {
	e.writeInt16(int16(len(s)))
	e.WriteString(s)
}

func (e *encoder) writeVarString(s string) {
	e.writeVarInt(int64(len(s)))
	e.WriteString(s)
}

func (e *encoder) writeCompactString(s string) {
	e.writeUnsignedVarInt(uint64(len(s)) + 1)
	e.WriteString(s)
}

func (e *encoder) writeNullString(s string) {
	if s == "" {
		e.writeInt16(-1)
	} else {
		e.writeInt16(int16(len(s)))
		e.WriteString(s)
	}
}

func (e *encoder) writeVarNullString(s string) {
	if s == "" {
		e.writeVarInt(-1)
	} else {
		e.writeVarInt(int64(len(s)))
		e.WriteString(s)
	}
}

func (e *encoder) writeCompactNullString(s string) {
	if s == "" {
		e.writeUnsignedVarInt(0)
	} else {
		e.writeUnsignedVarInt(uint64(len(s)) + 1)
		e.WriteString(s)
	}
}

func (e *encoder) writeBytes(b []byte) {
	e.writeInt32(int32(len(b)))
	e.Write(b)
}

func (e *encoder) writeVarBytes(b []byte) {
	e.writeVarInt(int64(len(b)))
	e.Write(b)
}

func (e *encoder) writeCompactBytes(b []byte) {
	e.writeUnsignedVarInt(uint64(len(b)) + 1)
	e.Write(b)
}

func (e *encoder) writeNullBytes(b []byte) {
	if b == nil {
		e.writeInt32(-1)
	} else {
		e.writeInt32(int32(len(b)))
		e.Write(b)
	}
}

func (e *encoder) writeVarNullBytes(b []byte) {
	if b == nil {
		e.writeVarInt(-1)
	} else {
		e.writeVarInt(int64(len(b)))
		e.Write(b)
	}
}

func (e *encoder) writeCompactNullBytes(b []byte) {
	if b == nil {
		e.writeUnsignedVarInt(0)
	} else {
		e.writeUnsignedVarInt(uint64(len(b)) + 1)
		e.Write(b)
	}
}

func (e *encoder) writeBytesFrom(b Bytes) error {
	size := int64(b.Len())
	e.writeInt32(int32(size))
	n, err := io.Copy(e, b)
	if err == nil && n != size {
		err = fmt.Errorf("size of bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF)
	}
	return err
}

func (e *encoder) writeNullBytesFrom(b Bytes) error {
	if b == nil {
		e.writeInt32(-1)
		return nil
	} else {
		size := int64(b.Len())
		e.writeInt32(int32(size))
		n, err := io.Copy(e, b)
		if err == nil && n != size {
			err = fmt.Errorf("size of nullable bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF)
		}
		return err
	}
}

func (e *encoder) writeVarNullBytesFrom(b Bytes) error {
	if b == nil {
		e.writeVarInt(-1)
		return nil
	} else {
		size := int64(b.Len())
		e.writeVarInt(size)
		n, err := io.Copy(e, b)
		if err == nil && n != size {
			err = fmt.Errorf("size of nullable bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF)
		}
		return err
	}
}

func (e *encoder) writeCompactNullBytesFrom(b Bytes) error {
	if b == nil {
		e.writeUnsignedVarInt(0)
		return nil
	} else {
		size := int64(b.Len())
		e.writeUnsignedVarInt(uint64(size + 1))
		n, err := io.Copy(e, b)
		if err == nil && n != size {
			err = fmt.Errorf("size of compact nullable bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF)
		}
		return err
	}
}

func (e *encoder) writeVarInt(i int64) {
	e.writeUnsignedVarInt(uint64((i << 1) ^ (i >> 63)))
}

func (e *encoder) writeUnsignedVarInt(i uint64) {
	b := e.buffer[:]
	n := 0

	for i >= 0x80 && n < len(b) {
		b[n] = byte(i) | 0x80
		i >>= 7
		n++
	}

	if n < len(b) {
		b[n] = byte(i)
		n++
	}

	e.Write(b[:n])
}

type encodeFunc func(*encoder, value)

var (
	_ io.ReaderFrom   = (*encoder)(nil)
	_ io.Writer       = (*encoder)(nil)
	_ io.ByteWriter   = (*encoder)(nil)
	_ io.StringWriter = (*encoder)(nil)

	writerTo = reflect.TypeOf((*io.WriterTo)(nil)).Elem()
)

func encodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) encodeFunc {
	if reflect.PtrTo(typ).Implements(writerTo) {
		return writerEncodeFuncOf(typ)
	}
	switch typ.Kind() {
	case reflect.Bool:
		return (*encoder).encodeBool
	case reflect.Int8:
		return (*encoder).encodeInt8
	case reflect.Int16:
		return (*encoder).encodeInt16
	case reflect.Int32:
		return (*encoder).encodeInt32
	case reflect.Int64:
		return (*encoder).encodeInt64
	case reflect.String:
		return stringEncodeFuncOf(flexible, tag)
	case reflect.Struct:
		return structEncodeFuncOf(typ, version, flexible)
	case reflect.Slice:
		if typ.Elem().Kind() == reflect.Uint8 { // []byte
			return bytesEncodeFuncOf(flexible, tag)
		}
		return arrayEncodeFuncOf(typ, version, flexible, tag)
	default:
		panic("unsupported type: " + typ.String())
	}
}

func stringEncodeFuncOf(flexible bool, tag structTag) encodeFunc {
	switch {
	case flexible && tag.Nullable:
		// In flexible messages, all strings are compact
		return (*encoder).encodeCompactNullString
	case flexible:
		// In flexible messages, all strings are compact
		return (*encoder).encodeCompactString
	case tag.Nullable:
		return (*encoder).encodeNullString
	default:
		return (*encoder).encodeString
	}
}

func bytesEncodeFuncOf(flexible bool, tag structTag) encodeFunc {
	switch {
	case flexible && tag.Nullable:
		// In flexible messages, all arrays are compact
		return (*encoder).encodeCompactNullBytes
	case flexible:
		// In flexible messages, all arrays are compact
		return (*encoder).encodeCompactBytes
	case tag.Nullable:
		return (*encoder).encodeNullBytes
	default:
		return (*encoder).encodeBytes
	}
}

func structEncodeFuncOf(typ reflect.Type, version int16, flexible bool) encodeFunc {
	type field struct {
		encode encodeFunc
		index  index
		tagID  int
	}

	var fields []field
	var taggedFields []field

	forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
		if typ.Size() != 0 { // skip struct{}
			forEachStructTag(tag, func(tag structTag) bool {
				if tag.MinVersion <= version && version <= tag.MaxVersion {
					f := field{
						encode: encodeFuncOf(typ, version, flexible, tag),
						index:  index,
						tagID:  tag.TagID,
					}

					if tag.TagID < -1 {
						// Normal required field
						fields = append(fields, f)
					} else {
						// Optional tagged field (flexible messages only)
						taggedFields = append(taggedFields, f)
					}
					return false
				}
				return true
			})
		}
	})

	return func(e *encoder, v value) {
		for i := range fields {
			f := &fields[i]
			f.encode(e, v.fieldByIndex(f.index))
		}

		if flexible {
			// See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
			// for details of tag buffers in "flexible" messages.
			e.writeUnsignedVarInt(uint64(len(taggedFields)))

			for i := range taggedFields {
				f := &taggedFields[i]
				e.writeUnsignedVarInt(uint64(f.tagID))

				buf := &bytes.Buffer{}
				se := &encoder{writer: buf}
				f.encode(se, v.fieldByIndex(f.index))
				e.writeUnsignedVarInt(uint64(buf.Len()))
				e.Write(buf.Bytes())
			}
		}
	}
}

func arrayEncodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) encodeFunc {
	elemType := typ.Elem()
	elemFunc := encodeFuncOf(elemType, version, flexible, tag)
	switch {
	case flexible && tag.Nullable:
		// In flexible messages, all arrays are compact
		return func(e *encoder, v value) { e.encodeCompactNullArray(v, elemType, elemFunc) }
	case flexible:
		// In flexible messages, all arrays are compact
		return func(e *encoder, v value) { e.encodeCompactArray(v, elemType, elemFunc) }
	case tag.Nullable:
		return func(e *encoder, v value) { e.encodeNullArray(v, elemType, elemFunc) }
	default:
		return func(e *encoder, v value) { e.encodeArray(v, elemType, elemFunc) }
	}
}

func writerEncodeFuncOf(typ reflect.Type) encodeFunc {
	typ = reflect.PtrTo(typ)
	return func(e *encoder, v value) {
		// Optimization to write directly into the buffer when the encoder
		// does no need to compute a crc32 checksum.
		w := io.Writer(e)
		if e.table == nil {
			w = e.writer
		}
		_, err := v.iface(typ).(io.WriterTo).WriteTo(w)
		if err != nil {
			e.err = err
		}
	}
}

func writeInt8(b []byte, i int8) {
	b[0] = byte(i)
}

func writeInt16(b []byte, i int16) {
	binary.BigEndian.PutUint16(b, uint16(i))
}

func writeInt32(b []byte, i int32) {
	binary.BigEndian.PutUint32(b, uint32(i))
}

func writeInt64(b []byte, i int64) {
	binary.BigEndian.PutUint64(b, uint64(i))
}

func Marshal(version int16, value interface{}) ([]byte, error) {
	typ := typeOf(value)
	cache, _ := marshalers.Load().(map[_type]encodeFunc)
	encode := cache[typ]

	if encode == nil {
		encode = encodeFuncOf(reflect.TypeOf(value), version, false, structTag{
			MinVersion: -1,
			MaxVersion: -1,
			TagID:      -2,
			Compact:    true,
			Nullable:   true,
		})

		newCache := make(map[_type]encodeFunc, len(cache)+1)
		newCache[typ] = encode

		for typ, fun := range cache {
			newCache[typ] = fun
		}

		marshalers.Store(newCache)
	}

	e, _ := encoders.Get().(*encoder)
	if e == nil {
		e = &encoder{writer: new(bytes.Buffer)}
	}

	b, _ := e.writer.(*bytes.Buffer)
	defer func() {
		b.Reset()
		e.Reset(b)
		encoders.Put(e)
	}()

	encode(e, nonAddressableValueOf(value))

	if e.err != nil {
		return nil, e.err
	}

	buf := b.Bytes()
	out := make([]byte, len(buf))
	copy(out, buf)
	return out, nil
}

var (
	encoders   sync.Pool    // *encoder
	marshalers atomic.Value // map[_type]encodeFunc
)