From 9c88f718e825130ef542c672b5c76e2c80a959de Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 28 Jun 2016 14:02:08 -0300 Subject: [PATCH] Require ResultSet to just implement `read`, optionally implementing `read(T.class)`. Fixes #5 --- spec/custom_drivers_types_spec.cr | 33 ++---- spec/dummy_driver.cr | 46 +++----- spec/dummy_driver_spec.cr | 128 +++++++++++++++++++--- src/db.cr | 1 + src/db/error.cr | 4 + src/db/query_methods.cr | 173 ++++++++++++++++++++++++++++-- src/db/result_set.cr | 38 ++++--- src/db/statement.cr | 2 +- 8 files changed, 329 insertions(+), 96 deletions(-) create mode 100644 src/db/error.cr diff --git a/spec/custom_drivers_types_spec.cr b/spec/custom_drivers_types_spec.cr index 16e88c3..f0fc8f1 100644 --- a/spec/custom_drivers_types_spec.cr +++ b/spec/custom_drivers_types_spec.cr @@ -16,18 +16,7 @@ module GenericResultSet index.to_s end - def column_type(index : Int32) - @row[index].class - end - - {% for t in DB::TYPES %} - # Reads the next column as a nillable {{t}}. - def read?(t : {{t}}.class) : {{t}}? - read_and_move_next_column as {{t}}? - end - {% end %} - - def read_and_move_next_column + def read @index += 1 @row[@index - 1] end @@ -89,10 +78,6 @@ class FooDriver < DB::Driver def initialize(statement, @row : Array(FooDriver::Any)) super(statement) end - - def read?(t : FooValue.class) : FooValue? - read_and_move_next_column.as(FooValue?) - end end end @@ -152,10 +137,6 @@ class BarDriver < DB::Driver def initialize(statement, @row : Array(BarDriver::Any)) super(statement) end - - def read?(t : BarValue.class) : BarValue? - read_and_move_next_column.as(BarValue?) - end end end @@ -174,8 +155,8 @@ describe DB do db.query "query" do |rs| w.check rs.move_next - rs.read?(Int32).should eq(1) - rs.read?(String).should eq("string") + rs.read(Int32).should eq(1) + rs.read(String).should eq("string") rs.read(FooValue).value.should eq(3) end end @@ -188,8 +169,8 @@ describe DB do w.check rs.move_next rs.read(BarValue).value.should eq(4) - rs.read?(String).should eq("lorem") - rs.read?(Float64).should eq(1.0) + rs.read(String).should eq("lorem") + rs.read(Float64).should eq(1.0) end end end @@ -208,7 +189,7 @@ describe DB do FooDriver.fake_row = [1] of FooDriver::Any db.query "query" do |rs| rs.move_next - expect_raises Exception, "read?(t : BarValue) is not implemented in FooDriver::FooResultSet" do + expect_raises(TypeCastError) do w.check rs.read(BarValue) end @@ -221,7 +202,7 @@ describe DB do BarDriver.fake_row = [1] of BarDriver::Any db.query "query" do |rs| rs.move_next - expect_raises Exception, "read?(t : FooValue) is not implemented in BarDriver::BarResultSet" do + expect_raises(TypeCastError) do w.check rs.read(FooValue) end diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index 9933604..9bf09f0 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -73,12 +73,10 @@ class DummyDriver < DB::Driver end class DummyResultSet < DB::ResultSet - @@next_column_type = String @top_values : Array(Array(String)) @values : Array(String)? @@last_result_set : self? - @@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Bytes.class def initialize(statement, query) super(statement) @@ -108,15 +106,7 @@ class DummyDriver < DB::Driver "c#{index}" end - def column_type(index : Int32) - @@next_column_type - end - - def self.next_column_type=(value) - @@next_column_type = value - end - - private def read? : DB::Any? + def read n = @values.not_nil!.shift? raise "end of row" if n.is_a?(Nil) return nil if n == "NULL" @@ -128,38 +118,36 @@ class DummyDriver < DB::Driver return n end - def read?(t : Nil.class) - read?.as(Nil) + def read(t : String.class) + read.to_s end - def read?(t : String.class) - read?.try &.to_s + def read(t : String?.class) + read.try &.to_s end - def read?(t : Int32.class) - read?(String).try &.to_i32 + def read(t : Int32.class) + read(String).to_i32 end - def read?(t : Int64.class) - read?(String).try &.to_i64 + def read(t : Int64.class) + read(String).to_i64 end - def read?(t : Float32.class) - read?(String).try &.to_f32 + def read(t : Float32.class) + read(String).to_f32 end - def read?(t : Float64.class) - read?(String).try &.to_f64 + def read(t : Float64.class) + read(String).to_f64 end - def read?(t : Bytes.class) - value = read? - if value.is_a?(Nil) - value - elsif value.is_a?(String) + def read(t : Bytes.class) + case value = read + when String ary = value.bytes Slice.new(ary.to_unsafe, ary.size) - elsif value.is_a?(Bytes) + when Bytes value else raise "#{value} is not convertible to Bytes" diff --git a/spec/dummy_driver_spec.cr b/spec/dummy_driver_spec.cr index 3eb2220..328b7c8 100644 --- a/spec/dummy_driver_spec.cr +++ b/spec/dummy_driver_spec.cr @@ -74,11 +74,11 @@ describe DummyDriver do with_dummy do |db| db.query "a,NULL 1,NULL" do |rs| rs.move_next - rs.read?(String).should eq("a") - rs.read?(String).should be_nil + rs.read(String).should eq("a") + rs.read(String | Nil).should be_nil rs.move_next - rs.read?(Int64).should eq(1) - rs.read?(Int64).should be_nil + rs.read(Int64).should eq(1) + rs.read(Int64 | Nil).should be_nil end end end @@ -96,6 +96,116 @@ describe DummyDriver do end end + describe "query one" do + it "queries" do + with_dummy do |db| + db.query_one("3,4", &.read(Int64, Int64)).should eq({3i64, 4i64}) + end + end + + it "raises if more than one row" do + with_dummy do |db| + expect_raises(DB::Error, "more than one row") do + db.query_one("3,4 5,6") { } + end + end + end + + it "raises if no rows" do + with_dummy do |db| + expect_raises(DB::Error, "no rows") do + db.query_one("") { } + end + end + end + + it "with as" do + with_dummy do |db| + db.query_one("3,4", as: {Int64, Int64}).should eq({3i64, 4i64}) + end + end + + it "with as, just one" do + with_dummy do |db| + db.query_one("3", as: Int64).should eq(3i64) + end + end + end + + describe "query one?" do + it "queries" do + with_dummy do |db| + value = db.query_one?("3,4", &.read(Int64, Int64)) + value.should eq({3i64, 4i64}) + value.should be_a(Tuple(Int64, Int64)?) + end + end + + it "raises if more than one row" do + with_dummy do |db| + expect_raises(DB::Error, "more than one row") do + db.query_one?("3,4 5,6") { } + end + end + end + + it "returns nil if no rows" do + with_dummy do |db| + db.query_one?("") { fail("block shouldn't be invoked") }.should be_nil + end + end + + it "with as" do + with_dummy do |db| + value = db.query_one?("3,4", as: {Int64, Int64}) + value.should be_a(Tuple(Int64, Int64)?) + value.should eq({3i64, 4i64}) + end + end + + it "with as, just one" do + with_dummy do |db| + value = db.query_one?("3", as: Int64) + value.should be_a(Int64?) + value.should eq(3i64) + end + end + end + + describe "query all" do + it "queries" do + with_dummy do |db| + ary = db.query_all "3,4 1,2", &.read(Int64, Int64) + ary.should eq([{3, 4}, {1, 2}]) + end + end + + it "queries with as" do + with_dummy do |db| + ary = db.query_all "3,4 1,2", as: {Int64, Int64} + ary.should eq([{3, 4}, {1, 2}]) + end + end + + it "queries with as, just one" do + with_dummy do |db| + ary = db.query_all "3 1", as: Int64 + ary.should eq([3, 1]) + end + end + end + + it "reads multiple values" do + with_dummy do |db| + db.query "3,4 1,2" do |rs| + rs.move_next + rs.read(Int64, Int64).should eq({3i64, 4i64}) + rs.move_next + rs.read(Int64, Int64).should eq({1i64, 2i64}) + end + end + end + it "should enumerate blob fields" do with_dummy do |db| db.query("az,AZ") do |rs| @@ -110,22 +220,13 @@ describe DummyDriver do it "should get Nil scalars" do with_dummy do |db| - DummyDriver::DummyResultSet.next_column_type = Nil db.scalar("NULL").should be_nil end end {% for value in [1, 1_i64, "hello", 1.5, 1.5_f32] %} - it "numeric scalars of type of {{value.id}} should return value or nil" do - with_dummy do |db| - DummyDriver::DummyResultSet.next_column_type = typeof({{value}}) - db.scalar("#{{{value}}}").should eq({{value}}) - end - end - it "should set positional arguments for {{value.id}}" do with_dummy do |db| - DummyDriver::DummyResultSet.next_column_type = typeof({{value}}) db.scalar("?", {{value}}).should eq({{value}}) end end @@ -135,7 +236,6 @@ describe DummyDriver do with_dummy do |db| ary = UInt8[0x53, 0x51, 0x4C] slice = Bytes.new(ary.to_unsafe, ary.size) - DummyDriver::DummyResultSet.next_column_type = typeof(slice) (db.scalar("?", slice).as(Bytes)).to_a.should eq(ary) end end diff --git a/src/db.cr b/src/db.cr index 75e8c13..1e0e2eb 100644 --- a/src/db.cr +++ b/src/db.cr @@ -126,3 +126,4 @@ require "./db/driver" require "./db/connection" require "./db/statement" require "./db/result_set" +require "./db/error" diff --git a/src/db/error.cr b/src/db/error.cr new file mode 100644 index 0000000..595c2f1 --- /dev/null +++ b/src/db/error.cr @@ -0,0 +1,4 @@ +module DB + class Error < Exception + end +end diff --git a/src/db/query_methods.cr b/src/db/query_methods.cr index 3047f30..572ff20 100644 --- a/src/db/query_methods.cr +++ b/src/db/query_methods.cr @@ -21,27 +21,190 @@ module DB # :nodoc: abstract def prepare(query) : Statement - # Returns a `ResultSet` for the `query`. + # Executes a *query* and returns a `ResultSet` with the results. # The `ResultSet` must be closed manually. + # + # ``` + # result = db.query "select name from contacts where id = ?", 10 + # begin + # if result.move_next + # id = result.read(Int32) + # end + # ensure + # result.close + # end + # ``` def query(query, *args) prepare query, &.query(*args) end - # Yields a `ResultSet` for the `query`. + # Executes a *query* and yields a `ResultSet` with the results. # The `ResultSet` is closed automatically. + # + # ``` + # db.query("select name from contacts where age > ?", 18) do |rs| + # rs.each do + # name = rs.read(String) + # end + # end + # ``` def query(query, *args) # CHECK prepare(query).query(*args, &block) rs = query(query, *args) yield rs ensure rs.close end + # Executes a *query* that expects a single row and yields a `ResultSet` + # positioned at that first row. + # + # The given block must not invoke `move_next` on the yielded result set. + # + # Raises `DB::Error` if there were no rows, or if there were more than one row. + # + # ``` + # name = db.query_one "select name from contacts where id = ?", 18, &.read(String) + # ``` + def query_one(query, *args, &block : ResultSet -> U) : U + query(query, *args) do |rs| + raise DB::Error.new("no rows") unless rs.move_next + + value = yield rs + raise DB::Error.new("more than one row") if rs.move_next + return value + end + end + + # Executes a *query* that expects a single row and returns it + # as a tuple of the given *types*. + # + # Raises `DB::Error` if there were no rows, or if there were more than one row. + # + # ``` + # db.query_one "select name, age from contacts where id = ?", 1, as: {String, Int32} + # ``` + def query_one(query, *args, as types : Tuple) + query_one(query, *args) do |rs| + rs.read(*types) + end + end + + # Executes a *query* that expects a single row + # and returns the first column's value as the given *type*. + # + # Raises `DB::Error` if there were no rows, or if there were more than one row. + # + # ``` + # db.query_one "select name from contacts where id = ?", 1, as: String + # ``` + def query_one(query, *args, as type : Class) + query_one(query, *args) do |rs| + rs.read(type) + end + end + + # Executes a *query* that expects at most a single row and yields a `ResultSet` + # positioned at that first row. + # + # Returns `nil`, not invoking the block, if there were no rows. + # + # Raises `DB::Error` if there were more than one row + # (this ends up invoking the block once). + # + # ``` + # name = db.query_one? "select name from contacts where id = ?", 18, &.read(String) + # typeof(name) # => String | Nil + # ``` + def query_one?(query, *args, &block : ResultSet -> U) : U? + query(query, *args) do |rs| + return nil unless rs.move_next + + value = yield rs + raise DB::Error.new("more than one row") if rs.move_next + return value + end + end + + # Executes a *query* that expects a single row and returns it + # as a tuple of the given *types*. + # + # Returns `nil` if there were no rows. + # + # Raises `DB::Error` if there were more than one row. + # + # ``` + # result = db.query_one? "select name, age from contacts where id = ?", 1, as: {String, Int32} + # typeof(result) # => Tuple(String, Int32) | Nil + # ``` + def query_one?(query, *args, as types : Tuple) + query_one?(query, *args) do |rs| + rs.read(*types) + end + end + + # Executes a *query* that expects a single row + # and returns the first column's value as the given *type*. + # + # Returns `nil` if there were no rows. + # + # Raises `DB::Error` if there were more than one row. + # + # ``` + # name = db.query_one? "select name from contacts where id = ?", 1, as: String + # typeof(name) # => String? + # ``` + def query_one?(query, *args, as type : Class) + query_one?(query, *args) do |rs| + rs.read(type) + end + end + + # Executes a *query* and yield a `ResultSet` positioned at the beginning + # of each row, returning an array of the values of the blocks. + # + # ``` + # names = db.query_all "select name from contacts", &.read(String) + # ``` + def query_all(query, *args, &block : ResultSet -> U) : Array(U) + ary = [] of U + query(query, *args) do |rs| + rs.each do + ary.push(yield rs) + end + end + ary + end + + # Executes a *query* and returns an array where each row is + # read as a tuple of the given *types*. + # + # ``` + # contacts = db.query_all "select name, age from contactas", as: {String, Int32} + # ``` + def query_all(query, *args, as types : Tuple) + query_all(query, *args) do |rs| + rs.read(*types) + end + end + + # Executes a *query* and returns an array where there first + # column's value of each row is read as the given *type*. + # + # ``` + # names = db.query_all "select name from contactas", as: String + # ``` + def query_all(query, *args, as type : Class) + query_all(query, *args) do |rs| + rs.read(type) + end + end + # Performs the `query` and returns an `ExecResult` def exec(query, *args) prepare query, &.exec(*args) end - # Performs the `query` and returns a single scalar `DB::Any` value - # puts db.scalar("SELECT MAX(name)") as String # => (a String) + # Performs the `query` and returns a single scalar value + # puts db.scalar("SELECT MAX(name)").as(String) # => (a String) def scalar(query, *args) prepare query, &.scalar(*args) end @@ -55,7 +218,5 @@ module DB raise ex end end - - # TODO add query_row end end diff --git a/src/db/result_set.cr b/src/db/result_set.cr index 012c99e..1b048ae 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -16,10 +16,9 @@ module DB # ### Note to implementors # # 1. Override `#move_next` to move to the next row. - # 2. Override `#read?(t)` for all `t` in `DB::TYPES` and any other types the driver should handle. - # 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES` and any other. + # 2. Override `#read` returning the next value in the row. + # 3. (Optional) Override `#read(t)` for some types `t` for which custom logic other than a simple cast is needed. # 4. Override `#column_count`, `#column_name`. - # 5. Override `#column_type`. It must return a type in `DB::TYPES`. abstract class ResultSet include Disposable @@ -55,30 +54,29 @@ module DB # Returns the name of the column in `index` 0-based position. abstract def column_name(index : Int32) : String - # Returns the type of the column in `index` 0-based position. - # The result is one of `DB::TYPES`. - abstract def column_type(index : Int32) + # Reads the next column value + abstract def read - def read(t) - read?(t).not_nil! + # Reads the next column value as a **type** + def read(type : T.class) : T + read.as(T) end - # Reads the next column as a Nil. - def read(t : Nil.class) : Nil - read?(Nil) + # Reads the next columns and returns a tuple of the values. + def read(*types : Class) + internal_read(*types) end - def read?(t) - raise "read?(t : #{t}) is not implemented in #{self.class}" + private def internal_read(*types : *T) + {% begin %} + Tuple.new( + {% for type in T %} + read({{type.instance}}), + {% end %} + ) + {% end %} end - # list datatypes that must be supported form the driver - # users will call read(String) or read?(String) for nillables - {% for t in DB::TYPES %} - # Reads the next column as a nillable {{t}}. - abstract def read?(t : {{t}}.class) : {{t}}? - {% end %} - # def read_blob # yield ... io .... # end diff --git a/src/db/statement.cr b/src/db/statement.cr index 94d7afa..55120b4 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -45,7 +45,7 @@ module DB def scalar(*args) query(*args) do |rs| rs.each do - return rs.read?(rs.column_type(0)) + return rs.read end end