Statement#exec and #query require named argument for array values (#110)

This change allows to use an array as single argument for #exec and
 #query and #scalar methods. Before it was shadowed by the *args splat overload.
This commit is contained in:
Johannes Müller 2019-09-20 22:23:09 +02:00 committed by Brian J. Cardiff
parent af6d837bcd
commit b3898ae2a2
9 changed files with 168 additions and 88 deletions

View file

@ -227,14 +227,14 @@ describe DB do
db.query "query", 1, "string" { } db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { } db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", FooValue.new(5)) { } db.query("query", 1, "string", FooValue.new(5)) { }
db.query "query", [1, "string", FooValue.new(5)] { } db.query "query", args: [1, "string", FooValue.new(5)] { }
db.query("query").close db.query("query").close
db.query("query", 1).close db.query("query", 1).close
db.query("query", 1, "string").close db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", FooValue.new(5)).close db.query("query", 1, "string", FooValue.new(5)).close
db.query("query", [1, "string", FooValue.new(5)]).close db.query("query", args: [1, "string", FooValue.new(5)]).close
end end
DB.open("bar://host") do |db| DB.open("bar://host") do |db|
@ -244,14 +244,14 @@ describe DB do
db.query "query", 1, "string" { } db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { } db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", BarValue.new(5)) { } db.query("query", 1, "string", BarValue.new(5)) { }
db.query "query", [1, "string", BarValue.new(5)] { } db.query "query", args: [1, "string", BarValue.new(5)] { }
db.query("query").close db.query("query").close
db.query("query", 1).close db.query("query", 1).close
db.query("query", 1, "string").close db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", BarValue.new(5)).close db.query("query", 1, "string", BarValue.new(5)).close
db.query("query", [1, "string", BarValue.new(5)]).close db.query("query", args: [1, "string", BarValue.new(5)]).close
end end
end end
@ -263,7 +263,7 @@ describe DB do
db.exec("query", 1, "string") db.exec("query", 1, "string")
db.exec("query", Bytes.new(4)) db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", FooValue.new(5)) db.exec("query", 1, "string", FooValue.new(5))
db.exec("query", [1, "string", FooValue.new(5)]) db.exec("query", args: [1, "string", FooValue.new(5)])
end end
DB.open("bar://host") do |db| DB.open("bar://host") do |db|
@ -273,20 +273,20 @@ describe DB do
db.exec("query", 1, "string") db.exec("query", 1, "string")
db.exec("query", Bytes.new(4)) db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", BarValue.new(5)) db.exec("query", 1, "string", BarValue.new(5))
db.exec("query", [1, "string", BarValue.new(5)]) db.exec("query", args: [1, "string", BarValue.new(5)])
end end
end end
it "Foo and Bar drivers should not implement each other params" do it "Foo and Bar drivers should not implement each other params" do
DB.open("foo://host") do |db| DB.open("foo://host") do |db|
expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do
db.exec("query", [BarValue.new(5)]) db.exec("query", args: [BarValue.new(5)])
end end
end end
DB.open("bar://host") do |db| DB.open("bar://host") do |db|
expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do
db.exec("query", [FooValue.new(5)]) db.exec("query", args: [FooValue.new(5)])
end end
end end
end end

View file

@ -97,7 +97,7 @@ class DummyDriver < DB::Driver
property params property params
def initialize(connection, @query : String, @prepared : Bool) def initialize(connection, @query : String, @prepared : Bool)
@params = Hash(Int32 | String, DB::Any).new @params = Hash(Int32 | String, DB::Any | Array(DB::Any)).new
super(connection) super(connection)
raise DB::Error.new(query) if query == "syntax error" raise DB::Error.new(query) if query == "syntax error"
end end
@ -126,6 +126,10 @@ class DummyDriver < DB::Driver
@params[index] = value @params[index] = value
end end
private def set_param(index, value : Array)
@params[index] = value.map(&.as(DB::Any))
end
private def set_param(index, value) private def set_param(index, value)
raise "not implemented for #{value.class}" raise "not implemented for #{value.class}"
end end

View file

@ -43,10 +43,37 @@ describe DB::Statement do
end end
end end
it "should initialize positional params in query with array" do it "accepts array as single argument" do
with_dummy_connection do |cnn| with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query ["a", 1, nil] stmt.query ["a", 1, nil]
stmt.params[0].should eq(["a", 1, nil])
end
end
it "allows no arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query
stmt.params.should be_empty
end
end
it "concatenate arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query 1, 2, args: ["a", [1, nil]]
stmt.params[0].should eq(1)
stmt.params[1].should eq(2)
stmt.params[2].should eq("a")
stmt.params[3].should eq([1, nil])
end
end
it "should initialize positional params in query with array" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query args: ["a", 1, nil]
stmt.params[0].should eq("a") stmt.params[0].should eq("a")
stmt.params[1].should eq(1) stmt.params[1].should eq(1)
stmt.params[2].should eq(nil) stmt.params[2].should eq(nil)
@ -63,16 +90,43 @@ describe DB::Statement do
end end
end end
it "should initialize positional params in exec with array" do it "accepts array as single argument" do
with_dummy_connection do |cnn| with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec ["a", 1, nil] stmt.exec ["a", 1, nil]
stmt.params[0].should eq(["a", 1, nil])
end
end
it "should initialize positional params in exec with array" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec args: ["a", 1, nil]
stmt.params[0].should eq("a") stmt.params[0].should eq("a")
stmt.params[1].should eq(1) stmt.params[1].should eq(1)
stmt.params[2].should eq(nil) stmt.params[2].should eq(nil)
end end
end end
it "allows no arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec
stmt.params.should be_empty
end
end
it "concatenate arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec 1, 2, args: ["a", [1, nil]]
stmt.params[0].should eq(1)
stmt.params[1].should eq(2)
stmt.params[2].should eq("a")
stmt.params[3].should eq([1, nil])
end
end
it "should initialize positional params in scalar" do it "should initialize positional params in scalar" do
with_dummy_connection do |cnn| with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)

View file

@ -181,6 +181,7 @@ end
require "./db/pool" require "./db/pool"
require "./db/string_key_cache" require "./db/string_key_cache"
require "./db/enumerable_concat"
require "./db/query_methods" require "./db/query_methods"
require "./db/session_methods" require "./db/session_methods"
require "./db/disposable" require "./db/disposable"

View file

@ -0,0 +1,38 @@
module DB
# :nodoc:
struct EnumerableConcat(S, T, U)
include Enumerable(S)
def initialize(@e1 : T, @e2 : U)
end
def each
if e1 = @e1
@e1.each do |e|
yield e
end
end
if e2 = @e2
e2.each do |e|
yield e
end
end
end
# returns given `e1 : T` an `Enumerable(T')` and `e2 : U` an `Enumerable(U') | Nil`
# it retuns and `Enumerable(T' | U')` that enumerates the elements of `e1`
# and, later, the elements of `e2`.
def self.build(e1 : T, e2 : U)
return e1 if e2.nil? || e2.empty?
return e2 if e1.nil? || e1.empty?
EnumerableConcat(Union(typeof(sample(e1)), typeof(sample(e2))), T, U).new(e1, e2)
end
private def self.sample(c : Enumerable?)
c.not_nil!.each do |e|
return e
end
raise ""
end
end
end

View file

@ -15,13 +15,8 @@ module DB
end end
# See `QueryMethods#exec` # See `QueryMethods#exec`
def exec(*args) : ExecResult def exec(*args_, args : Array? = nil) : ExecResult
statement_with_retry &.exec(*args) statement_with_retry &.exec(*args_, args: args)
end
# See `QueryMethods#exec`
def exec(args : Array) : ExecResult
statement_with_retry &.exec(args)
end end
# See `QueryMethods#query` # See `QueryMethods#query`
@ -30,18 +25,13 @@ module DB
end end
# See `QueryMethods#query` # See `QueryMethods#query`
def query(*args) : ResultSet def query(*args_, args : Array? = nil) : ResultSet
statement_with_retry &.query(*args) statement_with_retry &.query(*args_, args: args)
end
# See `QueryMethods#query`
def query(args : Array) : ResultSet
statement_with_retry &.query(args)
end end
# See `QueryMethods#scalar` # See `QueryMethods#scalar`
def scalar(*args) def scalar(*args_, args : Array? = nil)
statement_with_retry &.scalar(*args) statement_with_retry &.scalar(*args_, args: args)
end end
# builds a statement over a real connection # builds a statement over a real connection

View file

@ -7,10 +7,11 @@ module DB
# 2. `#scalar` reads a single value of the response. A union of possible values is returned. # 2. `#scalar` reads a single value of the response. A union of possible values is returned.
# 3. `#query` returns a `ResultSet` that allows iteration over the rows in the response and column information. # 3. `#query` returns a `ResultSet` that allows iteration over the rows in the response and column information.
# #
# Arguments can be passed by position # Arguments can be passed by position or as an array.
# #
# ``` # ```
# db.query("SELECT name FROM ... WHERE age > ?", age) # db.query("SELECT name FROM ... WHERE age > ?", age)
# db.query("SELECT name FROM ... WHERE age > ?", args: [age])
# ``` # ```
# #
# Convention of mapping how arguments are mapped to the query depends on each driver. # Convention of mapping how arguments are mapped to the query depends on each driver.
@ -34,8 +35,15 @@ module DB
# result.close # result.close
# end # end
# ``` # ```
def query(query, *args) #
build(query).query(*args) # Note: to use a dynamic list length of arguments use `args:` keyword argument.
#
# ```
# result = db.query "select name from contacts where id = ?", args: [10]
# ```
#
def query(query, *args_, args : Array? = nil)
build(query).query(*args_, args: args)
end end
# Executes a *query* and yields a `ResultSet` with the results. # Executes a *query* and yields a `ResultSet` with the results.
@ -48,9 +56,9 @@ module DB
# end # end
# end # end
# ``` # ```
def query(query, *args) def query(query, *args_, args : Array? = nil)
# CHECK build(query).query(*args, &block) # CHECK build(query).query(*args, &block)
rs = query(query, *args) rs = query(query, *args_, args: args)
yield rs ensure rs.close yield rs ensure rs.close
end end
@ -64,8 +72,8 @@ module DB
# ``` # ```
# name = db.query_one "select name from contacts where id = ?", 18, &.read(String) # name = db.query_one "select name from contacts where id = ?", 18, &.read(String)
# ``` # ```
def query_one(query, *args, &block : ResultSet -> U) : U forall U def query_one(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U forall U
query(query, *args) do |rs| query(query, *args_, args: args) do |rs|
raise DB::Error.new("no rows") unless rs.move_next raise DB::Error.new("no rows") unless rs.move_next
value = yield rs value = yield rs
@ -82,8 +90,8 @@ module DB
# ``` # ```
# db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32} # db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32}
# ``` # ```
def query_one(query, *args, as types : Tuple) def query_one(query, *args_, args : Array? = nil, as types : Tuple)
query_one(query, *args) do |rs| query_one(query, *args_, args: args) do |rs|
rs.read(*types) rs.read(*types)
end end
end end
@ -97,8 +105,8 @@ module DB
# ``` # ```
# db.query_one "select name, age from contacts where id = ?", 1, as: {name: String, age: Int32} # db.query_one "select name, age from contacts where id = ?", 1, as: {name: String, age: Int32}
# ``` # ```
def query_one(query, *args, as types : NamedTuple) def query_one(query, *args_, args : Array? = nil, as types : NamedTuple)
query_one(query, *args) do |rs| query_one(query, *args_, args: args) do |rs|
rs.read(**types) rs.read(**types)
end end
end end
@ -111,8 +119,8 @@ module DB
# ``` # ```
# db.query_one "select name from contacts where id = ?", 1, as: String # db.query_one "select name from contacts where id = ?", 1, as: String
# ``` # ```
def query_one(query, *args, as type : Class) def query_one(query, *args_, args : Array? = nil, as type : Class)
query_one(query, *args) do |rs| query_one(query, *args_, args: args) do |rs|
rs.read(type) rs.read(type)
end end
end end
@ -129,8 +137,8 @@ module DB
# name = db.query_one? "select name from contacts where id = ?", 18, &.read(String) # name = db.query_one? "select name from contacts where id = ?", 18, &.read(String)
# typeof(name) # => String | Nil # typeof(name) # => String | Nil
# ``` # ```
def query_one?(query, *args, &block : ResultSet -> U) : U? forall U def query_one?(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U? forall U
query(query, *args) do |rs| query(query, *args_, args: args) do |rs|
return nil unless rs.move_next return nil unless rs.move_next
value = yield rs value = yield rs
@ -150,8 +158,8 @@ module DB
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32} # result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32}
# typeof(result) # => Tuple(String, Int32) | Nil # typeof(result) # => Tuple(String, Int32) | Nil
# ``` # ```
def query_one?(query, *args, as types : Tuple) def query_one?(query, *args_, args : Array? = nil, as types : Tuple)
query_one?(query, *args) do |rs| query_one?(query, *args_, args: args) do |rs|
rs.read(*types) rs.read(*types)
end end
end end
@ -168,8 +176,8 @@ module DB
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {age: String, name: Int32} # result = db.query_one? "select name, age from contacts where id = ?", 1, as: {age: String, name: Int32}
# typeof(result) # => NamedTuple(age: String, name: Int32) | Nil # typeof(result) # => NamedTuple(age: String, name: Int32) | Nil
# ``` # ```
def query_one?(query, *args, as types : NamedTuple) def query_one?(query, *args_, args : Array? = nil, as types : NamedTuple)
query_one?(query, *args) do |rs| query_one?(query, *args_, args: args) do |rs|
rs.read(**types) rs.read(**types)
end end
end end
@ -185,8 +193,8 @@ module DB
# name = db.query_one? "select name from contacts where id = ?", 1, as: String # name = db.query_one? "select name from contacts where id = ?", 1, as: String
# typeof(name) # => String? # typeof(name) # => String?
# ``` # ```
def query_one?(query, *args, as type : Class) def query_one?(query, *args_, args : Array? = nil, as type : Class)
query_one?(query, *args) do |rs| query_one?(query, *args_, args: args) do |rs|
rs.read(type) rs.read(type)
end end
end end
@ -197,9 +205,9 @@ module DB
# ``` # ```
# names = db.query_all "select name from contacts", &.read(String) # names = db.query_all "select name from contacts", &.read(String)
# ``` # ```
def query_all(query, *args, &block : ResultSet -> U) : Array(U) forall U def query_all(query, *args_, args : Array? = nil, &block : ResultSet -> U) : Array(U) forall U
ary = [] of U ary = [] of U
query_each(query, *args) do |rs| query_each(query, *args_, args: args) do |rs|
ary.push(yield rs) ary.push(yield rs)
end end
ary ary
@ -211,8 +219,8 @@ module DB
# ``` # ```
# contacts = db.query_all "select name, age from contacts", as: {String, Int32} # contacts = db.query_all "select name, age from contacts", as: {String, Int32}
# ``` # ```
def query_all(query, *args, as types : Tuple) def query_all(query, *args_, args : Array? = nil, as types : Tuple)
query_all(query, *args) do |rs| query_all(query, *args_, args: args) do |rs|
rs.read(*types) rs.read(*types)
end end
end end
@ -224,8 +232,8 @@ module DB
# ``` # ```
# contacts = db.query_all "select name, age from contacts", as: {name: String, age: Int32} # contacts = db.query_all "select name, age from contacts", as: {name: String, age: Int32}
# ``` # ```
def query_all(query, *args, as types : NamedTuple) def query_all(query, *args_, args : Array? = nil, as types : NamedTuple)
query_all(query, *args) do |rs| query_all(query, *args_, args: args) do |rs|
rs.read(**types) rs.read(**types)
end end
end end
@ -236,8 +244,8 @@ module DB
# ``` # ```
# names = db.query_all "select name from contacts", as: String # names = db.query_all "select name from contacts", as: String
# ``` # ```
def query_all(query, *args, as type : Class) def query_all(query, *args_, args : Array? = nil, as type : Class)
query_all(query, *args) do |rs| query_all(query, *args_, args: args) do |rs|
rs.read(type) rs.read(type)
end end
end end
@ -250,8 +258,8 @@ module DB
# puts rs.read(String) # puts rs.read(String)
# end # end
# ``` # ```
def query_each(query, *args) def query_each(query, *args_, args : Array? = nil)
query(query, *args) do |rs| query(query, *args_, args: args) do |rs|
rs.each do rs.each do
yield rs yield rs
end end
@ -259,8 +267,8 @@ module DB
end end
# Performs the `query` and returns an `ExecResult` # Performs the `query` and returns an `ExecResult`
def exec(query, *args) def exec(query, *args_, args : Array? = nil)
build(query).exec(*args) build(query).exec(*args_, args: args)
end end
# Performs the `query` and returns a single scalar value # Performs the `query` and returns a single scalar value
@ -268,8 +276,8 @@ module DB
# ``` # ```
# puts db.scalar("SELECT MAX(name)").as(String) # => (a String) # puts db.scalar("SELECT MAX(name)").as(String) # => (a String)
# ``` # ```
def scalar(query, *args) def scalar(query, *args_, args : Array? = nil)
build(query).scalar(*args) build(query).scalar(*args_, args: args)
end end
end end
end end

View file

@ -8,8 +8,8 @@ module DB
end end
# See `QueryMethods#scalar` # See `QueryMethods#scalar`
def scalar(*args) def scalar(*args_, args : Array? = nil)
query(*args) do |rs| query(*args_, args: args) do |rs|
rs.each do rs.each do
return rs.read return rs.read
end end
@ -19,24 +19,20 @@ module DB
end end
# See `QueryMethods#query` # See `QueryMethods#query`
def query(*args) def query(*args_, args : Array? = nil)
rs = query(*args) rs = query(*args_, args: args)
yield rs ensure rs.close yield rs ensure rs.close
end end
# See `QueryMethods#exec` # See `QueryMethods#exec`
abstract def exec : ExecResult abstract def exec : ExecResult
# See `QueryMethods#exec` # See `QueryMethods#exec`
abstract def exec(*args) : ExecResult abstract def exec(*args_, args : Array? = nil) : ExecResult
# See `QueryMethods#exec`
abstract def exec(args : Array) : ExecResult
# See `QueryMethods#query` # See `QueryMethods#query`
abstract def query : ResultSet abstract def query : ResultSet
# See `QueryMethods#query` # See `QueryMethods#query`
abstract def query(*args) : ResultSet abstract def query(*args_, args : Array? = nil) : ResultSet
# See `QueryMethods#query`
abstract def query(args : Array) : ResultSet
end end
# Represents a query in a `Connection`. # Represents a query in a `Connection`.
@ -68,14 +64,8 @@ module DB
end end
# See `QueryMethods#exec` # See `QueryMethods#exec`
def exec(args : Array) : DB::ExecResult def exec(*args_, args : Array? = nil) : DB::ExecResult
perform_exec_and_release(args) perform_exec_and_release(EnumerableConcat.build(args_, args))
end
# See `QueryMethods#exec`
def exec(*args)
# TODO better way to do it
perform_exec_and_release(args)
end end
# See `QueryMethods#query` # See `QueryMethods#query`
@ -84,13 +74,8 @@ module DB
end end
# See `QueryMethods#query` # See `QueryMethods#query`
def query(args : Array) : DB::ResultSet def query(*args_, args : Array? = nil) : DB::ResultSet
perform_query_with_rescue args perform_query_with_rescue(EnumerableConcat.build(args_, args))
end
# See `QueryMethods#query`
def query(*args)
perform_query_with_rescue args
end end
private def perform_exec_and_release(args : Enumerable) : ExecResult private def perform_exec_and_release(args : Enumerable) : ExecResult

View file

@ -154,7 +154,7 @@ module DB
end end
it "executes with bind #{value_desc} as array" do |db| it "executes with bind #{value_desc} as array" do |db|
db.scalar(select_scalar(param(1), sql_type), [value]).should eq(value) db.scalar(select_scalar(param(1), sql_type), args: [value]).should eq(value)
end end
it "select #{value_desc} as literal" do |db| it "select #{value_desc} as literal" do |db|