diff --git a/spec/database_spec.cr b/spec/database_spec.cr index 5ed460b..21d9a28 100644 --- a/spec/database_spec.cr +++ b/spec/database_spec.cr @@ -47,4 +47,20 @@ describe DB::Database do db.prepare("query3").should be_a(DB::PoolStatement) end end + + it "should return same statement in pool per query" do + with_dummy do |db| + stmt = db.prepare("query1") + db.prepare("query2").should_not eq(stmt) + db.prepare("query1").should eq(stmt) + end + end + + it "should close pool statements when closing db" do + stmt = uninitialized DB::PoolStatement + with_dummy do |db| + stmt = db.prepare("query1") + end + stmt.closed?.should be_true + end end diff --git a/src/db.cr b/src/db.cr index 6446c98..f474cef 100644 --- a/src/db.cr +++ b/src/db.cr @@ -119,6 +119,7 @@ module DB end require "./db/pool" +require "./db/string_key_cache" require "./db/query_methods" require "./db/disposable" require "./db/database" diff --git a/src/db/connection.cr b/src/db/connection.cr index 112f2ef..89cfe02 100644 --- a/src/db/connection.cr +++ b/src/db/connection.cr @@ -20,25 +20,20 @@ module DB # :nodoc: getter database - @statements_cache = {} of String => Statement + @statements_cache = StringKeyCache(Statement).new def initialize(@database : Database) end # :nodoc: def prepare(query) : Statement - stmt = @statements_cache.fetch(query, nil) - stmt = @statements_cache[query] = build_statement(query) unless stmt - - stmt + @statements_cache.fetch(query) { build_statement(query) } end abstract def build_statement(query) : Statement protected def do_close - @statements_cache.each do |_, stmt| - stmt.close - end + @statements_cache.each_value &.close @statements_cache.clear end end diff --git a/src/db/database.cr b/src/db/database.cr index d7486fd..d277bb2 100644 --- a/src/db/database.cr +++ b/src/db/database.cr @@ -24,6 +24,7 @@ module DB @pool : Pool(Connection) @setup_connection : Connection -> Nil + @statements_cache = StringKeyCache(PoolStatement).new # :nodoc: def initialize(@driver : Driver, @uri : URI) @@ -48,14 +49,15 @@ module DB # Closes all connection to the database. def close + @statements_cache.each_value &.close + @statements_cache.clear + @pool.close end # :nodoc: def prepare(query) - # TODO query based cache for pool statement - # TODO clear PoolStatements when closing the DB - PoolStatement.new self, query + @statements_cache.fetch(query) { PoolStatement.new(self, query) } end # :nodoc: diff --git a/src/db/string_key_cache.cr b/src/db/string_key_cache.cr new file mode 100644 index 0000000..f2cae62 --- /dev/null +++ b/src/db/string_key_cache.cr @@ -0,0 +1,21 @@ +module DB + class StringKeyCache(T) + @cache = {} of String => T + + def fetch(key : String) : T + value = @cache.fetch(key, nil) + value = @cache[key] = yield unless value + value + end + + def each_value + @cache.each do |_, value| + yield value + end + end + + def clear + @cache.clear + end + end +end