123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418 |
- // Copyright 2018 Huan Du. All rights reserved.
- // Licensed under the MIT license that can be found in the LICENSE file.
- package sqlbuilder
- import (
- "fmt"
- "strconv"
- "time"
- "unicode"
- "unicode/utf8"
- "unsafe"
- )
- // mysqlInterpolate parses query and replace all "?" with encoded args.
- // If there are more "?" than len(args), returns ErrMissingArgs.
- // Otherwise, if there are less "?" than len(args), the redundant args are omitted.
- func mysqlInterpolate(query string, args ...interface{}) (string, error) {
- // Roughly estimate the size to avoid useless memory allocation and copy.
- buf := make([]byte, 0, len(query)+len(args)*20)
- var quote rune
- var err error
- cnt := 0
- max := len(args)
- escaping := false
- offset := 0
- target := query
- r, sz := utf8.DecodeRuneInString(target)
- for ; sz != 0; r, sz = utf8.DecodeRuneInString(target) {
- offset += sz
- target = query[offset:]
- if escaping {
- escaping = false
- continue
- }
- switch r {
- case '?':
- if quote != 0 {
- continue
- }
- if cnt >= max {
- return "", ErrInterpolateMissingArgs
- }
- buf = append(buf, query[:offset-sz]...)
- buf, err = encodeValue(buf, args[cnt], MySQL)
- if err != nil {
- return "", err
- }
- query = target
- offset = 0
- cnt++
- case '\'':
- if quote == '\'' {
- quote = 0
- continue
- }
- if quote == 0 {
- quote = '\''
- }
- case '"':
- if quote == '"' {
- quote = 0
- continue
- }
- if quote == 0 {
- quote = '"'
- }
- case '`':
- if quote == '`' {
- quote = 0
- continue
- }
- if quote == 0 {
- quote = '`'
- }
- case '\\':
- if quote != 0 {
- escaping = true
- }
- }
- }
- buf = append(buf, query...)
- return *(*string)(unsafe.Pointer(&buf)), nil
- }
- // postgresqlInterpolate parses query and replace all "$*" with encoded args.
- // If there are more "$*" than len(args), returns ErrMissingArgs.
- // Otherwise, if there are less "$*" than len(args), the redundant args are omitted.
- func postgresqlInterpolate(query string, args ...interface{}) (string, error) {
- // Roughly estimate the size to avoid useless memory allocation and copy.
- buf := make([]byte, 0, len(query)+len(args)*20)
- var quote rune
- var dollarQuote string
- var err error
- var idx int64
- max := len(args)
- escaping := false
- offset := 0
- target := query
- r, sz := utf8.DecodeRuneInString(target)
- for ; sz != 0; r, sz = utf8.DecodeRuneInString(target) {
- offset += sz
- target = query[offset:]
- if escaping {
- escaping = false
- continue
- }
- switch r {
- case '$':
- if quote != 0 {
- if quote != '$' {
- continue
- }
- // Try to find the end of dollar quote.
- pos := offset
- for r, sz = utf8.DecodeRuneInString(target); sz != 0 && r != '$'; r, sz = utf8.DecodeRuneInString(target) {
- pos += sz
- target = query[pos:]
- }
- if sz == 0 {
- break
- }
- if r == '$' {
- dq := query[offset : pos+sz]
- offset = pos
- target = query[offset:]
- if dq == dollarQuote {
- quote = 0
- dollarQuote = ""
- offset += sz
- target = query[offset:]
- }
- continue
- }
- continue
- }
- oldSz := sz
- pos := offset
- r, sz = utf8.DecodeRuneInString(target)
- if '1' <= r && r <= '9' {
- // A placeholder is found.
- pos += sz
- target = query[pos:]
- for r, sz = utf8.DecodeRuneInString(target); sz != 0 && '0' <= r && r <= '9'; r, sz = utf8.DecodeRuneInString(target) {
- pos += sz
- target = query[pos:]
- }
- idx, err = strconv.ParseInt(query[offset:pos], 10, strconv.IntSize)
- if err != nil {
- return "", err
- }
- if int(idx) >= max+1 {
- return "", ErrInterpolateMissingArgs
- }
- buf = append(buf, query[:offset-oldSz]...)
- buf, err = encodeValue(buf, args[idx-1], PostgreSQL)
- if err != nil {
- return "", err
- }
- query = target
- offset = 0
- if sz == 0 {
- break
- }
- continue
- }
- // Try to find the beginning of dollar quote.
- for ; sz != 0 && r != '$' && unicode.IsLetter(r); r, sz = utf8.DecodeRuneInString(target) {
- pos += sz
- target = query[pos:]
- }
- if sz == 0 {
- break
- }
- if !unicode.IsLetter(r) && r != '$' {
- continue
- }
- pos += sz
- quote = '$'
- dollarQuote = query[offset:pos]
- offset = pos
- target = query[offset:]
- case '\'':
- if quote == '\'' {
- // PostgreSQL uses two single quotes to represent one single quote.
- r, sz = utf8.DecodeRuneInString(target)
- if r == '\'' {
- offset += sz
- target = query[offset:]
- continue
- }
- quote = 0
- continue
- }
- if quote == 0 {
- quote = '\''
- }
- case '"':
- if quote == '"' {
- quote = 0
- continue
- }
- if quote == 0 {
- quote = '"'
- }
- case '\\':
- if quote == '\'' || quote == '"' {
- escaping = true
- }
- }
- }
- buf = append(buf, query...)
- return *(*string)(unsafe.Pointer(&buf)), nil
- }
- func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
- switch v := arg.(type) {
- case nil:
- buf = append(buf, "NULL"...)
- case bool:
- if v {
- buf = append(buf, "TRUE"...)
- } else {
- buf = append(buf, "FALSE"...)
- }
- case int:
- buf = strconv.AppendInt(buf, int64(v), 10)
- case int8:
- buf = strconv.AppendInt(buf, int64(v), 10)
- case int16:
- buf = strconv.AppendInt(buf, int64(v), 10)
- case int32:
- buf = strconv.AppendInt(buf, int64(v), 10)
- case int64:
- buf = strconv.AppendInt(buf, v, 10)
- case uint:
- buf = strconv.AppendUint(buf, uint64(v), 10)
- case uint8:
- buf = strconv.AppendUint(buf, uint64(v), 10)
- case uint16:
- buf = strconv.AppendUint(buf, uint64(v), 10)
- case uint32:
- buf = strconv.AppendUint(buf, uint64(v), 10)
- case uint64:
- buf = strconv.AppendUint(buf, v, 10)
- case float32:
- buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
- case float64:
- buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
- case []byte:
- if v == nil {
- buf = append(buf, "NULL"...)
- break
- }
- switch flavor {
- case MySQL:
- buf = append(buf, "_binary"...)
- buf = quoteStringValue(buf, *(*string)(unsafe.Pointer(&v)), flavor)
- case PostgreSQL:
- hex := make([]byte, 0, 2)
- buf = append(buf, "E'\\\\x"...)
- for _, b := range v {
- runes := strconv.AppendInt(hex, int64(b), 16)
- buf = append(buf, byte(unicode.ToUpper(rune(runes[0]))))
- buf = append(buf, byte(unicode.ToUpper(rune(runes[1]))))
- }
- buf = append(buf, "'::bytea"...)
- }
- case string:
- buf = quoteStringValue(buf, v, flavor)
- case time.Time:
- if v.IsZero() {
- buf = append(buf, "'0000-00-00'"...)
- break
- }
- // In SQL standard, the precision of fractional seconds in time literal is up to 6 digits.
- // Round up v.
- v = v.Add(500 * time.Nanosecond)
- buf = append(buf, '\'')
- switch flavor {
- case MySQL:
- buf = append(buf, v.Format("2006-01-02 15:04:05.999999")...)
- case PostgreSQL:
- buf = append(buf, v.Format("2006-01-02 15:04:05.999999 MST")...)
- }
- buf = append(buf, '\'')
- case fmt.Stringer:
- buf = quoteStringValue(buf, v.String(), flavor)
- default:
- return nil, ErrInterpolateUnsupportedArgs
- }
- return buf, nil
- }
- func quoteStringValue(buf []byte, s string, flavor Flavor) []byte {
- if flavor == PostgreSQL {
- buf = append(buf, 'E')
- }
- buf = append(buf, '\'')
- r, sz := utf8.DecodeRuneInString(s)
- for ; sz != 0; r, sz = utf8.DecodeRuneInString(s) {
- switch r {
- case '\x00':
- buf = append(buf, "\\0"...)
- case '\b':
- buf = append(buf, "\\b"...)
- case '\n':
- buf = append(buf, "\\n"...)
- case '\r':
- buf = append(buf, "\\r"...)
- case '\t':
- buf = append(buf, "\\t"...)
- case '\x1a':
- buf = append(buf, "\\Z"...)
- case '\'':
- buf = append(buf, "\\'"...)
- case '"':
- buf = append(buf, "\\\""...)
- case '\\':
- buf = append(buf, "\\\\"...)
- default:
- buf = append(buf, s[:sz]...)
- }
- s = s[sz:]
- }
- buf = append(buf, '\'')
- return buf
- }
|