diff --git a/spec/custom_drivers_types_spec.cr b/spec/custom_drivers_types_spec.cr index 21d1668..0b25a6a 100644 --- a/spec/custom_drivers_types_spec.cr +++ b/spec/custom_drivers_types_spec.cr @@ -227,14 +227,14 @@ describe DB do db.query "query", 1, "string" { } db.query("query", Bytes.new(4)) { } db.query("query", 1, "string", FooValue.new(5)) { } - db.query "query", [1, "string", FooValue.new(5)] { } + db.query "query", args: [1, "string", FooValue.new(5)] { } db.query("query").close db.query("query", 1).close db.query("query", 1, "string").close db.query("query", Bytes.new(4)).close db.query("query", 1, "string", FooValue.new(5)).close - db.query("query", [1, "string", FooValue.new(5)]).close + db.query("query", args: [1, "string", FooValue.new(5)]).close end DB.open("bar://host") do |db| @@ -244,14 +244,14 @@ describe DB do db.query "query", 1, "string" { } db.query("query", Bytes.new(4)) { } db.query("query", 1, "string", BarValue.new(5)) { } - db.query "query", [1, "string", BarValue.new(5)] { } + db.query "query", args: [1, "string", BarValue.new(5)] { } db.query("query").close db.query("query", 1).close db.query("query", 1, "string").close db.query("query", Bytes.new(4)).close db.query("query", 1, "string", BarValue.new(5)).close - db.query("query", [1, "string", BarValue.new(5)]).close + db.query("query", args: [1, "string", BarValue.new(5)]).close end end @@ -263,7 +263,7 @@ describe DB do db.exec("query", 1, "string") db.exec("query", Bytes.new(4)) db.exec("query", 1, "string", FooValue.new(5)) - db.exec("query", [1, "string", FooValue.new(5)]) + db.exec("query", args: [1, "string", FooValue.new(5)]) end DB.open("bar://host") do |db| @@ -273,20 +273,20 @@ describe DB do db.exec("query", 1, "string") db.exec("query", Bytes.new(4)) db.exec("query", 1, "string", BarValue.new(5)) - db.exec("query", [1, "string", BarValue.new(5)]) + db.exec("query", args: [1, "string", BarValue.new(5)]) end end it "Foo and Bar drivers should not implement each other params" do DB.open("foo://host") do |db| expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do - db.exec("query", [BarValue.new(5)]) + db.exec("query", args: [BarValue.new(5)]) end end DB.open("bar://host") do |db| expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do - db.exec("query", [FooValue.new(5)]) + db.exec("query", args: [FooValue.new(5)]) end end end diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index 830d371..0158770 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -97,7 +97,7 @@ class DummyDriver < DB::Driver property params def initialize(connection, @query : String, @prepared : Bool) - @params = Hash(Int32 | String, DB::Any).new + @params = Hash(Int32 | String, DB::Any | Array(DB::Any)).new super(connection) raise DB::Error.new(query) if query == "syntax error" end @@ -126,6 +126,10 @@ class DummyDriver < DB::Driver @params[index] = value end + private def set_param(index, value : Array) + @params[index] = value.map(&.as(DB::Any)) + end + private def set_param(index, value) raise "not implemented for #{value.class}" end diff --git a/spec/statement_spec.cr b/spec/statement_spec.cr index fcdd05a..f70937e 100644 --- a/spec/statement_spec.cr +++ b/spec/statement_spec.cr @@ -43,10 +43,37 @@ describe DB::Statement do end end - it "should initialize positional params in query with array" do + it "accepts array as single argument" do with_dummy_connection do |cnn| stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) stmt.query ["a", 1, nil] + stmt.params[0].should eq(["a", 1, nil]) + end + end + + it "allows no arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.query + stmt.params.should be_empty + end + end + + it "concatenate arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.query 1, 2, args: ["a", [1, nil]] + stmt.params[0].should eq(1) + stmt.params[1].should eq(2) + stmt.params[2].should eq("a") + stmt.params[3].should eq([1, nil]) + end + end + + it "should initialize positional params in query with array" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.query args: ["a", 1, nil] stmt.params[0].should eq("a") stmt.params[1].should eq(1) stmt.params[2].should eq(nil) @@ -63,16 +90,43 @@ describe DB::Statement do end end - it "should initialize positional params in exec with array" do + it "accepts array as single argument" do with_dummy_connection do |cnn| stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) stmt.exec ["a", 1, nil] + stmt.params[0].should eq(["a", 1, nil]) + end + end + + it "should initialize positional params in exec with array" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.exec args: ["a", 1, nil] stmt.params[0].should eq("a") stmt.params[1].should eq(1) stmt.params[2].should eq(nil) end end + it "allows no arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.exec + stmt.params.should be_empty + end + end + + it "concatenate arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.exec 1, 2, args: ["a", [1, nil]] + stmt.params[0].should eq(1) + stmt.params[1].should eq(2) + stmt.params[2].should eq("a") + stmt.params[3].should eq([1, nil]) + end + end + it "should initialize positional params in scalar" do with_dummy_connection do |cnn| stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) diff --git a/src/db.cr b/src/db.cr index 915e7b8..bbe0e3d 100644 --- a/src/db.cr +++ b/src/db.cr @@ -181,6 +181,7 @@ end require "./db/pool" require "./db/string_key_cache" +require "./db/enumerable_concat" require "./db/query_methods" require "./db/session_methods" require "./db/disposable" diff --git a/src/db/enumerable_concat.cr b/src/db/enumerable_concat.cr new file mode 100644 index 0000000..3f70b0d --- /dev/null +++ b/src/db/enumerable_concat.cr @@ -0,0 +1,38 @@ +module DB + # :nodoc: + struct EnumerableConcat(S, T, U) + include Enumerable(S) + + def initialize(@e1 : T, @e2 : U) + end + + def each + if e1 = @e1 + @e1.each do |e| + yield e + end + end + if e2 = @e2 + e2.each do |e| + yield e + end + end + end + + # returns given `e1 : T` an `Enumerable(T')` and `e2 : U` an `Enumerable(U') | Nil` + # it retuns and `Enumerable(T' | U')` that enumerates the elements of `e1` + # and, later, the elements of `e2`. + def self.build(e1 : T, e2 : U) + return e1 if e2.nil? || e2.empty? + return e2 if e1.nil? || e1.empty? + EnumerableConcat(Union(typeof(sample(e1)), typeof(sample(e2))), T, U).new(e1, e2) + end + + private def self.sample(c : Enumerable?) + c.not_nil!.each do |e| + return e + end + raise "" + end + end +end diff --git a/src/db/pool_statement.cr b/src/db/pool_statement.cr index 668ce2b..f0fc2b2 100644 --- a/src/db/pool_statement.cr +++ b/src/db/pool_statement.cr @@ -15,13 +15,8 @@ module DB end # See `QueryMethods#exec` - def exec(*args) : ExecResult - statement_with_retry &.exec(*args) - end - - # See `QueryMethods#exec` - def exec(args : Array) : ExecResult - statement_with_retry &.exec(args) + def exec(*args_, args : Array? = nil) : ExecResult + statement_with_retry &.exec(*args_, args: args) end # See `QueryMethods#query` @@ -30,18 +25,13 @@ module DB end # See `QueryMethods#query` - def query(*args) : ResultSet - statement_with_retry &.query(*args) - end - - # See `QueryMethods#query` - def query(args : Array) : ResultSet - statement_with_retry &.query(args) + def query(*args_, args : Array? = nil) : ResultSet + statement_with_retry &.query(*args_, args: args) end # See `QueryMethods#scalar` - def scalar(*args) - statement_with_retry &.scalar(*args) + def scalar(*args_, args : Array? = nil) + statement_with_retry &.scalar(*args_, args: args) end # builds a statement over a real connection diff --git a/src/db/query_methods.cr b/src/db/query_methods.cr index 9676256..9016acc 100644 --- a/src/db/query_methods.cr +++ b/src/db/query_methods.cr @@ -7,10 +7,11 @@ module DB # 2. `#scalar` reads a single value of the response. A union of possible values is returned. # 3. `#query` returns a `ResultSet` that allows iteration over the rows in the response and column information. # - # Arguments can be passed by position + # Arguments can be passed by position or as an array. # # ``` # db.query("SELECT name FROM ... WHERE age > ?", age) + # db.query("SELECT name FROM ... WHERE age > ?", args: [age]) # ``` # # Convention of mapping how arguments are mapped to the query depends on each driver. @@ -34,8 +35,15 @@ module DB # result.close # end # ``` - def query(query, *args) - build(query).query(*args) + # + # Note: to use a dynamic list length of arguments use `args:` keyword argument. + # + # ``` + # result = db.query "select name from contacts where id = ?", args: [10] + # ``` + # + def query(query, *args_, args : Array? = nil) + build(query).query(*args_, args: args) end # Executes a *query* and yields a `ResultSet` with the results. @@ -48,9 +56,9 @@ module DB # end # end # ``` - def query(query, *args) + def query(query, *args_, args : Array? = nil) # CHECK build(query).query(*args, &block) - rs = query(query, *args) + rs = query(query, *args_, args: args) yield rs ensure rs.close end @@ -64,8 +72,8 @@ module DB # ``` # name = db.query_one "select name from contacts where id = ?", 18, &.read(String) # ``` - def query_one(query, *args, &block : ResultSet -> U) : U forall U - query(query, *args) do |rs| + def query_one(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U forall U + query(query, *args_, args: args) do |rs| raise DB::Error.new("no rows") unless rs.move_next value = yield rs @@ -82,8 +90,8 @@ module DB # ``` # db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32} # ``` - def query_one(query, *args, as types : Tuple) - query_one(query, *args) do |rs| + def query_one(query, *args_, args : Array? = nil, as types : Tuple) + query_one(query, *args_, args: args) do |rs| rs.read(*types) end end @@ -97,8 +105,8 @@ module DB # ``` # db.query_one "select name, age from contacts where id = ?", 1, as: {name: String, age: Int32} # ``` - def query_one(query, *args, as types : NamedTuple) - query_one(query, *args) do |rs| + def query_one(query, *args_, args : Array? = nil, as types : NamedTuple) + query_one(query, *args_, args: args) do |rs| rs.read(**types) end end @@ -111,8 +119,8 @@ module DB # ``` # db.query_one "select name from contacts where id = ?", 1, as: String # ``` - def query_one(query, *args, as type : Class) - query_one(query, *args) do |rs| + def query_one(query, *args_, args : Array? = nil, as type : Class) + query_one(query, *args_, args: args) do |rs| rs.read(type) end end @@ -129,8 +137,8 @@ module DB # name = db.query_one? "select name from contacts where id = ?", 18, &.read(String) # typeof(name) # => String | Nil # ``` - def query_one?(query, *args, &block : ResultSet -> U) : U? forall U - query(query, *args) do |rs| + def query_one?(query, *args_, args : Array? = nil, &block : ResultSet -> U) : U? forall U + query(query, *args_, args: args) do |rs| return nil unless rs.move_next value = yield rs @@ -150,8 +158,8 @@ module DB # result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32} # typeof(result) # => Tuple(String, Int32) | Nil # ``` - def query_one?(query, *args, as types : Tuple) - query_one?(query, *args) do |rs| + def query_one?(query, *args_, args : Array? = nil, as types : Tuple) + query_one?(query, *args_, args: args) do |rs| rs.read(*types) end end @@ -168,8 +176,8 @@ module DB # result = db.query_one? "select name, age from contacts where id = ?", 1, as: {age: String, name: Int32} # typeof(result) # => NamedTuple(age: String, name: Int32) | Nil # ``` - def query_one?(query, *args, as types : NamedTuple) - query_one?(query, *args) do |rs| + def query_one?(query, *args_, args : Array? = nil, as types : NamedTuple) + query_one?(query, *args_, args: args) do |rs| rs.read(**types) end end @@ -185,8 +193,8 @@ module DB # name = db.query_one? "select name from contacts where id = ?", 1, as: String # typeof(name) # => String? # ``` - def query_one?(query, *args, as type : Class) - query_one?(query, *args) do |rs| + def query_one?(query, *args_, args : Array? = nil, as type : Class) + query_one?(query, *args_, args: args) do |rs| rs.read(type) end end @@ -197,9 +205,9 @@ module DB # ``` # names = db.query_all "select name from contacts", &.read(String) # ``` - def query_all(query, *args, &block : ResultSet -> U) : Array(U) forall U + def query_all(query, *args_, args : Array? = nil, &block : ResultSet -> U) : Array(U) forall U ary = [] of U - query_each(query, *args) do |rs| + query_each(query, *args_, args: args) do |rs| ary.push(yield rs) end ary @@ -211,8 +219,8 @@ module DB # ``` # contacts = db.query_all "select name, age from contacts", as: {String, Int32} # ``` - def query_all(query, *args, as types : Tuple) - query_all(query, *args) do |rs| + def query_all(query, *args_, args : Array? = nil, as types : Tuple) + query_all(query, *args_, args: args) do |rs| rs.read(*types) end end @@ -224,8 +232,8 @@ module DB # ``` # contacts = db.query_all "select name, age from contacts", as: {name: String, age: Int32} # ``` - def query_all(query, *args, as types : NamedTuple) - query_all(query, *args) do |rs| + def query_all(query, *args_, args : Array? = nil, as types : NamedTuple) + query_all(query, *args_, args: args) do |rs| rs.read(**types) end end @@ -236,8 +244,8 @@ module DB # ``` # names = db.query_all "select name from contacts", as: String # ``` - def query_all(query, *args, as type : Class) - query_all(query, *args) do |rs| + def query_all(query, *args_, args : Array? = nil, as type : Class) + query_all(query, *args_, args: args) do |rs| rs.read(type) end end @@ -250,8 +258,8 @@ module DB # puts rs.read(String) # end # ``` - def query_each(query, *args) - query(query, *args) do |rs| + def query_each(query, *args_, args : Array? = nil) + query(query, *args_, args: args) do |rs| rs.each do yield rs end @@ -259,8 +267,8 @@ module DB end # Performs the `query` and returns an `ExecResult` - def exec(query, *args) - build(query).exec(*args) + def exec(query, *args_, args : Array? = nil) + build(query).exec(*args_, args: args) end # Performs the `query` and returns a single scalar value @@ -268,8 +276,8 @@ module DB # ``` # puts db.scalar("SELECT MAX(name)").as(String) # => (a String) # ``` - def scalar(query, *args) - build(query).scalar(*args) + def scalar(query, *args_, args : Array? = nil) + build(query).scalar(*args_, args: args) end end end diff --git a/src/db/statement.cr b/src/db/statement.cr index 0ab4be3..5c3f5e3 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -8,8 +8,8 @@ module DB end # See `QueryMethods#scalar` - def scalar(*args) - query(*args) do |rs| + def scalar(*args_, args : Array? = nil) + query(*args_, args: args) do |rs| rs.each do return rs.read end @@ -19,24 +19,20 @@ module DB end # See `QueryMethods#query` - def query(*args) - rs = query(*args) + def query(*args_, args : Array? = nil) + rs = query(*args_, args: args) yield rs ensure rs.close end # See `QueryMethods#exec` abstract def exec : ExecResult # See `QueryMethods#exec` - abstract def exec(*args) : ExecResult - # See `QueryMethods#exec` - abstract def exec(args : Array) : ExecResult + abstract def exec(*args_, args : Array? = nil) : ExecResult # See `QueryMethods#query` abstract def query : ResultSet # See `QueryMethods#query` - abstract def query(*args) : ResultSet - # See `QueryMethods#query` - abstract def query(args : Array) : ResultSet + abstract def query(*args_, args : Array? = nil) : ResultSet end # Represents a query in a `Connection`. @@ -68,14 +64,8 @@ module DB end # See `QueryMethods#exec` - def exec(args : Array) : DB::ExecResult - perform_exec_and_release(args) - end - - # See `QueryMethods#exec` - def exec(*args) - # TODO better way to do it - perform_exec_and_release(args) + def exec(*args_, args : Array? = nil) : DB::ExecResult + perform_exec_and_release(EnumerableConcat.build(args_, args)) end # See `QueryMethods#query` @@ -84,13 +74,8 @@ module DB end # See `QueryMethods#query` - def query(args : Array) : DB::ResultSet - perform_query_with_rescue args - end - - # See `QueryMethods#query` - def query(*args) - perform_query_with_rescue args + def query(*args_, args : Array? = nil) : DB::ResultSet + perform_query_with_rescue(EnumerableConcat.build(args_, args)) end private def perform_exec_and_release(args : Enumerable) : ExecResult diff --git a/src/spec.cr b/src/spec.cr index 65d12c3..b8c64c2 100644 --- a/src/spec.cr +++ b/src/spec.cr @@ -154,7 +154,7 @@ module DB end it "executes with bind #{value_desc} as array" do |db| - db.scalar(select_scalar(param(1), sql_type), [value]).should eq(value) + db.scalar(select_scalar(param(1), sql_type), args: [value]).should eq(value) end it "select #{value_desc} as literal" do |db|