diff --git a/spec/dummy_driver.cr b/spec/dummy_driver.cr index 1143880..1add530 100644 --- a/spec/dummy_driver.cr +++ b/spec/dummy_driver.cr @@ -49,6 +49,7 @@ class DummyDriver < DB::Driver protected def perform_exec(args : Enumerable) set_params args + raise "forced exception due to query" if @query == "raise" DB::ExecResult.new 0i64, 0_i64 end diff --git a/spec/dummy_driver_spec.cr b/spec/dummy_driver_spec.cr index 97a47da..e24af78 100644 --- a/spec/dummy_driver_spec.cr +++ b/spec/dummy_driver_spec.cr @@ -237,6 +237,14 @@ describe DummyDriver do end end + it "should raise executing raise query" do + with_dummy do |db| + expect_raises do + db.exec "raise" + end + end + end + {% for value in [1, 1_i64, "hello", 1.5, 1.5_f32] %} it "should set positional arguments for {{value.id}}" do with_dummy do |db| diff --git a/spec/result_set_spec.cr b/spec/result_set_spec.cr index 21d8424..bb633fd 100644 --- a/spec/result_set_spec.cr +++ b/spec/result_set_spec.cr @@ -56,4 +56,12 @@ describe DB::ResultSet do cols.should eq(["c0", "c1"]) end + + it "gets all column names" do + with_dummy do |db| + db.query "1,2" do |rs| + rs.column_names.should eq(%w(c0 c1)) + end + end + end end diff --git a/spec/statement_spec.cr b/spec/statement_spec.cr index 9621ed5..196c2ed 100644 --- a/spec/statement_spec.cr +++ b/spec/statement_spec.cr @@ -118,4 +118,13 @@ describe DB::Statement do rs.statement.should be(stmt) end end + + it "connection should be released if error occurs during exec" do + with_dummy do |db| + expect_raises do + db.exec "raise" + end + db.@in_pool.should be_true + end + end end diff --git a/src/db/mapping.cr b/src/db/mapping.cr index b7b0093..c84b4b7 100644 --- a/src/db/mapping.cr +++ b/src/db/mapping.cr @@ -138,4 +138,7 @@ module DB end end + macro mapping(**properties) + ::DB.mapping({{properties}}) + end end diff --git a/src/db/query_methods.cr b/src/db/query_methods.cr index 572ff20..f41596b 100644 --- a/src/db/query_methods.cr +++ b/src/db/query_methods.cr @@ -35,7 +35,7 @@ module DB # end # ``` def query(query, *args) - prepare query, &.query(*args) + prepare(query).query(*args) end # Executes a *query* and yields a `ResultSet` with the results. @@ -200,23 +200,13 @@ module DB # Performs the `query` and returns an `ExecResult` def exec(query, *args) - prepare query, &.exec(*args) + prepare(query).exec(*args) end # Performs the `query` and returns a single scalar value # puts db.scalar("SELECT MAX(name)").as(String) # => (a String) def scalar(query, *args) - prepare query, &.scalar(*args) - end - - private def prepare(query) - stm = prepare(query) - begin - yield stm - rescue ex - stm.release_connection - raise ex - end + prepare(query).scalar(*args) end end end diff --git a/src/db/result_set.cr b/src/db/result_set.cr index a3a4e6d..99d9c54 100644 --- a/src/db/result_set.cr +++ b/src/db/result_set.cr @@ -61,6 +61,11 @@ module DB # Returns the name of the column in `index` 0-based position. abstract def column_name(index : Int32) : String + # Returns the name of the columns. + def column_names + Array(String).new(column_count) { |i| column_name(i) } + end + # Reads the next column value abstract def read diff --git a/src/db/statement.cr b/src/db/statement.cr index 55120b4..45a6997 100644 --- a/src/db/statement.cr +++ b/src/db/statement.cr @@ -74,9 +74,9 @@ module DB end private def perform_exec_and_release(args : Enumerable) : ExecResult - res = perform_exec(args) + return perform_exec(args) + ensure release_connection - res end protected abstract def perform_query(args : Enumerable) : ResultSet