fixes #5. Merge branch 'feature/5-type-extensibility'

This commit is contained in:
Brian J. Cardiff 2016-06-23 15:12:20 -03:00
commit 22db7d1043
8 changed files with 373 additions and 57 deletions

View file

@ -0,0 +1,299 @@
require "./spec_helper"
module GenericResultSet
@index = 0
def move_next
@index = 0
true
end
def column_count : Int32
@row.size
end
def column_name(index : Int32) : String
index.to_s
end
def column_type(index : Int32)
@row[index].class
end
{% for t in DB::TYPES %}
# Reads the next column as a nillable {{t}}.
def read?(t : {{t}}.class) : {{t}}?
read_and_move_next_column as {{t}}?
end
{% end %}
def read_and_move_next_column
@index += 1
@row[@index - 1]
end
end
class FooValue
def initialize(@value : Int32)
end
def value
@value
end
end
class FooDriver < DB::Driver
alias Any = DB::Any | FooValue
@@row = [] of Any
def self.fake_row=(row : Array(Any))
@@row = row
end
def self.fake_row
@@row
end
def build_connection(db : DB::Database) : DB::Connection
FooConnection.new(db)
end
class FooConnection < DB::Connection
def build_statement(query)
FooStatement.new(self)
end
end
class FooStatement < DB::Statement
protected def perform_query(args : Enumerable) : DB::ResultSet
args.each { |arg| process_arg arg }
FooResultSet.new(self, FooDriver.fake_row)
end
protected def perform_exec(args : Enumerable) : DB::ExecResult
args.each { |arg| process_arg arg }
DB::ExecResult.new 0i64, 0i64
end
private def process_arg(value : FooDriver::Any)
end
private def process_arg(value)
raise "#{self.class} does not support #{value.class} params"
end
end
class FooResultSet < DB::ResultSet
include GenericResultSet
def initialize(statement, @row : Array(FooDriver::Any))
super(statement)
end
def read?(t : FooValue.class) : FooValue?
read_and_move_next_column.as(FooValue?)
end
end
end
DB.register_driver "foo", FooDriver
class BarValue
getter value
def initialize(@value : Int32)
end
end
class BarDriver < DB::Driver
alias Any = DB::Any | BarValue
@@row = [] of Any
def self.fake_row=(row : Array(Any))
@@row = row
end
def self.fake_row
@@row
end
def build_connection(db : DB::Database) : DB::Connection
BarConnection.new(db)
end
class BarConnection < DB::Connection
def build_statement(query)
BarStatement.new(self)
end
end
class BarStatement < DB::Statement
protected def perform_query(args : Enumerable) : DB::ResultSet
args.each { |arg| process_arg arg }
BarResultSet.new(self, BarDriver.fake_row)
end
protected def perform_exec(args : Enumerable) : DB::ExecResult
args.each { |arg| process_arg arg }
DB::ExecResult.new 0i64, 0i64
end
private def process_arg(value : BarDriver::Any)
end
private def process_arg(value)
raise "#{self.class} does not support #{value.class} params"
end
end
class BarResultSet < DB::ResultSet
include GenericResultSet
def initialize(statement, @row : Array(BarDriver::Any))
super(statement)
end
def read?(t : BarValue.class) : BarValue?
read_and_move_next_column.as(BarValue?)
end
end
end
DB.register_driver "bar", BarDriver
describe DB do
it "should be able to register multiple drivers" do
DB.open("foo://host").driver.should be_a(FooDriver)
DB.open("bar://host").driver.should be_a(BarDriver)
end
it "Foo and Bar drivers should return fake_row" do
with_witness do |w|
DB.open("foo://host") do |db|
# TODO somehow FooValue.new(99) is needed otherwise the read_object assertion fail
FooDriver.fake_row = [1, "string", FooValue.new(3), FooValue.new(99)] of FooDriver::Any
db.query "query" do |rs|
w.check
rs.move_next
rs.read?(Int32).should eq(1)
rs.read?(String).should eq("string")
rs.read(FooValue).value.should eq(3)
end
end
end
with_witness do |w|
DB.open("bar://host") do |db|
# TODO somehow BarValue.new(99) is needed otherwise the read_object assertion fail
BarDriver.fake_row = [BarValue.new(4), "lorem", 1.0, BarValue.new(99)] of BarDriver::Any
db.query "query" do |rs|
w.check
rs.move_next
rs.read(BarValue).value.should eq(4)
rs.read?(String).should eq("lorem")
rs.read?(Float64).should eq(1.0)
end
end
end
end
it "Foo and Bar drivers should not implement each other read" do
with_witness do |w|
DB.open("foo://host") do |db|
FooDriver.fake_row = [1] of FooDriver::Any
db.query "query" do |rs|
rs.move_next
expect_raises Exception, "read?(t : BarValue) is not implemented in FooDriver::FooResultSet" do
w.check
rs.read(BarValue)
end
end
end
end
with_witness do |w|
DB.open("bar://host") do |db|
BarDriver.fake_row = [1] of BarDriver::Any
db.query "query" do |rs|
rs.move_next
expect_raises Exception, "read?(t : FooValue) is not implemented in BarDriver::BarResultSet" do
w.check
rs.read(FooValue)
end
end
end
end
end
it "allow custom types to be used as arguments for query" do
DB.open("foo://host") do |db|
FooDriver.fake_row = [1, "string"] of FooDriver::Any
db.query "query" { }
db.query "query", 1 { }
db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", FooValue.new(5)) { }
db.query "query", [1, "string", FooValue.new(5)] { }
db.query("query").close
db.query("query", 1).close
db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", FooValue.new(5)).close
db.query("query", [1, "string", FooValue.new(5)]).close
end
DB.open("bar://host") do |db|
BarDriver.fake_row = [1, "string"] of BarDriver::Any
db.query "query" { }
db.query "query", 1 { }
db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", BarValue.new(5)) { }
db.query "query", [1, "string", BarValue.new(5)] { }
db.query("query").close
db.query("query", 1).close
db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", BarValue.new(5)).close
db.query("query", [1, "string", BarValue.new(5)]).close
end
end
it "allow custom types to be used as arguments for exec" do
DB.open("foo://host") do |db|
FooDriver.fake_row = [1, "string"] of FooDriver::Any
db.exec("query")
db.exec("query", 1)
db.exec("query", 1, "string")
db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", FooValue.new(5))
db.exec("query", [1, "string", FooValue.new(5)])
end
DB.open("bar://host") do |db|
BarDriver.fake_row = [1, "string"] of BarDriver::Any
db.exec("query")
db.exec("query", 1)
db.exec("query", 1, "string")
db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", BarValue.new(5))
db.exec("query", [1, "string", BarValue.new(5)])
end
end
it "Foo and Bar drivers should not implement each other params" do
DB.open("foo://host") do |db|
expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do
db.exec("query", [BarValue.new(5)])
end
end
DB.open("bar://host") do |db|
expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do
db.exec("query", [FooValue.new(5)])
end
end
end
end

View file

@ -42,23 +42,31 @@ class DummyDriver < DB::Driver
super(connection) super(connection)
end end
protected def perform_query(args : Slice(DB::Any)) protected def perform_query(args : Enumerable)
set_params args set_params args
DummyResultSet.new self, @query DummyResultSet.new self, @query
end end
protected def perform_exec(args : Slice(DB::Any)) protected def perform_exec(args : Enumerable)
set_params args set_params args
DB::ExecResult.new 0, 0_i64 DB::ExecResult.new 0i64, 0_i64
end end
private def set_params(args) private def set_params(args)
@params.clear @params.clear
args.each_with_index do |arg, index| args.each_with_index do |arg, index|
@params[index] = arg set_param(index, arg)
end end
end end
private def set_param(index, value : DB::Any)
@params[index] = value
end
private def set_param(index, value)
raise "not implemented for #{value.class}"
end
protected def do_close protected def do_close
super super
end end
@ -70,7 +78,7 @@ class DummyDriver < DB::Driver
@values : Array(String)? @values : Array(String)?
@@last_result_set : self? @@last_result_set : self?
@@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Slice(UInt8).class @@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Bytes.class
def initialize(statement, query) def initialize(statement, query)
super(statement) super(statement)
@ -114,12 +122,16 @@ class DummyDriver < DB::Driver
return nil if n == "NULL" return nil if n == "NULL"
if n == "?" if n == "?"
return @statement.params[0] return (@statement.as(DummyStatement)).params[0]
end end
return n return n
end end
def read?(t : Nil.class)
read?.as(Nil)
end
def read?(t : String.class) def read?(t : String.class)
read?.try &.to_s read?.try &.to_s
end end
@ -140,17 +152,17 @@ class DummyDriver < DB::Driver
read?(String).try &.to_f64 read?(String).try &.to_f64
end end
def read?(t : Slice(UInt8).class) def read?(t : Bytes.class)
value = read? value = read?
if value.is_a?(Nil) if value.is_a?(Nil)
value value
elsif value.is_a?(String) elsif value.is_a?(String)
ary = value.bytes ary = value.bytes
Slice.new(ary.to_unsafe, ary.size) Slice.new(ary.to_unsafe, ary.size)
elsif value.is_a?(Slice(UInt8)) elsif value.is_a?(Bytes)
value value
else else
raise "#{value} is not convertible to Slice(UInt8)" raise "#{value} is not convertible to Bytes"
end end
end end
end end

View file

@ -101,9 +101,9 @@ describe DummyDriver do
db.query("az,AZ") do |rs| db.query("az,AZ") do |rs|
rs.move_next rs.move_next
ary = [97u8, 122u8] ary = [97u8, 122u8]
rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size))
ary = [65u8, 90u8] ary = [65u8, 90u8]
rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size)) rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size))
end end
end end
end end
@ -134,9 +134,9 @@ describe DummyDriver do
it "executes and selects blob" do it "executes and selects blob" do
with_dummy do |db| with_dummy do |db|
ary = UInt8[0x53, 0x51, 0x4C] ary = UInt8[0x53, 0x51, 0x4C]
slice = Slice.new(ary.to_unsafe, ary.size) slice = Bytes.new(ary.to_unsafe, ary.size)
DummyDriver::DummyResultSet.next_column_type = typeof(slice) DummyDriver::DummyResultSet.next_column_type = typeof(slice)
(db.scalar("?", slice) as Slice(UInt8)).to_a.should eq(ary) (db.scalar("?", slice).as(Bytes)).to_a.should eq(ary)
end end
end end
end end

View file

@ -9,7 +9,7 @@ describe DB::Statement do
it "should initialize positional params in query" do it "should initialize positional params in query" do
with_dummy do |db| with_dummy do |db|
stmt = db.prepare("the query") stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.query "a", 1, nil stmt.query "a", 1, nil
stmt.params[0].should eq("a") stmt.params[0].should eq("a")
stmt.params[1].should eq(1) stmt.params[1].should eq(1)
@ -19,7 +19,7 @@ describe DB::Statement do
it "should initialize positional params in query with array" do it "should initialize positional params in query with array" do
with_dummy do |db| with_dummy do |db|
stmt = db.prepare("the query") stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.query ["a", 1, nil] stmt.query ["a", 1, nil]
stmt.params[0].should eq("a") stmt.params[0].should eq("a")
stmt.params[1].should eq(1) stmt.params[1].should eq(1)
@ -29,7 +29,7 @@ describe DB::Statement do
it "should initialize positional params in exec" do it "should initialize positional params in exec" do
with_dummy do |db| with_dummy do |db|
stmt = db.prepare("the query") stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.exec "a", 1, nil stmt.exec "a", 1, nil
stmt.params[0].should eq("a") stmt.params[0].should eq("a")
stmt.params[1].should eq(1) stmt.params[1].should eq(1)
@ -39,7 +39,7 @@ describe DB::Statement do
it "should initialize positional params in exec with array" do it "should initialize positional params in exec with array" do
with_dummy do |db| with_dummy do |db|
stmt = db.prepare("the query") stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.exec ["a", 1, nil] stmt.exec ["a", 1, nil]
stmt.params[0].should eq("a") stmt.params[0].should eq("a")
stmt.params[1].should eq(1) stmt.params[1].should eq(1)
@ -49,7 +49,7 @@ describe DB::Statement do
it "should initialize positional params in scalar" do it "should initialize positional params in scalar" do
with_dummy do |db| with_dummy do |db|
stmt = db.prepare("the query") stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.scalar "a", 1, nil stmt.scalar "a", 1, nil
stmt.params[0].should eq("a") stmt.params[0].should eq("a")
stmt.params[1].should eq(1) stmt.params[1].should eq(1)

View file

@ -70,13 +70,15 @@ module DB
# Types supported to interface with database driver. # Types supported to interface with database driver.
# These can be used in any `ResultSet#read` or any `Database#query` related # These can be used in any `ResultSet#read` or any `Database#query` related
# method to be used as query parameters # method to be used as query parameters
TYPES = [String, Int32, Int64, Float32, Float64, Slice(UInt8)] TYPES = [Nil, String, Int32, Int64, Float32, Float64, Bytes]
# See `DB::TYPES` in `DB`. `Any` is a nillable version of the union of all types in `DB::TYPES` # See `DB::TYPES` in `DB`. `Any` is a union of all types in `DB::TYPES`
alias Any = Nil | String | Int32 | Int64 | Float32 | Float64 | Slice(UInt8) {% begin %}
alias Any = Union({{*TYPES}})
{% end %}
# Result of a `#exec` statement. # Result of a `#exec` statement.
record ExecResult, rows_affected : Int32, last_insert_id : Int64 record ExecResult, rows_affected : Int64, last_insert_id : Int64
# :nodoc: # :nodoc:
def self.driver_class(driver_name) : Driver.class def self.driver_class(driver_name) : Driver.class

View file

@ -40,7 +40,7 @@ module DB
end end
end end
# Performs the `query` discarding any response # Performs the `query` and returns an `ExecResult`
def exec(query, *args) def exec(query, *args)
prepare(query).exec(*args) prepare(query).exec(*args)
end end

View file

@ -16,8 +16,8 @@ module DB
# ### Note to implementors # ### Note to implementors
# #
# 1. Override `#move_next` to move to the next row. # 1. Override `#move_next` to move to the next row.
# 2. Override `#read?(t)` for all `t` in `DB::TYPES`. # 2. Override `#read?(t)` for all `t` in `DB::TYPES` and any other types the driver should handle.
# 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES`. # 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES` and any other.
# 4. Override `#column_count`, `#column_name`. # 4. Override `#column_count`, `#column_name`.
# 5. Override `#column_type`. It must return a type in `DB::TYPES`. # 5. Override `#column_type`. It must return a type in `DB::TYPES`.
abstract class ResultSet abstract class ResultSet
@ -59,17 +59,24 @@ module DB
# The result is one of `DB::TYPES`. # The result is one of `DB::TYPES`.
abstract def column_type(index : Int32) abstract def column_type(index : Int32)
def read(t)
read?(t).not_nil!
end
# Reads the next column as a Nil.
def read(t : Nil.class) : Nil
read?(Nil)
end
def read?(t)
raise "read?(t : #{t}) is not implemented in #{self.class}"
end
# list datatypes that must be supported form the driver # list datatypes that must be supported form the driver
# users will call read(String) or read?(String) for nillables # users will call read(String) or read?(String) for nillables
{% for t in DB::TYPES %} {% for t in DB::TYPES %}
# Reads the next column as a nillable {{t}}. # Reads the next column as a nillable {{t}}.
abstract def read?(t : {{t}}.class) : {{t}}? abstract def read?(t : {{t}}.class) : {{t}}?
# Reads the next column as a {{t}}.
def read(t : {{t}}.class) : {{t}}
read?({{t}}).not_nil!
end
{% end %} {% end %}
# def read_blob # def read_blob

View file

@ -27,18 +27,18 @@ module DB
# See `QueryMethods#exec` # See `QueryMethods#exec`
def exec def exec
perform_exec_and_release(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn) perform_exec_and_release(Slice(Any).new(0))
end end
# See `QueryMethods#exec` # See `QueryMethods#exec`
def exec(args : Enumerable(Any)) def exec(args : Array)
perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size)) perform_exec_and_release(args)
end end
# See `QueryMethods#exec` # See `QueryMethods#exec`
def exec(*args) def exec(*args)
# TODO better way to do it # TODO better way to do it
perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size)) perform_exec_and_release(args)
end end
# See `QueryMethods#scalar` # See `QueryMethods#scalar`
@ -57,8 +57,8 @@ module DB
return rs.read?(Float32) return rs.read?(Float32)
when Float64.class when Float64.class
return rs.read?(Float64) return rs.read?(Float64)
when Slice(UInt8).class when Bytes.class
return rs.read?(Slice(UInt8)) return rs.read?(Bytes)
when Nil.class when Nil.class
return rs.read?(Int32) return rs.read?(Int32)
else else
@ -71,13 +71,23 @@ module DB
end end
# See `QueryMethods#query` # See `QueryMethods#query`
def query(*args) def query
perform_query *args perform_query Slice(Any).new(0)
end
# See `QueryMethods#query`
def query(args : Array)
perform_query args
end end
# See `QueryMethods#query` # See `QueryMethods#query`
def query(*args) def query(*args)
perform_query(*args).tap do |rs| perform_query args
end
# See `QueryMethods#query`
def query(*args)
query(*args).tap do |rs|
begin begin
yield rs yield rs
ensure ensure
@ -86,27 +96,13 @@ module DB
end end
end end
private def perform_query : ResultSet private def perform_exec_and_release(args : Enumerable) : ExecResult
perform_query(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn)
end
private def perform_query(args : Enumerable(Any)) : ResultSet
# TODO better way to do it
perform_query(args.to_a.to_unsafe.to_slice(args.size))
end
private def perform_query(*args) : ResultSet
# TODO better way to do it
perform_query(args.to_a.to_unsafe.to_slice(args.size))
end
private def perform_exec_and_release(args : Slice(Any)) : ExecResult
perform_exec(args).tap do perform_exec(args).tap do
release_connection release_connection
end end
end end
protected abstract def perform_query(args : Slice(Any)) : ResultSet protected abstract def perform_query(args : Enumerable) : ResultSet
protected abstract def perform_exec(args : Slice(Any)) : ExecResult protected abstract def perform_exec(args : Enumerable) : ExecResult
end end
end end