dialer.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. package kafka
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "io"
  6. "net"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/segmentio/kafka-go/sasl"
  11. )
  12. // The Dialer type mirrors the net.Dialer API but is designed to open kafka
  13. // connections instead of raw network connections.
  14. type Dialer struct {
  15. // Unique identifier for client connections established by this Dialer.
  16. ClientID string
  17. // Optionally specifies the function that the dialer uses to establish
  18. // network connections. If nil, net.(*Dialer).DialContext is used instead.
  19. //
  20. // When DialFunc is set, LocalAddr, DualStack, FallbackDelay, and KeepAlive
  21. // are ignored.
  22. DialFunc func(ctx context.Context, network string, address string) (net.Conn, error)
  23. // Timeout is the maximum amount of time a dial will wait for a connect to
  24. // complete. If Deadline is also set, it may fail earlier.
  25. //
  26. // The default is no timeout.
  27. //
  28. // When dialing a name with multiple IP addresses, the timeout may be
  29. // divided between them.
  30. //
  31. // With or without a timeout, the operating system may impose its own
  32. // earlier timeout. For instance, TCP timeouts are often around 3 minutes.
  33. Timeout time.Duration
  34. // Deadline is the absolute point in time after which dials will fail.
  35. // If Timeout is set, it may fail earlier.
  36. // Zero means no deadline, or dependent on the operating system as with the
  37. // Timeout option.
  38. Deadline time.Time
  39. // LocalAddr is the local address to use when dialing an address.
  40. // The address must be of a compatible type for the network being dialed.
  41. // If nil, a local address is automatically chosen.
  42. LocalAddr net.Addr
  43. // DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing when the
  44. // network is "tcp" and the destination is a host name with both IPv4 and
  45. // IPv6 addresses. This allows a client to tolerate networks where one
  46. // address family is silently broken.
  47. DualStack bool
  48. // FallbackDelay specifies the length of time to wait before spawning a
  49. // fallback connection, when DualStack is enabled.
  50. // If zero, a default delay of 300ms is used.
  51. FallbackDelay time.Duration
  52. // KeepAlive specifies the keep-alive period for an active network
  53. // connection.
  54. // If zero, keep-alives are not enabled. Network protocols that do not
  55. // support keep-alives ignore this field.
  56. KeepAlive time.Duration
  57. // Resolver optionally gives a hook to convert the broker address into an
  58. // alternate host or IP address which is useful for custom service discovery.
  59. // If a custom resolver returns any possible hosts, the first one will be
  60. // used and the original discarded. If a port number is included with the
  61. // resolved host, it will only be used if a port number was not previously
  62. // specified. If no port is specified or resolved, the default of 9092 will be
  63. // used.
  64. Resolver Resolver
  65. // TLS enables Dialer to open secure connections. If nil, standard net.Conn
  66. // will be used.
  67. TLS *tls.Config
  68. // SASLMechanism configures the Dialer to use SASL authentication. If nil,
  69. // no authentication will be performed.
  70. SASLMechanism sasl.Mechanism
  71. // The transactional id to use for transactional delivery. Idempotent
  72. // deliver should be enabled if transactional id is configured.
  73. // For more details look at transactional.id description here: http://kafka.apache.org/documentation.html#producerconfigs
  74. // Empty string means that the connection will be non-transactional.
  75. TransactionalID string
  76. }
  77. // Dial connects to the address on the named network.
  78. func (d *Dialer) Dial(network string, address string) (*Conn, error) {
  79. return d.DialContext(context.Background(), network, address)
  80. }
  81. // DialContext connects to the address on the named network using the provided
  82. // context.
  83. //
  84. // The provided Context must be non-nil. If the context expires before the
  85. // connection is complete, an error is returned. Once successfully connected,
  86. // any expiration of the context will not affect the connection.
  87. //
  88. // When using TCP, and the host in the address parameter resolves to multiple
  89. // network addresses, any dial timeout (from d.Timeout or ctx) is spread over
  90. // each consecutive dial, such that each is given an appropriate fraction of the
  91. // time to connect. For example, if a host has 4 IP addresses and the timeout is
  92. // 1 minute, the connect to each single address will be given 15 seconds to
  93. // complete before trying the next one.
  94. func (d *Dialer) DialContext(ctx context.Context, network string, address string) (*Conn, error) {
  95. return d.connect(
  96. ctx,
  97. network,
  98. address,
  99. ConnConfig{
  100. ClientID: d.ClientID,
  101. TransactionalID: d.TransactionalID,
  102. },
  103. )
  104. }
  105. // DialLeader opens a connection to the leader of the partition for a given
  106. // topic.
  107. //
  108. // The address given to the DialContext method may not be the one that the
  109. // connection will end up being established to, because the dialer will lookup
  110. // the partition leader for the topic and return a connection to that server.
  111. // The original address is only used as a mechanism to discover the
  112. // configuration of the kafka cluster that we're connecting to.
  113. func (d *Dialer) DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
  114. p, err := d.LookupPartition(ctx, network, address, topic, partition)
  115. if err != nil {
  116. return nil, err
  117. }
  118. return d.DialPartition(ctx, network, address, p)
  119. }
  120. // DialPartition opens a connection to the leader of the partition specified by partition
  121. // descriptor. It's strongly advised to use descriptor of the partition that comes out of
  122. // functions LookupPartition or LookupPartitions.
  123. func (d *Dialer) DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
  124. return d.connect(ctx, network, net.JoinHostPort(partition.Leader.Host, strconv.Itoa(partition.Leader.Port)), ConnConfig{
  125. ClientID: d.ClientID,
  126. Topic: partition.Topic,
  127. Partition: partition.ID,
  128. TransactionalID: d.TransactionalID,
  129. })
  130. }
  131. // LookupLeader searches for the kafka broker that is the leader of the
  132. // partition for a given topic, returning a Broker value representing it.
  133. func (d *Dialer) LookupLeader(ctx context.Context, network string, address string, topic string, partition int) (Broker, error) {
  134. p, err := d.LookupPartition(ctx, network, address, topic, partition)
  135. return p.Leader, err
  136. }
  137. // LookupPartition searches for the description of specified partition id.
  138. func (d *Dialer) LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
  139. c, err := d.DialContext(ctx, network, address)
  140. if err != nil {
  141. return Partition{}, err
  142. }
  143. defer c.Close()
  144. brkch := make(chan Partition, 1)
  145. errch := make(chan error, 1)
  146. go func() {
  147. for attempt := 0; true; attempt++ {
  148. if attempt != 0 {
  149. if !sleep(ctx, backoff(attempt, 100*time.Millisecond, 10*time.Second)) {
  150. errch <- ctx.Err()
  151. return
  152. }
  153. }
  154. partitions, err := c.ReadPartitions(topic)
  155. if err != nil {
  156. if isTemporary(err) {
  157. continue
  158. }
  159. errch <- err
  160. return
  161. }
  162. for _, p := range partitions {
  163. if p.ID == partition {
  164. brkch <- p
  165. return
  166. }
  167. }
  168. }
  169. errch <- UnknownTopicOrPartition
  170. }()
  171. var prt Partition
  172. select {
  173. case prt = <-brkch:
  174. case err = <-errch:
  175. case <-ctx.Done():
  176. err = ctx.Err()
  177. }
  178. return prt, err
  179. }
  180. // LookupPartitions returns the list of partitions that exist for the given topic.
  181. func (d *Dialer) LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
  182. conn, err := d.DialContext(ctx, network, address)
  183. if err != nil {
  184. return nil, err
  185. }
  186. defer conn.Close()
  187. prtch := make(chan []Partition, 1)
  188. errch := make(chan error, 1)
  189. go func() {
  190. if prt, err := conn.ReadPartitions(topic); err != nil {
  191. errch <- err
  192. } else {
  193. prtch <- prt
  194. }
  195. }()
  196. var prt []Partition
  197. select {
  198. case prt = <-prtch:
  199. case err = <-errch:
  200. case <-ctx.Done():
  201. err = ctx.Err()
  202. }
  203. return prt, err
  204. }
  205. // connectTLS returns a tls.Conn that has already completed the Handshake
  206. func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn, config *tls.Config) (tlsConn *tls.Conn, err error) {
  207. tlsConn = tls.Client(conn, config)
  208. errch := make(chan error)
  209. go func() {
  210. defer close(errch)
  211. errch <- tlsConn.Handshake()
  212. }()
  213. select {
  214. case <-ctx.Done():
  215. conn.Close()
  216. tlsConn.Close()
  217. <-errch // ignore possible error from Handshake
  218. err = ctx.Err()
  219. case err = <-errch:
  220. }
  221. return
  222. }
  223. // connect opens a socket connection to the broker, wraps it to create a
  224. // kafka connection, and performs SASL authentication if configured to do so.
  225. func (d *Dialer) connect(ctx context.Context, network, address string, connCfg ConnConfig) (*Conn, error) {
  226. if d.Timeout != 0 {
  227. var cancel context.CancelFunc
  228. ctx, cancel = context.WithTimeout(ctx, d.Timeout)
  229. defer cancel()
  230. }
  231. if !d.Deadline.IsZero() {
  232. var cancel context.CancelFunc
  233. ctx, cancel = context.WithDeadline(ctx, d.Deadline)
  234. defer cancel()
  235. }
  236. c, err := d.dialContext(ctx, network, address)
  237. if err != nil {
  238. return nil, err
  239. }
  240. conn := NewConnWith(c, connCfg)
  241. if d.SASLMechanism != nil {
  242. if err := d.authenticateSASL(ctx, conn); err != nil {
  243. _ = conn.Close()
  244. return nil, err
  245. }
  246. }
  247. return conn, nil
  248. }
  249. // authenticateSASL performs all of the required requests to authenticate this
  250. // connection. If any step fails, this function returns with an error. A nil
  251. // error indicates successful authentication.
  252. //
  253. // In case of error, this function *does not* close the connection. That is the
  254. // responsibility of the caller.
  255. func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
  256. if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
  257. return err
  258. }
  259. sess, state, err := d.SASLMechanism.Start(ctx)
  260. if err != nil {
  261. return err
  262. }
  263. for completed := false; !completed; {
  264. challenge, err := conn.saslAuthenticate(state)
  265. switch err {
  266. case nil:
  267. case io.EOF:
  268. // the broker may communicate a failed exchange by closing the
  269. // connection (esp. in the case where we're passing opaque sasl
  270. // data over the wire since there's no protocol info).
  271. return SASLAuthenticationFailed
  272. default:
  273. return err
  274. }
  275. completed, state, err = sess.Next(ctx, challenge)
  276. if err != nil {
  277. return err
  278. }
  279. }
  280. return nil
  281. }
  282. func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
  283. address, err := lookupHost(ctx, addr, d.Resolver)
  284. if err != nil {
  285. return nil, err
  286. }
  287. dial := d.DialFunc
  288. if dial == nil {
  289. dial = (&net.Dialer{
  290. LocalAddr: d.LocalAddr,
  291. DualStack: d.DualStack,
  292. FallbackDelay: d.FallbackDelay,
  293. KeepAlive: d.KeepAlive,
  294. }).DialContext
  295. }
  296. conn, err := dial(ctx, network, address)
  297. if err != nil {
  298. return nil, err
  299. }
  300. if d.TLS != nil {
  301. c := d.TLS
  302. // If no ServerName is set, infer the ServerName
  303. // from the hostname we're connecting to.
  304. if c.ServerName == "" {
  305. c = d.TLS.Clone()
  306. // Copied from tls.go in the standard library.
  307. colonPos := strings.LastIndex(address, ":")
  308. if colonPos == -1 {
  309. colonPos = len(address)
  310. }
  311. hostname := address[:colonPos]
  312. c.ServerName = hostname
  313. }
  314. return d.connectTLS(ctx, conn, c)
  315. }
  316. return conn, nil
  317. }
  318. // DefaultDialer is the default dialer used when none is specified.
  319. var DefaultDialer = &Dialer{
  320. Timeout: 10 * time.Second,
  321. DualStack: true,
  322. }
  323. // Dial is a convenience wrapper for DefaultDialer.Dial.
  324. func Dial(network string, address string) (*Conn, error) {
  325. return DefaultDialer.Dial(network, address)
  326. }
  327. // DialContext is a convenience wrapper for DefaultDialer.DialContext.
  328. func DialContext(ctx context.Context, network string, address string) (*Conn, error) {
  329. return DefaultDialer.DialContext(ctx, network, address)
  330. }
  331. // DialLeader is a convenience wrapper for DefaultDialer.DialLeader.
  332. func DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) {
  333. return DefaultDialer.DialLeader(ctx, network, address, topic, partition)
  334. }
  335. // DialPartition is a convenience wrapper for DefaultDialer.DialPartition.
  336. func DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) {
  337. return DefaultDialer.DialPartition(ctx, network, address, partition)
  338. }
  339. // LookupPartition is a convenience wrapper for DefaultDialer.LookupPartition.
  340. func LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) {
  341. return DefaultDialer.LookupPartition(ctx, network, address, topic, partition)
  342. }
  343. // LookupPartitions is a convenience wrapper for DefaultDialer.LookupPartitions.
  344. func LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) {
  345. return DefaultDialer.LookupPartitions(ctx, network, address, topic)
  346. }
  // sleep pauses for duration unless ctx is done first. It reports true when
// the full duration elapsed and false when ctx was done.
//
// A zero duration never blocks; it only reports whether ctx is still live.
// Any other duration is measured with a timer that is stopped on return.
// Negative durations expire immediately.
func sleep(ctx context.Context, duration time.Duration) bool {
	if duration != 0 {
		timer := time.NewTimer(duration)
		defer timer.Stop()
		select {
		case <-ctx.Done():
			return false
		case <-timer.C:
			return true
		}
	}

	select {
	case <-ctx.Done():
		return false
	default:
		return true
	}
}
  // backoff returns the delay before retry number attempt: attempt²·min,
// capped at max.
//
// For positive min and non-negative max the computation saturates at max
// instead of overflowing. Very large attempt counts (for example an
// unbounded retry loop running for weeks) therefore still yield max rather
// than a wrapped-around, possibly negative, delay that would make the caller
// spin without waiting.
func backoff(attempt int, min time.Duration, max time.Duration) time.Duration {
	if attempt != 0 && min > 0 && max >= 0 {
		// |attempt| as an unsigned value. Unsigned negation is also correct
		// for the most negative int.
		a := uint64(attempt)
		if attempt < 0 {
			a = -a
		}
		// For positive integers, a²·min > max  <=>  a > (max/min)/a.
		// Comparing this way never overflows.
		if limit := uint64(max / min); a > limit/a {
			return max
		}
		return time.Duration(a*a) * min
	}

	// Degenerate inputs (attempt 0, non-positive min, negative max) keep the
	// plain formula.
	d := time.Duration(attempt*attempt) * min
	if d > max {
		d = max
	}
	return d
}
  // splitHostPort splits s into host and port like net.SplitHostPort, but never
// fails. When s has no port, or is otherwise not a valid "host:port", the
// whole of s is returned as the host with an empty port. In that fallback, a
// bracketed IPv6 literal such as "[::1]" has its brackets removed. Otherwise
// net.JoinHostPort would bracket it a second time ("[[::1]]:9092").
func splitHostPort(s string) (host string, port string) {
	host, port, _ = net.SplitHostPort(s)
	if len(host) == 0 && len(port) == 0 {
		host = s
		if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
			host = host[1 : n-1]
		}
	}
	return host, port
}
  379. func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) {
  380. host, port := splitHostPort(address)
  381. if resolver != nil {
  382. resolved, err := resolver.LookupHost(ctx, host)
  383. if err != nil {
  384. return "", err
  385. }
  386. // if the resolver doesn't return anything, we'll fall back on the provided
  387. // address instead
  388. if len(resolved) > 0 {
  389. resolvedHost, resolvedPort := splitHostPort(resolved[0])
  390. // we'll always prefer the resolved host
  391. host = resolvedHost
  392. // in the case of port though, the provided address takes priority, and we
  393. // only use the resolved address to set the port when not specified
  394. if port == "" {
  395. port = resolvedPort
  396. }
  397. }
  398. }
  399. if port == "" {
  400. port = "9092"
  401. }
  402. return net.JoinHostPort(host, port), nil
  403. }