diff --git a/spec/custom_drivers_types_spec.cr b/spec/custom_drivers_types_spec.cr new file mode 100644 index 0000000..2929d0f --- /dev/null +++ b/spec/custom_drivers_types_spec.cr @@ -0,0 +1,299 @@ +require "./spec_helper" + +module GenericResultSet + @index = 0 + + def move_next + @index = 0 + true + end + + def column_count : Int32 + @row.size + end + + def column_name(index : Int32) : String + index.to_s + end + + def column_type(index : Int32) + @row[index].class + end + + {% for t in DB::TYPES %} + # Reads the next column as a nillable {{t}}. + def read?(t : {{t}}.class) : {{t}}? + read_and_move_next_column as {{t}}? + end + {% end %} + + def read_and_move_next_column + @index += 1 + @row[@index - 1] + end +end + +class FooValue + def initialize(@value : Int32) + end + + def value + @value + end +end + +class FooDriver < DB::Driver + alias Any = DB::Any | FooValue + @@row = [] of Any + + def self.fake_row=(row : Array(Any)) + @@row = row + end + + def self.fake_row + @@row + end + + def build_connection(db : DB::Database) : DB::Connection + FooConnection.new(db) + end + + class FooConnection < DB::Connection + def build_statement(query) + FooStatement.new(self) + end + end + + class FooStatement < DB::Statement + protected def perform_query(args : Enumerable) : DB::ResultSet + args.each { |arg| process_arg arg } + FooResultSet.new(self, FooDriver.fake_row) + end + + protected def perform_exec(args : Enumerable) : DB::ExecResult + args.each { |arg| process_arg arg } + DB::ExecResult.new 0i64, 0i64 + end + + private def process_arg(value : FooDriver::Any) + end + + private def process_arg(value) + raise "#{self.class} does not support #{value.class} params" + end + end + + class FooResultSet < DB::ResultSet + include GenericResultSet + + def initialize(statement, @row : Array(FooDriver::Any)) + super(statement) + end + + def read?(t : FooValue.class) : FooValue? + read_and_move_next_column.as(FooValue?) + end + end +end + +DB.register_driver "foo", FooDriver + +class BarValue + getter value + + def initialize(@value : Int32) + end +end + +class BarDriver < DB::Driver + alias Any = DB::Any | BarValue + @@row = [] of Any + + def self.fake_row=(row : Array(Any)) + @@row = row + end + + def self.fake_row + @@row + end + + def build_connection(db : DB::Database) : DB::Connection + BarConnection.new(db) + end + + class BarConnection < DB::Connection + def build_statement(query) + BarStatement.new(self) + end + end + + class BarStatement < DB::Statement + protected def perform_query(args : Enumerable) : DB::ResultSet + args.each { |arg| process_arg arg } + BarResultSet.new(self, BarDriver.fake_row) + end + + protected def perform_exec(args : Enumerable) : DB::ExecResult + args.each { |arg| process_arg arg } + DB::ExecResult.new 0i64, 0i64 + end + + private def process_arg(value : BarDriver::Any) + end + + private def process_arg(value) + raise "#{self.class} does not support #{value.class} params" + end + end + + class BarResultSet < DB::ResultSet + include GenericResultSet + + def initialize(statement, @row : Array(BarDriver::Any)) + super(statement) + end + + def read?(t : BarValue.class) : BarValue? + read_and_move_next_column.as(BarValue?) + end + end +end + +DB.register_driver "bar", BarDriver + +describe DB do + it "should be able to register multiple drivers" do + DB.open("foo://host").driver.should be_a(FooDriver) + DB.open("bar://host").driver.should be_a(BarDriver) + end + + it "Foo and Bar drivers should return fake_row" do + with_witness do |w| + DB.open("foo://host") do |db| + # TODO somehow FooValue.new(99) is needed otherwise the read_object assertion fail + FooDriver.fake_row = [1, "string", FooValue.new(3), FooValue.new(99)] of FooDriver::Any + db.query "query" do |rs| + w.check + rs.move_next + rs.read?(Int32).should eq(1) + rs.read?(String).should eq("string") + rs.read(FooValue).value.should eq(3) + end + end + end + + with_witness do |w| + DB.open("bar://host") do |db| + # TODO somehow BarValue.new(99) is needed otherwise the read_object assertion fail + BarDriver.fake_row = [BarValue.new(4), "lorem", 1.0, BarValue.new(99)] of BarDriver::Any + db.query "query" do |rs| + w.check + rs.move_next + rs.read(BarValue).value.should eq(4) + rs.read?(String).should eq("lorem") + rs.read?(Float64).should eq(1.0) + end + end + end + end + + it "Foo and Bar drivers should not implement each other read" do + with_witness do |w| + DB.open("foo://host") do |db| + FooDriver.fake_row = [1] of FooDriver::Any + db.query "query" do |rs| + rs.move_next + expect_raises Exception, "read?(t : BarValue) is not implemented in FooDriver::FooResultSet" do + w.check + rs.read(BarValue) + end + end + end + end + + with_witness do |w| + DB.open("bar://host") do |db| + BarDriver.fake_row = [1] of BarDriver::Any + db.query "query" do |rs| + rs.move_next + expect_raises Exception, "read?(t : FooValue) is not implemented in BarDriver::BarResultSet" do + w.check + rs.read(FooValue) + end + end + end + end + end + + it "allow custom types to be used as arguments for query" do + DB.open("foo://host") do |db| + FooDriver.fake_row = [1, "string"] of FooDriver::Any + db.query "query" { } + db.query "query", 1 { } + db.query "query", 1, "string" { } + db.query("query", Bytes.new(4)) { } + db.query("query", 1, "string", FooValue.new(5)) { } + db.query "query", [1, "string", FooValue.new(5)] { } + + db.query("query").close + db.query("query", 1).close + db.query("query", 1, "string").close + db.query("query", Bytes.new(4)).close + db.query("query", 1, "string", FooValue.new(5)).close + db.query("query", [1, "string", FooValue.new(5)]).close + end + + DB.open("bar://host") do |db| + BarDriver.fake_row = [1, "string"] of BarDriver::Any + db.query "query" { } + db.query "query", 1 { } + db.query "query", 1, "string" { } + db.query("query", Bytes.new(4)) { } + db.query("query", 1, "string", BarValue.new(5)) { } + db.query "query", [1, "string", BarValue.new(5)] { } + + db.query("query").close + db.query("query", 1).close + db.query("query", 1, "string").close + db.query("query", Bytes.new(4)).close + db.query("query", 1, "string", BarValue.new(5)).close + db.query("query", [1, "string", BarValue.new(5)]).close + end + end + + it "allow custom types to be used as arguments for exec" do + DB.open("foo://host") do |db| + FooDriver.fake_row = [1, "string"] of FooDriver::Any + db.exec("query") + db.exec("query", 1) + db.exec("query", 1, "string") + db.exec("query", Bytes.new(4)) + db.exec("query", 1, "string", FooValue.new(5)) + db.exec("query", [1, "string", FooValue.new(5)]) + end + + DB.open("bar://host") do |db| + BarDriver.fake_row = [1, "string"] of BarDriver::Any + db.exec("query") + db.exec("query", 1) + db.exec("query", 1, "string") + db.exec("query", Bytes.new(4)) + db.exec("query", 1, "string", BarValue.new(5)) + db.exec("query", [1, "string", BarValue.new(5)]) + end + end + + it "Foo and Bar drivers should not implement each other params" do + DB.open("foo://host") do |db| + expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do + db.exec("query", [BarValue.new(5)]) + end + end + + DB.open("bar://host") do |db| + expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do + db.exec("query", [FooValue.new(5)]) + end + end + end +end diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index b57e5d4..9933604 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -42,23 +42,31 @@ class DummyDriver < DB::Driver super(connection) end - protected def perform_query(args : Slice(DB::Any)) + protected def perform_query(args : Enumerable) set_params args DummyResultSet.new self, @query end - protected def perform_exec(args : Slice(DB::Any)) + protected def perform_exec(args : Enumerable) set_params args - DB::ExecResult.new 0, 0_i64 + DB::ExecResult.new 0i64, 0_i64 end private def set_params(args) @params.clear args.each_with_index do |arg, index| - @params[index] = arg + set_param(index, arg) end end + private def set_param(index, value : DB::Any) + @params[index] = value + end + + private def set_param(index, value) + raise "not implemented for #{value.class}" + end + protected def do_close super end @@ -70,7 +78,7 @@ class DummyDriver < DB::Driver @values : Array(String)? @@last_result_set : self? - @@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Slice(UInt8).class + @@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Bytes.class def initialize(statement, query) super(statement) @@ -114,12 +122,16 @@ class DummyDriver < DB::Driver return nil if n == "NULL" if n == "?" - return @statement.params[0] + return (@statement.as(DummyStatement)).params[0] end return n end + def read?(t : Nil.class) + read?.as(Nil) + end + def read?(t : String.class) read?.try &.to_s end @@ -140,17 +152,17 @@ class DummyDriver < DB::Driver read?(String).try &.to_f64 end - def read?(t : Slice(UInt8).class) + def read?(t : Bytes.class) value = read? if value.is_a?(Nil) value elsif value.is_a?(String) ary = value.bytes Slice.new(ary.to_unsafe, ary.size) - elsif value.is_a?(Slice(UInt8)) + elsif value.is_a?(Bytes) value else - raise "#{value} is not convertible to Slice(UInt8)" + raise "#{value} is not convertible to Bytes" end end end diff --git a/spec/dummy_driver_spec.cr b/spec/dummy_driver_spec.cr index d9bf76d..3eb2220 100644 --- a/spec/dummy_driver_spec.cr +++ b/spec/dummy_driver_spec.cr @@ -101,9 +101,9 @@ describe DummyDriver do db.query("az,AZ") do |rs| rs.move_next ary = [97u8, 122u8] - rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) + rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size)) ary = [65u8, 90u8] - rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) + rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size)) end end end @@ -134,9 +134,9 @@ describe DummyDriver do it "executes and selects blob" do with_dummy do |db| ary = UInt8[0x53, 0x51, 0x4C] - slice = Slice.new(ary.to_unsafe, ary.size) + slice = Bytes.new(ary.to_unsafe, ary.size) DummyDriver::DummyResultSet.next_column_type = typeof(slice) - (db.scalar("?", slice) as Slice(UInt8)).to_a.should eq(ary) + (db.scalar("?", slice).as(Bytes)).to_a.should eq(ary) end end end diff --git a/spec/statement_spec.cr b/spec/statement_spec.cr index 5232621..d362ebf 100644 --- a/spec/statement_spec.cr +++ b/spec/statement_spec.cr @@ -9,7 +9,7 @@ describe DB::Statement do it "should initialize positional params in query" do with_dummy do |db| - stmt = db.prepare("the query") + stmt = db.prepare("the query").as(DummyDriver::DummyStatement) stmt.query "a", 1, nil stmt.params[0].should eq("a") stmt.params[1].should eq(1) @@ -19,7 +19,7 @@ describe DB::Statement do it "should initialize positional params in query with array" do with_dummy do |db| - stmt = db.prepare("the query") + stmt = db.prepare("the query").as(DummyDriver::DummyStatement) stmt.query ["a", 1, nil] stmt.params[0].should eq("a") stmt.params[1].should eq(1) @@ -29,7 +29,7 @@ describe DB::Statement do it "should initialize positional params in exec" do with_dummy do |db| - stmt = db.prepare("the query") + stmt = db.prepare("the query").as(DummyDriver::DummyStatement) stmt.exec "a", 1, nil stmt.params[0].should eq("a") stmt.params[1].should eq(1) @@ -39,7 +39,7 @@ describe DB::Statement do it "should initialize positional params in exec with array" do with_dummy do |db| - stmt = db.prepare("the query") + stmt = db.prepare("the query").as(DummyDriver::DummyStatement) stmt.exec ["a", 1, nil] stmt.params[0].should eq("a") stmt.params[1].should eq(1) @@ -49,7 +49,7 @@ describe DB::Statement do it "should initialize positional params in scalar" do with_dummy do |db| - stmt = db.prepare("the query") + stmt = db.prepare("the query").as(DummyDriver::DummyStatement) stmt.scalar "a", 1, nil stmt.params[0].should eq("a") stmt.params[1].should eq(1) diff --git a/src/db.cr b/src/db.cr index 65799fe..499f799 100644 --- a/src/db.cr +++ b/src/db.cr @@ -70,13 +70,15 @@ module DB # Types supported to interface with database driver. # These can be used in any `ResultSet#read` or any `Database#query` related # method to be used as query parameters - TYPES = [String, Int32, Int64, Float32, Float64, Slice(UInt8)] + TYPES = [Nil, String, Int32, Int64, Float32, Float64, Bytes] - # See `DB::TYPES` in `DB`. `Any` is a nillable version of the union of all types in `DB::TYPES` - alias Any = Nil | String | Int32 | Int64 | Float32 | Float64 | Slice(UInt8) + # See `DB::TYPES` in `DB`. `Any` is a union of all types in `DB::TYPES` + {% begin %} + alias Any = Union({{*TYPES}}) + {% end %} # Result of a `#exec` statement. - record ExecResult, rows_affected : Int32, last_insert_id : Int64 + record ExecResult, rows_affected : Int64, last_insert_id : Int64 # :nodoc: def self.driver_class(driver_name) : Driver.class diff --git a/src/db/query_methods.cr b/src/db/query_methods.cr index 085ee43..517eee9 100644 --- a/src/db/query_methods.cr +++ b/src/db/query_methods.cr @@ -40,7 +40,7 @@ module DB end end - # Performs the `query` discarding any response + # Performs the `query` and returns an `ExecResult` def exec(query, *args) prepare(query).exec(*args) end diff --git a/src/db/result_set.cr b/src/db/result_set.cr index 8dbaa54..012c99e 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -16,8 +16,8 @@ module DB # ### Note to implementors # # 1. Override `#move_next` to move to the next row. - # 2. Override `#read?(t)` for all `t` in `DB::TYPES`. - # 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES`. + # 2. Override `#read?(t)` for all `t` in `DB::TYPES` and any other types the driver should handle. + # 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES` and any other. # 4. Override `#column_count`, `#column_name`. # 5. Override `#column_type`. It must return a type in `DB::TYPES`. abstract class ResultSet @@ -59,17 +59,24 @@ module DB # The result is one of `DB::TYPES`. abstract def column_type(index : Int32) + def read(t) + read?(t).not_nil! + end + + # Reads the next column as a Nil. + def read(t : Nil.class) : Nil + read?(Nil) + end + + def read?(t) + raise "read?(t : #{t}) is not implemented in #{self.class}" + end + # list datatypes that must be supported form the driver # users will call read(String) or read?(String) for nillables - {% for t in DB::TYPES %} # Reads the next column as a nillable {{t}}. abstract def read?(t : {{t}}.class) : {{t}}? - - # Reads the next column as a {{t}}. - def read(t : {{t}}.class) : {{t}} - read?({{t}}).not_nil! - end {% end %} # def read_blob diff --git a/src/db/statement.cr b/src/db/statement.cr index 82e7fbc..0991019 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -27,18 +27,18 @@ module DB # See `QueryMethods#exec` def exec - perform_exec_and_release(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn) + perform_exec_and_release(Slice(Any).new(0)) end # See `QueryMethods#exec` - def exec(args : Enumerable(Any)) - perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size)) + def exec(args : Array) + perform_exec_and_release(args) end # See `QueryMethods#exec` def exec(*args) # TODO better way to do it - perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size)) + perform_exec_and_release(args) end # See `QueryMethods#scalar` @@ -57,8 +57,8 @@ module DB return rs.read?(Float32) when Float64.class return rs.read?(Float64) - when Slice(UInt8).class - return rs.read?(Slice(UInt8)) + when Bytes.class + return rs.read?(Bytes) when Nil.class return rs.read?(Int32) else @@ -71,13 +71,23 @@ module DB end # See `QueryMethods#query` - def query(*args) - perform_query *args + def query + perform_query Slice(Any).new(0) + end + + # See `QueryMethods#query` + def query(args : Array) + perform_query args end # See `QueryMethods#query` def query(*args) - perform_query(*args).tap do |rs| + perform_query args + end + + # See `QueryMethods#query` + def query(*args) + query(*args).tap do |rs| begin yield rs ensure @@ -86,27 +96,13 @@ module DB end end - private def perform_query : ResultSet - perform_query(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn) - end - - private def perform_query(args : Enumerable(Any)) : ResultSet - # TODO better way to do it - perform_query(args.to_a.to_unsafe.to_slice(args.size)) - end - - private def perform_query(*args) : ResultSet - # TODO better way to do it - perform_query(args.to_a.to_unsafe.to_slice(args.size)) - end - - private def perform_exec_and_release(args : Slice(Any)) : ExecResult + private def perform_exec_and_release(args : Enumerable) : ExecResult perform_exec(args).tap do release_connection end end - protected abstract def perform_query(args : Slice(Any)) : ResultSet - protected abstract def perform_exec(args : Slice(Any)) : ExecResult + protected abstract def perform_query(args : Enumerable) : ResultSet + protected abstract def perform_exec(args : Enumerable) : ExecResult end end