From 55b8399d7e87edf304044a198430684aa41e33ed Mon Sep 17 00:00:00 2001 From: Ryan Westlund Date: Wed, 11 Nov 2020 13:00:40 -0500 Subject: [PATCH] Enable REGEXP by connecting Crystal's stdlib Regex (#62) --- spec/db_spec.cr | 5 +++++ src/sqlite3.cr | 9 +++++++++ src/sqlite3/connection.cr | 2 ++ src/sqlite3/lib_sqlite3.cr | 7 +++++++ 4 files changed, 23 insertions(+) diff --git a/spec/db_spec.cr b/spec/db_spec.cr index 3a2fa1f..42d34b5 100644 --- a/spec/db_spec.cr +++ b/spec/db_spec.cr @@ -128,4 +128,9 @@ DB::DriverSpecs(DB::Any).run do it "handles multi-step pragma statements" do |db| db.exec %(PRAGMA journal_mode = memory) end + + it "handles REGEXP operator" do |db| + (db.scalar "select 'unmatching text' REGEXP '^m'").should eq 0 + (db.scalar "select 'matching text' REGEXP '^m'").should eq 1 + end end diff --git a/src/sqlite3.cr b/src/sqlite3.cr index 77ff98f..d1c973b 100644 --- a/src/sqlite3.cr +++ b/src/sqlite3.cr @@ -6,4 +6,13 @@ module SQLite3 # :nodoc: TIME_ZONE = Time::Location::UTC + + # :nodoc: + REGEXP_FN = ->(context : LibSQLite3::SQLite3Context, argc : Int32, argv : LibSQLite3::SQLite3Value*) do + argv = Slice.new(argv, sizeof(Void*)) + pattern = LibSQLite3.value_text(argv[0]) + text = LibSQLite3.value_text(argv[1]) + LibSQLite3.result_int(context, Regex.new(String.new(pattern)).matches?(String.new(text)).to_unsafe) + nil + end end diff --git a/src/sqlite3/connection.cr b/src/sqlite3/connection.cr index 54d1d96..ca60043 100644 --- a/src/sqlite3/connection.cr +++ b/src/sqlite3/connection.cr @@ -4,6 +4,8 @@ class SQLite3::Connection < DB::Connection filename = self.class.filename(database.uri) # TODO maybe enable Flag::URI to parse query string in the uri as additional flags check LibSQLite3.open_v2(filename, out @db, (Flag::READWRITE | Flag::CREATE), nil) + # 2 means 2 arguments; 1 is the code for UTF-8 + check LibSQLite3.create_function(@db, "regexp", 2, 1, nil, SQLite3::REGEXP_FN, nil, nil) rescue raise DB::ConnectionRefused.new end diff --git a/src/sqlite3/lib_sqlite3.cr b/src/sqlite3/lib_sqlite3.cr index d99f67c..9455f8b 100644 --- a/src/sqlite3/lib_sqlite3.cr +++ b/src/sqlite3/lib_sqlite3.cr @@ -5,6 +5,8 @@ lib LibSQLite3 type SQLite3 = Void* type Statement = Void* type SQLite3Backup = Void* + type SQLite3Context = Void* + type SQLite3Value = Void* enum Code # Successful result @@ -72,6 +74,7 @@ lib LibSQLite3 end alias Callback = (Void*, Int32, UInt8**, UInt8**) -> Int32 + alias FuncCallback = (SQLite3Context, Int32, SQLite3Value*) -> Void fun open_v2 = sqlite3_open_v2(filename : UInt8*, db : SQLite3*, flags : ::SQLite3::Flag, zVfs : UInt8*) : Int32 @@ -108,4 +111,8 @@ lib LibSQLite3 fun finalize = sqlite3_finalize(stmt : Statement) : Int32 fun close_v2 = sqlite3_close_v2(SQLite3) : Int32 fun close = sqlite3_close(SQLite3) : Int32 + + fun create_function = sqlite3_create_function(SQLite3, funcName : UInt8*, nArg : Int32, eTextRep : Int32, pApp : Void*, xFunc : FuncCallback, xStep : Void*, xFinal : Void*) : Int32 + fun value_text = sqlite3_value_text(SQLite3Value) : UInt8* + fun result_int = sqlite3_result_int(SQLite3Context, Int32) : Nil end