conn.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package protocol
  2. import (
  3. "bufio"
  4. "fmt"
  5. "net"
  6. "sync/atomic"
  7. "time"
  8. )
  // Conn is a Kafka protocol connection. It pairs a network connection with a
  // buffered reader layered over it, and carries the client ID, the
  // correlation-ID counter, and the per-API version table used by RoundTrip.
  type Conn struct {
  	// buffer wraps conn; Read, Peek and Discard all go through it.
  	buffer *bufio.Reader
  	// conn is the underlying network connection; Write, Close, the address
  	// getters and the deadline setters use it directly.
  	conn net.Conn
  	// clientID is passed to RoundTrip and shown by String.
  	clientID string
  	// idgen is the last correlation ID handed out; only touched through
  	// atomic.AddInt32.
  	idgen int32
  	// versions holds a map[ApiKey]int16 installed by SetVersions; it is
  	// empty (Load returns nil) until SetVersions is called.
  	versions atomic.Value
  }
  16. func NewConn(conn net.Conn, clientID string) *Conn {
  17. return &Conn{
  18. buffer: bufio.NewReader(conn),
  19. conn: conn,
  20. clientID: clientID,
  21. }
  22. }
  23. func (c *Conn) String() string {
  24. return fmt.Sprintf("kafka://%s@%s->%s", c.clientID, c.LocalAddr(), c.RemoteAddr())
  25. }
  26. func (c *Conn) Close() error {
  27. return c.conn.Close()
  28. }
  29. func (c *Conn) Discard(n int) (int, error) {
  30. return c.buffer.Discard(n)
  31. }
  32. func (c *Conn) Peek(n int) ([]byte, error) {
  33. return c.buffer.Peek(n)
  34. }
  35. func (c *Conn) Read(b []byte) (int, error) {
  36. return c.buffer.Read(b)
  37. }
  38. func (c *Conn) Write(b []byte) (int, error) {
  39. return c.conn.Write(b)
  40. }
  41. func (c *Conn) LocalAddr() net.Addr {
  42. return c.conn.LocalAddr()
  43. }
  44. func (c *Conn) RemoteAddr() net.Addr {
  45. return c.conn.RemoteAddr()
  46. }
  47. func (c *Conn) SetDeadline(t time.Time) error {
  48. return c.conn.SetDeadline(t)
  49. }
  50. func (c *Conn) SetReadDeadline(t time.Time) error {
  51. return c.conn.SetReadDeadline(t)
  52. }
  53. func (c *Conn) SetWriteDeadline(t time.Time) error {
  54. return c.conn.SetWriteDeadline(t)
  55. }
  56. func (c *Conn) SetVersions(versions map[ApiKey]int16) {
  57. connVersions := make(map[ApiKey]int16, len(versions))
  58. for k, v := range versions {
  59. connVersions[k] = v
  60. }
  61. c.versions.Store(connVersions)
  62. }
  63. func (c *Conn) RoundTrip(msg Message) (Message, error) {
  64. correlationID := atomic.AddInt32(&c.idgen, +1)
  65. versions, _ := c.versions.Load().(map[ApiKey]int16)
  66. apiVersion := versions[msg.ApiKey()]
  67. if p, _ := msg.(PreparedMessage); p != nil {
  68. p.Prepare(apiVersion)
  69. }
  70. return RoundTrip(c, apiVersion, correlationID, c.clientID, msg)
  71. }
  72. var (
  73. _ net.Conn = (*Conn)(nil)
  74. _ bufferedReader = (*Conn)(nil)
  75. )