fixes #5. Merge branch 'feature/5-type-extensibility'
This commit is contained in:
commit
22db7d1043
|
@ -0,0 +1,299 @@
|
|||
require "./spec_helper"
|
||||
|
||||
module GenericResultSet
|
||||
@index = 0
|
||||
|
||||
def move_next
|
||||
@index = 0
|
||||
true
|
||||
end
|
||||
|
||||
def column_count : Int32
|
||||
@row.size
|
||||
end
|
||||
|
||||
def column_name(index : Int32) : String
|
||||
index.to_s
|
||||
end
|
||||
|
||||
def column_type(index : Int32)
|
||||
@row[index].class
|
||||
end
|
||||
|
||||
{% for t in DB::TYPES %}
|
||||
# Reads the next column as a nillable {{t}}.
|
||||
def read?(t : {{t}}.class) : {{t}}?
|
||||
read_and_move_next_column as {{t}}?
|
||||
end
|
||||
{% end %}
|
||||
|
||||
def read_and_move_next_column
|
||||
@index += 1
|
||||
@row[@index - 1]
|
||||
end
|
||||
end
|
||||
|
||||
class FooValue
|
||||
def initialize(@value : Int32)
|
||||
end
|
||||
|
||||
def value
|
||||
@value
|
||||
end
|
||||
end
|
||||
|
||||
class FooDriver < DB::Driver
|
||||
alias Any = DB::Any | FooValue
|
||||
@@row = [] of Any
|
||||
|
||||
def self.fake_row=(row : Array(Any))
|
||||
@@row = row
|
||||
end
|
||||
|
||||
def self.fake_row
|
||||
@@row
|
||||
end
|
||||
|
||||
def build_connection(db : DB::Database) : DB::Connection
|
||||
FooConnection.new(db)
|
||||
end
|
||||
|
||||
class FooConnection < DB::Connection
|
||||
def build_statement(query)
|
||||
FooStatement.new(self)
|
||||
end
|
||||
end
|
||||
|
||||
class FooStatement < DB::Statement
|
||||
protected def perform_query(args : Enumerable) : DB::ResultSet
|
||||
args.each { |arg| process_arg arg }
|
||||
FooResultSet.new(self, FooDriver.fake_row)
|
||||
end
|
||||
|
||||
protected def perform_exec(args : Enumerable) : DB::ExecResult
|
||||
args.each { |arg| process_arg arg }
|
||||
DB::ExecResult.new 0i64, 0i64
|
||||
end
|
||||
|
||||
private def process_arg(value : FooDriver::Any)
|
||||
end
|
||||
|
||||
private def process_arg(value)
|
||||
raise "#{self.class} does not support #{value.class} params"
|
||||
end
|
||||
end
|
||||
|
||||
class FooResultSet < DB::ResultSet
|
||||
include GenericResultSet
|
||||
|
||||
def initialize(statement, @row : Array(FooDriver::Any))
|
||||
super(statement)
|
||||
end
|
||||
|
||||
def read?(t : FooValue.class) : FooValue?
|
||||
read_and_move_next_column.as(FooValue?)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
DB.register_driver "foo", FooDriver
|
||||
|
||||
class BarValue
|
||||
getter value
|
||||
|
||||
def initialize(@value : Int32)
|
||||
end
|
||||
end
|
||||
|
||||
class BarDriver < DB::Driver
|
||||
alias Any = DB::Any | BarValue
|
||||
@@row = [] of Any
|
||||
|
||||
def self.fake_row=(row : Array(Any))
|
||||
@@row = row
|
||||
end
|
||||
|
||||
def self.fake_row
|
||||
@@row
|
||||
end
|
||||
|
||||
def build_connection(db : DB::Database) : DB::Connection
|
||||
BarConnection.new(db)
|
||||
end
|
||||
|
||||
class BarConnection < DB::Connection
|
||||
def build_statement(query)
|
||||
BarStatement.new(self)
|
||||
end
|
||||
end
|
||||
|
||||
class BarStatement < DB::Statement
|
||||
protected def perform_query(args : Enumerable) : DB::ResultSet
|
||||
args.each { |arg| process_arg arg }
|
||||
BarResultSet.new(self, BarDriver.fake_row)
|
||||
end
|
||||
|
||||
protected def perform_exec(args : Enumerable) : DB::ExecResult
|
||||
args.each { |arg| process_arg arg }
|
||||
DB::ExecResult.new 0i64, 0i64
|
||||
end
|
||||
|
||||
private def process_arg(value : BarDriver::Any)
|
||||
end
|
||||
|
||||
private def process_arg(value)
|
||||
raise "#{self.class} does not support #{value.class} params"
|
||||
end
|
||||
end
|
||||
|
||||
class BarResultSet < DB::ResultSet
|
||||
include GenericResultSet
|
||||
|
||||
def initialize(statement, @row : Array(BarDriver::Any))
|
||||
super(statement)
|
||||
end
|
||||
|
||||
def read?(t : BarValue.class) : BarValue?
|
||||
read_and_move_next_column.as(BarValue?)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
DB.register_driver "bar", BarDriver
|
||||
|
||||
describe DB do
|
||||
it "should be able to register multiple drivers" do
|
||||
DB.open("foo://host").driver.should be_a(FooDriver)
|
||||
DB.open("bar://host").driver.should be_a(BarDriver)
|
||||
end
|
||||
|
||||
it "Foo and Bar drivers should return fake_row" do
|
||||
with_witness do |w|
|
||||
DB.open("foo://host") do |db|
|
||||
# TODO somehow FooValue.new(99) is needed otherwise the read_object assertion fail
|
||||
FooDriver.fake_row = [1, "string", FooValue.new(3), FooValue.new(99)] of FooDriver::Any
|
||||
db.query "query" do |rs|
|
||||
w.check
|
||||
rs.move_next
|
||||
rs.read?(Int32).should eq(1)
|
||||
rs.read?(String).should eq("string")
|
||||
rs.read(FooValue).value.should eq(3)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
with_witness do |w|
|
||||
DB.open("bar://host") do |db|
|
||||
# TODO somehow BarValue.new(99) is needed otherwise the read_object assertion fail
|
||||
BarDriver.fake_row = [BarValue.new(4), "lorem", 1.0, BarValue.new(99)] of BarDriver::Any
|
||||
db.query "query" do |rs|
|
||||
w.check
|
||||
rs.move_next
|
||||
rs.read(BarValue).value.should eq(4)
|
||||
rs.read?(String).should eq("lorem")
|
||||
rs.read?(Float64).should eq(1.0)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
it "Foo and Bar drivers should not implement each other read" do
|
||||
with_witness do |w|
|
||||
DB.open("foo://host") do |db|
|
||||
FooDriver.fake_row = [1] of FooDriver::Any
|
||||
db.query "query" do |rs|
|
||||
rs.move_next
|
||||
expect_raises Exception, "read?(t : BarValue) is not implemented in FooDriver::FooResultSet" do
|
||||
w.check
|
||||
rs.read(BarValue)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
with_witness do |w|
|
||||
DB.open("bar://host") do |db|
|
||||
BarDriver.fake_row = [1] of BarDriver::Any
|
||||
db.query "query" do |rs|
|
||||
rs.move_next
|
||||
expect_raises Exception, "read?(t : FooValue) is not implemented in BarDriver::BarResultSet" do
|
||||
w.check
|
||||
rs.read(FooValue)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
it "allow custom types to be used as arguments for query" do
|
||||
DB.open("foo://host") do |db|
|
||||
FooDriver.fake_row = [1, "string"] of FooDriver::Any
|
||||
db.query "query" { }
|
||||
db.query "query", 1 { }
|
||||
db.query "query", 1, "string" { }
|
||||
db.query("query", Bytes.new(4)) { }
|
||||
db.query("query", 1, "string", FooValue.new(5)) { }
|
||||
db.query "query", [1, "string", FooValue.new(5)] { }
|
||||
|
||||
db.query("query").close
|
||||
db.query("query", 1).close
|
||||
db.query("query", 1, "string").close
|
||||
db.query("query", Bytes.new(4)).close
|
||||
db.query("query", 1, "string", FooValue.new(5)).close
|
||||
db.query("query", [1, "string", FooValue.new(5)]).close
|
||||
end
|
||||
|
||||
DB.open("bar://host") do |db|
|
||||
BarDriver.fake_row = [1, "string"] of BarDriver::Any
|
||||
db.query "query" { }
|
||||
db.query "query", 1 { }
|
||||
db.query "query", 1, "string" { }
|
||||
db.query("query", Bytes.new(4)) { }
|
||||
db.query("query", 1, "string", BarValue.new(5)) { }
|
||||
db.query "query", [1, "string", BarValue.new(5)] { }
|
||||
|
||||
db.query("query").close
|
||||
db.query("query", 1).close
|
||||
db.query("query", 1, "string").close
|
||||
db.query("query", Bytes.new(4)).close
|
||||
db.query("query", 1, "string", BarValue.new(5)).close
|
||||
db.query("query", [1, "string", BarValue.new(5)]).close
|
||||
end
|
||||
end
|
||||
|
||||
it "allow custom types to be used as arguments for exec" do
|
||||
DB.open("foo://host") do |db|
|
||||
FooDriver.fake_row = [1, "string"] of FooDriver::Any
|
||||
db.exec("query")
|
||||
db.exec("query", 1)
|
||||
db.exec("query", 1, "string")
|
||||
db.exec("query", Bytes.new(4))
|
||||
db.exec("query", 1, "string", FooValue.new(5))
|
||||
db.exec("query", [1, "string", FooValue.new(5)])
|
||||
end
|
||||
|
||||
DB.open("bar://host") do |db|
|
||||
BarDriver.fake_row = [1, "string"] of BarDriver::Any
|
||||
db.exec("query")
|
||||
db.exec("query", 1)
|
||||
db.exec("query", 1, "string")
|
||||
db.exec("query", Bytes.new(4))
|
||||
db.exec("query", 1, "string", BarValue.new(5))
|
||||
db.exec("query", [1, "string", BarValue.new(5)])
|
||||
end
|
||||
end
|
||||
|
||||
it "Foo and Bar drivers should not implement each other params" do
|
||||
DB.open("foo://host") do |db|
|
||||
expect_raises Exception, "FooDriver::FooStatement does not support BarValue params" do
|
||||
db.exec("query", [BarValue.new(5)])
|
||||
end
|
||||
end
|
||||
|
||||
DB.open("bar://host") do |db|
|
||||
expect_raises Exception, "BarDriver::BarStatement does not support FooValue params" do
|
||||
db.exec("query", [FooValue.new(5)])
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -42,23 +42,31 @@ class DummyDriver < DB::Driver
|
|||
super(connection)
|
||||
end
|
||||
|
||||
protected def perform_query(args : Slice(DB::Any))
|
||||
protected def perform_query(args : Enumerable)
|
||||
set_params args
|
||||
DummyResultSet.new self, @query
|
||||
end
|
||||
|
||||
protected def perform_exec(args : Slice(DB::Any))
|
||||
protected def perform_exec(args : Enumerable)
|
||||
set_params args
|
||||
DB::ExecResult.new 0, 0_i64
|
||||
DB::ExecResult.new 0i64, 0_i64
|
||||
end
|
||||
|
||||
private def set_params(args)
|
||||
@params.clear
|
||||
args.each_with_index do |arg, index|
|
||||
@params[index] = arg
|
||||
set_param(index, arg)
|
||||
end
|
||||
end
|
||||
|
||||
private def set_param(index, value : DB::Any)
|
||||
@params[index] = value
|
||||
end
|
||||
|
||||
private def set_param(index, value)
|
||||
raise "not implemented for #{value.class}"
|
||||
end
|
||||
|
||||
protected def do_close
|
||||
super
|
||||
end
|
||||
|
@ -70,7 +78,7 @@ class DummyDriver < DB::Driver
|
|||
@values : Array(String)?
|
||||
|
||||
@@last_result_set : self?
|
||||
@@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Slice(UInt8).class
|
||||
@@next_column_type : Nil.class | String.class | Int32.class | Int64.class | Float32.class | Float64.class | Bytes.class
|
||||
|
||||
def initialize(statement, query)
|
||||
super(statement)
|
||||
|
@ -114,12 +122,16 @@ class DummyDriver < DB::Driver
|
|||
return nil if n == "NULL"
|
||||
|
||||
if n == "?"
|
||||
return @statement.params[0]
|
||||
return (@statement.as(DummyStatement)).params[0]
|
||||
end
|
||||
|
||||
return n
|
||||
end
|
||||
|
||||
def read?(t : Nil.class)
|
||||
read?.as(Nil)
|
||||
end
|
||||
|
||||
def read?(t : String.class)
|
||||
read?.try &.to_s
|
||||
end
|
||||
|
@ -140,17 +152,17 @@ class DummyDriver < DB::Driver
|
|||
read?(String).try &.to_f64
|
||||
end
|
||||
|
||||
def read?(t : Slice(UInt8).class)
|
||||
def read?(t : Bytes.class)
|
||||
value = read?
|
||||
if value.is_a?(Nil)
|
||||
value
|
||||
elsif value.is_a?(String)
|
||||
ary = value.bytes
|
||||
Slice.new(ary.to_unsafe, ary.size)
|
||||
elsif value.is_a?(Slice(UInt8))
|
||||
elsif value.is_a?(Bytes)
|
||||
value
|
||||
else
|
||||
raise "#{value} is not convertible to Slice(UInt8)"
|
||||
raise "#{value} is not convertible to Bytes"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -101,9 +101,9 @@ describe DummyDriver do
|
|||
db.query("az,AZ") do |rs|
|
||||
rs.move_next
|
||||
ary = [97u8, 122u8]
|
||||
rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size))
|
||||
rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size))
|
||||
ary = [65u8, 90u8]
|
||||
rs.read(Slice(UInt8)).should eq(Slice.new(ary.to_unsafe, ary.size))
|
||||
rs.read(Bytes).should eq(Bytes.new(ary.to_unsafe, ary.size))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -134,9 +134,9 @@ describe DummyDriver do
|
|||
it "executes and selects blob" do
|
||||
with_dummy do |db|
|
||||
ary = UInt8[0x53, 0x51, 0x4C]
|
||||
slice = Slice.new(ary.to_unsafe, ary.size)
|
||||
slice = Bytes.new(ary.to_unsafe, ary.size)
|
||||
DummyDriver::DummyResultSet.next_column_type = typeof(slice)
|
||||
(db.scalar("?", slice) as Slice(UInt8)).to_a.should eq(ary)
|
||||
(db.scalar("?", slice).as(Bytes)).to_a.should eq(ary)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -9,7 +9,7 @@ describe DB::Statement do
|
|||
|
||||
it "should initialize positional params in query" do
|
||||
with_dummy do |db|
|
||||
stmt = db.prepare("the query")
|
||||
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.query "a", 1, nil
|
||||
stmt.params[0].should eq("a")
|
||||
stmt.params[1].should eq(1)
|
||||
|
@ -19,7 +19,7 @@ describe DB::Statement do
|
|||
|
||||
it "should initialize positional params in query with array" do
|
||||
with_dummy do |db|
|
||||
stmt = db.prepare("the query")
|
||||
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.query ["a", 1, nil]
|
||||
stmt.params[0].should eq("a")
|
||||
stmt.params[1].should eq(1)
|
||||
|
@ -29,7 +29,7 @@ describe DB::Statement do
|
|||
|
||||
it "should initialize positional params in exec" do
|
||||
with_dummy do |db|
|
||||
stmt = db.prepare("the query")
|
||||
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.exec "a", 1, nil
|
||||
stmt.params[0].should eq("a")
|
||||
stmt.params[1].should eq(1)
|
||||
|
@ -39,7 +39,7 @@ describe DB::Statement do
|
|||
|
||||
it "should initialize positional params in exec with array" do
|
||||
with_dummy do |db|
|
||||
stmt = db.prepare("the query")
|
||||
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.exec ["a", 1, nil]
|
||||
stmt.params[0].should eq("a")
|
||||
stmt.params[1].should eq(1)
|
||||
|
@ -49,7 +49,7 @@ describe DB::Statement do
|
|||
|
||||
it "should initialize positional params in scalar" do
|
||||
with_dummy do |db|
|
||||
stmt = db.prepare("the query")
|
||||
stmt = db.prepare("the query").as(DummyDriver::DummyStatement)
|
||||
stmt.scalar "a", 1, nil
|
||||
stmt.params[0].should eq("a")
|
||||
stmt.params[1].should eq(1)
|
||||
|
|
10
src/db.cr
10
src/db.cr
|
@ -70,13 +70,15 @@ module DB
|
|||
# Types supported to interface with database driver.
|
||||
# These can be used in any `ResultSet#read` or any `Database#query` related
|
||||
# method to be used as query parameters
|
||||
TYPES = [String, Int32, Int64, Float32, Float64, Slice(UInt8)]
|
||||
TYPES = [Nil, String, Int32, Int64, Float32, Float64, Bytes]
|
||||
|
||||
# See `DB::TYPES` in `DB`. `Any` is a nillable version of the union of all types in `DB::TYPES`
|
||||
alias Any = Nil | String | Int32 | Int64 | Float32 | Float64 | Slice(UInt8)
|
||||
# See `DB::TYPES` in `DB`. `Any` is a union of all types in `DB::TYPES`
|
||||
{% begin %}
|
||||
alias Any = Union({{*TYPES}})
|
||||
{% end %}
|
||||
|
||||
# Result of a `#exec` statement.
|
||||
record ExecResult, rows_affected : Int32, last_insert_id : Int64
|
||||
record ExecResult, rows_affected : Int64, last_insert_id : Int64
|
||||
|
||||
# :nodoc:
|
||||
def self.driver_class(driver_name) : Driver.class
|
||||
|
|
|
@ -40,7 +40,7 @@ module DB
|
|||
end
|
||||
end
|
||||
|
||||
# Performs the `query` discarding any response
|
||||
# Performs the `query` and returns an `ExecResult`
|
||||
def exec(query, *args)
|
||||
prepare(query).exec(*args)
|
||||
end
|
||||
|
|
|
@ -16,8 +16,8 @@ module DB
|
|||
# ### Note to implementors
|
||||
#
|
||||
# 1. Override `#move_next` to move to the next row.
|
||||
# 2. Override `#read?(t)` for all `t` in `DB::TYPES`.
|
||||
# 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES`.
|
||||
# 2. Override `#read?(t)` for all `t` in `DB::TYPES` and any other types the driver should handle.
|
||||
# 3. (Optional) Override `#read(t)` for all `t` in `DB::TYPES` and any other.
|
||||
# 4. Override `#column_count`, `#column_name`.
|
||||
# 5. Override `#column_type`. It must return a type in `DB::TYPES`.
|
||||
abstract class ResultSet
|
||||
|
@ -59,17 +59,24 @@ module DB
|
|||
# The result is one of `DB::TYPES`.
|
||||
abstract def column_type(index : Int32)
|
||||
|
||||
def read(t)
|
||||
read?(t).not_nil!
|
||||
end
|
||||
|
||||
# Reads the next column as a Nil.
|
||||
def read(t : Nil.class) : Nil
|
||||
read?(Nil)
|
||||
end
|
||||
|
||||
def read?(t)
|
||||
raise "read?(t : #{t}) is not implemented in #{self.class}"
|
||||
end
|
||||
|
||||
# list datatypes that must be supported form the driver
|
||||
# users will call read(String) or read?(String) for nillables
|
||||
|
||||
{% for t in DB::TYPES %}
|
||||
# Reads the next column as a nillable {{t}}.
|
||||
abstract def read?(t : {{t}}.class) : {{t}}?
|
||||
|
||||
# Reads the next column as a {{t}}.
|
||||
def read(t : {{t}}.class) : {{t}}
|
||||
read?({{t}}).not_nil!
|
||||
end
|
||||
{% end %}
|
||||
|
||||
# def read_blob
|
||||
|
|
|
@ -27,18 +27,18 @@ module DB
|
|||
|
||||
# See `QueryMethods#exec`
|
||||
def exec
|
||||
perform_exec_and_release(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn)
|
||||
perform_exec_and_release(Slice(Any).new(0))
|
||||
end
|
||||
|
||||
# See `QueryMethods#exec`
|
||||
def exec(args : Enumerable(Any))
|
||||
perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size))
|
||||
def exec(args : Array)
|
||||
perform_exec_and_release(args)
|
||||
end
|
||||
|
||||
# See `QueryMethods#exec`
|
||||
def exec(*args)
|
||||
# TODO better way to do it
|
||||
perform_exec_and_release(args.to_a.to_unsafe.to_slice(args.size))
|
||||
perform_exec_and_release(args)
|
||||
end
|
||||
|
||||
# See `QueryMethods#scalar`
|
||||
|
@ -57,8 +57,8 @@ module DB
|
|||
return rs.read?(Float32)
|
||||
when Float64.class
|
||||
return rs.read?(Float64)
|
||||
when Slice(UInt8).class
|
||||
return rs.read?(Slice(UInt8))
|
||||
when Bytes.class
|
||||
return rs.read?(Bytes)
|
||||
when Nil.class
|
||||
return rs.read?(Int32)
|
||||
else
|
||||
|
@ -71,13 +71,23 @@ module DB
|
|||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(*args)
|
||||
perform_query *args
|
||||
def query
|
||||
perform_query Slice(Any).new(0)
|
||||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(args : Array)
|
||||
perform_query args
|
||||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(*args)
|
||||
perform_query(*args).tap do |rs|
|
||||
perform_query args
|
||||
end
|
||||
|
||||
# See `QueryMethods#query`
|
||||
def query(*args)
|
||||
query(*args).tap do |rs|
|
||||
begin
|
||||
yield rs
|
||||
ensure
|
||||
|
@ -86,27 +96,13 @@ module DB
|
|||
end
|
||||
end
|
||||
|
||||
private def perform_query : ResultSet
|
||||
perform_query(Slice(Any).new(0)) # no overload matches ... with types Slice(NoReturn)
|
||||
end
|
||||
|
||||
private def perform_query(args : Enumerable(Any)) : ResultSet
|
||||
# TODO better way to do it
|
||||
perform_query(args.to_a.to_unsafe.to_slice(args.size))
|
||||
end
|
||||
|
||||
private def perform_query(*args) : ResultSet
|
||||
# TODO better way to do it
|
||||
perform_query(args.to_a.to_unsafe.to_slice(args.size))
|
||||
end
|
||||
|
||||
private def perform_exec_and_release(args : Slice(Any)) : ExecResult
|
||||
private def perform_exec_and_release(args : Enumerable) : ExecResult
|
||||
perform_exec(args).tap do
|
||||
release_connection
|
||||
end
|
||||
end
|
||||
|
||||
protected abstract def perform_query(args : Slice(Any)) : ResultSet
|
||||
protected abstract def perform_exec(args : Slice(Any)) : ExecResult
|
||||
protected abstract def perform_query(args : Enumerable) : ResultSet
|
||||
protected abstract def perform_exec(args : Enumerable) : ExecResult
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue