diff --git a/src/db.cr b/src/db.cr index e5fd6eb..1cc52f5 100644 --- a/src/db.cr +++ b/src/db.cr @@ -193,6 +193,7 @@ module DB end require "./db/pool" +require "./db/thread_local_pool" require "./db/string_key_cache" require "./db/enumerable_concat" require "./db/query_methods" diff --git a/src/db/database.cr b/src/db/database.cr index 16e1e7b..b007e3d 100644 --- a/src/db/database.cr +++ b/src/db/database.cr @@ -36,7 +36,11 @@ module DB getter pool @connection_options : Connection::Options - @pool : Pool(Connection) + {% if flag?(:preview_mt) %} + @pool : Pool(Connection) | ThreadLocalPool(Connection) + {% else %} + @pool : Pool(Connection) + {% end %} @setup_connection : Connection -> Nil # Initialize a database with the specified options and connection factory. @@ -44,14 +48,26 @@ module DB def initialize(connection_options : Connection::Options, pool_options : Pool::Options, &factory : -> Connection) @connection_options = connection_options @setup_connection = ->(conn : Connection) {} + @pool = uninitialized Pool(Connection) # in order to use self in the factory proc - @pool = Pool(Connection).new(pool_options) { - conn = factory.call + wrapped_factory = ->{ + conn = factory.call.as(Connection) conn.auto_release = false conn.context = self @setup_connection.call conn conn } + + {% if flag?(:preview_mt) %} + @pool = + if pool_options.thread_local_pool + ThreadLocalPool(Connection).new(pool_options, &wrapped_factory) + else + @pool = Pool(Connection).new(pool_options, &wrapped_factory) + end + {% else %} + @pool = Pool(Connection).new(pool_options, &wrapped_factory) + {% end %} end def prepared_statements? : Bool diff --git a/src/db/pool.cr b/src/db/pool.cr index 378e303..e7d349a 100644 --- a/src/db/pool.cr +++ b/src/db/pool.cr @@ -16,7 +16,9 @@ module DB # maximum amount of retry attempts to reconnect to the db. See `Pool#retry` retry_attempts : Int32 = 1, # seconds to wait before a retry attempt - retry_delay : Float64 = 0.2 do + retry_delay : Float64 = 0.2, + # there will be a connection pool per thread (only in multi-threaded) + thread_local_pool : Bool = false do def self.from_http_params(params : HTTP::Params, default = Options.new) Options.new( initial_pool_size: params.fetch("initial_pool_size", default.initial_pool_size).to_i, @@ -25,6 +27,7 @@ module DB checkout_timeout: params.fetch("checkout_timeout", default.checkout_timeout).to_f, retry_attempts: params.fetch("retry_attempts", default.retry_attempts).to_i, retry_delay: params.fetch("retry_delay", default.retry_delay).to_f, + thread_local_pool: DB.fetch_bool(params, "thread_local_pool", default.thread_local_pool), ) end end diff --git a/src/db/thread_local_pool.cr b/src/db/thread_local_pool.cr new file mode 100644 index 0000000..e71541b --- /dev/null +++ b/src/db/thread_local_pool.cr @@ -0,0 +1,68 @@ +module DB + class ThreadLocalPool(T) + @pools = Crystal::ThreadLocalValue(DB::Pool(T)).new + + def initialize(@pool_options : DB::Pool::Options = DB::Pool::Options.new, &@factory : -> T) + end + + def close : Nil + end + + def stats + raise "not implemented" + end + + def checkout : T + pool.checkout + end + + def checkout(&block : T ->) + pool.checkout do |resource| + yield resource + end + end + + def release(resource : T) + pool.release(resource) + end + + def retry + pool.retry do + yield + end + end + + def delete(resource : T) + pool.delete(resource) + end + + def each_resource + each_pool do |p| + p.each_resource do |conn| + yield conn + end + end + end + + def is_available?(resource : T) + each_pool do |p| + return true if p.is_available?(resource) + end + false + end + + private def pool + @pools.get do + DB::Pool.new(@pool_options, &@factory) + end + end + + private def each_pool + @pools.@mutex.sync do + @pools.@values.each_value do |p| + yield p + end + end + end + end +end