conn.go 34 KB


  1. // Package zk is a native Go client library for the ZooKeeper orchestration service.
  2. package zk
  3. /*
  4. TODO:
  5. * make sure a ping response comes back in a reasonable time
  6. Possible watcher events:
  7. * Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err}
  8. */
  9. import (
  10. "crypto/rand"
  11. "encoding/binary"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "sync/atomic"
  20. "time"
  21. )
  22. // ErrNoServer indicates that an operation cannot be completed
  23. // because attempts to connect to all servers in the list failed.
  24. var ErrNoServer = errors.New("zk: could not connect to a server")
  25. // ErrInvalidPath indicates that an operation was being attempted on
  26. // an invalid path. (e.g. empty path)
  27. var ErrInvalidPath = errors.New("zk: invalid path")
  28. // DefaultLogger uses the stdlib log package for logging.
  29. var DefaultLogger Logger = defaultLogger{}
  30. const (
  31. bufferSize = 1536 * 1024
  32. eventChanSize = 6
  33. sendChanSize = 16
  34. protectedPrefix = "_c_"
  35. )
  36. type watchType int
  37. const (
  38. watchTypeData = iota
  39. watchTypeExist
  40. watchTypeChild
  41. )
  42. type watchPathType struct {
  43. path string
  44. wType watchType
  45. }
  46. type Dialer func(network, address string, timeout time.Duration) (net.Conn, error)
// Logger is an interface that can be implemented to provide custom log output.
type Logger interface {
	// Printf formats and emits one log message. The connection uses it for
	// errors and, when logInfo is enabled, informational messages.
	Printf(string, ...interface{})
}

// authCreds is one credential registered through AddAuth. The recorded list is
// re-submitted to the server after every reconnect by resendZkAuth.
type authCreds struct {
	scheme string // auth scheme, sent as setAuthRequest.Scheme
	auth   []byte // scheme-specific credential bytes, sent as setAuthRequest.Auth
}
// Conn is a client connection to a ZooKeeper ensemble, created by Connect.
// It owns a background goroutine (loop) that dials servers, maintains the
// session and runs the send/receive loops.
type Conn struct {
	// NOTE(review): the 64-bit fields come first, presumably so that atomic
	// 64-bit access (sessionID) stays aligned on 32-bit platforms; keep them
	// first.
	lastZxid         int64 // last positive zxid seen in a response; sent when reconnecting and in setWatches
	sessionID        int64 // accessed atomically (SessionID, authenticate)
	state            State // must be 32-bit aligned; accessed atomically (State, setState)
	xid              uint32 // request id counter, incremented atomically by nextXid
	sessionTimeoutMs int32 // session timeout in milliseconds
	passwd           []byte // session password returned by the server

	dialer       Dialer
	hostProvider HostProvider
	serverMu     sync.Mutex // protects server
	server       string     // remember the address/port of the current server
	conn         net.Conn   // connection to the current server

	eventChan      chan Event    // session/watch events; returned by Connect and closed on shutdown
	eventCallback  EventCallback // may be nil
	shouldQuit     chan struct{} // closed by Close
	pingInterval   time.Duration // derived from the session timeout by setTimeouts
	recvTimeout    time.Duration // derived from the session timeout by setTimeouts
	connectTimeout time.Duration // timeout passed to dialer
	maxBufferSize  int           // receive-side packet size limit; <= 0 means unlimited

	creds   []authCreds // credentials from AddAuth, re-submitted after reconnect
	credsMu sync.Mutex  // protects creds

	sendChan     chan *request              // requests waiting to be written by the send loop
	requests     map[int32]*request         // Xid -> pending request
	requestsLock sync.Mutex                 // protects requests
	watchers     map[watchPathType][]chan Event // registered watchers per (path, watch type)
	watchersLock sync.Mutex                 // protects watchers
	closeChan    chan struct{}              // channel to tell send loop stop

	// Debug (used by unit tests)
	reconnectLatch   chan struct{}
	setWatchLimit    int
	setWatchCallback func([]*setWatchesRequest)
	// Debug (for recurring re-auth hang)
	debugCloseRecvLoop bool
	debugReauthDone    chan struct{}

	logger  Logger
	logInfo bool   // true if information messages are logged; false if only errors are logged
	buf     []byte // encoding buffer for outgoing packets (sendData and pings)
}
// connOption represents a connection option.
type connOption func(c *Conn)

// request is one outgoing operation together with what is needed to deliver
// its reply.
type request struct {
	xid        int32         // request id; the server echoes it in the response header
	opcode     int32         // operation code, encoded in the request header
	pkt        interface{}   // request body to encode
	recvStruct interface{}   // destination the response body is decoded into
	recvChan   chan response // receives exactly one response (created with a buffer of 1)

	// Because sending and receiving happen in separate go routines, there's
	// a possible race condition when creating watches from outside the read
	// loop. We must ensure that a watcher gets added to the list synchronously
	// with the response from the server on any request that creates a watch.
	// In order to not hard code the watch logic for each opcode in the recv
	// loop the caller can use recvFunc to insert some synchronous code
	// after a response.
	recvFunc func(*request, *responseHeader, error)
}

// response is what a caller receives on request.recvChan.
type response struct {
	zxid int64 // zxid from the response header; -1 when the request failed locally
	err  error
}

// Event is delivered on the channel returned by Connect (session state
// changes and watch notifications) and on individual watch channels.
type Event struct {
	Type   EventType
	State  State
	Path   string // For non-session events, the path of the watched node.
	Err    error
	Server string // For connection events
}

// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to.
// It is an analog of the Java equivalent:
// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup
type HostProvider interface {
	// Init is called first, with the servers specified in the connection string.
	Init(servers []string) error
	// Len returns the number of servers.
	Len() int
	// Next returns the next server to connect to. retryStart will be true if we've looped through
	// all known servers without Connected() being called.
	Next() (server string, retryStart bool)
	// Notify the HostProvider of a successful connection.
	Connected()
}
  135. // ConnectWithDialer establishes a new connection to a pool of zookeeper servers
  136. // using a custom Dialer. See Connect for further information about session timeout.
  137. // This method is deprecated and provided for compatibility: use the WithDialer option instead.
  138. func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
  139. return Connect(servers, sessionTimeout, WithDialer(dialer))
  140. }
  141. // Connect establishes a new connection to a pool of zookeeper
  142. // servers. The provided session timeout sets the amount of time for which
  143. // a session is considered valid after losing connection to a server. Within
  144. // the session timeout it's possible to reestablish a connection to a different
  145. // server and keep the same session. This is means any ephemeral nodes and
  146. // watches are maintained.
  147. func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
  148. if len(servers) == 0 {
  149. return nil, nil, errors.New("zk: server list must not be empty")
  150. }
  151. srvs := make([]string, len(servers))
  152. for i, addr := range servers {
  153. if strings.Contains(addr, ":") {
  154. srvs[i] = addr
  155. } else {
  156. srvs[i] = addr + ":" + strconv.Itoa(DefaultPort)
  157. }
  158. }
  159. // Randomize the order of the servers to avoid creating hotspots
  160. stringShuffle(srvs)
  161. ec := make(chan Event, eventChanSize)
  162. conn := &Conn{
  163. dialer: net.DialTimeout,
  164. hostProvider: &DNSHostProvider{},
  165. conn: nil,
  166. state: StateDisconnected,
  167. eventChan: ec,
  168. shouldQuit: make(chan struct{}),
  169. connectTimeout: 1 * time.Second,
  170. sendChan: make(chan *request, sendChanSize),
  171. requests: make(map[int32]*request),
  172. watchers: make(map[watchPathType][]chan Event),
  173. passwd: emptyPassword,
  174. logger: DefaultLogger,
  175. logInfo: true, // default is true for backwards compatability
  176. buf: make([]byte, bufferSize),
  177. }
  178. // Set provided options.
  179. for _, option := range options {
  180. option(conn)
  181. }
  182. if err := conn.hostProvider.Init(srvs); err != nil {
  183. return nil, nil, err
  184. }
  185. conn.setTimeouts(int32(sessionTimeout / time.Millisecond))
  186. go func() {
  187. conn.loop()
  188. conn.flushRequests(ErrClosing)
  189. conn.invalidateWatches(ErrClosing)
  190. close(conn.eventChan)
  191. }()
  192. return conn, ec, nil
  193. }
  194. // WithDialer returns a connection option specifying a non-default Dialer.
  195. func WithDialer(dialer Dialer) connOption {
  196. return func(c *Conn) {
  197. c.dialer = dialer
  198. }
  199. }
  200. // WithHostProvider returns a connection option specifying a non-default HostProvider.
  201. func WithHostProvider(hostProvider HostProvider) connOption {
  202. return func(c *Conn) {
  203. c.hostProvider = hostProvider
  204. }
  205. }
  206. // WithLogger returns a connection option specifying a non-default Logger
  207. func WithLogger(logger Logger) connOption {
  208. return func(c *Conn) {
  209. c.logger = logger
  210. }
  211. }
  212. // WithLogInfo returns a connection option specifying whether or not information messages
  213. // shoud be logged.
  214. func WithLogInfo(logInfo bool) connOption {
  215. return func(c *Conn) {
  216. c.logInfo = logInfo
  217. }
  218. }
  219. // EventCallback is a function that is called when an Event occurs.
  220. type EventCallback func(Event)
  221. // WithEventCallback returns a connection option that specifies an event
  222. // callback.
  223. // The callback must not block - doing so would delay the ZK go routines.
  224. func WithEventCallback(cb EventCallback) connOption {
  225. return func(c *Conn) {
  226. c.eventCallback = cb
  227. }
  228. }
  229. // WithMaxBufferSize sets the maximum buffer size used to read and decode
  230. // packets received from the Zookeeper server. The standard Zookeeper client for
  231. // Java defaults to a limit of 1mb. For backwards compatibility, this Go client
  232. // defaults to unbounded unless overridden via this option. A value that is zero
  233. // or negative indicates that no limit is enforced.
  234. //
  235. // This is meant to prevent resource exhaustion in the face of potentially
  236. // malicious data in ZK. It should generally match the server setting (which
  237. // also defaults ot 1mb) so that clients and servers agree on the limits for
  238. // things like the size of data in an individual znode and the total size of a
  239. // transaction.
  240. //
  241. // For production systems, this should be set to a reasonable value (ideally
  242. // that matches the server configuration). For ops tooling, it is handy to use a
  243. // much larger limit, in order to do things like clean-up problematic state in
  244. // the ZK tree. For example, if a single znode has a huge number of children, it
  245. // is possible for the response to a "list children" operation to exceed this
  246. // buffer size and cause errors in clients. The only way to subsequently clean
  247. // up the tree (by removing superfluous children) is to use a client configured
  248. // with a larger buffer size that can successfully query for all of the child
  249. // names and then remove them. (Note there are other tools that can list all of
  250. // the child names without an increased buffer size in the client, but they work
  251. // by inspecting the servers' transaction logs to enumerate children instead of
  252. // sending an online request to a server.
  253. func WithMaxBufferSize(maxBufferSize int) connOption {
  254. return func(c *Conn) {
  255. c.maxBufferSize = maxBufferSize
  256. }
  257. }
  258. // WithMaxConnBufferSize sets maximum buffer size used to send and encode
  259. // packets to Zookeeper server. The standard Zookeepeer client for java defaults
  260. // to a limit of 1mb. This option should be used for non-standard server setup
  261. // where znode is bigger than default 1mb.
  262. func WithMaxConnBufferSize(maxBufferSize int) connOption {
  263. return func(c *Conn) {
  264. c.buf = make([]byte, maxBufferSize)
  265. }
  266. }
  267. func (c *Conn) Close() {
  268. close(c.shouldQuit)
  269. select {
  270. case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil):
  271. case <-time.After(time.Second):
  272. }
  273. }
  274. // State returns the current state of the connection.
  275. func (c *Conn) State() State {
  276. return State(atomic.LoadInt32((*int32)(&c.state)))
  277. }
  278. // SessionID returns the current session id of the connection.
  279. func (c *Conn) SessionID() int64 {
  280. return atomic.LoadInt64(&c.sessionID)
  281. }
  282. // SetLogger sets the logger to be used for printing errors.
  283. // Logger is an interface provided by this package.
  284. func (c *Conn) SetLogger(l Logger) {
  285. c.logger = l
  286. }
  287. func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
  288. c.sessionTimeoutMs = sessionTimeoutMs
  289. sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond
  290. c.recvTimeout = sessionTimeout * 2 / 3
  291. c.pingInterval = c.recvTimeout / 2
  292. }
  293. func (c *Conn) setState(state State) {
  294. atomic.StoreInt32((*int32)(&c.state), int32(state))
  295. c.sendEvent(Event{Type: EventSession, State: state, Server: c.Server()})
  296. }
  297. func (c *Conn) sendEvent(evt Event) {
  298. if c.eventCallback != nil {
  299. c.eventCallback(evt)
  300. }
  301. select {
  302. case c.eventChan <- evt:
  303. default:
  304. // panic("zk: event channel full - it must be monitored and never allowed to be full")
  305. }
  306. }
  307. func (c *Conn) connect() error {
  308. var retryStart bool
  309. for {
  310. c.serverMu.Lock()
  311. c.server, retryStart = c.hostProvider.Next()
  312. c.serverMu.Unlock()
  313. c.setState(StateConnecting)
  314. if retryStart {
  315. c.flushUnsentRequests(ErrNoServer)
  316. select {
  317. case <-time.After(time.Second):
  318. // pass
  319. case <-c.shouldQuit:
  320. c.setState(StateDisconnected)
  321. c.flushUnsentRequests(ErrClosing)
  322. return ErrClosing
  323. }
  324. }
  325. zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
  326. if err == nil {
  327. c.conn = zkConn
  328. c.setState(StateConnected)
  329. if c.logInfo {
  330. c.logger.Printf("Connected to %s", c.Server())
  331. }
  332. return nil
  333. }
  334. c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err)
  335. }
  336. }
  337. func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) {
  338. shouldCancel := func() bool {
  339. select {
  340. case <-c.shouldQuit:
  341. return true
  342. case <-c.closeChan:
  343. return true
  344. default:
  345. return false
  346. }
  347. }
  348. c.credsMu.Lock()
  349. defer c.credsMu.Unlock()
  350. defer close(reauthReadyChan)
  351. if c.logInfo {
  352. c.logger.Printf("re-submitting `%d` credentials after reconnect", len(c.creds))
  353. }
  354. for _, cred := range c.creds {
  355. if shouldCancel() {
  356. return
  357. }
  358. resChan, err := c.sendRequest(
  359. opSetAuth,
  360. &setAuthRequest{Type: 0,
  361. Scheme: cred.scheme,
  362. Auth: cred.auth,
  363. },
  364. &setAuthResponse{},
  365. nil)
  366. if err != nil {
  367. c.logger.Printf("call to sendRequest failed during credential resubmit: %s", err)
  368. // FIXME(prozlach): lets ignore errors for now
  369. continue
  370. }
  371. var res response
  372. select {
  373. case res = <-resChan:
  374. case <-c.closeChan:
  375. c.logger.Printf("recv closed, cancel re-submitting credentials")
  376. return
  377. case <-c.shouldQuit:
  378. c.logger.Printf("should quit, cancel re-submitting credentials")
  379. return
  380. }
  381. if res.err != nil {
  382. c.logger.Printf("credential re-submit failed: %s", res.err)
  383. // FIXME(prozlach): lets ignore errors for now
  384. continue
  385. }
  386. }
  387. }
  388. func (c *Conn) sendRequest(
  389. opcode int32,
  390. req interface{},
  391. res interface{},
  392. recvFunc func(*request, *responseHeader, error),
  393. ) (
  394. <-chan response,
  395. error,
  396. ) {
  397. rq := &request{
  398. xid: c.nextXid(),
  399. opcode: opcode,
  400. pkt: req,
  401. recvStruct: res,
  402. recvChan: make(chan response, 1),
  403. recvFunc: recvFunc,
  404. }
  405. if err := c.sendData(rq); err != nil {
  406. return nil, err
  407. }
  408. return rq.recvChan, nil
  409. }
  410. func (c *Conn) loop() {
  411. for {
  412. if err := c.connect(); err != nil {
  413. // c.Close() was called
  414. return
  415. }
  416. err := c.authenticate()
  417. switch {
  418. case err == ErrSessionExpired:
  419. c.logger.Printf("authentication failed: %s", err)
  420. c.invalidateWatches(err)
  421. case err != nil && c.conn != nil:
  422. c.logger.Printf("authentication failed: %s", err)
  423. c.conn.Close()
  424. case err == nil:
  425. if c.logInfo {
  426. c.logger.Printf("authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs)
  427. }
  428. c.hostProvider.Connected() // mark success
  429. c.closeChan = make(chan struct{}) // channel to tell send loop stop
  430. reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted
  431. var wg sync.WaitGroup
  432. wg.Add(1)
  433. go func() {
  434. <-reauthChan
  435. if c.debugCloseRecvLoop {
  436. close(c.debugReauthDone)
  437. }
  438. err := c.sendLoop()
  439. if err != nil || c.logInfo {
  440. c.logger.Printf("send loop terminated: err=%v", err)
  441. }
  442. c.conn.Close() // causes recv loop to EOF/exit
  443. wg.Done()
  444. }()
  445. wg.Add(1)
  446. go func() {
  447. var err error
  448. if c.debugCloseRecvLoop {
  449. err = errors.New("DEBUG: close recv loop")
  450. } else {
  451. err = c.recvLoop(c.conn)
  452. }
  453. if err != io.EOF || c.logInfo {
  454. c.logger.Printf("recv loop terminated: err=%v", err)
  455. }
  456. if err == nil {
  457. panic("zk: recvLoop should never return nil error")
  458. }
  459. close(c.closeChan) // tell send loop to exit
  460. wg.Done()
  461. }()
  462. c.resendZkAuth(reauthChan)
  463. c.sendSetWatches()
  464. wg.Wait()
  465. }
  466. c.setState(StateDisconnected)
  467. select {
  468. case <-c.shouldQuit:
  469. c.flushRequests(ErrClosing)
  470. return
  471. default:
  472. }
  473. if err != ErrSessionExpired {
  474. err = ErrConnectionClosed
  475. }
  476. c.flushRequests(err)
  477. if c.reconnectLatch != nil {
  478. select {
  479. case <-c.shouldQuit:
  480. return
  481. case <-c.reconnectLatch:
  482. }
  483. }
  484. }
  485. }
  486. func (c *Conn) flushUnsentRequests(err error) {
  487. for {
  488. select {
  489. default:
  490. return
  491. case req := <-c.sendChan:
  492. req.recvChan <- response{-1, err}
  493. }
  494. }
  495. }
  496. // Send error to all pending requests and clear request map
  497. func (c *Conn) flushRequests(err error) {
  498. c.requestsLock.Lock()
  499. for _, req := range c.requests {
  500. req.recvChan <- response{-1, err}
  501. }
  502. c.requests = make(map[int32]*request)
  503. c.requestsLock.Unlock()
  504. }
  505. // Send error to all watchers and clear watchers map
  506. func (c *Conn) invalidateWatches(err error) {
  507. c.watchersLock.Lock()
  508. defer c.watchersLock.Unlock()
  509. if len(c.watchers) >= 0 {
  510. for pathType, watchers := range c.watchers {
  511. ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err}
  512. for _, ch := range watchers {
  513. ch <- ev
  514. close(ch)
  515. }
  516. }
  517. c.watchers = make(map[watchPathType][]chan Event)
  518. }
  519. }
// sendSetWatches re-registers all currently known watches with the server after
// a reconnect, so that watches survive the connection change.
//
// The watched paths are grouped by watch type into one or more
// setWatchesRequest packets, split so that no packet's estimated size exceeds
// the limit (128kb, or c.setWatchLimit when positive). Each packet carries
// c.lastZxid as RelativeZxid. The packets are sent one at a time on a separate
// goroutine. Sending stops at the first failure, which is logged.
func (c *Conn) sendSetWatches() {
	c.watchersLock.Lock()
	defer c.watchersLock.Unlock()
	if len(c.watchers) == 0 {
		return
	}
	// NB: A ZK server, by default, rejects packets >1mb. So, if we have too
	// many watches to reset, we need to break this up into multiple packets
	// to avoid hitting that limit. Mirroring the Java client behavior: we are
	// conservative in that we limit requests to 128kb (since server limit is
	// actually configurable and could conceivably be configured smaller
	// than default of 1mb).
	limit := 128 * 1024
	if c.setWatchLimit > 0 {
		limit = c.setWatchLimit
	}
	var reqs []*setWatchesRequest
	var req *setWatchesRequest
	var sizeSoFar int
	n := 0 // number of watched paths added to any packet
	for pathType, watchers := range c.watchers {
		if len(watchers) == 0 {
			continue
		}
		// Estimated encoded size of one path entry: a 4-byte length prefix
		// plus the path bytes (presumably the wire string encoding — confirm).
		addlLen := 4 + len(pathType.path)
		if req == nil || sizeSoFar+addlLen > limit {
			if req != nil {
				// add to set of requests that we'll send
				reqs = append(reqs, req)
			}
			sizeSoFar = 28 // fixed overhead of a set-watches packet
			req = &setWatchesRequest{
				RelativeZxid: c.lastZxid,
				DataWatches:  make([]string, 0),
				ExistWatches: make([]string, 0),
				ChildWatches: make([]string, 0),
			}
		}
		sizeSoFar += addlLen
		switch pathType.wType {
		case watchTypeData:
			req.DataWatches = append(req.DataWatches, pathType.path)
		case watchTypeExist:
			req.ExistWatches = append(req.ExistWatches, pathType.path)
		case watchTypeChild:
			req.ChildWatches = append(req.ChildWatches, pathType.path)
		}
		n++
	}
	if n == 0 {
		return
	}
	if req != nil { // don't forget any trailing packet we were building
		reqs = append(reqs, req)
	}
	// Debug hook used by unit tests to inspect the generated packets.
	if c.setWatchCallback != nil {
		c.setWatchCallback(reqs)
	}
	// Sent asynchronously: c.request blocks until the send loop picks the
	// request up and the response arrives, and this function holds
	// watchersLock.
	go func() {
		res := &setWatchesResponse{}
		// TODO: Pipeline these so queue all of them up before waiting on any
		// response. That will require some investigation to make sure there
		// aren't failure modes where a blocking write to the channel of requests
		// could hang indefinitely and cause this goroutine to leak...
		for _, req := range reqs {
			_, err := c.request(opSetWatches, req, res, nil)
			if err != nil {
				c.logger.Printf("Failed to set previous watches: %s", err.Error())
				break
			}
		}
	}()
}
  593. func (c *Conn) authenticate() error {
  594. buf := make([]byte, 256)
  595. // Encode and send a connect request.
  596. n, err := encodePacket(buf[4:], &connectRequest{
  597. ProtocolVersion: protocolVersion,
  598. LastZxidSeen: c.lastZxid,
  599. TimeOut: c.sessionTimeoutMs,
  600. SessionID: c.SessionID(),
  601. Passwd: c.passwd,
  602. })
  603. if err != nil {
  604. return err
  605. }
  606. binary.BigEndian.PutUint32(buf[:4], uint32(n))
  607. if err := c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10)); err != nil {
  608. return err
  609. }
  610. _, err = c.conn.Write(buf[:n+4])
  611. if err != nil {
  612. return err
  613. }
  614. if err := c.conn.SetWriteDeadline(time.Time{}); err != nil {
  615. return err
  616. }
  617. // Receive and decode a connect response.
  618. if err := c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10)); err != nil {
  619. return err
  620. }
  621. _, err = io.ReadFull(c.conn, buf[:4])
  622. if err != nil {
  623. return err
  624. }
  625. if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
  626. return err
  627. }
  628. blen := int(binary.BigEndian.Uint32(buf[:4]))
  629. if cap(buf) < blen {
  630. buf = make([]byte, blen)
  631. }
  632. _, err = io.ReadFull(c.conn, buf[:blen])
  633. if err != nil {
  634. return err
  635. }
  636. r := connectResponse{}
  637. _, err = decodePacket(buf[:blen], &r)
  638. if err != nil {
  639. return err
  640. }
  641. if r.SessionID == 0 {
  642. atomic.StoreInt64(&c.sessionID, int64(0))
  643. c.passwd = emptyPassword
  644. c.lastZxid = 0
  645. c.setState(StateExpired)
  646. return ErrSessionExpired
  647. }
  648. atomic.StoreInt64(&c.sessionID, r.SessionID)
  649. c.setTimeouts(r.TimeOut)
  650. c.passwd = r.Passwd
  651. c.setState(StateHasSession)
  652. return nil
  653. }
  654. func (c *Conn) sendData(req *request) error {
  655. header := &requestHeader{req.xid, req.opcode}
  656. n, err := encodePacket(c.buf[4:], header)
  657. if err != nil {
  658. req.recvChan <- response{-1, err}
  659. return nil
  660. }
  661. n2, err := encodePacket(c.buf[4+n:], req.pkt)
  662. if err != nil {
  663. req.recvChan <- response{-1, err}
  664. return nil
  665. }
  666. n += n2
  667. binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
  668. c.requestsLock.Lock()
  669. select {
  670. case <-c.closeChan:
  671. req.recvChan <- response{-1, ErrConnectionClosed}
  672. c.requestsLock.Unlock()
  673. return ErrConnectionClosed
  674. default:
  675. }
  676. c.requests[req.xid] = req
  677. c.requestsLock.Unlock()
  678. if err := c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)); err != nil {
  679. return err
  680. }
  681. _, err = c.conn.Write(c.buf[:n+4])
  682. if err != nil {
  683. req.recvChan <- response{-1, err}
  684. c.conn.Close()
  685. return err
  686. }
  687. if err := c.conn.SetWriteDeadline(time.Time{}); err != nil {
  688. return err
  689. }
  690. return nil
  691. }
  692. func (c *Conn) sendLoop() error {
  693. pingTicker := time.NewTicker(c.pingInterval)
  694. defer pingTicker.Stop()
  695. for {
  696. select {
  697. case req := <-c.sendChan:
  698. if err := c.sendData(req); err != nil {
  699. return err
  700. }
  701. case <-pingTicker.C:
  702. n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
  703. if err != nil {
  704. panic("zk: opPing should never fail to serialize")
  705. }
  706. binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
  707. if err := c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)); err != nil {
  708. return err
  709. }
  710. _, err = c.conn.Write(c.buf[:n+4])
  711. if err != nil {
  712. c.conn.Close()
  713. return err
  714. }
  715. if err := c.conn.SetWriteDeadline(time.Time{}); err != nil {
  716. return err
  717. }
  718. case <-c.closeChan:
  719. return nil
  720. }
  721. }
  722. }
  723. func (c *Conn) recvLoop(conn net.Conn) error {
  724. sz := bufferSize
  725. if c.maxBufferSize > 0 && sz > c.maxBufferSize {
  726. sz = c.maxBufferSize
  727. }
  728. buf := make([]byte, sz)
  729. for {
  730. // package length
  731. if err := conn.SetReadDeadline(time.Now().Add(c.recvTimeout)); err != nil {
  732. c.logger.Printf("failed to set connection deadline: %v", err)
  733. }
  734. _, err := io.ReadFull(conn, buf[:4])
  735. if err != nil {
  736. return fmt.Errorf("failed to read from connection: %v", err)
  737. }
  738. blen := int(binary.BigEndian.Uint32(buf[:4]))
  739. if cap(buf) < blen {
  740. if c.maxBufferSize > 0 && blen > c.maxBufferSize {
  741. return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize)
  742. }
  743. buf = make([]byte, blen)
  744. }
  745. _, err = io.ReadFull(conn, buf[:blen])
  746. if err != nil {
  747. return err
  748. }
  749. if err := conn.SetReadDeadline(time.Time{}); err != nil {
  750. return err
  751. }
  752. res := responseHeader{}
  753. _, err = decodePacket(buf[:16], &res)
  754. if err != nil {
  755. return err
  756. }
  757. if res.Xid == -1 {
  758. res := &watcherEvent{}
  759. _, err := decodePacket(buf[16:blen], res)
  760. if err != nil {
  761. return err
  762. }
  763. ev := Event{
  764. Type: res.Type,
  765. State: res.State,
  766. Path: res.Path,
  767. Err: nil,
  768. }
  769. c.sendEvent(ev)
  770. wTypes := make([]watchType, 0, 2)
  771. switch res.Type {
  772. case EventNodeCreated:
  773. wTypes = append(wTypes, watchTypeExist)
  774. case EventNodeDeleted, EventNodeDataChanged:
  775. wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild)
  776. case EventNodeChildrenChanged:
  777. wTypes = append(wTypes, watchTypeChild)
  778. }
  779. c.watchersLock.Lock()
  780. for _, t := range wTypes {
  781. wpt := watchPathType{res.Path, t}
  782. if watchers, ok := c.watchers[wpt]; ok {
  783. for _, ch := range watchers {
  784. ch <- ev
  785. close(ch)
  786. }
  787. delete(c.watchers, wpt)
  788. }
  789. }
  790. c.watchersLock.Unlock()
  791. } else if res.Xid == -2 {
  792. // Ping response. Ignore.
  793. } else if res.Xid < 0 {
  794. c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid)
  795. } else {
  796. if res.Zxid > 0 {
  797. c.lastZxid = res.Zxid
  798. }
  799. c.requestsLock.Lock()
  800. req, ok := c.requests[res.Xid]
  801. if ok {
  802. delete(c.requests, res.Xid)
  803. }
  804. c.requestsLock.Unlock()
  805. if !ok {
  806. c.logger.Printf("Response for unknown request with xid %d", res.Xid)
  807. } else {
  808. if res.Err != 0 {
  809. err = res.Err.toError()
  810. } else {
  811. _, err = decodePacket(buf[16:blen], req.recvStruct)
  812. }
  813. if req.recvFunc != nil {
  814. req.recvFunc(req, &res, err)
  815. }
  816. req.recvChan <- response{res.Zxid, err}
  817. if req.opcode == opClose {
  818. return io.EOF
  819. }
  820. }
  821. }
  822. }
  823. }
  824. func (c *Conn) nextXid() int32 {
  825. return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff)
  826. }
  827. func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event {
  828. c.watchersLock.Lock()
  829. defer c.watchersLock.Unlock()
  830. ch := make(chan Event, 1)
  831. wpt := watchPathType{path, watchType}
  832. c.watchers[wpt] = append(c.watchers[wpt], ch)
  833. return ch
  834. }
  835. func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response {
  836. rq := &request{
  837. xid: c.nextXid(),
  838. opcode: opcode,
  839. pkt: req,
  840. recvStruct: res,
  841. recvChan: make(chan response, 1),
  842. recvFunc: recvFunc,
  843. }
  844. c.sendChan <- rq
  845. return rq.recvChan
  846. }
  847. func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) {
  848. r := <-c.queueRequest(opcode, req, res, recvFunc)
  849. return r.zxid, r.err
  850. }
  851. func (c *Conn) AddAuth(scheme string, auth []byte) error {
  852. _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)
  853. if err != nil {
  854. return err
  855. }
  856. // Remember authdata so that it can be re-submitted on reconnect
  857. //
  858. // FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2"
  859. // as two different entries, which will be re-submitted on reconnet. Some
  860. // research is needed on how ZK treats these cases and
  861. // then maybe switch to something like "map[username] = password" to allow
  862. // only single password for given user with users being unique.
  863. obj := authCreds{
  864. scheme: scheme,
  865. auth: auth,
  866. }
  867. c.credsMu.Lock()
  868. c.creds = append(c.creds, obj)
  869. c.credsMu.Unlock()
  870. return nil
  871. }
  872. func (c *Conn) Children(path string) ([]string, *Stat, error) {
  873. if err := validatePath(path, false); err != nil {
  874. return nil, nil, err
  875. }
  876. res := &getChildren2Response{}
  877. _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil)
  878. return res.Children, &res.Stat, err
  879. }
  880. func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) {
  881. if err := validatePath(path, false); err != nil {
  882. return nil, nil, nil, err
  883. }
  884. var ech <-chan Event
  885. res := &getChildren2Response{}
  886. _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
  887. if err == nil {
  888. ech = c.addWatcher(path, watchTypeChild)
  889. }
  890. })
  891. if err != nil {
  892. return nil, nil, nil, err
  893. }
  894. return res.Children, &res.Stat, ech, err
  895. }
  896. func (c *Conn) Get(path string) ([]byte, *Stat, error) {
  897. if err := validatePath(path, false); err != nil {
  898. return nil, nil, err
  899. }
  900. res := &getDataResponse{}
  901. _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil)
  902. return res.Data, &res.Stat, err
  903. }
  904. // GetW returns the contents of a znode and sets a watch
  905. func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) {
  906. if err := validatePath(path, false); err != nil {
  907. return nil, nil, nil, err
  908. }
  909. var ech <-chan Event
  910. res := &getDataResponse{}
  911. _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
  912. if err == nil {
  913. ech = c.addWatcher(path, watchTypeData)
  914. }
  915. })
  916. if err != nil {
  917. return nil, nil, nil, err
  918. }
  919. return res.Data, &res.Stat, ech, err
  920. }
  921. func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) {
  922. if err := validatePath(path, false); err != nil {
  923. return nil, err
  924. }
  925. res := &setDataResponse{}
  926. _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil)
  927. return &res.Stat, err
  928. }
  929. func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) {
  930. if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil {
  931. return "", err
  932. }
  933. res := &createResponse{}
  934. _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil)
  935. return res.Path, err
  936. }
  937. // CreateProtectedEphemeralSequential fixes a race condition if the server crashes
  938. // after it creates the node. On reconnect the session may still be valid so the
  939. // ephemeral node still exists. Therefore, on reconnect we need to check if a node
  940. // with a GUID generated on create exists.
  941. func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) {
  942. if err := validatePath(path, true); err != nil {
  943. return "", err
  944. }
  945. var guid [16]byte
  946. _, err := io.ReadFull(rand.Reader, guid[:16])
  947. if err != nil {
  948. return "", err
  949. }
  950. guidStr := fmt.Sprintf("%x", guid)
  951. parts := strings.Split(path, "/")
  952. parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1])
  953. rootPath := strings.Join(parts[:len(parts)-1], "/")
  954. protectedPath := strings.Join(parts, "/")
  955. var newPath string
  956. for i := 0; i < 3; i++ {
  957. newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl)
  958. switch err {
  959. case ErrSessionExpired:
  960. // No need to search for the node since it can't exist. Just try again.
  961. case ErrConnectionClosed:
  962. children, _, err := c.Children(rootPath)
  963. if err != nil {
  964. return "", err
  965. }
  966. for _, p := range children {
  967. parts := strings.Split(p, "/")
  968. if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) {
  969. if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr {
  970. return rootPath + "/" + p, nil
  971. }
  972. }
  973. }
  974. case nil:
  975. return newPath, nil
  976. default:
  977. return "", err
  978. }
  979. }
  980. return "", err
  981. }
  982. func (c *Conn) Delete(path string, version int32) error {
  983. if err := validatePath(path, false); err != nil {
  984. return err
  985. }
  986. _, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil)
  987. return err
  988. }
  989. func (c *Conn) Exists(path string) (bool, *Stat, error) {
  990. if err := validatePath(path, false); err != nil {
  991. return false, nil, err
  992. }
  993. res := &existsResponse{}
  994. _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil)
  995. exists := true
  996. if err == ErrNoNode {
  997. exists = false
  998. err = nil
  999. }
  1000. return exists, &res.Stat, err
  1001. }
  1002. func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) {
  1003. if err := validatePath(path, false); err != nil {
  1004. return false, nil, nil, err
  1005. }
  1006. var ech <-chan Event
  1007. res := &existsResponse{}
  1008. _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
  1009. if err == nil {
  1010. ech = c.addWatcher(path, watchTypeData)
  1011. } else if err == ErrNoNode {
  1012. ech = c.addWatcher(path, watchTypeExist)
  1013. }
  1014. })
  1015. exists := true
  1016. if err == ErrNoNode {
  1017. exists = false
  1018. err = nil
  1019. }
  1020. if err != nil {
  1021. return false, nil, nil, err
  1022. }
  1023. return exists, &res.Stat, ech, err
  1024. }
  1025. func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {
  1026. if err := validatePath(path, false); err != nil {
  1027. return nil, nil, err
  1028. }
  1029. res := &getAclResponse{}
  1030. _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil)
  1031. return res.Acl, &res.Stat, err
  1032. }
  1033. func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
  1034. if err := validatePath(path, false); err != nil {
  1035. return nil, err
  1036. }
  1037. res := &setAclResponse{}
  1038. _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil)
  1039. return &res.Stat, err
  1040. }
  1041. func (c *Conn) Sync(path string) (string, error) {
  1042. if err := validatePath(path, false); err != nil {
  1043. return "", err
  1044. }
  1045. res := &syncResponse{}
  1046. _, err := c.request(opSync, &syncRequest{Path: path}, res, nil)
  1047. return res.Path, err
  1048. }
  1049. type MultiResponse struct {
  1050. Stat *Stat
  1051. String string
  1052. Error error
  1053. }
  1054. // Multi executes multiple ZooKeeper operations or none of them. The provided
  1055. // ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or
  1056. // *CheckVersionRequest.
  1057. func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
  1058. req := &multiRequest{
  1059. Ops: make([]multiRequestOp, 0, len(ops)),
  1060. DoneHeader: multiHeader{Type: -1, Done: true, Err: -1},
  1061. }
  1062. for _, op := range ops {
  1063. var opCode int32
  1064. switch op.(type) {
  1065. case *CreateRequest:
  1066. opCode = opCreate
  1067. case *SetDataRequest:
  1068. opCode = opSetData
  1069. case *DeleteRequest:
  1070. opCode = opDelete
  1071. case *CheckVersionRequest:
  1072. opCode = opCheck
  1073. default:
  1074. return nil, fmt.Errorf("unknown operation type %T", op)
  1075. }
  1076. req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op})
  1077. }
  1078. res := &multiResponse{}
  1079. _, err := c.request(opMulti, req, res, nil)
  1080. mr := make([]MultiResponse, len(res.Ops))
  1081. for i, op := range res.Ops {
  1082. mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()}
  1083. }
  1084. return mr, err
  1085. }
  1086. // IncrementalReconfig is the zookeeper reconfiguration api that allows adding and removing servers
  1087. // by lists of members.
  1088. // Return the new configuration stats.
  1089. func (c *Conn) IncrementalReconfig(joining, leaving []string, version int64) (*Stat, error) {
  1090. // TODO: validate the shape of the member string to give early feedback.
  1091. request := &reconfigRequest{
  1092. JoiningServers: []byte(strings.Join(joining, ",")),
  1093. LeavingServers: []byte(strings.Join(leaving, ",")),
  1094. CurConfigId: version,
  1095. }
  1096. return c.internalReconfig(request)
  1097. }
  1098. // Reconfig is the non-incremental update functionality for Zookeeper where the list preovided
  1099. // is the entire new member list.
  1100. // the optional version allows for conditional reconfigurations, -1 ignores the condition.
  1101. func (c *Conn) Reconfig(members []string, version int64) (*Stat, error) {
  1102. request := &reconfigRequest{
  1103. NewMembers: []byte(strings.Join(members, ",")),
  1104. CurConfigId: version,
  1105. }
  1106. return c.internalReconfig(request)
  1107. }
  1108. func (c *Conn) internalReconfig(request *reconfigRequest) (*Stat, error) {
  1109. response := &reconfigReponse{}
  1110. _, err := c.request(opReconfig, request, response, nil)
  1111. return &response.Stat, err
  1112. }
  1113. // Server returns the current or last-connected server name.
  1114. func (c *Conn) Server() string {
  1115. c.serverMu.Lock()
  1116. defer c.serverMu.Unlock()
  1117. return c.server
  1118. }