diff --git a/spec/custom_drivers_types_spec.cr b/spec/custom_drivers_types_spec.cr index 5655dac..c372ad3 100644 --- a/spec/custom_drivers_types_spec.cr +++ b/spec/custom_drivers_types_spec.cr @@ -20,6 +20,10 @@ module GenericResultSet @index += 1 @row[@index - 1] end + + def next_column_index : Int32 + @index + end end class FooValue @@ -197,7 +201,7 @@ describe DB do FooDriver.fake_row = [1] of FooDriver::Any db.query "query" do |rs| rs.move_next - expect_raises(Exception, "FooResultSet#read returned a Int32. A BarValue was expected.") do + expect_raises(DB::ColumnTypeMismatchError, "In FooDriver::FooResultSet#read the column 0 returned a Int32 but a BarValue was expected.") do w.check rs.read(BarValue) end @@ -210,7 +214,7 @@ describe DB do BarDriver.fake_row = [1] of BarDriver::Any db.query "query" do |rs| rs.move_next - expect_raises(Exception, "BarResultSet#read returned a Int32. A FooValue was expected.") do + expect_raises(DB::ColumnTypeMismatchError, "In BarDriver::BarResultSet#read the column 0 returned a Int32 but a FooValue was expected.") do w.check rs.read(FooValue) end diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index 6cdbead..aa40790 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -190,6 +190,10 @@ class DummyDriver < DB::Driver return n end + def next_column_index : Int32 + @column_count - @values.not_nil!.size + end + def read(t : String.class) read.to_s end diff --git a/src/db/error.cr b/src/db/error.cr index d3c5862..6e072c9 100644 --- a/src/db/error.cr +++ b/src/db/error.cr @@ -46,4 +46,17 @@ module DB # Raised when a scalar query returns no results. class NoResultsError < Error end + + # Raised when the type returned for the column value + # does not match the type expected. + class ColumnTypeMismatchError < Error + getter column_index : Int32 + getter column_name : String + getter column_type : String + getter expected_type : String + + def initialize(*, context : String, @column_index : Int32, @column_name : String, @column_type : String, @expected_type : String) + super("In #{context} the column #{column_name} returned a #{column_type} but a #{expected_type} was expected.") + end + end end diff --git a/src/db/result_set.cr b/src/db/result_set.cr index b2bd722..7520251 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -69,6 +69,11 @@ module DB # Reads the next column value abstract def read + # Returns the column index that corresponds to the next `#read`. + # + # If the last column of the current row has been read, it must return `#column_count`. + abstract def next_column_index : Int32 + # Reads the next columns and maps them to a class def read(type : DB::Mappable.class) type.new(self) @@ -76,11 +81,18 @@ module DB # Reads the next column value as a **type** def read(type : T.class) : T forall T + col_index = next_column_index value = read if value.is_a?(T) value else - raise "#{self.class}#read returned a #{value.class}. A #{T} was expected." + raise DB::ColumnTypeMismatchError.new( + context: "#{self.class}#read", + column_index: col_index, + column_name: column_name(col_index), + column_type: value.class.to_s, + expected_type: T.to_s + ) end end diff --git a/src/spec.cr b/src/spec.cr index 225a419..c69b43c 100644 --- a/src/spec.cr +++ b/src/spec.cr @@ -289,6 +289,62 @@ module DB ages.should eq([10, 20, 30]) end + it "next_column_index" do |db| + db.exec sql_create_table_person + db.exec sql_insert_person, "foo", 10 + db.exec sql_insert_person, "bar", 20 + + db.query sql_select_person do |rs| + rs.move_next + rs.next_column_index.should eq(0) + rs.read(String) + rs.next_column_index.should eq(1) + rs.read(Int32) + rs.next_column_index.should eq(2) + + rs.move_next + rs.next_column_index.should eq(0) + rs.read(String) + rs.next_column_index.should eq(1) + rs.read(Int32) + rs.next_column_index.should eq(2) + end + end + + it "next_column_index when ColumnTypeMismatchError" do |db| + db.exec sql_create_table_person + db.exec sql_insert_person, "foo", 10 + db.exec sql_insert_person, "bar", 20 + + db.query sql_select_person do |rs| + rs.move_next + rs.next_column_index.should eq(0) + ex = expect_raises(ColumnTypeMismatchError) { rs.read(Int32) } + ex.column_index.should eq(0) + ex.column_name.should eq("name") + # NOTE: sqlite currently returns Int64 due to how Int32 is implemented + ex.column_type.should match(/String/) + # NOTE: pg currently returns Slice(UInt8) | String due to how String is implemented + ex.expected_type.should match(/Int/) + rs.next_column_index.should eq(1) + ex = expect_raises(ColumnTypeMismatchError) { rs.read(String) } + ex.column_index.should eq(1) + ex.column_name.should eq("age") + # NOTE: sqlite returns Int64 + ex.column_type.should match(/Int/) + # NOTE: pg currently returns Slice(UInt8) | String due to how String is implemented + ex.expected_type.should match(/String/) + rs.next_column_index.should eq(2) + + rs.move_next + rs.next_column_index.should eq(0) + expect_raises(ColumnTypeMismatchError) { rs.read(Int32) } + rs.next_column_index.should eq(1) + expect_raises(ColumnTypeMismatchError) { rs.read(String) } + rs.next_column_index.should eq(2) + end + end + # describe "transactions" do it "transactions: can read inside transaction and rollback after" do |db| db.exec sql_create_table_person