mirror of
https://gitea.invidious.io/iv-org/shard-crystal-db.git
synced 2024-08-15 00:53:32 +00:00
Statement#exec and #query require named argument for array values (#110)
This change allows to use an array as single argument for #exec and #query and #scalar methods. Before it was shadowed by the *args splat overload.
This commit is contained in:
parent
af6d837bcd
commit
b3898ae2a2
9 changed files with 168 additions and 88 deletions
|
@ -227,14 +227,14 @@ describe DB do
|
|||
db.query "query", 1, "string" { }
|
||||
db.query("query", Bytes.new(4)) { }
|
||||
db.query("query", 1, "string", FooValue.new(5)) { }
|
||||
db.query "query", [1, "string", FooValue.new(5)] { }
|
||||
db.query "query", args: [1, "string", FooValue.new(5)] { }
|
||||
|
||||
db.query("query").close
|
||||
db.query("query", 1).close
|
||||
db.query("query", 1, "string").close
|
||||
db.query("query", Bytes.new(4)).close
|
||||
db.query("query", 1, "string", FooValue.new(5)).close
|
||||
db.query("query", [1, "string", FooValue.new(5)]).close
|
||||
db.query("query", args: [1, "string", FooValue.new(5)]).close
|
||||
end
|
||||
|
||||
DB.open("bar://host") do |db|
|
||||
|
@ -244,14 +244,14 @@ describe DB do
|
|||
db.query "query", 1, "string" { }
|
||||
db.query("query", Bytes.new(4)) { }
|
||||
db.query("query", 1, "string", BarValue.new(5)) { }
|
||||
db.query "query", [1, "string", BarValue.new(5)] { }
|
||||
db.query "query", args: [1, "string", BarValue.new(5)] { }
|
||||
|
||||
db.query("query").close
|
||||
db.query("query", 1).close
|
||||
db.query("query", 1, "string").close
|
||||
db.query("query", Bytes.new(4)).close
|
||||
db.query("query", 1, "string", BarValue.new(5)).close
|
||||
db.query("query", [1, "string", BarValue.new(5)]).close
|
||||
db.query("query", args: [1, "string", BarValue.new(5)]).close
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -263,7 +263,7 @@ describe DB do
|
|||
db.exec("query", 1, "string")
|
||||
db.exec("query", Bytes.new(4))
|
||||
db.exec("query", 1, "string", FooValue.new(5))
|
||||
db.exec("query", [1, "string", FooValue.new(5)])
|
||||
db.exec("query", args: [1, "string", FooValue.new(5)])
|
||||
end
|
||||
|
||||
DB.open("bar://host") do |db|
|
||||
|
@ -273,20 +273,20 @@ describe DB do
|
|||
db.exec("query", 1, "string")
|
||||
db.exec("query", Bytes.new(4))
|
||||
db.exec("query", 1, "string", BarValue.new(5))
|
||||
db.exec("query", [1, "string", BarValue.new(5)])
|
||||
db.exec("query", args: [1, "string", BarValue.new(5)])
|
||||
end
|
||||
end
|
||||
|
||||
it "Foo and Bar drivers should not implement each other params" do
|
||||
DB.open("foo://host") do |db|
|
||||
expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do
|
||||
db.exec("query", [BarValue.new(5)])
|
||||
db.exec("query", args: [BarValue.new(5)])
|
||||
end
|
||||
end
|
||||
|
||||
DB.open("bar://host") do |db|
|
||||
expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do
|
||||
db.exec("query", [FooValue.new(5)])
|
||||
db.exec("query", args: [FooValue.new(5)])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -97,7 +97,7 @@ class DummyDriver < DB::Driver
|
|||
property params
|
||||
|
||||
def initialize(connection, @query : String, @prepared : Bool)
|
||||
@params = Hash(Int32 | String, DB::Any).new
|
||||
@params = Hash(Int32 | String, DB::Any | Array(DB::Any)).new
|
||||
super(connection)
|
||||
raise DB::Error.new(query) if query == "syntax error"
|
||||
end
|
||||
|
@ -126,6 +126,10 @@ class DummyDriver < DB::Driver
|
|||
@params[index] = value
|
||||
end
|
||||
|
||||
private def set_param(index, value : Array)
|
||||
@params[index] = value.map(&.as(DB::Any))
|
||||
end
|
||||
|
||||
private def set_param(index, value)
|
||||
raise "not implemented for #{value.class}"
|
||||
end
|
||||
|
|
|
@ -43,10 +43,37 @@ describe DB::Statement do
|
|||
end
|
||||
end
|
||||
|
||||
it "should initialize positional params in query with array" do
|
||||
it "accepts array as single argument" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.query ["a", 1, nil]
|
||||
stmt.params[0].should eq(["a", 1, nil])
|
||||
end
|
||||
end
|
||||
|
||||
it "allows no arguments" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.query
|
||||
stmt.params.should be_empty
|
||||
end
|
||||
end
|
||||
|
||||
it "concatenate arguments" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.query 1, 2, args: ["a", [1, nil]]
|
||||
stmt.params[0].should eq(1)
|
||||
stmt.params[1].should eq(2)
|
||||
stmt.params[2].should eq("a")
|
||||
stmt.params[3].should eq([1, nil])
|
||||
end
|
||||
end
|
||||
|
||||
it "should initialize positional params in query with array" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.query args: ["a", 1, nil]
|
||||
stmt.params[0].should eq("a")
|
||||
stmt.params[1].should eq(1)
|
||||
stmt.params[2].should eq(nil)
|
||||
|
@ -63,16 +90,43 @@ describe DB::Statement do
|
|||
end
|
||||
end
|
||||
|
||||
it "should initialize positional params in exec with array" do
|
||||
it "accepts array as single argument" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.exec ["a", 1, nil]
|
||||
stmt.params[0].should eq(["a", 1, nil])
|
||||
end
|
||||
end
|
||||
|
||||
it "should initialize positional params in exec with array" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.exec args: ["a", 1, nil]
|
||||
stmt.params[0].should eq("a")
|
||||
stmt.params[1].should eq(1)
|
||||
stmt.params[2].should eq(nil)
|
||||
end
|
||||
end
|
||||
|
||||
it "allows no arguments" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.exec
|
||||
stmt.params.should be_empty
|
||||
end
|
||||
end
|
||||
|
||||
it "concatenate arguments" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.exec 1, 2, args: ["a", [1, nil]]
|
||||
stmt.params[0].should eq(1)
|
||||
stmt.params[1].should eq(2)
|
||||
stmt.params[2].should eq("a")
|
||||
stmt.params[3].should eq([1, nil])
|
||||
end
|
||||
end
|
||||
|
||||
it "should initialize positional params in scalar" do
|
||||
with_dummy_connection do |cnn|
|
||||
stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement)
|
||||
|
|
|
@ -181,6 +181,7 @@ end
|
|||
|
||||
require "./db/pool"
|
||||
require "./db/string_key_cache"
|
||||
require "./db/enumerable_concat"
|
||||
require "./db/query_methods"
|
||||
require "./db/session_methods"
|
||||
require "./db/disposable"
|
||||
|
|
38
src/db/enumerable_concat.cr
Normal file
38
src/db/enumerable_concat.cr
Normal file
|
@ -0,0 +1,38 @@
|
|||
module DB
|
||||
# :nodoc:
|
||||
struct EnumerableConcat(S, T, U)
|
||||
include Enumerable(S)
|
||||
|
||||
def initialize(@e1 : T, @e2 : U)
|
||||
end
|
||||
|
||||
def each
|
||||
if e1 = @e1
|
||||
@e1.each do |e|
|
||||
yield e
|
||||
end
|
||||
end
|
||||
if e2 = @e2
|
||||
e2.each do |e|
|
||||
yield e
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# returns given `e1 : T` an `Enumerable(T')` and `e2 : U` an `Enumerable(U') | Nil`
|
||||
# it retuns and `Enumerable(T' | U')` that enumerates the elements of `e1`
|
||||
# and, later, the elements of `e2`.
|
||||
def self.build(e1 : T, e2 : U)
|
||||
return e1 if e2.nil? || e2.empty?
|
||||
return e2 if e1.nil? || e1.empty?
|
||||
EnumerableConcat(Union(typeof(sample(e1)), typeof(sample(e2))), T, U).new(e1, e2)
|
||||
end
|
||||
|
||||
private def self.sample(c : Enumerable?)
|
||||
c.not_nil!.each do |e|
|
||||
return e
|
||||
end
|
||||
raise ""
|
||||
end
|
||||
end
|
||||
end
|
|
@ -15,13 +15,8 @@ module DB
|
|||
end
|
||||
|
||||
# See `QueryMethods#exec`
|
||||
def exec(*args) : ExecResult
|
||||
statement_with_retry &.exec(*args)
|
||||
end
|
||||
|
||||
# See `QueryMethods#exec`
|
||||
def exec(args : Array) : ExecResult
|
||||
statement_with_retry &.exec(args)
|
||||
def exec(*args_, args : Array? = nil) : ExecResult
|
||||
statement_with_retry &.exec(*args_, args: args)
|
||||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
|
@ -30,18 +25,13 @@ module DB
|
|||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(*args) : ResultSet
|
||||
statement_with_retry &.query(*args)
|
||||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(args : Array) : ResultSet
|
||||
statement_with_retry &.query(args)
|
||||
def query(*args_, args : Array? = nil) : ResultSet
|
||||
statement_with_retry &.query(*args_, args: args)
|
||||
end
|
||||
|
||||
# See `QueryMethods#scalar`
|
||||
def scalar(*args)
|
||||
statement_with_retry &.scalar(*args)
|
||||
def scalar(*args_, args : Array? = nil)
|
||||
statement_with_retry &.scalar(*args_, args: args)
|
||||
end
|
||||
|
||||
# builds a statement over a real connection
|
||||
|
|
|
@ -7,10 +7,11 @@ module DB
|
|||
# 2. `#scalar` reads a single value of the response. A union of possible values is returned.
|
||||
# 3. `#query` returns a `ResultSet` that allows iteration over the rows in the response and column information.
|
||||
#
|
||||
# Arguments can be passed by position
|
||||
# Arguments can be passed by position or as an array.
|
||||
#
|
||||
# ```
|
||||
# db.query("SELECT name FROM ... WHERE age > ?", age)
|
||||
# db.query("SELECT name FROM ... WHERE age > ?", args: [age])
|
||||
# ```
|
||||
#
|
||||
# Convention of mapping how arguments are mapped to the query depends on each driver.
|
||||
|
@ -34,8 +35,15 @@ module DB
|
|||
# result.close
|
||||
# end
|
||||
# ```
|
||||
def query(query, *args)
|
||||
build(query).query(*args)
|
||||
#
|
||||
# Note: to use a dynamic list length of arguments use `args:` keyword argument.
|
||||
#
|
||||
# ```
|
||||
# result = db.query "select name from contacts where id = ?", args: [10]
|
||||
# ```
|
||||
#
|
||||
def query(query, *args_, args : Array? = nil)
|
||||
build(query).query(*args_, args: args)
|
||||
end
|
||||
|
||||
# Executes a *query* and yields a `ResultSet` with the results.
|
||||
|
@ -48,9 +56,9 @@ module DB
|
|||
# end
|
||||
# end
|
||||
# ```
|
||||
def query(query, *args)
|
||||
def query(query, *args_, args : Array? = nil)
|
||||
# CHECK build(query).query(*args, &block)
|
||||
rs = query(query, *args)
|
||||
rs = query(query, *args_, args: args)
|
||||
yield rs ensure rs.close
|
||||
end
|
||||
|
||||
|
@ -64,8 +72,8 @@ module DB
|
|||
# ```
|
||||
# name = db.query_one "select name from contacts where id = ?", 18, &.read(String)
|
||||
# ```
|
||||
def query_one(query, *args, &block : ResultSet -> U) : U forall U
|
||||
query(query, *args) do |rs|
|
||||
def query_one(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U forall U
|
||||
query(query, *args_, args: args) do |rs|
|
||||
raise DB::Error.new("no rows") unless rs.move_next
|
||||
|
||||
value = yield rs
|
||||
|
@ -82,8 +90,8 @@ module DB
|
|||
# ```
|
||||
# db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32}
|
||||
# ```
|
||||
def query_one(query, *args, as types : Tuple)
|
||||
query_one(query, *args) do |rs|
|
||||
def query_one(query, *args_, args : Array? = nil, as types : Tuple)
|
||||
query_one(query, *args_, args: args) do |rs|
|
||||
rs.read(*types)
|
||||
end
|
||||
end
|
||||
|
@ -97,8 +105,8 @@ module DB
|
|||
# ```
|
||||
# db.query_one "select name, age from contacts where id = ?", 1, as: {name: String, age: Int32}
|
||||
# ```
|
||||
def query_one(query, *args, as types : NamedTuple)
|
||||
query_one(query, *args) do |rs|
|
||||
def query_one(query, *args_, args : Array? = nil, as types : NamedTuple)
|
||||
query_one(query, *args_, args: args) do |rs|
|
||||
rs.read(**types)
|
||||
end
|
||||
end
|
||||
|
@ -111,8 +119,8 @@ module DB
|
|||
# ```
|
||||
# db.query_one "select name from contacts where id = ?", 1, as: String
|
||||
# ```
|
||||
def query_one(query, *args, as type : Class)
|
||||
query_one(query, *args) do |rs|
|
||||
def query_one(query, *args_, args : Array? = nil, as type : Class)
|
||||
query_one(query, *args_, args: args) do |rs|
|
||||
rs.read(type)
|
||||
end
|
||||
end
|
||||
|
@ -129,8 +137,8 @@ module DB
|
|||
# name = db.query_one? "select name from contacts where id = ?", 18, &.read(String)
|
||||
# typeof(name) # => String | Nil
|
||||
# ```
|
||||
def query_one?(query, *args, &block : ResultSet -> U) : U? forall U
|
||||
query(query, *args) do |rs|
|
||||
def query_one?(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U? forall U
|
||||
query(query, *args_, args: args) do |rs|
|
||||
return nil unless rs.move_next
|
||||
|
||||
value = yield rs
|
||||
|
@ -150,8 +158,8 @@ module DB
|
|||
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32}
|
||||
# typeof(result) # => Tuple(String, Int32) | Nil
|
||||
# ```
|
||||
def query_one?(query, *args, as types : Tuple)
|
||||
query_one?(query, *args) do |rs|
|
||||
def query_one?(query, *args_, args : Array? = nil, as types : Tuple)
|
||||
query_one?(query, *args_, args: args) do |rs|
|
||||
rs.read(*types)
|
||||
end
|
||||
end
|
||||
|
@ -168,8 +176,8 @@ module DB
|
|||
# result = db.query_one? "select name, age from contacts where id = ?", 1, as: {age: String, name: Int32}
|
||||
# typeof(result) # => NamedTuple(age: String, name: Int32) | Nil
|
||||
# ```
|
||||
def query_one?(query, *args, as types : NamedTuple)
|
||||
query_one?(query, *args) do |rs|
|
||||
def query_one?(query, *args_, args : Array? = nil, as types : NamedTuple)
|
||||
query_one?(query, *args_, args: args) do |rs|
|
||||
rs.read(**types)
|
||||
end
|
||||
end
|
||||
|
@ -185,8 +193,8 @@ module DB
|
|||
# name = db.query_one? "select name from contacts where id = ?", 1, as: String
|
||||
# typeof(name) # => String?
|
||||
# ```
|
||||
def query_one?(query, *args, as type : Class)
|
||||
query_one?(query, *args) do |rs|
|
||||
def query_one?(query, *args_, args : Array? = nil, as type : Class)
|
||||
query_one?(query, *args_, args: args) do |rs|
|
||||
rs.read(type)
|
||||
end
|
||||
end
|
||||
|
@ -197,9 +205,9 @@ module DB
|
|||
# ```
|
||||
# names = db.query_all "select name from contacts", &.read(String)
|
||||
# ```
|
||||
def query_all(query, *args, &block : ResultSet -> U) : Array(U) forall U
|
||||
def query_all(query, *args_, args : Array? = nil, &block : ResultSet -> U) : Array(U) forall U
|
||||
ary = [] of U
|
||||
query_each(query, *args) do |rs|
|
||||
query_each(query, *args_, args: args) do |rs|
|
||||
ary.push(yield rs)
|
||||
end
|
||||
ary
|
||||
|
@ -211,8 +219,8 @@ module DB
|
|||
# ```
|
||||
# contacts = db.query_all "select name, age from contacts", as: {String, Int32}
|
||||
# ```
|
||||
def query_all(query, *args, as types : Tuple)
|
||||
query_all(query, *args) do |rs|
|
||||
def query_all(query, *args_, args : Array? = nil, as types : Tuple)
|
||||
query_all(query, *args_, args: args) do |rs|
|
||||
rs.read(*types)
|
||||
end
|
||||
end
|
||||
|
@ -224,8 +232,8 @@ module DB
|
|||
# ```
|
||||
# contacts = db.query_all "select name, age from contacts", as: {name: String, age: Int32}
|
||||
# ```
|
||||
def query_all(query, *args, as types : NamedTuple)
|
||||
query_all(query, *args) do |rs|
|
||||
def query_all(query, *args_, args : Array? = nil, as types : NamedTuple)
|
||||
query_all(query, *args_, args: args) do |rs|
|
||||
rs.read(**types)
|
||||
end
|
||||
end
|
||||
|
@ -236,8 +244,8 @@ module DB
|
|||
# ```
|
||||
# names = db.query_all "select name from contacts", as: String
|
||||
# ```
|
||||
def query_all(query, *args, as type : Class)
|
||||
query_all(query, *args) do |rs|
|
||||
def query_all(query, *args_, args : Array? = nil, as type : Class)
|
||||
query_all(query, *args_, args: args) do |rs|
|
||||
rs.read(type)
|
||||
end
|
||||
end
|
||||
|
@ -250,8 +258,8 @@ module DB
|
|||
# puts rs.read(String)
|
||||
# end
|
||||
# ```
|
||||
def query_each(query, *args)
|
||||
query(query, *args) do |rs|
|
||||
def query_each(query, *args_, args : Array? = nil)
|
||||
query(query, *args_, args: args) do |rs|
|
||||
rs.each do
|
||||
yield rs
|
||||
end
|
||||
|
@ -259,8 +267,8 @@ module DB
|
|||
end
|
||||
|
||||
# Performs the `query` and returns an `ExecResult`
|
||||
def exec(query, *args)
|
||||
build(query).exec(*args)
|
||||
def exec(query, *args_, args : Array? = nil)
|
||||
build(query).exec(*args_, args: args)
|
||||
end
|
||||
|
||||
# Performs the `query` and returns a single scalar value
|
||||
|
@ -268,8 +276,8 @@ module DB
|
|||
# ```
|
||||
# puts db.scalar("SELECT MAX(name)").as(String) # => (a String)
|
||||
# ```
|
||||
def scalar(query, *args)
|
||||
build(query).scalar(*args)
|
||||
def scalar(query, *args_, args : Array? = nil)
|
||||
build(query).scalar(*args_, args: args)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -8,8 +8,8 @@ module DB
|
|||
end
|
||||
|
||||
# See `QueryMethods#scalar`
|
||||
def scalar(*args)
|
||||
query(*args) do |rs|
|
||||
def scalar(*args_, args : Array? = nil)
|
||||
query(*args_, args: args) do |rs|
|
||||
rs.each do
|
||||
return rs.read
|
||||
end
|
||||
|
@ -19,24 +19,20 @@ module DB
|
|||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(*args)
|
||||
rs = query(*args)
|
||||
def query(*args_, args : Array? = nil)
|
||||
rs = query(*args_, args: args)
|
||||
yield rs ensure rs.close
|
||||
end
|
||||
|
||||
# See `QueryMethods#exec`
|
||||
abstract def exec : ExecResult
|
||||
# See `QueryMethods#exec`
|
||||
abstract def exec(*args) : ExecResult
|
||||
# See `QueryMethods#exec`
|
||||
abstract def exec(args : Array) : ExecResult
|
||||
abstract def exec(*args_, args : Array? = nil) : ExecResult
|
||||
|
||||
# See `QueryMethods#query`
|
||||
abstract def query : ResultSet
|
||||
# See `QueryMethods#query`
|
||||
abstract def query(*args) : ResultSet
|
||||
# See `QueryMethods#query`
|
||||
abstract def query(args : Array) : ResultSet
|
||||
abstract def query(*args_, args : Array? = nil) : ResultSet
|
||||
end
|
||||
|
||||
# Represents a query in a `Connection`.
|
||||
|
@ -68,14 +64,8 @@ module DB
|
|||
end
|
||||
|
||||
# See `QueryMethods#exec`
|
||||
def exec(args : Array) : DB::ExecResult
|
||||
perform_exec_and_release(args)
|
||||
end
|
||||
|
||||
# See `QueryMethods#exec`
|
||||
def exec(*args)
|
||||
# TODO better way to do it
|
||||
perform_exec_and_release(args)
|
||||
def exec(*args_, args : Array? = nil) : DB::ExecResult
|
||||
perform_exec_and_release(EnumerableConcat.build(args_, args))
|
||||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
|
@ -84,13 +74,8 @@ module DB
|
|||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(args : Array) : DB::ResultSet
|
||||
perform_query_with_rescue args
|
||||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(*args)
|
||||
perform_query_with_rescue args
|
||||
def query(*args_, args : Array? = nil) : DB::ResultSet
|
||||
perform_query_with_rescue(EnumerableConcat.build(args_, args))
|
||||
end
|
||||
|
||||
private def perform_exec_and_release(args : Enumerable) : ExecResult
|
||||
|
|
|
@ -154,7 +154,7 @@ module DB
|
|||
end
|
||||
|
||||
it "executes with bind #{value_desc} as array" do |db|
|
||||
db.scalar(select_scalar(param(1), sql_type), [value]).should eq(value)
|
||||
db.scalar(select_scalar(param(1), sql_type), args: [value]).should eq(value)
|
||||
end
|
||||
|
||||
it "select #{value_desc} as literal" do |db|
|
||||
|
|
Loading…
Reference in a new issue