diff --git a/spec/database_spec.cr b/spec/database_spec.cr index 31db8bd..8371d57 100644 --- a/spec/database_spec.cr +++ b/spec/database_spec.cr @@ -8,7 +8,7 @@ describe DB::Database do db.setup_connection do |cnn| cnn_setup += 1 - cnn.scalar("1").should eq "1" + cnn.scalar("a").should eq "a" end cnn_setup.should eq(2) diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index aa40790..994eeaf 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -187,7 +187,7 @@ class DummyDriver < DB::Driver return (@statement.as(DummyStatement)).params[0] end - return n + n.to_i64? || n end def next_column_index : Int32 diff --git a/spec/serializable_spec.cr b/spec/serializable_spec.cr index 60c25f8..4ad2cf9 100644 --- a/spec/serializable_spec.cr +++ b/spec/serializable_spec.cr @@ -81,6 +81,29 @@ class ModelWithJSON property c1 : String end +struct ModelWithEnum + include DB::Serializable + + getter c0 : Int32 + getter c1 : MyEnum + # Ensure multiple enum types work together + getter c2 : MyOtherEnum + + enum MyEnum + Foo = 0 + Bar = 1 + Baz = 2 + Quux = 3 + end + + enum MyOtherEnum + OMG + LOL + WTF + BBQ + end +end + macro from_dummy(query, type) with_dummy do |db| rs = db.query({{ query }}) @@ -172,6 +195,17 @@ describe "DB::Serializable" do expect_model("1,a", ModelWithJSON, {c0: 1, c1: "a"}) end + it "should initialize a model with an enum property" do + expect_model("1,2,LOL", ModelWithEnum, { + c0: 1, + c1: ModelWithEnum::MyEnum::Baz, + c2: ModelWithEnum::MyOtherEnum::LOL, + }) + expect_raises DB::MappingException, "Unknown enum ModelWithEnum::MyEnum value: adsf" do + from_dummy("1,adsf,BBQ", ModelWithEnum) + end + end + it "should initialize multiple instances from a single resultset" do with_dummy do |db| db.query("1,a 2,b") do |rs| diff --git a/src/db/result_set.cr b/src/db/result_set.cr index 7520251..ccc58b4 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -96,6 +96,23 @@ module DB end end + # Read the value based on the given `enum` type, supporting both string and + # numeric column types. + # + # ``` + # enum Status + # Pending + # Complete + # end + # + # db.query "SELECT 'complete'" do |rs| + # rs.read Status # => Status::Complete + # end + # ``` + def read(type : Enum.class) + type.new(self) + end + # Reads the next columns and returns a tuple of the values. def read(*types : Class) internal_read(*types) @@ -135,3 +152,24 @@ module DB # end end end + +struct Enum + def self.new(rs : DB::ResultSet) : self + index = rs.next_column_index + + case value = rs.read + when String + parse value + when Int + from_value value + else + raise DB::ColumnTypeMismatchError.new( + context: "#{self}.new(rs : DB::ResultSet)", + column_index: index, + column_name: rs.column_name(index), + column_type: value.class.to_s, + expected_type: "String | Int", + ) + end + end +end