diff --git a/spec/statement_spec.cr b/spec/statement_spec.cr index a080b35..f70937e 100644 --- a/spec/statement_spec.cr +++ b/spec/statement_spec.cr @@ -51,6 +51,25 @@ describe DB::Statement do end end + it "allows no arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.query + stmt.params.should be_empty + end + end + + it "concatenate arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.query 1, 2, args: ["a", [1, nil]] + stmt.params[0].should eq(1) + stmt.params[1].should eq(2) + stmt.params[2].should eq("a") + stmt.params[3].should eq([1, nil]) + end + end + it "should initialize positional params in query with array" do with_dummy_connection do |cnn| stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) @@ -89,6 +108,25 @@ describe DB::Statement do end end + it "allows no arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.exec + stmt.params.should be_empty + end + end + + it "concatenate arguments" do + with_dummy_connection do |cnn| + stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) + stmt.exec 1, 2, args: ["a", [1, nil]] + stmt.params[0].should eq(1) + stmt.params[1].should eq(2) + stmt.params[2].should eq("a") + stmt.params[3].should eq([1, nil]) + end + end + it "should initialize positional params in scalar" do with_dummy_connection do |cnn| stmt = cnn.prepared("the query").as(DummyDriver::DummyStatement) diff --git a/src/db.cr b/src/db.cr index 915e7b8..bbe0e3d 100644 --- a/src/db.cr +++ b/src/db.cr @@ -181,6 +181,7 @@ end require "./db/pool" require "./db/string_key_cache" +require "./db/enumerable_concat" require "./db/query_methods" require "./db/session_methods" require "./db/disposable" diff --git a/src/db/enumerable_concat.cr b/src/db/enumerable_concat.cr new file mode 100644 index 0000000..3f70b0d --- /dev/null +++ b/src/db/enumerable_concat.cr @@ -0,0 +1,38 @@ +module DB + # :nodoc: + struct EnumerableConcat(S, T, U) + include Enumerable(S) + + def initialize(@e1 : T, @e2 : U) + end + + def each + if e1 = @e1 + @e1.each do |e| + yield e + end + end + if e2 = @e2 + e2.each do |e| + yield e + end + end + end + + # returns given `e1 : T` an `Enumerable(T')` and `e2 : U` an `Enumerable(U') | Nil` + # it retuns and `Enumerable(T' | U')` that enumerates the elements of `e1` + # and, later, the elements of `e2`. + def self.build(e1 : T, e2 : U) + return e1 if e2.nil? || e2.empty? + return e2 if e1.nil? || e1.empty? + EnumerableConcat(Union(typeof(sample(e1)), typeof(sample(e2))), T, U).new(e1, e2) + end + + private def self.sample(c : Enumerable?) + c.not_nil!.each do |e| + return e + end + raise "" + end + end +end diff --git a/src/db/statement.cr b/src/db/statement.cr index 3eae1f6..b1a8173 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -65,7 +65,7 @@ module DB # See `QueryMethods#exec` def exec(*t_args, args : Array? = nil) : DB::ExecResult - perform_exec_and_release(args || t_args) + perform_exec_and_release(EnumerableConcat.build(t_args, args)) end # See `QueryMethods#query` @@ -75,7 +75,7 @@ module DB # See `QueryMethods#query` def query(*t_args, args : Array? = nil) : DB::ResultSet - perform_query_with_rescue(args || t_args) + perform_query_with_rescue(EnumerableConcat.build(t_args, args)) end private def perform_exec_and_release(args : Enumerable) : ExecResult