diff --git a/spec/std/db/driver_spec.cr b/spec/std/db/driver_spec.cr index f2ef75e..36d0bea 100644 --- a/spec/std/db/driver_spec.cr +++ b/spec/std/db/driver_spec.cr @@ -8,8 +8,9 @@ describe DB::Driver do end it "should instantiate driver with options" do - driver = DB.driver "dummy", {"host": "localhost", "port": "1027"} - driver.options["host"].should eq("localhost") - driver.options["port"].should eq("1027") + db = DB.open "dummy", {"host": "localhost", "port": "1027"} + db.driver_class.should eq(DummyDriver) + db.options["host"].should eq("localhost") + db.options["port"].should eq("1027") end end diff --git a/spec/std/db/dummy_driver.cr b/spec/std/db/dummy_driver.cr index f92f989..2aac908 100644 --- a/spec/std/db/dummy_driver.cr +++ b/spec/std/db/dummy_driver.cr @@ -18,7 +18,7 @@ class DummyDriver < DB::Driver super(statement) end - def has_next + def move_next @iterator.next.tap do |n| return false if n.is_a?(Iterator::Stop) @values = n.each @@ -26,12 +26,27 @@ class DummyDriver < DB::Driver end end - def read_string - @values.not_nil!.next as String + def read?(t : String.class) + n = @values.not_nil!.next + raise "end of row" if n.is_a?(Iterator::Stop) + return nil if n == "NULL" + return n as String end - def read_u_int64 - read_string.to_u64 + def read?(t : Int32.class) + read?(String).try &.to_i32 + end + + def read?(t : Int64.class) + read?(String).try &.to_i64 + end + + def read?(t : Float32.class) + read?(String).try &.to_f23 + end + + def read?(t : Float64.class) + read?(String).try &.to_f64 end end end @@ -39,5 +54,5 @@ end DB.register_driver "dummy", DummyDriver def get_dummy - DB.driver "dummy", {} of String => String + DB.open "dummy", {} of String => String end diff --git a/spec/std/db/dummy_driver_spec.cr b/spec/std/db/dummy_driver_spec.cr index d8ff216..f910301 100644 --- a/spec/std/db/dummy_driver_spec.cr +++ b/spec/std/db/dummy_driver_spec.cr @@ -17,50 +17,60 @@ describe DummyDriver do it "should enumerate records by spaces" do result_set = get_dummy.prepare("").exec - result_set.has_next.should be_false + result_set.move_next.should be_false result_set = get_dummy.prepare("a,b").exec - result_set.has_next.should be_true - result_set.has_next.should be_false + result_set.move_next.should be_true + result_set.move_next.should be_false result_set = get_dummy.prepare("a,b 1,2").exec - result_set.has_next.should be_true - result_set.has_next.should be_true - result_set.has_next.should be_false + result_set.move_next.should be_true + result_set.move_next.should be_true + result_set.move_next.should be_false result_set = get_dummy.prepare("a,b 1,2 c,d").exec - result_set.has_next.should be_true - result_set.has_next.should be_true - result_set.has_next.should be_true - result_set.has_next.should be_false + result_set.move_next.should be_true + result_set.move_next.should be_true + result_set.move_next.should be_true + result_set.move_next.should be_false end it "should enumerate string fields" do result_set = get_dummy.prepare("a,b 1,2").exec - result_set.has_next + result_set.move_next result_set.read(String).should eq("a") result_set.read(String).should eq("b") - result_set.has_next + result_set.move_next result_set.read(String).should eq("1") result_set.read(String).should eq("2") end - it "should enumerate uint64 fields" do + it "should enumerate nil fields" do + result_set = get_dummy.prepare("a,NULL 1,NULL").exec + result_set.move_next + result_set.read?(String).should eq("a") + result_set.read?(String).should be_nil + result_set.move_next + result_set.read?(Int64).should eq(1) + result_set.read?(Int64).should be_nil + end + + it "should enumerate int64 fields" do result_set = get_dummy.prepare("3,4 1,2").exec - result_set.has_next - result_set.read(UInt64).should eq(3) - result_set.read(UInt64).should eq(4) - result_set.has_next - result_set.read(UInt64).should eq(1) - result_set.read(UInt64).should eq(2) + result_set.move_next + result_set.read(Int64).should eq(3i64) + result_set.read(Int64).should eq(4i64) + result_set.move_next + result_set.read(Int64).should eq(1i64) + result_set.read(Int64).should eq(2i64) end it "should enumerate records using each" do - nums = [] of UInt64 + nums = [] of Int32 result_set = get_dummy.prepare("3,4 1,2").exec result_set.each do - nums << result_set.read(UInt64) - nums << result_set.read(UInt64) + nums << result_set.read(Int32) + nums << result_set.read(Int32) end nums.should eq([3, 4, 1, 2]) diff --git a/src/db/database.cr b/src/db/database.cr new file mode 100644 index 0000000..25a171f --- /dev/null +++ b/src/db/database.cr @@ -0,0 +1,20 @@ +module DB + # Acts as an entry point for database access. + # Offers a com + class Database + getter driver_class + getter options + + def initialize(@driver_class, @options) + @driver = @driver_class.new(@options) + end + + def prepare(query) + @driver.prepare(query) + end + + def exec(query, *args) + prepare(query).exec(*args) + end + end +end diff --git a/src/db/db.cr b/src/db/db.cr index 608dc8a..77cea9d 100644 --- a/src/db/db.cr +++ b/src/db/db.cr @@ -1,4 +1,6 @@ module DB + TYPES = [String, Int32, Int64, Float32, Float64] + def self.driver_class(name) # : Driver.class @@drivers.not_nil![name] end @@ -8,11 +10,12 @@ module DB @@drivers.not_nil![name] = klass end - def self.driver(name, options) - driver_class(name).new(options) + def self.open(name, options) + Database.new(driver_class(name), options) end end +require "./database" require "./driver" require "./statement" require "./result_set" diff --git a/src/db/result_set.cr b/src/db/result_set.cr index 147f8e3..ed396c5 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -6,32 +6,21 @@ module DB end def each - while has_next + while move_next yield end end - abstract def has_next : Bool - - # def read(t : T.class) : T - # end + abstract def move_next : Bool # list datatypes that must be supported form the driver - # implementors will override read_string - # users will call read(String) due to overloads read(T) will be a T - # TODO: unable to write unions (nillables) - {% for t in [String, UInt64] %} + # users will call read(String) or read?(String) for nillables + {% for t in DB::TYPES %} + abstract def read?(t : {{t}}.class) : {{t}}? + def read(t : {{t}}.class) : {{t}} - read_{{t.name.underscore}} + read?({{t}}).not_nil! end - - protected abstract def read_{{t.name.underscore}} : {{t}} {% end %} - - # def read(t : String.class) : String - # read_string - # end - # - # protected abstract def read_string : String end end diff --git a/src/db/statement.cr b/src/db/statement.cr index fc62c03..3bf6ff5 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -1,5 +1,7 @@ module DB abstract class Statement + getter driver + def initialize(@driver) end