Statement#exec and #query require named argument for array values (#110)

This change allows to use an array as single argument for #exec and
 #query and #scalar methods. Before it was shadowed by the *args splat overload.
This commit is contained in:
Johannes Müller 2019-09-20 22:23:09 +02:00 committed by Brian J. Cardiff
parent af6d837bcd
commit b3898ae2a2
9 changed files with 168 additions and 88 deletions

View File

@ -227,14 +227,14 @@ describe DB do
db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", FooValue.new(5)) { }
db.query "query", [1, "string", FooValue.new(5)] { }
db.query "query", args: [1, "string", FooValue.new(5)] { }
db.query("query").close
db.query("query", 1).close
db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", FooValue.new(5)).close
db.query("query", [1, "string", FooValue.new(5)]).close
db.query("query", args: [1, "string", FooValue.new(5)]).close
end
DB.open("bar://host") do |db|
@ -244,14 +244,14 @@ describe DB do
db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", BarValue.new(5)) { }
db.query "query", [1, "string", BarValue.new(5)] { }
db.query "query", args: [1, "string", BarValue.new(5)] { }
db.query("query").close
db.query("query", 1).close
db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", BarValue.new(5)).close
db.query("query", [1, "string", BarValue.new(5)]).close
db.query("query", args: [1, "string", BarValue.new(5)]).close
end
end
@ -263,7 +263,7 @@ describe DB do
db.exec("query", 1, "string")
db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", FooValue.new(5))
db.exec("query", [1, "string", FooValue.new(5)])
db.exec("query", args: [1, "string", FooValue.new(5)])
end
DB.open("bar://host") do |db|
@ -273,20 +273,20 @@ describe DB do
db.exec("query", 1, "string")
db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", BarValue.new(5))
db.exec("query", [1, "string", BarValue.new(5)])
db.exec("query", args: [1, "string", BarValue.new(5)])
end
end
it "Foo and Bar drivers should not implement each other params" do
DB.open("foo://host") do |db|
expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do
db.exec("query", [BarValue.new(5)])
db.exec("query", args: [BarValue.new(5)])
end
end
DB.open("bar://host") do |db|
expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do
db.exec("query", [FooValue.new(5)])
db.exec("query", args: [FooValue.new(5)])
end
end
end

View File

@ -97,7 +97,7 @@ class DummyDriver < DB::Driver
property params
def initialize(connection, @query : String, @prepared : Bool)
@params = Hash(Int32 | String, DB::Any).new
@params = Hash(Int32 | String, DB::Any | Array(DB::Any)).new
super(connection)
raise DB::Error.new(query) if query == "syntax error"
end
@ -126,6 +126,10 @@ class DummyDriver < DB::Driver
@params[index] = value
end
private def set_param(index, value : Array)
@params[index] = value.map(&.as(DB::Any))
end
private def set_param(index, value)
raise "not implemented for #{value.class}"
end

View File

@ -43,10 +43,37 @@ describe DB::Statement do
end
end
it "should initialize positional params in query with array" do
it "accepts array as single argument" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query ["a", 1, nil]
stmt.params[0].should eq(["a", 1, nil])
end
end
it "allows no arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query
stmt.params.should be_empty
end
end
it "concatenate arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query 1, 2, args: ["a", [1, nil]]
stmt.params[0].should eq(1)
stmt.params[1].should eq(2)
stmt.params[2].should eq("a")
stmt.params[3].should eq([1, nil])
end
end
it "should initialize positional params in query with array" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.query args: ["a", 1, nil]
stmt.params[0].should eq("a")
stmt.params[1].should eq(1)
stmt.params[2].should eq(nil)
@ -63,16 +90,43 @@ describe DB::Statement do
end
end
it "should initialize positional params in exec with array" do
it "accepts array as single argument" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec ["a", 1, nil]
stmt.params[0].should eq(["a", 1, nil])
end
end
it "should initialize positional params in exec with array" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec args: ["a", 1, nil]
stmt.params[0].should eq("a")
stmt.params[1].should eq(1)
stmt.params[2].should eq(nil)
end
end
it "allows no arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec
stmt.params.should be_empty
end
end
it "concatenate arguments" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
stmt.exec 1, 2, args: ["a", [1, nil]]
stmt.params[0].should eq(1)
stmt.params[1].should eq(2)
stmt.params[2].should eq("a")
stmt.params[3].should eq([1, nil])
end
end
it "should initialize positional params in scalar" do
with_dummy_connection do |cnn|
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)

View File

@ -181,6 +181,7 @@ end
require "./db/pool"
require "./db/string_key_cache"
require "./db/enumerable_concat"
require "./db/query_methods"
require "./db/session_methods"
require "./db/disposable"

View File

@ -0,0 +1,38 @@
module DB
# :nodoc:
struct EnumerableConcat(S, T, U)
include Enumerable(S)
def initialize(@e1 : T, @e2 : U)
end
def each
if e1 = @e1
@e1.each do |e|
yield e
end
end
if e2 = @e2
e2.each do |e|
yield e
end
end
end
# returns given `e1 : T` an `Enumerable(T')` and `e2 : U` an `Enumerable(U') | Nil`
# it retuns and `Enumerable(T' | U')` that enumerates the elements of `e1`
# and, later, the elements of `e2`.
def self.build(e1 : T, e2 : U)
return e1 if e2.nil? || e2.empty?
return e2 if e1.nil? || e1.empty?
EnumerableConcat(Union(typeof(sample(e1)), typeof(sample(e2))), T, U).new(e1, e2)
end
private def self.sample(c : Enumerable?)
c.not_nil!.each do |e|
return e
end
raise ""
end
end
end

View File

@ -15,13 +15,8 @@ module DB
end
# See `QueryMethods#exec`
def exec(*args) : ExecResult
statement_with_retry &.exec(*args)
end
# See `QueryMethods#exec`
def exec(args : Array) : ExecResult
statement_with_retry &.exec(args)
def exec(*args_, args : Array? = nil) : ExecResult
statement_with_retry &.exec(*args_, args: args)
end
# See `QueryMethods#query`
@ -30,18 +25,13 @@ module DB
end
# See `QueryMethods#query`
def query(*args) : ResultSet
statement_with_retry &.query(*args)
end
# See `QueryMethods#query`
def query(args : Array) : ResultSet
statement_with_retry &.query(args)
def query(*args_, args : Array? = nil) : ResultSet
statement_with_retry &.query(*args_, args: args)
end
# See `QueryMethods#scalar`
def scalar(*args)
statement_with_retry &.scalar(*args)
def scalar(*args_, args : Array? = nil)
statement_with_retry &.scalar(*args_, args: args)
end
# builds a statement over a real connection

View File

@ -7,10 +7,11 @@ module DB
# 2. `#scalar` reads a single value of the response. A union of possible values is returned.
# 3. `#query` returns a `ResultSet` that allows iteration over the rows in the response and column information.
#
# Arguments can be passed by position
# Arguments can be passed by position or as an array.
#
# ```
# db.query("SELECT name FROM ... WHERE age > ?", age)
# db.query("SELECT name FROM ... WHERE age > ?", args: [age])
# ```
#
# Convention of mapping how arguments are mapped to the query depends on each driver.
@ -34,8 +35,15 @@ module DB
# result.close
# end
# ```
def query(query, *args)
build(query).query(*args)
#
# Note: to use a dynamic list length of arguments use `args:` keyword argument.
#
# ```
# result = db.query "select name from contacts where id = ?", args: [10]
# ```
#
def query(query, *args_, args : Array? = nil)
build(query).query(*args_, args: args)
end
# Executes a *query* and yields a `ResultSet` with the results.
@ -48,9 +56,9 @@ module DB
# end
# end
# ```
def query(query, *args)
def query(query, *args_, args : Array? = nil)
# CHECK build(query).query(*args, &block)
rs = query(query, *args)
rs = query(query, *args_, args: args)
yield rs ensure rs.close
end
@ -64,8 +72,8 @@ module DB
# ```
# name = db.query_one "select name from contacts where id = ?", 18, &.read(String)
# ```
def query_one(query, *args, &block : ResultSet -> U) : U forall U
query(query, *args) do |rs|
def query_one(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U forall U
query(query, *args_, args: args) do |rs|
raise DB::Error.new("no rows") unless rs.move_next
value = yield rs
@ -82,8 +90,8 @@ module DB
# ```
# db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32}
# ```
def query_one(query, *args, as types : Tuple)
query_one(query, *args) do |rs|
def query_one(query, *args_, args : Array? = nil, as types : Tuple)
query_one(query, *args_, args: args) do |rs|
rs.read(*types)
end
end
@ -97,8 +105,8 @@ module DB
# ```
# db.query_one "select name, age from contacts where id = ?", 1, as: {name: String, age: Int32}
# ```
def query_one(query, *args, as types : NamedTuple)
query_one(query, *args) do |rs|
def query_one(query, *args_, args : Array? = nil, as types : NamedTuple)
query_one(query, *args_, args: args) do |rs|
rs.read(**types)
end
end
@ -111,8 +119,8 @@ module DB
# ```
# db.query_one "select name from contacts where id = ?", 1, as: String
# ```
def query_one(query, *args, as type : Class)
query_one(query, *args) do |rs|
def query_one(query, *args_, args : Array? = nil, as type : Class)
query_one(query, *args_, args: args) do |rs|
rs.read(type)
end
end
@ -129,8 +137,8 @@ module DB
# name = db.query_one? "select name from contacts where id = ?", 18, &.read(String)
# typeof(name) # => String | Nil
# ```
def query_one?(query, *args, &block : ResultSet -> U) : U? forall U
query(query, *args) do |rs|
def query_one?(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U? forall U
query(query, *args_, args: args) do |rs|
return nil unless rs.move_next
value = yield rs
@ -150,8 +158,8 @@ module DB
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32}
# typeof(result) # => Tuple(String, Int32) | Nil
# ```
def query_one?(query, *args, as types : Tuple)
query_one?(query, *args) do |rs|
def query_one?(query, *args_, args : Array? = nil, as types : Tuple)
query_one?(query, *args_, args: args) do |rs|
rs.read(*types)
end
end
@ -168,8 +176,8 @@ module DB
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {age: String, name: Int32}
# typeof(result) # => NamedTuple(age: String, name: Int32) | Nil
# ```
def query_one?(query, *args, as types : NamedTuple)
query_one?(query, *args) do |rs|
def query_one?(query, *args_, args : Array? = nil, as types : NamedTuple)
query_one?(query, *args_, args: args) do |rs|
rs.read(**types)
end
end
@ -185,8 +193,8 @@ module DB
# name = db.query_one? "select name from contacts where id = ?", 1, as: String
# typeof(name) # => String?
# ```
def query_one?(query, *args, as type : Class)
query_one?(query, *args) do |rs|
def query_one?(query, *args_, args : Array? = nil, as type : Class)
query_one?(query, *args_, args: args) do |rs|
rs.read(type)
end
end
@ -197,9 +205,9 @@ module DB
# ```
# names = db.query_all "select name from contacts", &.read(String)
# ```
def query_all(query, *args, &block : ResultSet -> U) : Array(U) forall U
def query_all(query, *args_, args : Array? = nil, &block : ResultSet -> U) : Array(U) forall U
ary = [] of U
query_each(query, *args) do |rs|
query_each(query, *args_, args: args) do |rs|
ary.push(yield rs)
end
ary
@ -211,8 +219,8 @@ module DB
# ```
# contacts = db.query_all "select name, age from contacts", as: {String, Int32}
# ```
def query_all(query, *args, as types : Tuple)
query_all(query, *args) do |rs|
def query_all(query, *args_, args : Array? = nil, as types : Tuple)
query_all(query, *args_, args: args) do |rs|
rs.read(*types)
end
end
@ -224,8 +232,8 @@ module DB
# ```
# contacts = db.query_all "select name, age from contacts", as: {name: String, age: Int32}
# ```
def query_all(query, *args, as types : NamedTuple)
query_all(query, *args) do |rs|
def query_all(query, *args_, args : Array? = nil, as types : NamedTuple)
query_all(query, *args_, args: args) do |rs|
rs.read(**types)
end
end
@ -236,8 +244,8 @@ module DB
# ```
# names = db.query_all "select name from contacts", as: String
# ```
def query_all(query, *args, as type : Class)
query_all(query, *args) do |rs|
def query_all(query, *args_, args : Array? = nil, as type : Class)
query_all(query, *args_, args: args) do |rs|
rs.read(type)
end
end
@ -250,8 +258,8 @@ module DB
# puts rs.read(String)
# end
# ```
def query_each(query, *args)
query(query, *args) do |rs|
def query_each(query, *args_, args : Array? = nil)
query(query, *args_, args: args) do |rs|
rs.each do
yield rs
end
@ -259,8 +267,8 @@ module DB
end
# Performs the `query` and returns an `ExecResult`
def exec(query, *args)
build(query).exec(*args)
def exec(query, *args_, args : Array? = nil)
build(query).exec(*args_, args: args)
end
# Performs the `query` and returns a single scalar value
@ -268,8 +276,8 @@ module DB
# ```
# puts db.scalar("SELECT MAX(name)").as(String) # => (a String)
# ```
def scalar(query, *args)
build(query).scalar(*args)
def scalar(query, *args_, args : Array? = nil)
build(query).scalar(*args_, args: args)
end
end
end

View File

@ -8,8 +8,8 @@ module DB
end
# See `QueryMethods#scalar`
def scalar(*args)
query(*args) do |rs|
def scalar(*args_, args : Array? = nil)
query(*args_, args: args) do |rs|
rs.each do
return rs.read
end
@ -19,24 +19,20 @@ module DB
end
# See `QueryMethods#query`
def query(*args)
rs = query(*args)
def query(*args_, args : Array? = nil)
rs = query(*args_, args: args)
yield rs ensure rs.close
end
# See `QueryMethods#exec`
abstract def exec : ExecResult
# See `QueryMethods#exec`
abstract def exec(*args) : ExecResult
# See `QueryMethods#exec`
abstract def exec(args : Array) : ExecResult
abstract def exec(*args_, args : Array? = nil) : ExecResult
# See `QueryMethods#query`
abstract def query : ResultSet
# See `QueryMethods#query`
abstract def query(*args) : ResultSet
# See `QueryMethods#query`
abstract def query(args : Array) : ResultSet
abstract def query(*args_, args : Array? = nil) : ResultSet
end
# Represents a query in a `Connection`.
@ -68,14 +64,8 @@ module DB
end
# See `QueryMethods#exec`
def exec(args : Array) : DB::ExecResult
perform_exec_and_release(args)
end
# See `QueryMethods#exec`
def exec(*args)
# TODO better way to do it
perform_exec_and_release(args)
def exec(*args_, args : Array? = nil) : DB::ExecResult
perform_exec_and_release(EnumerableConcat.build(args_, args))
end
# See `QueryMethods#query`
@ -84,13 +74,8 @@ module DB
end
# See `QueryMethods#query`
def query(args : Array) : DB::ResultSet
perform_query_with_rescue args
end
# See `QueryMethods#query`
def query(*args)
perform_query_with_rescue args
def query(*args_, args : Array? = nil) : DB::ResultSet
perform_query_with_rescue(EnumerableConcat.build(args_, args))
end
private def perform_exec_and_release(args : Enumerable) : ExecResult

View File

@ -154,7 +154,7 @@ module DB
end
it "executes with bind #{value_desc} as array" do |db|
db.scalar(select_scalar(param(1), sql_type), [value]).should eq(value)
db.scalar(select_scalar(param(1), sql_type), args: [value]).should eq(value)
end
it "select #{value_desc} as literal" do |db|