diff --git a/spec/custom_drivers_types_spec.cr b/spec/custom_drivers_types_spec.cr index 902ec43..7a7f8a1 100644 --- a/spec/custom_drivers_types_spec.cr +++ b/spec/custom_drivers_types_spec.cr @@ -36,6 +36,15 @@ class FooValue end class FooDriver < DB::Driver + class FooConnectionBuilder < DB::ConnectionBuilder + def initialize(@options : DB::Connection::Options) + end + + def build : DB::Connection + FooConnection.new(@options) + end + end + alias Any = DB::Any | FooValue @@row = [] of Any @@ -47,10 +56,9 @@ class FooDriver < DB::Driver @@row end - def connection_builder(uri : URI) : Proc(DB::Connection) + def connection_builder(uri : URI) : DB::ConnectionBuilder params = HTTP::Params.parse(uri.query || "") - options = connection_options(params) - ->{ FooConnection.new(options).as(DB::Connection) } + FooConnectionBuilder.new(connection_options(params)) end class FooConnection < DB::Connection @@ -101,6 +109,15 @@ class BarValue end class BarDriver < DB::Driver + class BarConnectionBuilder < DB::ConnectionBuilder + def initialize(@options : DB::Connection::Options) + end + + def build : DB::Connection + BarConnection.new(@options) + end + end + alias Any = DB::Any | BarValue @@row = [] of Any @@ -112,10 +129,9 @@ class BarDriver < DB::Driver @@row end - def connection_builder(uri : URI) : Proc(DB::Connection) + def connection_builder(uri : URI) : DB::ConnectionBuilder params = HTTP::Params.parse(uri.query || "") - options = connection_options(params) - ->{ BarConnection.new(options).as(DB::Connection) } + BarConnectionBuilder.new(connection_options(params)) end class BarConnection < DB::Connection diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index 4528fd9..85f947f 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -2,10 +2,18 @@ require "spec" require "../src/db" class DummyDriver < DB::Driver - def connection_builder(uri : URI) : Proc(DB::Connection) + class DummyConnectionBuilder < DB::ConnectionBuilder + def initialize(@options : DB::Connection::Options) + end + + def build : DB::Connection + DummyConnection.new(@options) + end + end + + def connection_builder(uri : URI) : DB::ConnectionBuilder params = HTTP::Params.parse(uri.query || "") - options = connection_options(params) - ->{ DummyConnection.new(options).as(DB::Connection) } + DummyConnectionBuilder.new(connection_options(params)) end class DummyConnection < DB::Connection diff --git a/src/db.cr b/src/db.cr index 5697155..5dc5fff 100644 --- a/src/db.cr +++ b/src/db.cr @@ -156,7 +156,8 @@ module DB params = HTTP::Params.parse(uri.query || "") connection_options = driver.connection_options(params) pool_options = driver.pool_options(params) - factory = driver.connection_builder(uri) + builder = driver.connection_builder(uri) + factory = ->{ builder.build } Database.new(connection_options, pool_options, &factory) end @@ -165,7 +166,7 @@ module DB end private def self.build_connection(uri : URI) - build_driver(uri).connection_builder(uri).call + build_driver(uri).connection_builder(uri).build end private def self.build_driver(uri : URI) @@ -193,6 +194,7 @@ require "./db/enumerable_concat" require "./db/query_methods" require "./db/session_methods" require "./db/disposable" +require "./db/connection_builder" require "./db/driver" require "./db/statement" require "./db/begin_transaction" diff --git a/src/db/connection_builder.cr b/src/db/connection_builder.cr new file mode 100644 index 0000000..c1fd70e --- /dev/null +++ b/src/db/connection_builder.cr @@ -0,0 +1,8 @@ +module DB + # A connection factory with a specific configuration. + # + # See `Driver#connection_builder`. + abstract class ConnectionBuilder + abstract def build : Connection + end +end diff --git a/src/db/driver.cr b/src/db/driver.cr index ed517f5..afce958 100644 --- a/src/db/driver.cr +++ b/src/db/driver.cr @@ -35,7 +35,7 @@ module DB # # NOTE: For implementors *uri* should be parsed once. If all the options # are sound a factory Proc is returned. - abstract def connection_builder(uri : URI) : Proc(Connection) + abstract def connection_builder(uri : URI) : ConnectionBuilder def connection_options(params : HTTP::Params) : Connection::Options Connection::Options.from_http_params(params)