diff --git a/shard.yml b/shard.yml index 82ccb5c..ec568a5 100644 --- a/shard.yml +++ b/shard.yml @@ -4,11 +4,11 @@ version: 0.19.0 dependencies: db: github: crystal-lang/crystal-db - version: ~> 0.11.0 + version: ~> 0.12.0 authors: - Ary Borenszweig - - Brian J. Cardiff + - Brian J. Cardiff crystal: ">= 1.0.0, < 2.0.0" diff --git a/spec/db_spec.cr b/spec/db_spec.cr index f3675a0..c130070 100644 --- a/spec/db_spec.cr +++ b/spec/db_spec.cr @@ -13,7 +13,7 @@ private def cast_if_blob(expr, sql_type) end end -DB::DriverSpecs(DB::Any).run do +DB::DriverSpecs(DB::Any).run do |ctx| support_unprepared false before do @@ -104,7 +104,7 @@ DB::DriverSpecs(DB::Any).run do db.exec %(insert into a (i, str) values (23, "bai bai");) 2.times do |i| - DB.open db.uri do |db| + DB.open ctx.connection_string do |db| begin db.query("SELECT i, str FROM a WHERE i = ?", 23) do |rs| rs.move_next diff --git a/spec/driver_spec.cr b/spec/driver_spec.cr index fbd64b5..6337279 100644 --- a/spec/driver_spec.cr +++ b/spec/driver_spec.cr @@ -27,7 +27,7 @@ describe Driver do it "should use database option as file to open" do with_db do |db| - db.driver.should be_a(SQLite3::Driver) + db.checkout.should be_a(SQLite3::Connection) File.exists?(DB_FILENAME).should be_true end end diff --git a/src/sqlite3/connection.cr b/src/sqlite3/connection.cr index 3979120..750776f 100644 --- a/src/sqlite3/connection.cr +++ b/src/sqlite3/connection.cr @@ -1,12 +1,56 @@ class SQLite3::Connection < DB::Connection - def initialize(database) - super - filename = self.class.filename(database.uri) - check LibSQLite3.open_v2(filename, out @db, (Flag::READWRITE | Flag::CREATE), nil) + record Options, + filename : String = ":memory:", + # pragmas + busy_timeout : String? = nil, + cache_size : String? = nil, + foreign_keys : String? = nil, + journal_mode : String? = nil, + synchronous : String? = nil, + wal_autocheckpoint : String? = nil do + def self.from_uri(uri : URI, default = Options.new) + params = HTTP::Params.parse(uri.query || "") + + Options.new( + filename: URI.decode_www_form((uri.host || "") + uri.path), + # pragmas + busy_timeout: params.fetch("busy_timeout", default.busy_timeout), + cache_size: params.fetch("cache_size", default.cache_size), + foreign_keys: params.fetch("foreign_keys", default.foreign_keys), + journal_mode: params.fetch("journal_mode", default.journal_mode), + synchronous: params.fetch("synchronous", default.synchronous), + wal_autocheckpoint: params.fetch("wal_autocheckpoint", default.wal_autocheckpoint), + ) + end + + def pragma_statement + res = String.build do |str| + pragma_append(str, "busy_timeout", busy_timeout) + pragma_append(str, "cache_size", cache_size) + pragma_append(str, "foreign_keys", foreign_keys) + pragma_append(str, "journal_mode", journal_mode) + pragma_append(str, "synchronous", synchronous) + pragma_append(str, "wal_autocheckpoint", wal_autocheckpoint) + end + + res.empty? ? nil : res + end + + private def pragma_append(io, key, value) + return unless value + io << "PRAGMA #{key}=#{value};" + end + end + + def initialize(options : ::DB::Connection::Options, sqlite3_options : Options) + super(options) + check LibSQLite3.open_v2(sqlite3_options.filename, out @db, (Flag::READWRITE | Flag::CREATE), nil) # 2 means 2 arguments; 1 is the code for UTF-8 check LibSQLite3.create_function(@db, "regexp", 2, 1, nil, SQLite3::REGEXP_FN, nil, nil) - process_query_params(database.uri) + if pragma_statement = sqlite3_options.pragma_statement + check LibSQLite3.exec(@db, pragma_statement, nil, nil, nil) + end rescue raise DB::ConnectionRefused.new end @@ -89,44 +133,4 @@ class SQLite3::Connection < DB::Connection private def check(code) raise Exception.new(self) unless code == 0 end - - private def process_query_params(uri : URI) - return unless query = uri.query - - detected_pragmas = extract_params(query, - busy_timeout: nil, - cache_size: nil, - foreign_keys: nil, - journal_mode: nil, - synchronous: nil, - wal_autocheckpoint: nil, - ) - - # concatenate all into a single SQL string - sql = String.build do |str| - detected_pragmas.each do |key, value| - next unless value - str << "PRAGMA #{key}=#{value};" - end - end - - check LibSQLite3.exec(@db, sql, nil, nil, nil) - end - - private def extract_params(query : String, **default : **T) forall T - res = default - - URI::Params.parse(query) do |key, value| - {% begin %} - case key - {% for key in T %} - when {{ key.stringify }} - res = res.merge({{key.id}}: value) - {% end %} - end - {% end %} - end - - res - end end diff --git a/src/sqlite3/driver.cr b/src/sqlite3/driver.cr index d43a512..93ae0b0 100644 --- a/src/sqlite3/driver.cr +++ b/src/sqlite3/driver.cr @@ -1,6 +1,16 @@ class SQLite3::Driver < DB::Driver - def build_connection(context : DB::ConnectionContext) : SQLite3::Connection - SQLite3::Connection.new(context) + class ConnectionBuilder < ::DB::ConnectionBuilder + def initialize(@options : ::DB::Connection::Options, @sqlite3_options : SQLite3::Connection::Options) + end + + def build : ::DB::Connection + SQLite3::Connection.new(@options, @sqlite3_options) + end + end + + def connection_builder(uri : URI) : ::DB::ConnectionBuilder + params = HTTP::Params.parse(uri.query || "") + ConnectionBuilder.new(connection_options(params), SQLite3::Connection::Options.from_uri(uri)) end end