diff --git a/spec/custom_drivers_types_spec.cr b/spec/custom_drivers_types_spec.cr index 9f028d3..6adc60c 100644 --- a/spec/custom_drivers_types_spec.cr +++ b/spec/custom_drivers_types_spec.cr @@ -68,7 +68,7 @@ class FooDriver < DB::Driver end class FooStatement < DB::Statement - protected def perform_query(args : Slice(DB::Any)) : DB::ResultSet + protected def perform_query(args : Enumerable) : DB::ResultSet GenericResultSet(Any).new(self, FooDriver.fake_row) end @@ -110,7 +110,7 @@ class BarDriver < DB::Driver end class BarStatement < DB::Statement - protected def perform_query(args : Slice(DB::Any)) : DB::ResultSet + protected def perform_query(args : Enumerable) : DB::ResultSet GenericResultSet(Any).new(self, BarDriver.fake_row) end @@ -157,4 +157,26 @@ describe DB do end end end + + it "allow custom types to be used as arguments for query" do + DB.open("foo://host") do |db| + FooDriver.fake_row = [1, "string"] of FooDriver::Any + db.query "query" { } + db.query "query", 1 { } + db.query "query", 1, "string" { } + db.query("query", Slice(UInt8).new(4)) { } + db.query("query", 1, "string", FooValue.new(5)) { } + db.query "query", [1, "string", FooValue.new(5)] { } + end + + DB.open("bar://host") do |db| + BarDriver.fake_row = [1, "string"] of BarDriver::Any + db.query "query" { } + db.query "query", 1 { } + db.query "query", 1, "string" { } + db.query("query", Slice(UInt8).new(4)) { } + db.query("query", 1, "string", BarValue.new(5)) { } + db.query "query", [1, "string", FooValue.new(5)] { } + end + end end diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index cd59642..8872cae 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -42,7 +42,7 @@ class DummyDriver < DB::Driver super(connection) end - protected def perform_query(args : Slice(DB::Any)) + protected def perform_query(args : Enumerable) set_params args DummyResultSet.new self, @query end diff --git a/src/db/statement.cr b/src/db/statement.cr index 82e7fbc..0ad66ea 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -71,13 +71,13 @@ module DB end # See `QueryMethods#query` - def query(*args) - perform_query *args + def query + perform_query Slice(Any).new(0) end # See `QueryMethods#query` - def query(*args) - perform_query(*args).tap do |rs| + def query + perform_query(Slice(Any).new(0)).tap do |rs| begin yield rs ensure @@ -86,18 +86,36 @@ module DB end end - private def perform_query : ResultSet - perform_query(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn) + # See `QueryMethods#query` + def query(args : Array) + perform_query args end - private def perform_query(args : Enumerable(Any)) : ResultSet - # TODO better way to do it - perform_query(args.to_a.to_unsafe.to_slice(args.size)) + # See `QueryMethods#query` + def query(args : Array) + perform_query(args).tap do |rs| + begin + yield rs + ensure + rs.close + end + end end - private def perform_query(*args) : ResultSet - # TODO better way to do it - perform_query(args.to_a.to_unsafe.to_slice(args.size)) + # See `QueryMethods#query` + def query(*args) + perform_query args + end + + # See `QueryMethods#query` + def query(*args) + perform_query(args).tap do |rs| + begin + yield rs + ensure + rs.close + end + end end private def perform_exec_and_release(args : Slice(Any)) : ExecResult @@ -106,7 +124,7 @@ module DB end end - protected abstract def perform_query(args : Slice(Any)) : ResultSet + protected abstract def perform_query(args : Enumerable) : ResultSet protected abstract def perform_exec(args : Slice(Any)) : ExecResult end end