diff --git a/spec/custom_drivers_types_spec.cr b/spec/custom_drivers_types_spec.cr index 4524b68..42153ff 100644 --- a/spec/custom_drivers_types_spec.cr +++ b/spec/custom_drivers_types_spec.cr @@ -1,10 +1,7 @@ require "./spec_helper" -class GenericResultSet(T) < DB::ResultSet - def initialize(statement, @row : Array(T)) - super(statement) - @index = 0 - end +module GenericResultSet + @index = 0 def move_next @index = 0 @@ -26,11 +23,11 @@ class GenericResultSet(T) < DB::ResultSet {% for t in DB::TYPES %} # Reads the next column as a nillable {{t}}. def read?(t : {{t}}.class) : {{t}}? - read_object as {{t}}? + read_and_move_next_column as {{t}}? end {% end %} - def read_object + def read_and_move_next_column @index += 1 @row[@index - 1] end @@ -69,13 +66,25 @@ class FooDriver < DB::Driver class FooStatement < DB::Statement protected def perform_query(args : Enumerable) : DB::ResultSet - GenericResultSet(Any).new(self, FooDriver.fake_row) + FooResultSet.new(self, FooDriver.fake_row) end protected def perform_exec(args : Enumerable) : DB::ExecResult DB::ExecResult.new 0, 0i64 end end + + class FooResultSet < DB::ResultSet + include GenericResultSet + + def initialize(statement, @row : Array(FooDriver::Any)) + super(statement) + end + + def read?(t : FooValue.class) : FooValue? + read_and_move_next_column.as(FooValue?) + end + end end DB.register_driver "foo", FooDriver @@ -111,13 +120,25 @@ class BarDriver < DB::Driver class BarStatement < DB::Statement protected def perform_query(args : Enumerable) : DB::ResultSet - GenericResultSet(Any).new(self, BarDriver.fake_row) + BarResultSet.new(self, BarDriver.fake_row) end protected def perform_exec(args : Enumerable) : DB::ExecResult DB::ExecResult.new 0, 0i64 end end + + class BarResultSet < DB::ResultSet + include GenericResultSet + + def initialize(statement, @row : Array(BarDriver::Any)) + super(statement) + end + + def read?(t : BarValue.class) : BarValue? + read_and_move_next_column.as(BarValue?) + end + end end DB.register_driver "bar", BarDriver @@ -138,7 +159,7 @@ describe DB do rs.move_next rs.read?(Int32).should eq(1) rs.read?(String).should eq("string") - (rs.read_object.as(FooValue)).value.should eq(3) + rs.read(FooValue).value.should eq(3) end end end @@ -150,7 +171,7 @@ describe DB do db.query "query" do |rs| w.check rs.move_next - (rs.read_object.as(BarValue)).value.should eq(4) + rs.read(BarValue).value.should eq(4) rs.read?(String).should eq("lorem") rs.read?(Float64).should eq(1.0) end @@ -158,6 +179,34 @@ describe DB do end end + it "Foo and Bar drivers should not implement each other read" do + with_witness do |w| + DB.open("foo://host") do |db| + FooDriver.fake_row = [1] of FooDriver::Any + db.query "query" do |rs| + rs.move_next + expect_raises Exception, "read?(t : BarValue) is not implemented in FooDriver::FooResultSet" do + w.check + rs.read(BarValue) + end + end + end + end + + with_witness do |w| + DB.open("bar://host") do |db| + BarDriver.fake_row = [1] of BarDriver::Any + db.query "query" do |rs| + rs.move_next + expect_raises Exception, "read?(t : FooValue) is not implemented in BarDriver::BarResultSet" do + w.check + rs.read(FooValue) + end + end + end + end + end + it "allow custom types to be used as arguments for query" do DB.open("foo://host") do |db| FooDriver.fake_row = [1, "string"] of FooDriver::Any diff --git a/src/db/result_set.cr b/src/db/result_set.cr index ca7300d..012c99e 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -16,14 +16,10 @@ module DB # ### Note to implementors # # 1. Override `#move_next` to move to the next row. - # 2. Override `#read?(t)` for all `t` in `DB::TYPES`. - # 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES`. + # 2. Override `#read?(t)` for all `t` in `DB::TYPES` and any other types the driver should handle. + # 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES` and any other. # 4. Override `#column_count`, `#column_name`. # 5. Override `#column_type`. It must return a type in `DB::TYPES`. - # 6. Override `#read_object` to return other data types not included in `DB::TYPES`. This - # will create a union type, so user will be forced to cast result type. Usually `#read` - # should be used to avoid unnecesary intermediate union type values. Calling `#read_object` - # should also move to the next column. abstract class ResultSet include Disposable @@ -63,28 +59,26 @@ module DB # The result is one of `DB::TYPES`. abstract def column_type(index : Int32) - # list datatypes that must be supported form the driver - # users will call read(String) or read?(String) for nillables - - {% for t in DB::TYPES %} - # Reads the next column as a nillable {{t}}. - abstract def read?(t : {{t}}.class) : {{t}}? - - # Reads the next column as a {{t}}. - def read(t : {{t}}.class) : {{t}} - read?({{t}}).not_nil! - end - {% end %} + def read(t) + read?(t).not_nil! + end # Reads the next column as a Nil. def read(t : Nil.class) : Nil read?(Nil) end - def read_object - raise "Not implemented" + def read?(t) + raise "read?(t : #{t}) is not implemented in #{self.class}" end + # list datatypes that must be supported form the driver + # users will call read(String) or read?(String) for nillables + {% for t in DB::TYPES %} + # Reads the next column as a nillable {{t}}. + abstract def read?(t : {{t}}.class) : {{t}}? + {% end %} + # def read_blob # yield ... io .... # end