args.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. // Copyright 2018 Huan Du. All rights reserved.
  2. // Licensed under the MIT license that can be found in the LICENSE file.
  3. package sqlbuilder
  4. import (
  5. "bytes"
  6. "database/sql"
  7. "fmt"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. )
  12. // Args stores arguments associated with a SQL.
  13. type Args struct {
  14. // The default flavor used by `Args#Compile`
  15. Flavor Flavor
  16. args []interface{}
  17. namedArgs map[string]int
  18. sqlNamedArgs map[string]int
  19. onlyNamed bool
  20. }
  21. func init() {
  22. // Predefine some $n args to avoid additional memory allocation.
  23. predefinedArgs = make([]string, 0, maxPredefinedArgs)
  24. for i := 0; i < maxPredefinedArgs; i++ {
  25. predefinedArgs = append(predefinedArgs, fmt.Sprintf("$%v", i))
  26. }
  27. }
  28. const maxPredefinedArgs = 64
  29. var predefinedArgs []string
  30. // Add adds an arg to Args and returns a placeholder.
  31. func (args *Args) Add(arg interface{}) string {
  32. idx := args.add(arg)
  33. if idx < maxPredefinedArgs {
  34. return predefinedArgs[idx]
  35. }
  36. return fmt.Sprintf("$%v", idx)
  37. }
  38. func (args *Args) add(arg interface{}) int {
  39. idx := len(args.args)
  40. switch a := arg.(type) {
  41. case sql.NamedArg:
  42. if args.sqlNamedArgs == nil {
  43. args.sqlNamedArgs = map[string]int{}
  44. }
  45. if p, ok := args.sqlNamedArgs[a.Name]; ok {
  46. arg = args.args[p]
  47. break
  48. }
  49. args.sqlNamedArgs[a.Name] = idx
  50. case namedArgs:
  51. if args.namedArgs == nil {
  52. args.namedArgs = map[string]int{}
  53. }
  54. if p, ok := args.namedArgs[a.name]; ok {
  55. arg = args.args[p]
  56. break
  57. }
  58. // Find out the real arg and add it to args.
  59. idx = args.add(a.arg)
  60. args.namedArgs[a.name] = idx
  61. return idx
  62. }
  63. args.args = append(args.args, arg)
  64. return idx
  65. }
  66. // Compile compiles builder's format to standard sql and returns associated args.
  67. //
  68. // The format string uses a special syntax to represent arguments.
  69. //
  70. // $? refers successive arguments passed in the call. It works similar as `%v` in `fmt.Sprintf`.
  71. // $0 $1 ... $n refers nth-argument passed in the call. Next $? will use arguments n+1.
  72. // ${name} refers a named argument created by `Named` with `name`.
  73. // $$ is a "$" string.
  74. func (args *Args) Compile(format string, intialValue ...interface{}) (query string, values []interface{}) {
  75. return args.CompileWithFlavor(format, args.Flavor, intialValue...)
  76. }
  77. // CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args.
  78. //
  79. // See doc for `Compile` to learn details.
  80. func (args *Args) CompileWithFlavor(format string, flavor Flavor, intialValue ...interface{}) (query string, values []interface{}) {
  81. buf := &bytes.Buffer{}
  82. idx := strings.IndexRune(format, '$')
  83. offset := 0
  84. values = intialValue
  85. if flavor == invalidFlavor {
  86. flavor = DefaultFlavor
  87. }
  88. for idx >= 0 && len(format) > 0 {
  89. if idx > 0 {
  90. buf.WriteString(format[:idx])
  91. }
  92. format = format[idx+1:]
  93. // Treat the $ at the end of format is a normal $ rune.
  94. if len(format) == 0 {
  95. buf.WriteRune('$')
  96. break
  97. }
  98. if r := format[0]; r == '$' {
  99. buf.WriteRune('$')
  100. format = format[1:]
  101. } else if r == '{' {
  102. format, values = args.compileNamed(buf, flavor, format, values)
  103. } else if !args.onlyNamed && '0' <= r && r <= '9' {
  104. format, values, offset = args.compileDigits(buf, flavor, format, values, offset)
  105. } else if !args.onlyNamed && r == '?' {
  106. format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset)
  107. } else {
  108. // For unknown $ expression format, treat it as a normal $ rune.
  109. buf.WriteRune('$')
  110. }
  111. idx = strings.IndexRune(format, '$')
  112. }
  113. if len(format) > 0 {
  114. buf.WriteString(format)
  115. }
  116. query = buf.String()
  117. if len(args.sqlNamedArgs) > 0 {
  118. // Stabilize the sequence to make it easier to write test cases.
  119. ints := make([]int, 0, len(args.sqlNamedArgs))
  120. for _, p := range args.sqlNamedArgs {
  121. ints = append(ints, p)
  122. }
  123. sort.Ints(ints)
  124. for _, i := range ints {
  125. values = append(values, args.args[i])
  126. }
  127. }
  128. return
  129. }
  130. func (args *Args) compileNamed(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}) (string, []interface{}) {
  131. i := 1
  132. for ; i < len(format) && format[i] != '}'; i++ {
  133. // Nothing.
  134. }
  135. // Invalid $ format. Ignore it.
  136. if i == len(format) {
  137. return format, values
  138. }
  139. name := format[1:i]
  140. format = format[i+1:]
  141. if p, ok := args.namedArgs[name]; ok {
  142. format, values, _ = args.compileSuccessive(buf, flavor, format, values, p)
  143. }
  144. return format, values
  145. }
  146. func (args *Args) compileDigits(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
  147. i := 1
  148. for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
  149. // Nothing.
  150. }
  151. digits := format[:i]
  152. format = format[i:]
  153. if pointer, err := strconv.Atoi(digits); err == nil {
  154. return args.compileSuccessive(buf, flavor, format, values, pointer)
  155. }
  156. return format, values, offset
  157. }
  158. func (args *Args) compileSuccessive(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) {
  159. if offset >= len(args.args) {
  160. return format, values, offset
  161. }
  162. arg := args.args[offset]
  163. values = args.compileArg(buf, flavor, values, arg)
  164. return format, values, offset + 1
  165. }
  166. func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interface{}, arg interface{}) []interface{} {
  167. switch a := arg.(type) {
  168. case Builder:
  169. var s string
  170. s, values = a.BuildWithFlavor(flavor, values...)
  171. buf.WriteString(s)
  172. case sql.NamedArg:
  173. buf.WriteRune('@')
  174. buf.WriteString(a.Name)
  175. case rawArgs:
  176. buf.WriteString(a.expr)
  177. case listArgs:
  178. if len(a.args) > 0 {
  179. values = args.compileArg(buf, flavor, values, a.args[0])
  180. }
  181. for i := 1; i < len(a.args); i++ {
  182. buf.WriteString(", ")
  183. values = args.compileArg(buf, flavor, values, a.args[i])
  184. }
  185. default:
  186. switch flavor {
  187. case MySQL:
  188. buf.WriteRune('?')
  189. case PostgreSQL:
  190. fmt.Fprintf(buf, "$%v", len(values)+1)
  191. default:
  192. panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor)))
  193. }
  194. values = append(values, arg)
  195. }
  196. return values
  197. }