From 683e6bdfa7af624beb19c4989fa6f0210097ab89 Mon Sep 17 00:00:00 2001 From: "Brian J. Cardiff" Date: Sat, 30 Jan 2016 19:46:43 -0300 Subject: [PATCH] major db refactor for a better api. `#query`, `#exec`, `#scalar`, `#scalar?` as main query methods from `Database` blocks overrides that ensure statements are closed. --- spec/std/db/db_spec.cr | 53 ++++++++ spec/std/db/driver_spec.cr | 16 --- spec/std/db/dummy_driver.cr | 58 +++++++-- spec/std/db/dummy_driver_spec.cr | 205 +++++++++++++++++++------------ spec/std/db/result_set_spec.cr | 20 +++ spec/std/db/statement_spec.cr | 181 +++++++++++++++++++++++++++ src/db/connection.cr | 25 ++++ src/db/database.cr | 64 +++++++--- src/db/db.cr | 9 ++ src/db/driver.cr | 2 +- src/db/result_set.cr | 5 +- src/db/statement.cr | 75 ++++++++--- 12 files changed, 574 insertions(+), 139 deletions(-) create mode 100644 spec/std/db/db_spec.cr delete mode 100644 spec/std/db/driver_spec.cr create mode 100644 spec/std/db/result_set_spec.cr create mode 100644 spec/std/db/statement_spec.cr create mode 100644 src/db/connection.cr diff --git a/spec/std/db/db_spec.cr b/spec/std/db/db_spec.cr new file mode 100644 index 0000000..89764f0 --- /dev/null +++ b/spec/std/db/db_spec.cr @@ -0,0 +1,53 @@ +require "spec" +require "db" +require "./dummy_driver" + +describe DB do + it "should get driver class by name" do + DB.driver_class("dummy").should eq(DummyDriver) + end + + it "should instantiate driver with options" do + db = DB.open "dummy", {"host": "localhost", "port": "1027"} + db.driver_class.should eq(DummyDriver) + db.options["host"].should eq("localhost") + db.options["port"].should eq("1027") + end + + it "should create a connection and close it" do + cnn = nil + DB.open "dummy", {"host": "localhost"} do |db| + cnn = db.connection + end + + cnn.should be_a(DummyDriver::DummyConnection) + cnn.not_nil!.closed?.should be_true + end + + it "query should close statement" do + with_witness do |w| + with_dummy do |db| + db.query "1,2" do + break + end + + w.check + db.connection.last_statement.closed?.should be_true + end + end + end + + it "exec should close statement" do + with_dummy do |db| + db.exec "" + db.connection.last_statement.closed?.should be_true + end + end + + it "scalar should close statement" do + with_dummy do |db| + db.scalar "1" + db.connection.last_statement.closed?.should be_true + end + end +end diff --git a/spec/std/db/driver_spec.cr b/spec/std/db/driver_spec.cr deleted file mode 100644 index 36d0bea..0000000 --- a/spec/std/db/driver_spec.cr +++ /dev/null @@ -1,16 +0,0 @@ -require "spec" -require "db" -require "./dummy_driver" - -describe DB::Driver do - it "should get driver class by name" do - DB.driver_class("dummy").should eq(DummyDriver) - end - - it "should instantiate driver with options" do - db = DB.open "dummy", {"host": "localhost", "port": "1027"} - db.driver_class.should eq(DummyDriver) - db.options["host"].should eq("localhost") - db.options["port"].should eq("1027") - end -end diff --git a/spec/std/db/dummy_driver.cr b/spec/std/db/dummy_driver.cr index 24b76c0..0f41006 100644 --- a/spec/std/db/dummy_driver.cr +++ b/spec/std/db/dummy_driver.cr @@ -1,6 +1,19 @@ +require "spec" + class DummyDriver < DB::Driver - def prepare(query) - DummyStatement.new(self, query.split.map { |r| r.split ',' }) + def build_connection + DummyConnection.new(options) + end + + class DummyConnection < DB::Connection + getter! last_statement + + def prepare(query) + @last_statement = DummyStatement.new(self, query.split.map { |r| r.split ',' }) + end + + def perform_close + end end class DummyStatement < DB::Statement @@ -10,6 +23,10 @@ class DummyDriver < DB::Driver super(driver) end + protected def begin_parameters + @params = Hash(Int32 | String, DB::Any?).new + end + protected def add_parameter(index : Int32, value) params[index] = value end @@ -18,11 +35,7 @@ class DummyDriver < DB::Driver params[":#{name}"] = value end - protected def before_execute - @params = Hash(Int32 | String, DB::Any).new - end - - protected def execute + protected def perform DummyResultSet.new self, @items.each end end @@ -48,6 +61,10 @@ class DummyDriver < DB::Driver "c#{index}" end + def column_type(index : Int32) + String + end + private def read? : DB::Any? n = @values.not_nil!.next raise "end of row" if n.is_a?(Iterator::Stop) @@ -91,8 +108,10 @@ class DummyDriver < DB::Driver elsif value.is_a?(String) ary = value.bytes Slice.new(ary.to_unsafe, ary.size) + elsif value.is_a?(Slice(UInt8)) + value else - value as Slice(UInt8) + raise "#{value} is not convertible to Slice(UInt8)" end end end @@ -100,6 +119,25 @@ end DB.register_driver "dummy", DummyDriver -def get_dummy - DB.open "dummy", {} of String => String +class Witness + getter count + + def initialize(@count) + end + + def check + @count -= 1 + end +end + +def with_witness(count = 1) + w = Witness.new(count) + yield w + w.count.should eq(0), "The expected coverage was unmet" +end + +def with_dummy + DB.open "dummy", {} of String => String do |db| + yield db + end end diff --git a/spec/std/db/dummy_driver_spec.cr b/spec/std/db/dummy_driver_spec.cr index 6ba3994..4454564 100644 --- a/spec/std/db/dummy_driver_spec.cr +++ b/spec/std/db/dummy_driver_spec.cr @@ -3,118 +3,169 @@ require "db" require "./dummy_driver" describe DummyDriver do - it "should return statements" do - get_dummy.prepare("the query").should be_a(DB::Statement) + it "with_dummy executes the block with a database" do + with_witness do |w| + with_dummy do |db| + w.check + db.should be_a(DB::Database) + end + end end describe DummyDriver::DummyStatement do - it "exec should return a result_set" do - statement = get_dummy.prepare("a,b 1,2") - result_set = statement.exec - result_set.should be_a(DB::ResultSet) - result_set.statement.should be(statement) + it "should enumerate split rows by spaces" do + with_dummy do |db| + rs = db.query("") + rs.move_next.should be_false + rs.close + + rs = db.query("a,b") + rs.move_next.should be_true + rs.move_next.should be_false + rs.close + + rs = db.query("a,b 1,2") + rs.move_next.should be_true + rs.move_next.should be_true + rs.move_next.should be_false + rs.close + + rs = db.query("a,b 1,2 c,d") + rs.move_next.should be_true + rs.move_next.should be_true + rs.move_next.should be_true + rs.move_next.should be_false + rs.close + end end - it "should enumerate records by spaces" do - result_set = get_dummy.prepare("").exec - result_set.move_next.should be_false + it "should query with block shuold executes always" do + with_witness do |w| + with_dummy do |db| + db.query "" do |rs| + w.check + end + end + end - result_set = get_dummy.prepare("a,b").exec - result_set.move_next.should be_true - result_set.move_next.should be_false - - result_set = get_dummy.prepare("a,b 1,2").exec - result_set.move_next.should be_true - result_set.move_next.should be_true - result_set.move_next.should be_false - - result_set = get_dummy.prepare("a,b 1,2 c,d").exec - result_set.move_next.should be_true - result_set.move_next.should be_true - result_set.move_next.should be_true - result_set.move_next.should be_false + with_witness do |w| + with_dummy do |db| + db.query "lorem ipsum" do |rs| + w.check + end + end + end end it "should enumerate string fields" do - result_set = get_dummy.prepare("a,b 1,2").exec - result_set.move_next - result_set.read(String).should eq("a") - result_set.read(String).should eq("b") - result_set.move_next - result_set.read(String).should eq("1") - result_set.read(String).should eq("2") + with_dummy do |db| + db.query "a,b 1,2" do |rs| + rs.move_next + rs.read(String).should eq("a") + rs.read(String).should eq("b") + rs.move_next + rs.read(String).should eq("1") + rs.read(String).should eq("2") + end + end end it "should enumerate nil fields" do - result_set = get_dummy.prepare("a,NULL 1,NULL").exec - result_set.move_next - result_set.read?(String).should eq("a") - result_set.read?(String).should be_nil - result_set.move_next - result_set.read?(Int64).should eq(1) - result_set.read?(Int64).should be_nil + with_dummy do |db| + db.query "a,NULL 1,NULL" do |rs| + rs.move_next + rs.read?(String).should eq("a") + rs.read?(String).should be_nil + rs.move_next + rs.read?(Int64).should eq(1) + rs.read?(Int64).should be_nil + end + end end it "should enumerate int64 fields" do - result_set = get_dummy.prepare("3,4 1,2").exec - result_set.move_next - result_set.read(Int64).should eq(3i64) - result_set.read(Int64).should eq(4i64) - result_set.move_next - result_set.read(Int64).should eq(1i64) - result_set.read(Int64).should eq(2i64) + with_dummy do |db| + db.query "3,4 1,2" do |rs| + rs.move_next + rs.read(Int64).should eq(3i64) + rs.read(Int64).should eq(4i64) + rs.move_next + rs.read(Int64).should eq(1i64) + rs.read(Int64).should eq(2i64) + end + end end it "should enumerate blob fields" do - result_set = get_dummy.prepare("az,AZ").exec - result_set.move_next - ary = [97u8, 122u8] - result_set.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) - ary = [65u8, 90u8] - result_set.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) + with_dummy do |db| + db.query("az,AZ") do |rs| + rs.move_next + ary = [97u8, 122u8] + rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) + ary = [65u8, 90u8] + rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) + end + end end - it "should enumerate records using each" do - nums = [] of Int32 - result_set = get_dummy.prepare("3,4 1,2").exec - result_set.each do - nums << result_set.read(Int32) - nums << result_set.read(Int32) + it "should get Int32 scalars by default" do + with_dummy do |db| + db.scalar("1").should be_a(Int32) + db.scalar?("1").should be_a(Int32) + db.scalar?("NULL").should be_nil end + end - nums.should eq([3, 4, 1, 2]) + it "should get String scalars" do + with_dummy do |db| + db.scalar(String, "foo").should eq("foo") + end end {% for value in [1, 1_i64, "hello", 1.5, 1.5_f32] %} - it "should set arguments for {{value.id}}" do - result_set = get_dummy.exec "?", {{value}} - result_set.move_next.should be_true - result_set.read(typeof({{value}})).should eq({{value}}) + it "numeric scalars of type of {{value.id}} should return value or nil" do + with_dummy do |db| + db.scalar(typeof({{value}}), "#{{{value}}}").should eq({{value}}) + db.scalar?(typeof({{value}}), "#{{{value}}}").should eq({{value}}) + db.scalar?(typeof({{value}}), "NULL").should be_nil + end + end + + it "should set positional arguments for {{value.id}}" do + with_dummy do |db| + db.scalar(typeof({{value}}), "?", {{value}}).should eq({{value}}) + end end it "should set arguments by symbol for {{value.id}}" do - result_set = get_dummy.exec ":once :twice", {once: {{value}}, twice: {{value + value}} } - result_set.move_next.should be_true - result_set.read(typeof({{value}})).should eq({{value}}) - result_set.move_next.should be_true - result_set.read(typeof({{value}})).should eq({{value + value}}) + with_dummy do |db| + db.query ":once :twice", {once: {{value}}, twice: {{value + value}} } do |rs| + rs.move_next.should be_true + rs.read(typeof({{value}})).should eq({{value}}) + rs.move_next.should be_true + rs.read(typeof({{value}})).should eq({{value + value}}) + end + end end it "should set arguments by string for {{value.id}}" do - result_set = get_dummy.exec ":once :twice", {"once": {{value}}, "twice": {{value + value}} } - result_set.move_next.should be_true - result_set.read(typeof({{value}})).should eq({{value}}) - result_set.move_next.should be_true - result_set.read(typeof({{value}})).should eq({{value + value}}) + with_dummy do |db| + db.query ":once :twice", {"once": {{value}}, "twice": {{value + value}} } do |rs| + rs.move_next.should be_true + rs.read(typeof({{value}})).should eq({{value}}) + rs.move_next.should be_true + rs.read(typeof({{value}})).should eq({{value + value}}) + end + end end {% end %} it "executes and selects blob" do - ary = UInt8[0x53, 0x51, 0x4C] - slice = Slice.new(ary.to_unsafe, ary.size) - result_set = get_dummy.exec "?", slice - result_set.move_next - result_set.read(Slice(UInt8)).to_a.should eq(ary) + with_dummy do |db| + ary = UInt8[0x53, 0x51, 0x4C] + slice = Slice.new(ary.to_unsafe, ary.size) + db.scalar(Slice(UInt8), "?", slice).to_a.should eq(ary) + end end end end diff --git a/spec/std/db/result_set_spec.cr b/spec/std/db/result_set_spec.cr new file mode 100644 index 0000000..eabfb7f --- /dev/null +++ b/spec/std/db/result_set_spec.cr @@ -0,0 +1,20 @@ +require "spec" +require "db" +require "./dummy_driver" + +describe DB::ResultSet do + it "should enumerate records using each" do + nums = [] of Int32 + + with_dummy do |db| + db.query "3,4 1,2" do |rs| + rs.each do + nums << rs.read(Int32) + nums << rs.read(Int32) + end + end + end + + nums.should eq([3, 4, 1, 2]) + end +end diff --git a/spec/std/db/statement_spec.cr b/spec/std/db/statement_spec.cr new file mode 100644 index 0000000..9442482 --- /dev/null +++ b/spec/std/db/statement_spec.cr @@ -0,0 +1,181 @@ +require "spec" +require "db" +require "./dummy_driver" + +describe DB::Statement do + it "should prepare statements" do + with_dummy do |db| + db.prepare("the query").should be_a(DB::Statement) + end + end + + it "should initialize positional params in query" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.query "a", 1, nil + stmt.params[1].should eq("a") + stmt.params[2].should eq(1) + stmt.params[3].should eq(nil) + end + end + + it "should initialize symbol named params in query" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.query({a: "a", b: 1, c: nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "should initialize string named params in query" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.query({"a": "a", "b": 1, "c": nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "should initialize positional params in exec" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.exec "a", 1, nil + stmt.params[1].should eq("a") + stmt.params[2].should eq(1) + stmt.params[3].should eq(nil) + end + end + + it "should initialize symbol named params in exec" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.exec({a: "a", b: 1, c: nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "should initialize string named params in exec" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.exec({"a": "a", "b": 1, "c": nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "should initialize positional params in scalar" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.scalar String, "a", 1, nil + stmt.params[1].should eq("a") + stmt.params[2].should eq(1) + stmt.params[3].should eq(nil) + end + end + + it "should initialize symbol named params in scalar" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.scalar(String, {a: "a", b: 1, c: nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "should initialize string named params in scalar" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.scalar(String, {"a": "a", "b": 1, "c": nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "should initialize positional params in scalar?" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.scalar? String, "a", 1, nil + stmt.params[1].should eq("a") + stmt.params[2].should eq(1) + stmt.params[3].should eq(nil) + end + end + + it "should initialize symbol named params in scalar?" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.scalar?(String, {a: "a", b: 1, c: nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "should initialize string named params in scalar?" do + with_dummy do |db| + stmt = db.prepare("the query") + stmt.scalar?(String, {"a": "a", "b": 1, "c": nil}) + stmt.params[":a"].should eq("a") + stmt.params[":b"].should eq(1) + stmt.params[":c"].should eq(nil) + end + end + + it "query with block should not close statement" do + with_dummy do |db| + stmt = db.prepare "3,4 1,2" + stmt.query + stmt.closed?.should be_false + end + end + + it "query with block should close statement" do + with_dummy do |db| + stmt = db.prepare "3,4 1,2" + stmt.query do |rs| + end + stmt.closed?.should be_true + end + end + + it "query should close statement" do + with_dummy do |db| + stmt = db.prepare "3,4 1,2" + stmt.query do |rs| + end + stmt.closed?.should be_true + end + end + + it "scalar should close statement" do + with_dummy do |db| + stmt = db.prepare "3,4 1,2" + stmt.scalar + stmt.closed?.should be_true + end + end + + it "scalar should close statement" do + with_dummy do |db| + stmt = db.prepare "3,4 1,2" + stmt.scalar? + stmt.closed?.should be_true + end + end + + it "exec should close statement" do + with_dummy do |db| + stmt = db.prepare "3,4 1,2" + stmt.exec + stmt.closed?.should be_true + end + end +end diff --git a/src/db/connection.cr b/src/db/connection.cr new file mode 100644 index 0000000..2b515d5 --- /dev/null +++ b/src/db/connection.cr @@ -0,0 +1,25 @@ +module DB + abstract class Connection + getter options + + def initialize(@options) + @closed = false + end + + # Closes this connection. + def close + raise "Connection already closed" if @closed + @closed = true + perform_close + end + + # Returns `true` if this statement is closed. See `#close`. + def closed? + @closed + end + + abstract def prepare(query) : Statement + + protected abstract def perform_close + end +end diff --git a/src/db/database.cr b/src/db/database.cr index 3d8641d..466035b 100644 --- a/src/db/database.cr +++ b/src/db/database.cr @@ -1,43 +1,69 @@ module DB # Acts as an entry point for database access. - # Offers a com + # Currently it creates a single connection to the database. + # Eventually a connection pool will be handled. + # + # It should be created from DB module. See `DB.open`. class Database getter driver_class getter options def initialize(@driver_class, @options) @driver = @driver_class.new(@options) + @connection = @driver.build_connection end - # :nodoc: + # Closes all connection to the database + def close + @connection.close + end + + # Returns a `Connection` to the database + def connection + @connection + end + + # Prepares a `Statement`. The Statement must be closed explicitly + # after is not longer in use. + # + # Usually `#exec`, `#query` or `#scalar` should be used. def prepare(query) - @driver.prepare(query) + connection.prepare(query) + end + + def query(query, *args) + prepare(query).query(*args) + end + + def query(query, *args) + # CHECK prepare(query).query(*args, &block) + query(query, *args).tap do |rs| + begin + yield rs + ensure + rs.close + end + end end - # :nodoc: def exec(query, *args) prepare(query).exec(*args) end - def exec_non_query(query, *args) - exec_query(query) do |result_set| - result_set.move_next - end + def scalar(query, *args) + prepare(query).scalar(*args) end - # :nodoc: - def exec_query(query, *args) - result_set = exec(query, *args) - yield result_set - result_set.close + def scalar(t, query, *args) + prepare(query).scalar(t, *args) end - def exec_query_each(query, *args) - exec_query(query) do |result_set| - result_set.each do - yield result_set - end - end + def scalar?(query, *args) + prepare(query).scalar?(*args) + end + + def scalar?(t, query, *args) + prepare(query).scalar?(t, *args) end end end diff --git a/src/db/db.cr b/src/db/db.cr index 52b0f9a..74acd3b 100644 --- a/src/db/db.cr +++ b/src/db/db.cr @@ -2,6 +2,7 @@ module DB TYPES = [String, Int32, Int64, Float32, Float64, Slice(UInt8)] alias Any = String | Int32 | Int64 | Float32 | Float64 | Slice(UInt8) + # :nodoc: def self.driver_class(name) # : Driver.class @@drivers.not_nil![name] end @@ -14,9 +15,17 @@ module DB def self.open(name, options) Database.new(driver_class(name), options) end + + def self.open(name, options, &block) + open(name, options).tap do |db| + yield db + db.close + end + end end require "./database" require "./driver" +require "./connection" require "./statement" require "./result_set" diff --git a/src/db/driver.cr b/src/db/driver.cr index 7a9a72c..1c9ac1d 100644 --- a/src/db/driver.cr +++ b/src/db/driver.cr @@ -5,6 +5,6 @@ module DB def initialize(@options) end - abstract def prepare(query) : Statement + abstract def build_connection : Connection end end diff --git a/src/db/result_set.cr b/src/db/result_set.cr index 4958fc7..202ad57 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -17,10 +17,11 @@ module DB abstract def move_next : Bool + # TODO def empty? : Bool, handle internally with move_next (?) + abstract def column_count : Int32 abstract def column_name(index : Int32) : String - - # abstract def column_type(index : Int32) + abstract def column_type(index : Int32) # list datatypes that must be supported form the driver # users will call read(String) or read?(String) for nillables diff --git a/src/db/statement.cr b/src/db/statement.cr index a5e58ca..2c16a99 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -1,23 +1,71 @@ module DB abstract class Statement - getter driver + getter connection - def initialize(@driver) + def initialize(@connection) @closed = false end - def exec(*args) : ResultSet - exec args + def exec(*args) + execute(*args).close end - def exec(arg : Slice(UInt8)) - before_execute + def scalar(*args) + scalar(Int32, *args) + end + + # t in DB::TYPES + def scalar(t, *args) + query(*args) do |rs| + rs.each do + return rs.read(t) + end + end + + raise "unreachable" + end + + def scalar?(*args) + scalar?(Int32, *args) + end + + # t in DB::TYPES + def scalar?(t, *args) + query(*args) do |rs| + rs.each do + return rs.read?(t) + end + end + + raise "unreachable" + end + + def query(*args) + execute *args + end + + def query(*args) + execute(*args).tap do |rs| + begin + yield rs + ensure + rs.close + end + end + end + + private def execute(*args) : ResultSet + execute args + end + + private def execute(arg : Slice(UInt8)) + begin_parameters add_parameter 1, arg - execute + perform end - def exec(args : Enumerable) - before_execute + private def execute(args : Enumerable) + begin_parameters args.each_with_index(1) do |arg, index| if arg.is_a?(Hash) arg.each do |key, value| @@ -27,10 +75,7 @@ module DB add_parameter index, arg end end - execute - end - - protected def before_execute + perform end # Closes this statement. @@ -46,10 +91,12 @@ module DB end # 1-based positional arguments + protected def begin_parameters + end protected abstract def add_parameter(index : Int32, value) protected abstract def add_parameter(name : String, value) - protected abstract def execute : ResultSet + protected abstract def perform : ResultSet protected def on_close end end