diff --git a/shard.yml b/shard.yml index c9a66dc..c61beb0 100644 --- a/shard.yml +++ b/shard.yml @@ -1,7 +1,10 @@ name: sqlite3 version: 0.1.0 +dependencies: + db: + github: bcardiff/crystal-db + authors: - Ary Borenszweig - Brian J. Cardiff - diff --git a/spec/driver_spec.cr b/spec/driver_spec.cr index 4f728b5..19da869 100644 --- a/spec/driver_spec.cr +++ b/spec/driver_spec.cr @@ -25,6 +25,7 @@ def sqlite_type_for(v) when String ; "text" when Int32, Int64 ; "int" when Float32, Float64; "float" + when Time ; "text" else raise "not implemented for #{typeof(v)}" end @@ -46,6 +47,9 @@ def assert_filename(uri, filename) SQLite3::Connection.filename(URI.parse(uri)).should eq(filename) end +class NotSupportedType +end + describe Driver do it "should register sqlite3 name" do DB.driver_class("sqlite3").should eq(SQLite3::Driver) @@ -111,7 +115,7 @@ describe Driver do it "executes and selects blob" do with_db do |db| - slice = db.scalar(%(select X'53514C697465')) as Slice(UInt8) + slice = db.scalar(%(select X'53514C697465')).as(Bytes) slice.to_a.should eq([0x53, 0x51, 0x4C, 0x69, 0x74, 0x65]) end end @@ -119,7 +123,7 @@ describe Driver do it "executes with bind blob" do with_db do |db| ary = UInt8[0x53, 0x51, 0x4C, 0x69, 0x74, 0x65] - slice = db.scalar(%(select cast(? as BLOB)), Slice.new(ary.to_unsafe, ary.size)) as Slice(UInt8) + slice = db.scalar(%(select cast(? as BLOB)), Bytes.new(ary.to_unsafe, ary.size)).as(Bytes) slice.to_a.should eq(ary) end end @@ -158,7 +162,7 @@ describe Driver do rs.column_type(0).should eq(String) rs.column_type(1).should eq(Int64) rs.column_type(2).should eq(Float64) - rs.column_type(3).should eq(Slice(UInt8)) + rs.column_type(3).should eq(Bytes) end end end @@ -190,13 +194,46 @@ describe Driver do ary = UInt8[0x53, 0x51, 0x4C, 0x69, 0x74, 0x65] db.exec "create table table1 (col1 blob)" - db.exec %(insert into table1 values (?)), Slice.new(ary.to_unsafe, ary.size) + db.exec %(insert into table1 values (?)), Bytes.new(ary.to_unsafe, ary.size) - slice = db.scalar("select cast(col1 as blob) from table1") as Slice(UInt8) + slice = db.scalar("select cast(col1 as blob) from table1").as(Bytes) slice.to_a.should eq(ary) end end + it "insert/get value date from table" do + with_db do |db| + value = Time.new(2016, 7, 22, 15, 0, 0, 0) + db.exec "create table table1 (col1 #{sqlite_type_for(value)})" + db.exec %(insert into table1 values (?)), value + + db.query "select col1 from table1" do |rs| + rs.move_next + rs.read(Time).should eq(value) + end + + db.query "select col1 from table1" do |rs| + rs.move_next + rs.read?(Time).should eq(value) + end + end + end + + it "raises on unsupported param types" do + with_db do |db| + expect_raises Exception, "SQLite3::Statement does not support NotSupportedType params" do + db.query "select 1", NotSupportedType.new + end + # TODO raising exception does not close the connection and pool is exhausted + end + + with_db do |db| + expect_raises Exception, "SQLite3::Statement does not support NotSupportedType params" do + db.exec "select 1", NotSupportedType.new + end + end + end + it "gets many rows from table" do with_mem_db do |db| db.exec "create table person (name string, age integer)" diff --git a/src/sqlite3.cr b/src/sqlite3.cr index 2fc0340..a5b49c4 100644 --- a/src/sqlite3.cr +++ b/src/sqlite3.cr @@ -1,2 +1,6 @@ require "db" require "./sqlite3/**" + +module SQLite3 + DATE_FORMAT = "%F %H:%M:%S.%L" +end diff --git a/src/sqlite3/result_set.cr b/src/sqlite3/result_set.cr index 53001a5..39e3e0b 100644 --- a/src/sqlite3/result_set.cr +++ b/src/sqlite3/result_set.cr @@ -22,7 +22,7 @@ class SQLite3::ResultSet < DB::ResultSet end end - {% for t in DB::TYPES %} + macro nilable_read_for(t) def read?(t : {{t}}.class) : {{t}}? if read_nil? moving_column { nil } @@ -30,6 +30,10 @@ class SQLite3::ResultSet < DB::ResultSet read(t) end end + end + + {% for t in DB::TYPES %} + nilable_read_for({{t}}) {% end %} def read(t : String.class) : String @@ -52,16 +56,22 @@ class SQLite3::ResultSet < DB::ResultSet moving_column { |col| LibSQLite3.column_double(self, col) } end - def read(t : Slice(UInt8).class) : Slice(UInt8) + def read(t : Bytes.class) : Bytes moving_column do |col| blob = LibSQLite3.column_blob(self, col) bytes = LibSQLite3.column_bytes(self, col) ptr = Pointer(UInt8).malloc(bytes) ptr.copy_from(blob, bytes) - Slice(UInt8).new(ptr, bytes) + Bytes.new(ptr, bytes) end end + def read(t : Time.class) : Time + Time.parse read(String), SQLite3::DATE_FORMAT + end + + nilable_read_for Time + def column_count LibSQLite3.column_count(self) end @@ -74,7 +84,7 @@ class SQLite3::ResultSet < DB::ResultSet case LibSQLite3.column_type(self, index) when Type::INTEGER; Int64 when Type::FLOAT ; Float64 - when Type::BLOB ; Slice(UInt8) + when Type::BLOB ; Bytes when Type::TEXT ; String when Type::NULL ; Nil else diff --git a/src/sqlite3/statement.cr b/src/sqlite3/statement.cr index d0cde14..f7c36f2 100644 --- a/src/sqlite3/statement.cr +++ b/src/sqlite3/statement.cr @@ -4,7 +4,7 @@ class SQLite3::Statement < DB::Statement check LibSQLite3.prepare_v2(@connection, sql, sql.bytesize + 1, out @stmt, nil) end - protected def perform_query(args : Slice(DB::Any)) + protected def perform_query(args : Enumerable) : DB::ResultSet LibSQLite3.reset(self) args.each_with_index(1) do |arg, index| bind_arg(index, arg) @@ -12,12 +12,12 @@ class SQLite3::Statement < DB::Statement ResultSet.new(self) end - protected def perform_exec(args : Slice(DB::Any)) + protected def perform_exec(args : Enumerable) : DB::ExecResult rs = perform_query(args) rs.move_next rs.close - rows_affected = LibSQLite3.changes(connection) + rows_affected = LibSQLite3.changes(connection).to_i64 last_id = LibSQLite3.last_insert_rowid(connection) DB::ExecResult.new rows_affected, last_id @@ -52,10 +52,18 @@ class SQLite3::Statement < DB::Statement check LibSQLite3.bind_text(self, index, value, value.bytesize, nil) end - private def bind_arg(index, value : Slice(UInt8)) + private def bind_arg(index, value : Bytes) check LibSQLite3.bind_blob(self, index, value, value.size, nil) end + private def bind_arg(index, value : Time) + bind_arg(index, value.to_s(SQLite3::DATE_FORMAT)) + end + + private def bind_arg(index, value) + raise "#{self.class} does not support #{value.class} params" + end + private def check(code) raise Exception.new(@connection) unless code == 0 end