interpolate.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. // Copyright 2018 Huan Du. All rights reserved.
  2. // Licensed under the MIT license that can be found in the LICENSE file.
  3. package sqlbuilder
  4. import (
  5. "fmt"
  6. "strconv"
  7. "time"
  8. "unicode"
  9. "unicode/utf8"
  10. "unsafe"
  11. )
  12. // mysqlInterpolate parses query and replace all "?" with encoded args.
  13. // If there are more "?" than len(args), returns ErrMissingArgs.
  14. // Otherwise, if there are less "?" than len(args), the redundant args are omitted.
  15. func mysqlInterpolate(query string, args ...interface{}) (string, error) {
  16. // Roughly estimate the size to avoid useless memory allocation and copy.
  17. buf := make([]byte, 0, len(query)+len(args)*20)
  18. var quote rune
  19. var err error
  20. cnt := 0
  21. max := len(args)
  22. escaping := false
  23. offset := 0
  24. target := query
  25. r, sz := utf8.DecodeRuneInString(target)
  26. for ; sz != 0; r, sz = utf8.DecodeRuneInString(target) {
  27. offset += sz
  28. target = query[offset:]
  29. if escaping {
  30. escaping = false
  31. continue
  32. }
  33. switch r {
  34. case '?':
  35. if quote != 0 {
  36. continue
  37. }
  38. if cnt >= max {
  39. return "", ErrInterpolateMissingArgs
  40. }
  41. buf = append(buf, query[:offset-sz]...)
  42. buf, err = encodeValue(buf, args[cnt], MySQL)
  43. if err != nil {
  44. return "", err
  45. }
  46. query = target
  47. offset = 0
  48. cnt++
  49. case '\'':
  50. if quote == '\'' {
  51. quote = 0
  52. continue
  53. }
  54. if quote == 0 {
  55. quote = '\''
  56. }
  57. case '"':
  58. if quote == '"' {
  59. quote = 0
  60. continue
  61. }
  62. if quote == 0 {
  63. quote = '"'
  64. }
  65. case '`':
  66. if quote == '`' {
  67. quote = 0
  68. continue
  69. }
  70. if quote == 0 {
  71. quote = '`'
  72. }
  73. case '\\':
  74. if quote != 0 {
  75. escaping = true
  76. }
  77. }
  78. }
  79. buf = append(buf, query...)
  80. return *(*string)(unsafe.Pointer(&buf)), nil
  81. }
// postgresqlInterpolate scans query and replaces every "$N" placeholder (N is a
// decimal number whose first digit is 1-9) found outside of quoted text with the
// encoded value of args[N-1].
//
// Text inside single quotes, double quotes and dollar-quoted blocks
// ($tag$ ... $tag$, where tag consists of letters and may be empty) is not
// searched for placeholders. If a placeholder number is greater than len(args),
// it returns ErrInterpolateMissingArgs; args that no placeholder refers to are
// ignored. Errors from strconv.ParseInt and encodeValue are returned as is.
func postgresqlInterpolate(query string, args ...interface{}) (string, error) {
	// Roughly estimate the size to avoid useless memory allocation and copy.
	buf := make([]byte, 0, len(query)+len(args)*20)

	// quote is the currently open quote char ('\'', '"' or '$'), 0 when outside quotes.
	var quote rune
	// dollarQuote is the closing delimiter of the open dollar quote, stored
	// without its leading '$' (e.g. "tag$" for $tag$).
	var dollarQuote string
	var err error
	var idx int64
	max := len(args)
	// escaping is set after a backslash inside '...' or "..." so the next rune is ignored.
	escaping := false
	// offset is the byte position in query just past the rune being examined;
	// target is always the not-yet-consumed remainder of query.
	offset := 0
	target := query
	r, sz := utf8.DecodeRuneInString(target)

	for ; sz != 0; r, sz = utf8.DecodeRuneInString(target) {
		offset += sz
		target = query[offset:]

		if escaping {
			escaping = false
			continue
		}

		switch r {
		case '$':
			if quote != 0 {
				// A '$' inside '...' or "..." is plain text.
				if quote != '$' {
					continue
				}

				// Try to find the end of dollar quote.
				pos := offset

				for r, sz = utf8.DecodeRuneInString(target); sz != 0 && r != '$'; r, sz = utf8.DecodeRuneInString(target) {
					pos += sz
					target = query[pos:]
				}

				// End of query reached. This break only leaves the switch; the
				// loop then stops because target is empty.
				if sz == 0 {
					break
				}

				if r == '$' {
					// dq is the text between the two '$' plus the second '$',
					// the same shape as dollarQuote.
					dq := query[offset : pos+sz]
					// Resume at the second '$' so that, if this is not the
					// closing delimiter, it is examined again as a possible
					// start of the delimiter.
					offset = pos
					target = query[offset:]

					if dq == dollarQuote {
						// Closing delimiter found: leave the dollar quote and
						// step over the second '$'.
						quote = 0
						dollarQuote = ""
						offset += sz
						target = query[offset:]
					}

					continue
				}

				continue
			}

			// Outside quotes: "$" starts either a placeholder or a dollar quote.
			// oldSz remembers the width of '$' because sz is overwritten below.
			oldSz := sz
			pos := offset
			r, sz = utf8.DecodeRuneInString(target)

			if '1' <= r && r <= '9' {
				// A placeholder is found.
				pos += sz
				target = query[pos:]

				// Consume the remaining digits of the placeholder number.
				for r, sz = utf8.DecodeRuneInString(target); sz != 0 && '0' <= r && r <= '9'; r, sz = utf8.DecodeRuneInString(target) {
					pos += sz
					target = query[pos:]
				}

				idx, err = strconv.ParseInt(query[offset:pos], 10, strconv.IntSize)

				if err != nil {
					return "", err
				}

				// Placeholders are 1-based, so idx may be at most len(args).
				if int(idx) >= max+1 {
					return "", ErrInterpolateMissingArgs
				}

				// Emit the text before '$', then the encoded argument.
				buf = append(buf, query[:offset-oldSz]...)
				buf, err = encodeValue(buf, args[idx-1], PostgreSQL)

				if err != nil {
					return "", err
				}

				// Restart the bookkeeping on the text after the placeholder.
				query = target
				offset = 0

				// Leaves the switch only; the loop ends since target is empty.
				if sz == 0 {
					break
				}

				continue
			}

			// Try to find the beginning of dollar quote.
			for ; sz != 0 && r != '$' && unicode.IsLetter(r); r, sz = utf8.DecodeRuneInString(target) {
				pos += sz
				target = query[pos:]
			}

			// Leaves the switch only; the loop ends since target is empty.
			if sz == 0 {
				break
			}

			// The run of letters was not terminated by '$', so this is not a
			// dollar quote. Note that offset is not advanced to pos here.
			if !unicode.IsLetter(r) && r != '$' {
				continue
			}

			// Opening delimiter found: remember "tag$" and skip past it.
			pos += sz
			quote = '$'
			dollarQuote = query[offset:pos]
			offset = pos
			target = query[offset:]

		case '\'':
			if quote == '\'' {
				// PostgreSQL uses two single quotes to represent one single quote.
				r, sz = utf8.DecodeRuneInString(target)

				if r == '\'' {
					// Skip the second quote and stay inside the string.
					offset += sz
					target = query[offset:]
					continue
				}

				quote = 0
				continue
			}

			if quote == 0 {
				quote = '\''
			}

		case '"':
			if quote == '"' {
				quote = 0
				continue
			}

			if quote == 0 {
				quote = '"'
			}

		case '\\':
			// Backslash escapes are honored only inside '...' and "...",
			// not inside dollar quotes or outside quotes.
			if quote == '\'' || quote == '"' {
				escaping = true
			}
		}
	}

	// Append whatever follows the last placeholder.
	buf = append(buf, query...)
	// Reinterpret buf as a string without copying; buf is not used afterwards.
	return *(*string)(unsafe.Pointer(&buf)), nil
}
  211. func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
  212. switch v := arg.(type) {
  213. case nil:
  214. buf = append(buf, "NULL"...)
  215. case bool:
  216. if v {
  217. buf = append(buf, "TRUE"...)
  218. } else {
  219. buf = append(buf, "FALSE"...)
  220. }
  221. case int:
  222. buf = strconv.AppendInt(buf, int64(v), 10)
  223. case int8:
  224. buf = strconv.AppendInt(buf, int64(v), 10)
  225. case int16:
  226. buf = strconv.AppendInt(buf, int64(v), 10)
  227. case int32:
  228. buf = strconv.AppendInt(buf, int64(v), 10)
  229. case int64:
  230. buf = strconv.AppendInt(buf, v, 10)
  231. case uint:
  232. buf = strconv.AppendUint(buf, uint64(v), 10)
  233. case uint8:
  234. buf = strconv.AppendUint(buf, uint64(v), 10)
  235. case uint16:
  236. buf = strconv.AppendUint(buf, uint64(v), 10)
  237. case uint32:
  238. buf = strconv.AppendUint(buf, uint64(v), 10)
  239. case uint64:
  240. buf = strconv.AppendUint(buf, v, 10)
  241. case float32:
  242. buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
  243. case float64:
  244. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  245. case []byte:
  246. if v == nil {
  247. buf = append(buf, "NULL"...)
  248. break
  249. }
  250. switch flavor {
  251. case MySQL:
  252. buf = append(buf, "_binary"...)
  253. buf = quoteStringValue(buf, *(*string)(unsafe.Pointer(&v)), flavor)
  254. case PostgreSQL:
  255. hex := make([]byte, 0, 2)
  256. buf = append(buf, "E'\\\\x"...)
  257. for _, b := range v {
  258. runes := strconv.AppendInt(hex, int64(b), 16)
  259. buf = append(buf, byte(unicode.ToUpper(rune(runes[0]))))
  260. buf = append(buf, byte(unicode.ToUpper(rune(runes[1]))))
  261. }
  262. buf = append(buf, "'::bytea"...)
  263. }
  264. case string:
  265. buf = quoteStringValue(buf, v, flavor)
  266. case time.Time:
  267. if v.IsZero() {
  268. buf = append(buf, "'0000-00-00'"...)
  269. break
  270. }
  271. // In SQL standard, the precision of fractional seconds in time literal is up to 6 digits.
  272. // Round up v.
  273. v = v.Add(500 * time.Nanosecond)
  274. buf = append(buf, '\'')
  275. switch flavor {
  276. case MySQL:
  277. buf = append(buf, v.Format("2006-01-02 15:04:05.999999")...)
  278. case PostgreSQL:
  279. buf = append(buf, v.Format("2006-01-02 15:04:05.999999 MST")...)
  280. }
  281. buf = append(buf, '\'')
  282. case fmt.Stringer:
  283. buf = quoteStringValue(buf, v.String(), flavor)
  284. default:
  285. return nil, ErrInterpolateUnsupportedArgs
  286. }
  287. return buf, nil
  288. }
  289. func quoteStringValue(buf []byte, s string, flavor Flavor) []byte {
  290. if flavor == PostgreSQL {
  291. buf = append(buf, 'E')
  292. }
  293. buf = append(buf, '\'')
  294. r, sz := utf8.DecodeRuneInString(s)
  295. for ; sz != 0; r, sz = utf8.DecodeRuneInString(s) {
  296. switch r {
  297. case '\x00':
  298. buf = append(buf, "\\0"...)
  299. case '\b':
  300. buf = append(buf, "\\b"...)
  301. case '\n':
  302. buf = append(buf, "\\n"...)
  303. case '\r':
  304. buf = append(buf, "\\r"...)
  305. case '\t':
  306. buf = append(buf, "\\t"...)
  307. case '\x1a':
  308. buf = append(buf, "\\Z"...)
  309. case '\'':
  310. buf = append(buf, "\\'"...)
  311. case '"':
  312. buf = append(buf, "\\\""...)
  313. case '\\':
  314. buf = append(buf, "\\\\"...)
  315. default:
  316. buf = append(buf, s[:sz]...)
  317. }
  318. s = s[sz:]
  319. }
  320. buf = append(buf, '\'')
  321. return buf
  322. }