12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032 |
- package pq
- import (
- "bufio"
- "context"
- "crypto/md5"
- "crypto/sha256"
- "database/sql"
- "database/sql/driver"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "net"
- "os"
- "os/user"
- "path"
- "path/filepath"
- "strconv"
- "strings"
- "sync/atomic"
- "time"
- "unicode"
- "github.com/lib/pq/oid"
- "github.com/lib/pq/scram"
- )
// Exported errors returned by the driver.
var (
	// ErrNotSupported reports a command the driver does not implement.
	ErrNotSupported = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned, for example, by Commit when the
	// transaction had already failed (Commit rolls it back instead).
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported is returned when the server declines the SSL request.
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
	// ErrSSLKeyHasWorldPermissions reports a private key file that is
	// accessible by group or others.
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	// ErrCouldNotDetectUsername reports that no default user name could be
	// determined.
	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
)

// Internal errors.
var (
	// errUnexpectedReady is reported by simpleExec when ReadyForQuery arrives
	// without any result or error before it.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are returned by noRows.
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
- var (
- _ driver.Driver = Driver{}
- )
- type Driver struct{}
- func (d Driver) Open(name string) (driver.Conn, error) {
- return Open(name)
- }
- func init() {
- sql.Register("postgres", &Driver{})
- }
// parameterStatus holds server run-time parameters the driver needs when
// encoding values (it is passed to encode by stmt.exec). It is presumably
// populated by processParameterStatus, which is not shown here.
type parameterStatus struct {
	// serverVersion is the server version as an integer.
	// NOTE(review): encoding scheme not visible here — confirm.
	serverVersion int

	// currentLocation is the time zone location associated with the session,
	// if known.
	currentLocation *time.Location
}
- type transactionStatus byte
- const (
- txnStatusIdle transactionStatus = 'I'
- txnStatusIdleInTransaction transactionStatus = 'T'
- txnStatusInFailedTransaction transactionStatus = 'E'
- )
- func (s transactionStatus) String() string {
- switch s {
- case txnStatusIdle:
- return "idle"
- case txnStatusIdleInTransaction:
- return "idle in transaction"
- case txnStatusInFailedTransaction:
- return "in a failed transaction"
- default:
- errorf("unknown transactionStatus %d", s)
- }
- panic("not reached")
- }
// Dialer opens the network connection to the server.
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

// DialerContext is an optional extension of Dialer. When a Dialer also
// implements it, dial uses DialContext so the caller's context governs
// connection establishment.
type DialerContext interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// defaultDialer adapts a zero-value net.Dialer to Dialer and DialerContext.
type defaultDialer struct {
	d net.Dialer
}

// Dial delegates to net.Dialer.Dial.
func (dd defaultDialer) Dial(network, address string) (net.Conn, error) {
	return dd.d.Dial(network, address)
}

// DialTimeout dials through DialContext with a context that expires after
// timeout.
func (dd defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return dd.DialContext(ctx, network, address)
}

// DialContext delegates to net.Dialer.DialContext.
func (dd defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return dd.d.DialContext(ctx, network, address)
}
- type conn struct {
- c net.Conn
- buf *bufio.Reader
- namei int
- scratch [512]byte
- txnStatus transactionStatus
- txnFinish func()
-
- dialer Dialer
- opts values
-
- processID int
- secretKey int
- parameterStatus parameterStatus
- saveMessageType byte
- saveMessageBuffer []byte
-
-
- bad *atomic.Value
-
-
-
- disablePreparedBinaryResult bool
-
-
- binaryParameters bool
-
- inCopy bool
-
- noticeHandler func(*Error)
-
- notificationHandler func(*Notification)
-
- gss GSS
- }
- func (cn *conn) handleDriverSettings(o values) (err error) {
- boolSetting := func(key string, val *bool) error {
- if value, ok := o[key]; ok {
- if value == "yes" {
- *val = true
- } else if value == "no" {
- *val = false
- } else {
- return fmt.Errorf("unrecognized value %q for %s", value, key)
- }
- }
- return nil
- }
- err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
- if err != nil {
- return err
- }
- return boolSetting("binary_parameters", &cn.binaryParameters)
- }
- func (cn *conn) handlePgpass(o values) {
-
- if _, ok := o["password"]; ok {
- return
- }
- filename := os.Getenv("PGPASSFILE")
- if filename == "" {
-
-
-
- userHome := os.Getenv("HOME")
- if userHome == "" {
- user, err := user.Current()
- if err != nil {
- return
- }
- userHome = user.HomeDir
- }
- filename = filepath.Join(userHome, ".pgpass")
- }
- fileinfo, err := os.Stat(filename)
- if err != nil {
- return
- }
- mode := fileinfo.Mode()
- if mode&(0x77) != 0 {
-
- return
- }
- file, err := os.Open(filename)
- if err != nil {
- return
- }
- defer file.Close()
- scanner := bufio.NewScanner(io.Reader(file))
- hostname := o["host"]
- ntw, _ := network(o)
- port := o["port"]
- db := o["dbname"]
- username := o["user"]
-
- getFields := func(s string) []string {
- fs := make([]string, 0, 5)
- f := make([]rune, 0, len(s))
- var esc bool
- for _, c := range s {
- switch {
- case esc:
- f = append(f, c)
- esc = false
- case c == '\\':
- esc = true
- case c == ':':
- fs = append(fs, string(f))
- f = f[:0]
- default:
- f = append(f, c)
- }
- }
- return append(fs, string(f))
- }
- for scanner.Scan() {
- line := scanner.Text()
- if len(line) == 0 || line[0] == '#' {
- continue
- }
- split := getFields(line)
- if len(split) != 5 {
- continue
- }
- if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
- o["password"] = split[4]
- return
- }
- }
- }
- func (cn *conn) writeBuf(b byte) *writeBuf {
- cn.scratch[0] = b
- return &writeBuf{
- buf: cn.scratch[:5],
- pos: 1,
- }
- }
- func Open(dsn string) (_ driver.Conn, err error) {
- return DialOpen(defaultDialer{}, dsn)
- }
- func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
- c, err := NewConnector(dsn)
- if err != nil {
- return nil, err
- }
- c.dialer = d
- return c.open(context.Background())
- }
- func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
-
-
-
-
- defer errRecoverNoErrBadConn(&err)
- o := c.opts
- bad := &atomic.Value{}
- bad.Store(false)
- cn = &conn{
- opts: o,
- dialer: c.dialer,
- bad: bad,
- }
- err = cn.handleDriverSettings(o)
- if err != nil {
- return nil, err
- }
- cn.handlePgpass(o)
- cn.c, err = dial(ctx, c.dialer, o)
- if err != nil {
- return nil, err
- }
- err = cn.ssl(o)
- if err != nil {
- if cn.c != nil {
- cn.c.Close()
- }
- return nil, err
- }
-
- panicking := true
- defer func() {
- if panicking {
- cn.c.Close()
- }
- }()
- cn.buf = bufio.NewReader(cn.c)
- cn.startup(o)
-
- if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
- err = cn.c.SetDeadline(time.Time{})
- }
- panicking = false
- return cn, err
- }
- func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
- network, address := network(o)
-
- if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
- seconds, err := strconv.ParseInt(timeout, 10, 0)
- if err != nil {
- return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
- }
- duration := time.Duration(seconds) * time.Second
-
-
-
-
- deadline := time.Now().Add(duration)
- var conn net.Conn
- if dctx, ok := d.(DialerContext); ok {
- ctx, cancel := context.WithTimeout(ctx, duration)
- defer cancel()
- conn, err = dctx.DialContext(ctx, network, address)
- } else {
- conn, err = d.DialTimeout(network, address, duration)
- }
- if err != nil {
- return nil, err
- }
- err = conn.SetDeadline(deadline)
- return conn, err
- }
- if dctx, ok := d.(DialerContext); ok {
- return dctx.DialContext(ctx, network, address)
- }
- return d.Dial(network, address)
- }
// network returns the network type and address to dial for o. A host that
// starts with '/' names a directory holding the Unix-domain socket
// ".s.PGSQL.<port>"; any other host is dialed over TCP as host:port.
func network(o values) (string, string) {
	host := o["host"]

	if strings.HasPrefix(host, "/") {
		sockPath := path.Join(host, ".s.PGSQL."+o["port"])
		return "unix", sockPath
	}

	return "tcp", net.JoinHostPort(host, o["port"])
}

// values holds connection parameters keyed by parameter name.
type values map[string]string

// scanner is a minimal rune cursor used by parseOpts.
type scanner struct {
	s []rune
	i int
}

// newScanner returns a scanner positioned at the first rune of s.
func newScanner(s string) *scanner {
	return &scanner{[]rune(s), 0}
}

// Next returns the next rune and true, or (0, false) at the end of input.
func (s *scanner) Next() (rune, bool) {
	if s.i >= len(s.s) {
		return 0, false
	}
	r := s.s[s.i]
	s.i++
	return r, true
}

// SkipSpaces returns the next non-whitespace rune, or (0, false) if only
// whitespace remains.
func (s *scanner) SkipSpaces() (rune, bool) {
	r, ok := s.Next()
	for unicode.IsSpace(r) && ok {
		r, ok = s.Next()
	}
	return r, ok
}

// parseOpts parses a keyword=value connection string into o.
//
// Whitespace may surround '='. A value is either a run of non-whitespace
// characters or a single-quoted string. In both forms a backslash makes the
// next character literal. A trailing "key=" with nothing after it stores "".
// Later occurrences of a key overwrite earlier ones.
func parseOpts(name string, o values) error {
	s := newScanner(name)

	for {
		var (
			keyRunes, valRunes []rune
			r                  rune
			ok                 bool
		)

		if r, ok = s.SkipSpaces(); !ok {
			break
		}

		// Scan the key.
		for !unicode.IsSpace(r) && r != '=' {
			keyRunes = append(keyRunes, r)
			if r, ok = s.Next(); !ok {
				break
			}
		}

		// Skip any whitespace if we're not at the = yet.
		if r != '=' {
			r, ok = s.SkipSpaces()
		}

		// The current character should be =.
		if r != '=' || !ok {
			return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
		}

		// Skip any whitespace after the =.
		if r, ok = s.SkipSpaces(); !ok {
			// The value is empty and this was the last key.
			o[string(keyRunes)] = ""
			break
		}

		if r != '\'' {
			// Unquoted value: runs until whitespace or end of input.
			for !unicode.IsSpace(r) {
				if r == '\\' {
					if r, ok = s.Next(); !ok {
						return fmt.Errorf(`missing character after backslash`)
					}
				}
				valRunes = append(valRunes, r)

				if r, ok = s.Next(); !ok {
					break
				}
			}
		} else {
			// Quoted value: runs until the next unescaped single quote.
		quote:
			for {
				if r, ok = s.Next(); !ok {
					return fmt.Errorf(`unterminated quoted string literal in connection string`)
				}
				switch r {
				case '\'':
					break quote
				case '\\':
					// A trailing backslash yields 0 here; the next Next then
					// fails and reports the unterminated literal.
					r, _ = s.Next()
					fallthrough
				default:
					valRunes = append(valRunes, r)
				}
			}
		}

		o[string(keyRunes)] = string(valRunes)
	}

	return nil
}
- func (cn *conn) isInTransaction() bool {
- return cn.txnStatus == txnStatusIdleInTransaction ||
- cn.txnStatus == txnStatusInFailedTransaction
- }
- func (cn *conn) setBad() {
- if cn.bad != nil {
- cn.bad.Store(true)
- }
- }
- func (cn *conn) getBad() bool {
- if cn.bad != nil {
- return cn.bad.Load().(bool)
- }
- return false
- }
- func (cn *conn) checkIsInTransaction(intxn bool) {
- if cn.isInTransaction() != intxn {
- cn.setBad()
- errorf("unexpected transaction status %v", cn.txnStatus)
- }
- }
- func (cn *conn) Begin() (_ driver.Tx, err error) {
- return cn.begin("")
- }
- func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
- if cn.getBad() {
- return nil, driver.ErrBadConn
- }
- defer cn.errRecover(&err)
- cn.checkIsInTransaction(false)
- _, commandTag, err := cn.simpleExec("BEGIN" + mode)
- if err != nil {
- return nil, err
- }
- if commandTag != "BEGIN" {
- cn.setBad()
- return nil, fmt.Errorf("unexpected command tag %s", commandTag)
- }
- if cn.txnStatus != txnStatusIdleInTransaction {
- cn.setBad()
- return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
- }
- return cn, nil
- }
- func (cn *conn) closeTxn() {
- if finish := cn.txnFinish; finish != nil {
- finish()
- }
- }
- func (cn *conn) Commit() (err error) {
- defer cn.closeTxn()
- if cn.getBad() {
- return driver.ErrBadConn
- }
- defer cn.errRecover(&err)
- cn.checkIsInTransaction(true)
-
-
-
-
-
-
- if cn.txnStatus == txnStatusInFailedTransaction {
- if err := cn.rollback(); err != nil {
- return err
- }
- return ErrInFailedTransaction
- }
- _, commandTag, err := cn.simpleExec("COMMIT")
- if err != nil {
- if cn.isInTransaction() {
- cn.setBad()
- }
- return err
- }
- if commandTag != "COMMIT" {
- cn.setBad()
- return fmt.Errorf("unexpected command tag %s", commandTag)
- }
- cn.checkIsInTransaction(false)
- return nil
- }
- func (cn *conn) Rollback() (err error) {
- defer cn.closeTxn()
- if cn.getBad() {
- return driver.ErrBadConn
- }
- defer cn.errRecover(&err)
- return cn.rollback()
- }
- func (cn *conn) rollback() (err error) {
- cn.checkIsInTransaction(true)
- _, commandTag, err := cn.simpleExec("ROLLBACK")
- if err != nil {
- if cn.isInTransaction() {
- cn.setBad()
- }
- return err
- }
- if commandTag != "ROLLBACK" {
- return fmt.Errorf("unexpected command tag %s", commandTag)
- }
- cn.checkIsInTransaction(false)
- return nil
- }
- func (cn *conn) gname() string {
- cn.namei++
- return strconv.FormatInt(int64(cn.namei), 10)
- }
- func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
- b := cn.writeBuf('Q')
- b.string(q)
- cn.send(b)
- for {
- t, r := cn.recv1()
- switch t {
- case 'C':
- res, commandTag = cn.parseComplete(r.string())
- case 'Z':
- cn.processReadyForQuery(r)
- if res == nil && err == nil {
- err = errUnexpectedReady
- }
-
- return
- case 'E':
- err = parseError(r)
- case 'I':
- res = emptyRows
- case 'T', 'D':
-
- default:
- cn.setBad()
- errorf("unknown response for simple query: %q", t)
- }
- }
- }
// simpleQuery runs q through the simple query protocol ('Q') and returns a
// *rows positioned at the first result set.
//
// It returns as soon as the caller has something to read. That is either
// the first DataRow, which is pushed back with saveMessage for rows.Next, or
// a CommandComplete that follows a RowDescription (a result set with zero
// rows), or ReadyForQuery. A command that returns no rows still yields a
// *rows with done set, so the caller has something to Close. An
// ErrorResponse clears res and is returned as err. Unexpected messages mark
// cn bad and are raised via errorf; the deferred errRecover turns them into
// err.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)

	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// A CommandComplete or EmptyQueryResponse after an ErrorResponse is a
			// protocol violation.
			if err != nil {
				cn.setBad()
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Record the result and tag of this CommandComplete. If a row
			// description was already received (colNames set), the result set had
			// no rows; return now and leave the remaining messages to rows.Next.
			// NOTE(review): presumably Next then consumes the ReadyForQuery —
			// confirm against rows.Next.
			if t == 'C' {
				res.result, res.tag = cn.parseComplete(r.string())
				if res.colNames != nil {
					return
				}
			}
			res.done = true
		case 'Z':
			cn.processReadyForQuery(r)
			// The query is finished.
			return
		case 'E':
			res = nil
			err = parseError(r)
		case 'D':
			if res == nil {
				cn.setBad()
				errorf("unexpected DataRow in simple query execution")
			}
			// Push the first row back so rows.Next sees it.
			cn.saveMessage(t, r)
			return
		case 'T':
			// A new RowDescription replaces any rows object created by an earlier
			// CommandComplete. Keep reading instead of returning so that the first
			// DataRow (or CommandComplete) is seen before the caller gets control.
			res = &rows{cn: cn}
			res.rowsHeader = parsePortalRowDescribe(r)
		default:
			cn.setBad()
			errorf("unknown response for simple query: %q", t)
		}
	}
}
- type noRows struct{}
- var emptyRows noRows
- var _ driver.Result = noRows{}
- func (noRows) LastInsertId() (int64, error) {
- return 0, errNoLastInsertID
- }
- func (noRows) RowsAffected() (int64, error) {
- return 0, errNoRowsAffected
- }
- func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
- if len(colTyps) == 0 {
- return nil, colFmtDataAllText
- }
- colFmts = make([]format, len(colTyps))
- if forceText {
- return colFmts, colFmtDataAllText
- }
- allBinary := true
- allText := true
- for i, t := range colTyps {
- switch t.OID {
-
-
-
- case oid.T_bytea:
- fallthrough
- case oid.T_int8:
- fallthrough
- case oid.T_int4:
- fallthrough
- case oid.T_int2:
- fallthrough
- case oid.T_uuid:
- colFmts[i] = formatBinary
- allText = false
- default:
- allBinary = false
- }
- }
- if allBinary {
- return colFmts, colFmtDataAllBinary
- } else if allText {
- return colFmts, colFmtDataAllText
- } else {
- colFmtData = make([]byte, 2+len(colFmts)*2)
- binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
- for i, v := range colFmts {
- binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
- }
- return colFmts, colFmtData
- }
- }
- func (cn *conn) prepareTo(q, stmtName string) *stmt {
- st := &stmt{cn: cn, name: stmtName}
- b := cn.writeBuf('P')
- b.string(st.name)
- b.string(q)
- b.int16(0)
- b.next('D')
- b.byte('S')
- b.string(st.name)
- b.next('S')
- cn.send(b)
- cn.readParseResponse()
- st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
- st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
- cn.readReadyForQuery()
- return st
- }
- func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
- if cn.getBad() {
- return nil, driver.ErrBadConn
- }
- defer cn.errRecover(&err)
- if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
- s, err := cn.prepareCopyIn(q)
- if err == nil {
- cn.inCopy = true
- }
- return s, err
- }
- return cn.prepareTo(q, cn.gname()), nil
- }
- func (cn *conn) Close() (err error) {
-
- defer cn.errRecover(&err)
-
-
- defer func() {
- cerr := cn.c.Close()
- if err == nil {
- err = cerr
- }
- }()
-
-
- return cn.sendSimpleMessage('X')
- }
- func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
- return cn.query(query, args)
- }
- func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
- if cn.getBad() {
- return nil, driver.ErrBadConn
- }
- if cn.inCopy {
- return nil, errCopyInProgress
- }
- defer cn.errRecover(&err)
-
-
- if len(args) == 0 {
- return cn.simpleQuery(query)
- }
- if cn.binaryParameters {
- cn.sendBinaryModeQuery(query, args)
- cn.readParseResponse()
- cn.readBindResponse()
- rows := &rows{cn: cn}
- rows.rowsHeader = cn.readPortalDescribeResponse()
- cn.postExecuteWorkaround()
- return rows, nil
- }
- st := cn.prepareTo(query, "")
- st.exec(args)
- return &rows{
- cn: cn,
- rowsHeader: st.rowsHeader,
- }, nil
- }
- func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
- if cn.getBad() {
- return nil, driver.ErrBadConn
- }
- defer cn.errRecover(&err)
-
-
- if len(args) == 0 {
-
- r, _, err := cn.simpleExec(query)
- return r, err
- }
- if cn.binaryParameters {
- cn.sendBinaryModeQuery(query, args)
- cn.readParseResponse()
- cn.readBindResponse()
- cn.readPortalDescribeResponse()
- cn.postExecuteWorkaround()
- res, _, err = cn.readExecuteResponse("Execute")
- return res, err
- }
-
-
-
- st := cn.prepareTo(query, "")
- r, err := st.Exec(args)
- if err != nil {
- panic(err)
- }
- return r, err
- }
// safeRetryError wraps a write error that occurred before any byte of the
// message was written (see send), so the request never reached the server.
type safeRetryError struct {
	Err error
}

// Error returns the wrapped error's message unchanged.
func (e *safeRetryError) Error() string { return e.Err.Error() }
- func (cn *conn) send(m *writeBuf) {
- n, err := cn.c.Write(m.wrap())
- if err != nil {
- if n == 0 {
- err = &safeRetryError{Err: err}
- }
- panic(err)
- }
- }
- func (cn *conn) sendStartupPacket(m *writeBuf) error {
- _, err := cn.c.Write((m.wrap())[1:])
- return err
- }
- func (cn *conn) sendSimpleMessage(typ byte) (err error) {
- _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
- return err
- }
- func (cn *conn) saveMessage(typ byte, buf *readBuf) {
- if cn.saveMessageType != 0 {
- cn.setBad()
- errorf("unexpected saveMessageType %d", cn.saveMessageType)
- }
- cn.saveMessageType = typ
- cn.saveMessageBuffer = *buf
- }
- func (cn *conn) recvMessage(r *readBuf) (byte, error) {
-
- if cn.saveMessageType != 0 {
- t := cn.saveMessageType
- *r = cn.saveMessageBuffer
- cn.saveMessageType = 0
- cn.saveMessageBuffer = nil
- return t, nil
- }
- x := cn.scratch[:5]
- _, err := io.ReadFull(cn.buf, x)
- if err != nil {
- return 0, err
- }
-
- t := x[0]
- n := int(binary.BigEndian.Uint32(x[1:])) - 4
- var y []byte
- if n <= len(cn.scratch) {
- y = cn.scratch[:n]
- } else {
- y = make([]byte, n)
- }
- _, err = io.ReadFull(cn.buf, y)
- if err != nil {
- return 0, err
- }
- *r = y
- return t, nil
- }
- func (cn *conn) recv() (t byte, r *readBuf) {
- for {
- var err error
- r = &readBuf{}
- t, err = cn.recvMessage(r)
- if err != nil {
- panic(err)
- }
- switch t {
- case 'E':
- panic(parseError(r))
- case 'N':
- if n := cn.noticeHandler; n != nil {
- n(parseError(r))
- }
- case 'A':
- if n := cn.notificationHandler; n != nil {
- n(recvNotification(r))
- }
- default:
- return
- }
- }
- }
- func (cn *conn) recv1Buf(r *readBuf) byte {
- for {
- t, err := cn.recvMessage(r)
- if err != nil {
- panic(err)
- }
- switch t {
- case 'A':
- if n := cn.notificationHandler; n != nil {
- n(recvNotification(r))
- }
- case 'N':
- if n := cn.noticeHandler; n != nil {
- n(parseError(r))
- }
- case 'S':
- cn.processParameterStatus(r)
- default:
- return t
- }
- }
- }
- func (cn *conn) recv1() (t byte, r *readBuf) {
- r = &readBuf{}
- t = cn.recv1Buf(r)
- return t, r
- }
- func (cn *conn) ssl(o values) error {
- upgrade, err := ssl(o)
- if err != nil {
- return err
- }
- if upgrade == nil {
-
- return nil
- }
- w := cn.writeBuf(0)
- w.int32(80877103)
- if err = cn.sendStartupPacket(w); err != nil {
- return err
- }
- b := cn.scratch[:1]
- _, err = io.ReadFull(cn.c, b)
- if err != nil {
- return err
- }
- if b[0] != 'S' {
- return ErrSSLNotSupported
- }
- cn.c, err = upgrade(cn.c)
- return err
- }
// isDriverSetting reports whether key is consumed by the driver itself rather
// than forwarded to the server in the startup packet.
func isDriverSetting(key string) bool {
	switch key {
	case "host", "port",
		"password",
		"sslmode", "sslcert", "sslkey", "sslrootcert",
		"fallback_application_name",
		"connect_timeout",
		"disable_prepared_binary_result",
		"binary_parameters",
		"krbsrvname",
		"krbspn":
		return true
	}
	return false
}
- func (cn *conn) startup(o values) {
- w := cn.writeBuf(0)
- w.int32(196608)
-
-
-
-
- for k, v := range o {
- if isDriverSetting(k) {
-
- continue
- }
-
-
- if k == "dbname" {
- k = "database"
- }
- w.string(k)
- w.string(v)
- }
- w.string("")
- if err := cn.sendStartupPacket(w); err != nil {
- panic(err)
- }
- for {
- t, r := cn.recv()
- switch t {
- case 'K':
- cn.processBackendKeyData(r)
- case 'S':
- cn.processParameterStatus(r)
- case 'R':
- cn.auth(r, o)
- case 'Z':
- cn.processReadyForQuery(r)
- return
- default:
- errorf("unknown response for startup: %q", t)
- }
- }
- }
// auth handles one authentication request ('R') from the server during
// startup. The int32 code at the start of r selects the method:
//
//	0  success, nothing to do
//	3  cleartext password
//	5  MD5 password using the 4-byte salt that follows
//	7  GSSAPI/Kerberos initial token
//	8  GSSAPI continuation
//	10 SASL, answered with SCRAM-SHA-256 (expects 11 then 12 from the server)
//
// Unknown codes and protocol violations are raised via errorf. Follow-up
// requests that are not awaited here (e.g. after code 7) reach auth again
// through startup's message loop.
func (cn *conn) auth(r *readBuf, o values) {
	switch code := r.int32(); code {
	case 0:
		// OK
	case 3:
		// Send the password in clear text and expect an immediate success.
		w := cn.writeBuf('p')
		w.string(o["password"])
		cn.send(w)

		t, r := cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		if r.int32() != 0 {
			errorf("unexpected authentication response: %q", t)
		}
	case 5:
		// Reply with "md5" + md5(md5(password+user) + salt).
		s := string(r.next(4))
		w := cn.writeBuf('p')
		w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
		cn.send(w)

		t, r := cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		if r.int32() != 0 {
			errorf("unexpected authentication response: %q", t)
		}
	case 7:
		// GSSAPI: a provider must have been registered (newGss non-nil).
		if newGss == nil {
			errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
		}
		cli, err := newGss()
		if err != nil {
			errorf("kerberos error: %s", err.Error())
		}

		var token []byte

		if spn, ok := o["krbspn"]; ok {
			// An explicit service principal name takes precedence.
			token, err = cli.GetInitTokenFromSpn(spn)
		} else {
			// Otherwise build it from the service name (default "postgres")
			// and the host.
			service := "postgres"
			if val, ok := o["krbsrvname"]; ok {
				service = val
			}

			token, err = cli.GetInitToken(o["host"], service)
		}

		if err != nil {
			errorf("failed to get Kerberos ticket: %q", err)
		}

		w := cn.writeBuf('p')
		w.bytes(token)
		cn.send(w)

		// Keep the client for subsequent code-8 continuation requests.
		cn.gss = cli
	case 8:
		if cn.gss == nil {
			errorf("GSSAPI protocol error")
		}

		b := []byte(*r)

		done, tokOut, err := cn.gss.Continue(b)
		if err == nil && !done {
			w := cn.writeBuf('p')
			w.bytes(tokOut)
			cn.send(w)
		}

		// A Continue error is deliberately not raised here. Presumably the
		// server's ErrorResponse, which recv raises, then explains the failure.
	case 10:
		// NOTE(review): the mechanism list sent by the server is not
		// inspected; SCRAM-SHA-256 is assumed.
		sc := scram.NewClient(sha256.New, o["user"], o["password"])
		sc.Step(nil)
		if sc.Err() != nil {
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
		}
		scOut := sc.Out()

		// SASLInitialResponse: mechanism name, then length-prefixed data.
		w := cn.writeBuf('p')
		w.string("SCRAM-SHA-256")
		w.int32(len(scOut))
		w.bytes(scOut)
		cn.send(w)

		// Expect code 11 (SASL continue) carrying the server's challenge.
		t, r := cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		if r.int32() != 11 {
			errorf("unexpected authentication response: %q", t)
		}

		nextStep := r.next(len(*r))
		sc.Step(nextStep)
		if sc.Err() != nil {
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
		}

		scOut = sc.Out()
		w = cn.writeBuf('p')
		w.bytes(scOut)
		cn.send(w)

		// Expect code 12 (SASL final) carrying the server's final message,
		// which the SCRAM client then processes.
		t, r = cn.recv()
		if t != 'R' {
			errorf("unexpected password response: %q", t)
		}

		if r.int32() != 12 {
			errorf("unexpected authentication response: %q", t)
		}

		nextStep = r.next(len(*r))
		sc.Step(nextStep)
		if sc.Err() != nil {
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
		}

	default:
		errorf("unknown authentication response: %d", code)
	}
}
// format is a protocol format code for a parameter or result column.
type format int

const (
	formatText   format = 0
	formatBinary format = 1
)

// Encoded result-format lists for Bind: an int16 count followed by that many
// int16 codes. A count of 1 applies its single code to every column. A count
// of 0 leaves all columns in the default text format.
var (
	colFmtDataAllBinary = []byte{0, 1, 0, 1}
	colFmtDataAllText   = []byte{0, 0}
)
- type stmt struct {
- cn *conn
- name string
- rowsHeader
- colFmtData []byte
- paramTyps []oid.Oid
- closed bool
- }
- func (st *stmt) Close() (err error) {
- if st.closed {
- return nil
- }
- if st.cn.getBad() {
- return driver.ErrBadConn
- }
- defer st.cn.errRecover(&err)
- w := st.cn.writeBuf('C')
- w.byte('S')
- w.string(st.name)
- st.cn.send(w)
- st.cn.send(st.cn.writeBuf('S'))
- t, _ := st.cn.recv1()
- if t != '3' {
- st.cn.setBad()
- errorf("unexpected close response: %q", t)
- }
- st.closed = true
- t, r := st.cn.recv1()
- if t != 'Z' {
- st.cn.setBad()
- errorf("expected ready for query, but got: %q", t)
- }
- st.cn.processReadyForQuery(r)
- return nil
- }
- func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
- if st.cn.getBad() {
- return nil, driver.ErrBadConn
- }
- defer st.cn.errRecover(&err)
- st.exec(v)
- return &rows{
- cn: st.cn,
- rowsHeader: st.rowsHeader,
- }, nil
- }
- func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
- if st.cn.getBad() {
- return nil, driver.ErrBadConn
- }
- defer st.cn.errRecover(&err)
- st.exec(v)
- res, _, err = st.cn.readExecuteResponse("simple query")
- return res, err
- }
- func (st *stmt) exec(v []driver.Value) {
- if len(v) >= 65536 {
- errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
- }
- if len(v) != len(st.paramTyps) {
- errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
- }
- cn := st.cn
- w := cn.writeBuf('B')
- w.byte(0)
- w.string(st.name)
- if cn.binaryParameters {
- cn.sendBinaryParameters(w, v)
- } else {
- w.int16(0)
- w.int16(len(v))
- for i, x := range v {
- if x == nil {
- w.int32(-1)
- } else {
- b := encode(&cn.parameterStatus, x, st.paramTyps[i])
- w.int32(len(b))
- w.bytes(b)
- }
- }
- }
- w.bytes(st.colFmtData)
- w.next('E')
- w.byte(0)
- w.int32(0)
- w.next('S')
- cn.send(w)
- cn.readBindResponse()
- cn.postExecuteWorkaround()
- }
- func (st *stmt) NumInput() int {
- return len(st.paramTyps)
- }
- func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
- commandsWithAffectedRows := []string{
- "SELECT ",
-
- "UPDATE ",
- "DELETE ",
- "FETCH ",
- "MOVE ",
- "COPY ",
- }
- var affectedRows *string
- for _, tag := range commandsWithAffectedRows {
- if strings.HasPrefix(commandTag, tag) {
- t := commandTag[len(tag):]
- affectedRows = &t
- commandTag = tag[:len(tag)-1]
- break
- }
- }
-
-
-
-
- if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
- parts := strings.Split(commandTag, " ")
- if len(parts) != 3 {
- cn.setBad()
- errorf("unexpected INSERT command tag %s", commandTag)
- }
- affectedRows = &parts[len(parts)-1]
- commandTag = "INSERT"
- }
-
- if affectedRows == nil {
- return driver.RowsAffected(0), commandTag
- }
- n, err := strconv.ParseInt(*affectedRows, 10, 64)
- if err != nil {
- cn.setBad()
- errorf("could not parse commandTag: %s", err)
- }
- return driver.RowsAffected(n), commandTag
- }
- type rowsHeader struct {
- colNames []string
- colTyps []fieldDesc
- colFmts []format
- }
- type rows struct {
- cn *conn
- finish func()
- rowsHeader
- done bool
- rb readBuf
- result driver.Result
- tag string
- next *rowsHeader
- }
- func (rs *rows) Close() error {
- if finish := rs.finish; finish != nil {
- defer finish()
- }
-
- for {
- err := rs.Next(nil)
- switch err {
- case nil:
- case io.EOF:
-
-
-
- if rs.done {
- return nil
- }
- default:
- return err
- }
- }
- }
- func (rs *rows) Columns() []string {
- return rs.colNames
- }
- func (rs *rows) Result() driver.Result {
- if rs.result == nil {
- return emptyRows
- }
- return rs.result
- }
- func (rs *rows) Tag() string {
- return rs.tag
- }
- func (rs *rows) Next(dest []driver.Value) (err error) {
- if rs.done {
- return io.EOF
- }
- conn := rs.cn
- if conn.getBad() {
- return driver.ErrBadConn
- }
- defer conn.errRecover(&err)
- for {
- t := conn.recv1Buf(&rs.rb)
- switch t {
- case 'E':
- err = parseError(&rs.rb)
- case 'C', 'I':
- if t == 'C' {
- rs.result, rs.tag = conn.parseComplete(rs.rb.string())
- }
- continue
- case 'Z':
- conn.processReadyForQuery(&rs.rb)
- rs.done = true
- if err != nil {
- return err
- }
- return io.EOF
- case 'D':
- n := rs.rb.int16()
- if err != nil {
- conn.setBad()
- errorf("unexpected DataRow after error %s", err)
- }
- if n < len(dest) {
- dest = dest[:n]
- }
- for i := range dest {
- l := rs.rb.int32()
- if l == -1 {
- dest[i] = nil
- continue
- }
- dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
- }
- return
- case 'T':
- next := parsePortalRowDescribe(&rs.rb)
- rs.next = &next
- return io.EOF
- default:
- errorf("unexpected message after execute: %q", t)
- }
- }
- }
- func (rs *rows) HasNextResultSet() bool {
- hasNext := rs.next != nil && !rs.done
- return hasNext
- }
- func (rs *rows) NextResultSet() error {
- if rs.next == nil {
- return io.EOF
- }
- rs.rowsHeader = *rs.next
- rs.next = nil
- return nil
- }
// QuoteIdentifier quotes an identifier (such as a table or column name) for
// use in an SQL statement. It wraps name in double quotes and doubles any
// embedded double quotes; the quoted identifier is case sensitive. If name
// contains a zero byte, it is truncated immediately before it.
func QuoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	escaped := strings.Replace(name, `"`, `""`, -1)
	return `"` + escaped + `"`
}
// QuoteLiteral quotes a string literal for use in an SQL statement.
//
// Single quotes are doubled. If the literal contains backslashes, each one
// is doubled as well and the result uses the E'...' escape-string syntax.
// As with libpq's PQEscapeStringInternal, the escape form is preceded by a
// space (" E'...'"). Otherwise the literal is simply wrapped in single quotes.
func QuoteLiteral(literal string) string {
	escaped := strings.Replace(literal, `'`, `''`, -1)
	if !strings.Contains(escaped, `\`) {
		return `'` + escaped + `'`
	}
	escaped = strings.Replace(escaped, `\`, `\\`, -1)
	return ` E'` + escaped + `'`
}
// md5s returns the MD5 digest of s as 32 lowercase hexadecimal characters.
func md5s(s string) string {
	sum := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", sum[:])
}
- func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
-
-
-
- var paramFormats []int
- for i, x := range args {
- _, ok := x.([]byte)
- if ok {
- if paramFormats == nil {
- paramFormats = make([]int, len(args))
- }
- paramFormats[i] = 1
- }
- }
- if paramFormats == nil {
- b.int16(0)
- } else {
- b.int16(len(paramFormats))
- for _, x := range paramFormats {
- b.int16(x)
- }
- }
- b.int16(len(args))
- for _, x := range args {
- if x == nil {
- b.int32(-1)
- } else {
- datum := binaryEncode(&cn.parameterStatus, x)
- b.int32(len(datum))
- b.bytes(datum)
- }
- }
- }
- func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
- if len(args) >= 65536 {
- errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
- }
- b := cn.writeBuf('P')
- b.byte(0)
- b.string(query)
- b.int16(0)
- b.next('B')
- b.int16(0)
- cn.sendBinaryParameters(b, args)
- b.bytes(colFmtDataAllText)
- b.next('D')
- b.byte('P')
- b.byte(0)
- b.next('E')
- b.byte(0)
- b.int32(0)
- b.next('S')
- cn.send(b)
- }
- func (cn *conn) processParameterStatus(r *readBuf) {
- var err error
- param := r.string()
- switch param {
- case "server_version":
- var major1 int
- var major2 int
- var minor int
- _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
- if err == nil {
- cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
- }
- case "TimeZone":
- cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
- if err != nil {
- cn.parameterStatus.currentLocation = nil
- }
- default:
-
- }
- }
- func (cn *conn) processReadyForQuery(r *readBuf) {
- cn.txnStatus = transactionStatus(r.byte())
- }
- func (cn *conn) readReadyForQuery() {
- t, r := cn.recv1()
- switch t {
- case 'Z':
- cn.processReadyForQuery(r)
- return
- default:
- cn.setBad()
- errorf("unexpected message %q; expected ReadyForQuery", t)
- }
- }
- func (cn *conn) processBackendKeyData(r *readBuf) {
- cn.processID = r.int32()
- cn.secretKey = r.int32()
- }
- func (cn *conn) readParseResponse() {
- t, r := cn.recv1()
- switch t {
- case '1':
- return
- case 'E':
- err := parseError(r)
- cn.readReadyForQuery()
- panic(err)
- default:
- cn.setBad()
- errorf("unexpected Parse response %q", t)
- }
- }
- func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
- for {
- t, r := cn.recv1()
- switch t {
- case 't':
- nparams := r.int16()
- paramTyps = make([]oid.Oid, nparams)
- for i := range paramTyps {
- paramTyps[i] = r.oid()
- }
- case 'n':
- return paramTyps, nil, nil
- case 'T':
- colNames, colTyps = parseStatementRowDescribe(r)
- return paramTyps, colNames, colTyps
- case 'E':
- err := parseError(r)
- cn.readReadyForQuery()
- panic(err)
- default:
- cn.setBad()
- errorf("unexpected Describe statement response %q", t)
- }
- }
- }
- func (cn *conn) readPortalDescribeResponse() rowsHeader {
- t, r := cn.recv1()
- switch t {
- case 'T':
- return parsePortalRowDescribe(r)
- case 'n':
- return rowsHeader{}
- case 'E':
- err := parseError(r)
- cn.readReadyForQuery()
- panic(err)
- default:
- cn.setBad()
- errorf("unexpected Describe response %q", t)
- }
- panic("not reached")
- }
- func (cn *conn) readBindResponse() {
- t, r := cn.recv1()
- switch t {
- case '2':
- return
- case 'E':
- err := parseError(r)
- cn.readReadyForQuery()
- panic(err)
- default:
- cn.setBad()
- errorf("unexpected Bind response %q", t)
- }
- }
- func (cn *conn) postExecuteWorkaround() {
-
-
-
-
-
-
-
-
-
- for {
- t, r := cn.recv1()
- switch t {
- case 'E':
- err := parseError(r)
- cn.readReadyForQuery()
- panic(err)
- case 'C', 'D', 'I':
-
- cn.saveMessage(t, r)
- return
- default:
- cn.setBad()
- errorf("unexpected message during extended query execution: %q", t)
- }
- }
- }
- func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
- for {
- t, r := cn.recv1()
- switch t {
- case 'C':
- if err != nil {
- cn.setBad()
- errorf("unexpected CommandComplete after error %s", err)
- }
- res, commandTag = cn.parseComplete(r.string())
- case 'Z':
- cn.processReadyForQuery(r)
- if res == nil && err == nil {
- err = errUnexpectedReady
- }
- return res, commandTag, err
- case 'E':
- err = parseError(r)
- case 'T', 'D', 'I':
- if err != nil {
- cn.setBad()
- errorf("unexpected %q after error %s", t, err)
- }
- if t == 'I' {
- res = emptyRows
- }
-
- default:
- cn.setBad()
- errorf("unknown %s response: %q", protocolState, t)
- }
- }
- }
- func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
- n := r.int16()
- colNames = make([]string, n)
- colTyps = make([]fieldDesc, n)
- for i := range colNames {
- colNames[i] = r.string()
- r.next(6)
- colTyps[i].OID = r.oid()
- colTyps[i].Len = r.int16()
- colTyps[i].Mod = r.int32()
-
- r.next(2)
- }
- return
- }
- func parsePortalRowDescribe(r *readBuf) rowsHeader {
- n := r.int16()
- colNames := make([]string, n)
- colFmts := make([]format, n)
- colTyps := make([]fieldDesc, n)
- for i := range colNames {
- colNames[i] = r.string()
- r.next(6)
- colTyps[i].OID = r.oid()
- colTyps[i].Len = r.int16()
- colTyps[i].Mod = r.int32()
- colFmts[i] = format(r.int16())
- }
- return rowsHeader{
- colNames: colNames,
- colFmts: colFmts,
- colTyps: colTyps,
- }
- }
// parseEnviron translates libpq-style PG* environment entries ("NAME=value")
// into connection-option keys (e.g. PGHOST -> host, PGAPPNAME ->
// application_name). Entries for other variables are ignored. Recognized
// but unsupported variables (PGHOSTADDR, PGSERVICE, PGSSLCRL, ...) cause a
// panic with "setting NAME not supported".
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)
	for _, entry := range env {
		kv := strings.SplitN(entry, "=", 2)
		name := kv[0]

		var key string
		switch name {
		case "PGHOST":
			key = "host"
		case "PGPORT":
			key = "port"
		case "PGDATABASE":
			key = "dbname"
		case "PGUSER":
			key = "user"
		case "PGPASSWORD":
			key = "password"
		case "PGOPTIONS":
			key = "options"
		case "PGAPPNAME":
			key = "application_name"
		case "PGSSLMODE":
			key = "sslmode"
		case "PGSSLCERT":
			key = "sslcert"
		case "PGSSLKEY":
			key = "sslkey"
		case "PGSSLROOTCERT":
			key = "sslrootcert"
		case "PGCONNECT_TIMEOUT":
			key = "connect_timeout"
		case "PGCLIENTENCODING":
			key = "client_encoding"
		case "PGDATESTYLE":
			key = "datestyle"
		case "PGTZ":
			key = "timezone"
		case "PGGEQO":
			key = "geqo"
		case "PGHOSTADDR",
			"PGSERVICE", "PGSERVICEFILE", "PGREALM",
			"PGREQUIRESSL", "PGSSLCRL",
			"PGREQUIREPEER",
			"PGKRBSRVNAME", "PGGSSLIB",
			"PGSYSCONFDIR", "PGLOCALEDIR":
			panic(fmt.Sprintf("setting %v not supported", name))
		default:
			continue
		}
		out[key] = kv[1]
	}
	return out
}
// isUTF8 reports whether an encoding name denotes UTF-8. The name is compared
// after lowercasing ASCII letters and dropping every character that is not
// an ASCII letter or digit, so "UTF-8", "utf_8" and "Unicode" all match.
func isUTF8(name string) bool {
	switch strings.Map(alnumLowerASCII, name) {
	case "utf8", "unicode":
		return true
	}
	return false
}

// alnumLowerASCII is a strings.Map callback. It lowercases ASCII letters,
// keeps ASCII lowercase letters and digits, and drops (returns -1 for)
// everything else.
func alnumLowerASCII(ch rune) rune {
	switch {
	case ch >= 'A' && ch <= 'Z':
		return ch - 'A' + 'a'
	case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9':
		return ch
	default:
		return -1
	}
}
|