conn_go18.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package pq
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "sync/atomic"
  10. "time"
  11. )
  12. // Implement the "QueryerContext" interface
  13. func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
  14. list := make([]driver.Value, len(args))
  15. for i, nv := range args {
  16. list[i] = nv.Value
  17. }
  18. finish := cn.watchCancel(ctx)
  19. r, err := cn.query(query, list)
  20. if err != nil {
  21. if finish != nil {
  22. finish()
  23. }
  24. return nil, err
  25. }
  26. r.finish = finish
  27. return r, nil
  28. }
  29. // Implement the "ExecerContext" interface
  30. func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
  31. list := make([]driver.Value, len(args))
  32. for i, nv := range args {
  33. list[i] = nv.Value
  34. }
  35. if finish := cn.watchCancel(ctx); finish != nil {
  36. defer finish()
  37. }
  38. return cn.Exec(query, list)
  39. }
  40. // Implement the "ConnBeginTx" interface
  41. func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
  42. var mode string
  43. switch sql.IsolationLevel(opts.Isolation) {
  44. case sql.LevelDefault:
  45. // Don't touch mode: use the server's default
  46. case sql.LevelReadUncommitted:
  47. mode = " ISOLATION LEVEL READ UNCOMMITTED"
  48. case sql.LevelReadCommitted:
  49. mode = " ISOLATION LEVEL READ COMMITTED"
  50. case sql.LevelRepeatableRead:
  51. mode = " ISOLATION LEVEL REPEATABLE READ"
  52. case sql.LevelSerializable:
  53. mode = " ISOLATION LEVEL SERIALIZABLE"
  54. default:
  55. return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
  56. }
  57. if opts.ReadOnly {
  58. mode += " READ ONLY"
  59. } else {
  60. mode += " READ WRITE"
  61. }
  62. tx, err := cn.begin(mode)
  63. if err != nil {
  64. return nil, err
  65. }
  66. cn.txnFinish = cn.watchCancel(ctx)
  67. return tx, nil
  68. }
  69. func (cn *conn) Ping(ctx context.Context) error {
  70. if finish := cn.watchCancel(ctx); finish != nil {
  71. defer finish()
  72. }
  73. rows, err := cn.simpleQuery(";")
  74. if err != nil {
  75. return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
  76. }
  77. rows.Close()
  78. return nil
  79. }
// watchCancel arranges for the operation currently running on cn to be
// cancelled server-side if ctx is cancelled before the caller reports that
// the operation is finished.
//
// If ctx.Done() returns nil, the context can never be cancelled. Nothing is
// started and nil is returned, so callers must nil-check the result.
// Otherwise a watcher goroutine is started and a finish func is returned.
// The caller should call finish once the watched operation is over. Until
// then, or until ctx is done, the goroutine stays alive.
//
// finish must not be called twice. A second call can find the slot filled
// by the first call and take the <-finished branch, which closes cn.
//
// The goroutine and finish race to put a value into the one-slot channel
// finished, and whoever succeeds first decides the outcome:
//   - ctx cancelled first: the goroutine fills the slot, marks cn bad and
//     sends a cancel request over a separate connection. When finish is
//     called later, it drains the slot, marks cn bad and closes it.
//   - finish called first: finish fills the slot. The goroutine then either
//     receives that value and exits, or, if ctx was cancelled at the same
//     moment, fails its non-blocking send and exits without cancelling.
func (cn *conn) watchCancel(ctx context.Context) func() {
	if done := ctx.Done(); done != nil {
		// Capacity 1 lets whichever side arrives first claim the slot
		// without blocking, so finish never waits on the goroutine.
		finished := make(chan struct{}, 1)
		go func() {
			select {
			case <-done:
				select {
				case finished <- struct{}{}:
				default:
					// We raced with the finish func, let the next query handle this with the
					// context.
					return
				}
				// Set the connection state to bad so it does not get reused.
				cn.setBad()
				// At this point the function level context is canceled,
				// so it must not be used for the additional network
				// request to cancel the query.
				// Create a new context to pass into the dial.
				ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10)
				defer cancel()
				// Best effort: an error from the cancel request is deliberately
				// dropped; cn has already been marked bad above.
				_ = cn.cancel(ctxCancel)
			case <-finished:
				// finish claimed the slot first; there is nothing to cancel.
			}
		}()
		return func() {
			select {
			case <-finished:
				// The watcher claimed the slot, so a cancel was (or is being)
				// sent. Close cn rather than let it be reused.
				cn.setBad()
				cn.Close()
			case finished <- struct{}{}:
				// This call claimed the slot first; the watcher will exit
				// without sending a cancel request.
			}
		}
	}
	return nil
}
  116. func (cn *conn) cancel(ctx context.Context) error {
  117. c, err := dial(ctx, cn.dialer, cn.opts)
  118. if err != nil {
  119. return err
  120. }
  121. defer c.Close()
  122. {
  123. bad := &atomic.Value{}
  124. bad.Store(false)
  125. can := conn{
  126. c: c,
  127. bad: bad,
  128. }
  129. err = can.ssl(cn.opts)
  130. if err != nil {
  131. return err
  132. }
  133. w := can.writeBuf(0)
  134. w.int32(80877102) // cancel request code
  135. w.int32(cn.processID)
  136. w.int32(cn.secretKey)
  137. if err := can.sendStartupPacket(w); err != nil {
  138. return err
  139. }
  140. }
  141. // Read until EOF to ensure that the server received the cancel.
  142. {
  143. _, err := io.Copy(ioutil.Discard, c)
  144. return err
  145. }
  146. }