fixes #5. Merge branch 'feature/5-type-extensibility'

This commit is contained in:
Brian J. Cardiff 2016-06-23 15:12:20 -03:00
commit 22db7d1043
8 changed files with 373 additions and 57 deletions

View File

@ -0,0 +1,299 @@
require "./spec_helper"
module GenericResultSet
@index = 0
def move_next
@index = 0
true
end
def column_count : Int32
@row.size
end
def column_name(index : Int32) : String
index.to_s
end
def column_type(index : Int32)
@row[index].class
end
{% for t in DB::TYPES %}
# Reads the next column as a nillable {{t}}.
def read?(t : {{t}}.class) : {{t}}?
read_and_move_next_column as {{t}}?
end
{% end %}
def read_and_move_next_column
@index += 1
@row[@index - 1]
end
end
class FooValue
def initialize(@value : Int32)
end
def value
@value
end
end
class FooDriver < DB::Driver
alias Any = DB::Any | FooValue
@@row = [] of Any
def self.fake_row=(row : Array(Any))
@@row = row
end
def self.fake_row
@@row
end
def build_connection(db : DB::Database) : DB::Connection
FooConnection.new(db)
end
class FooConnection < DB::Connection
def build_statement(query)
FooStatement.new(self)
end
end
class FooStatement < DB::Statement
protected def perform_query(args : Enumerable) : DB::ResultSet
args.each { |arg| process_arg arg }
FooResultSet.new(self, FooDriver.fake_row)
end
protected def perform_exec(args : Enumerable) : DB::ExecResult
args.each { |arg| process_arg arg }
DB::ExecResult.new 0i64, 0i64
end
private def process_arg(value : FooDriver::Any)
end
private def process_arg(value)
raise "#{self.class} does not support #{value.class} params"
end
end
class FooResultSet < DB::ResultSet
include GenericResultSet
def initialize(statement, @row : Array(FooDriver::Any))
super(statement)
end
def read?(t : FooValue.class) : FooValue?
read_and_move_next_column.as(FooValue?)
end
end
end
DB.register_driver "foo", FooDriver
class BarValue
getter value
def initialize(@value : Int32)
end
end
class BarDriver < DB::Driver
alias Any = DB::Any | BarValue
@@row = [] of Any
def self.fake_row=(row : Array(Any))
@@row = row
end
def self.fake_row
@@row
end
def build_connection(db : DB::Database) : DB::Connection
BarConnection.new(db)
end
class BarConnection < DB::Connection
def build_statement(query)
BarStatement.new(self)
end
end
class BarStatement < DB::Statement
protected def perform_query(args : Enumerable) : DB::ResultSet
args.each { |arg| process_arg arg }
BarResultSet.new(self, BarDriver.fake_row)
end
protected def perform_exec(args : Enumerable) : DB::ExecResult
args.each { |arg| process_arg arg }
DB::ExecResult.new 0i64, 0i64
end
private def process_arg(value : BarDriver::Any)
end
private def process_arg(value)
raise "#{self.class} does not support #{value.class} params"
end
end
class BarResultSet < DB::ResultSet
include GenericResultSet
def initialize(statement, @row : Array(BarDriver::Any))
super(statement)
end
def read?(t : BarValue.class) : BarValue?
read_and_move_next_column.as(BarValue?)
end
end
end
DB.register_driver "bar", BarDriver
describe DB do
it "should be able to register multiple drivers" do
DB.open("foo://host").driver.should be_a(FooDriver)
DB.open("bar://host").driver.should be_a(BarDriver)
end
it "Foo and Bar drivers should return fake_row" do
with_witness do |w|
DB.open("foo://host") do |db|
# TODO somehow FooValue.new(99) is needed otherwise the read_object assertion fail
FooDriver.fake_row = [1, "string", FooValue.new(3), FooValue.new(99)] of FooDriver::Any
db.query "query" do |rs|
w.check
rs.move_next
rs.read?(Int32).should eq(1)
rs.read?(String).should eq("string")
rs.read(FooValue).value.should eq(3)
end
end
end
with_witness do |w|
DB.open("bar://host") do |db|
# TODO somehow BarValue.new(99) is needed otherwise the read_object assertion fail
BarDriver.fake_row = [BarValue.new(4), "lorem", 1.0, BarValue.new(99)] of BarDriver::Any
db.query "query" do |rs|
w.check
rs.move_next
rs.read(BarValue).value.should eq(4)
rs.read?(String).should eq("lorem")
rs.read?(Float64).should eq(1.0)
end
end
end
end
it "Foo and Bar drivers should not implement each other read" do
with_witness do |w|
DB.open("foo://host") do |db|
FooDriver.fake_row = [1] of FooDriver::Any
db.query "query" do |rs|
rs.move_next
expect_raises Exception, "read?(t : BarValue) is not implemented in FooDriver::FooResultSet" do
w.check
rs.read(BarValue)
end
end
end
end
with_witness do |w|
DB.open("bar://host") do |db|
BarDriver.fake_row = [1] of BarDriver::Any
db.query "query" do |rs|
rs.move_next
expect_raises Exception, "read?(t : FooValue) is not implemented in BarDriver::BarResultSet" do
w.check
rs.read(FooValue)
end
end
end
end
end
it "allow custom types to be used as arguments for query" do
DB.open("foo://host") do |db|
FooDriver.fake_row = [1, "string"] of FooDriver::Any
db.query "query" { }
db.query "query", 1 { }
db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", FooValue.new(5)) { }
db.query "query", [1, "string", FooValue.new(5)] { }
db.query("query").close
db.query("query", 1).close
db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", FooValue.new(5)).close
db.query("query", [1, "string", FooValue.new(5)]).close
end
DB.open("bar://host") do |db|
BarDriver.fake_row = [1, "string"] of BarDriver::Any
db.query "query" { }
db.query "query", 1 { }
db.query "query", 1, "string" { }
db.query("query", Bytes.new(4)) { }
db.query("query", 1, "string", BarValue.new(5)) { }
db.query "query", [1, "string", BarValue.new(5)] { }
db.query("query").close
db.query("query", 1).close
db.query("query", 1, "string").close
db.query("query", Bytes.new(4)).close
db.query("query", 1, "string", BarValue.new(5)).close
db.query("query", [1, "string", BarValue.new(5)]).close
end
end
it "allow custom types to be used as arguments for exec" do
DB.open("foo://host") do |db|
FooDriver.fake_row = [1, "string"] of FooDriver::Any
db.exec("query")
db.exec("query", 1)
db.exec("query", 1, "string")
db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", FooValue.new(5))
db.exec("query", [1, "string", FooValue.new(5)])
end
DB.open("bar://host") do |db|
BarDriver.fake_row = [1, "string"] of BarDriver::Any
db.exec("query")
db.exec("query", 1)
db.exec("query", 1, "string")
db.exec("query", Bytes.new(4))
db.exec("query", 1, "string", BarValue.new(5))
db.exec("query", [1, "string", BarValue.new(5)])
end
end
it "Foo and Bar drivers should not implement each other params" do
DB.open("foo://host") do |db|
expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do
db.exec("query", [BarValue.new(5)])
end
end
DB.open("bar://host") do |db|
expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do
db.exec("query", [FooValue.new(5)])
end
end
end
end

View File

@ -42,23 +42,31 @@ class DummyDriver < DB::Driver
super(connection)
end
protected def perform_query(args : Slice(DB::Any))
protected def perform_query(args : Enumerable)
set_params args
DummyResultSet.new self, @query
end
protected def perform_exec(args : Slice(DB::Any))
protected def perform_exec(args : Enumerable)
set_params args
DB::ExecResult.new 0, 0_i64
DB::ExecResult.new 0i64, 0_i64
end
private def set_params(args)
@params.clear
args.each_with_index do |arg, index|
@params[index] = arg
set_param(index, arg)
end
end
private def set_param(index, value : DB::Any)
@params[index] = value
end
private def set_param(index, value)
raise "not implemented for #{value.class}"
end
protected def do_close
super
end
@ -70,7 +78,7 @@ class DummyDriver < DB::Driver
@values : Array(String)?
@@last_result_set : self?
@@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Slice(UInt8).class
@@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Bytes.class
def initialize(statement, query)
super(statement)
@ -114,12 +122,16 @@ class DummyDriver < DB::Driver
return nil if n == "NULL"
if n == "?"
return @statement.params[0]
return (@statement.as(DummyStatement)).params[0]
end
return n
end
def read?(t : Nil.class)
read?.as(Nil)
end
def read?(t : String.class)
read?.try &.to_s
end
@ -140,17 +152,17 @@ class DummyDriver < DB::Driver
read?(String).try &.to_f64
end
def read?(t : Slice(UInt8).class)
def read?(t : Bytes.class)
value = read?
if value.is_a?(Nil)
value
elsif value.is_a?(String)
ary = value.bytes
Slice.new(ary.to_unsafe, ary.size)
elsif value.is_a?(Slice(UInt8))
elsif value.is_a?(Bytes)
value
else
raise "#{value} is not convertible to Slice(UInt8)"
raise "#{value} is not convertible to Bytes"
end
end
end

View File

@ -101,9 +101,9 @@ describe DummyDriver do
db.query("az,AZ") do |rs|
rs.move_next
ary = [97u8, 122u8]
rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size))
rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size))
ary = [65u8, 90u8]
rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size))
rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size))
end
end
end
@ -134,9 +134,9 @@ describe DummyDriver do
it "executes and selects blob" do
with_dummy do |db|
ary = UInt8[0x53, 0x51, 0x4C]
slice = Slice.new(ary.to_unsafe, ary.size)
slice = Bytes.new(ary.to_unsafe, ary.size)
DummyDriver::DummyResultSet.next_column_type = typeof(slice)
(db.scalar("?", slice) as Slice(UInt8)).to_a.should eq(ary)
(db.scalar("?", slice).as(Bytes)).to_a.should eq(ary)
end
end
end

View File

@ -9,7 +9,7 @@ describe DB::Statement do
it "should initialize positional params in query" do
with_dummy do |db|
stmt = db.prepare("the query")
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.query "a", 1, nil
stmt.params[0].should eq("a")
stmt.params[1].should eq(1)
@ -19,7 +19,7 @@ describe DB::Statement do
it "should initialize positional params in query with array" do
with_dummy do |db|
stmt = db.prepare("the query")
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.query ["a", 1, nil]
stmt.params[0].should eq("a")
stmt.params[1].should eq(1)
@ -29,7 +29,7 @@ describe DB::Statement do
it "should initialize positional params in exec" do
with_dummy do |db|
stmt = db.prepare("the query")
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.exec "a", 1, nil
stmt.params[0].should eq("a")
stmt.params[1].should eq(1)
@ -39,7 +39,7 @@ describe DB::Statement do
it "should initialize positional params in exec with array" do
with_dummy do |db|
stmt = db.prepare("the query")
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.exec ["a", 1, nil]
stmt.params[0].should eq("a")
stmt.params[1].should eq(1)
@ -49,7 +49,7 @@ describe DB::Statement do
it "should initialize positional params in scalar" do
with_dummy do |db|
stmt = db.prepare("the query")
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
stmt.scalar "a", 1, nil
stmt.params[0].should eq("a")
stmt.params[1].should eq(1)

View File

@ -70,13 +70,15 @@ module DB
# Types supported to interface with database driver.
# These can be used in any `ResultSet#read` or any `Database#query` related
# method to be used as query parameters
TYPES = [String, Int32, Int64, Float32, Float64, Slice(UInt8)]
TYPES = [Nil, String, Int32, Int64, Float32, Float64, Bytes]
# See `DB::TYPES` in `DB`. `Any` is a nillable version of the union of all types in `DB::TYPES`
alias Any = Nil | String | Int32 | Int64 | Float32 | Float64 | Slice(UInt8)
# See `DB::TYPES` in `DB`. `Any` is a union of all types in `DB::TYPES`
{% begin %}
alias Any = Union({{*TYPES}})
{% end %}
# Result of a `#exec` statement.
record ExecResult, rows_affected : Int32, last_insert_id : Int64
record ExecResult, rows_affected : Int64, last_insert_id : Int64
# :nodoc:
def self.driver_class(driver_name) : Driver.class

View File

@ -40,7 +40,7 @@ module DB
end
end
# Performs the `query` discarding any response
# Performs the `query` and returns an `ExecResult`
def exec(query, *args)
prepare(query).exec(*args)
end

View File

@ -16,8 +16,8 @@ module DB
# ### Note to implementors
#
# 1. Override `#move_next` to move to the next row.
# 2. Override `#read?(t)` for all `t` in `DB::TYPES`.
# 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES`.
# 2. Override `#read?(t)` for all `t` in `DB::TYPES` and any other types the driver should handle.
# 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES` and any other.
# 4. Override `#column_count`, `#column_name`.
# 5. Override `#column_type`. It must return a type in `DB::TYPES`.
abstract class ResultSet
@ -59,17 +59,24 @@ module DB
# The result is one of `DB::TYPES`.
abstract def column_type(index : Int32)
def read(t)
read?(t).not_nil!
end
# Reads the next column as a Nil.
def read(t : Nil.class) : Nil
read?(Nil)
end
def read?(t)
raise "read?(t : #{t}) is not implemented in #{self.class}"
end
# list datatypes that must be supported form the driver
# users will call read(String) or read?(String) for nillables
{% for t in DB::TYPES %}
# Reads the next column as a nillable {{t}}.
abstract def read?(t : {{t}}.class) : {{t}}?
# Reads the next column as a {{t}}.
def read(t : {{t}}.class) : {{t}}
read?({{t}}).not_nil!
end
{% end %}
# def read_blob

View File

@ -27,18 +27,18 @@ module DB
# See `QueryMethods#exec`
def exec
perform_exec_and_release(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn)
perform_exec_and_release(Slice(Any).new(0))
end
# See `QueryMethods#exec`
def exec(args : Enumerable(Any))
perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size))
def exec(args : Array)
perform_exec_and_release(args)
end
# See `QueryMethods#exec`
def exec(*args)
# TODO better way to do it
perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size))
perform_exec_and_release(args)
end
# See `QueryMethods#scalar`
@ -57,8 +57,8 @@ module DB
return rs.read?(Float32)
when Float64.class
return rs.read?(Float64)
when Slice(UInt8).class
return rs.read?(Slice(UInt8))
when Bytes.class
return rs.read?(Bytes)
when Nil.class
return rs.read?(Int32)
else
@ -71,13 +71,23 @@ module DB
end
# See `QueryMethods#query`
def query(*args)
perform_query *args
def query
perform_query Slice(Any).new(0)
end
# See `QueryMethods#query`
def query(args : Array)
perform_query args
end
# See `QueryMethods#query`
def query(*args)
perform_query(*args).tap do |rs|
perform_query args
end
# See `QueryMethods#query`
def query(*args)
query(*args).tap do |rs|
begin
yield rs
ensure
@ -86,27 +96,13 @@ module DB
end
end
private def perform_query : ResultSet
perform_query(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn)
end
private def perform_query(args : Enumerable(Any)) : ResultSet
# TODO better way to do it
perform_query(args.to_a.to_unsafe.to_slice(args.size))
end
private def perform_query(*args) : ResultSet
# TODO better way to do it
perform_query(args.to_a.to_unsafe.to_slice(args.size))
end
private def perform_exec_and_release(args : Slice(Any)) : ExecResult
private def perform_exec_and_release(args : Enumerable) : ExecResult
perform_exec(args).tap do
release_connection
end
end
protected abstract def perform_query(args : Slice(Any)) : ResultSet
protected abstract def perform_exec(args : Slice(Any)) : ExecResult
protected abstract def perform_query(args : Enumerable) : ResultSet
protected abstract def perform_exec(args : Enumerable) : ExecResult
end
end