From e076a08cd04b92b5446ff542000d3b945fb86361 Mon Sep 17 00:00:00 2001 From: Jamie Gaskins Date: Thu, 27 Oct 2022 12:35:57 -0400 Subject: [PATCH] Close a transaction when `return`ing from within its block (#167) --- spec/transaction_spec.cr | 20 ++++++++++++++++++++ src/db/begin_transaction.cr | 21 +++++++++++---------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/spec/transaction_spec.cr b/spec/transaction_spec.cr index 8089d98..29df7cc 100644 --- a/spec/transaction_spec.cr +++ b/spec/transaction_spec.cr @@ -95,6 +95,20 @@ describe DB::Transaction do t.committed.should be_false end + it "transaction with block from connection should be committed if `return` is called" do + t = uninitialized DummyDriver::DummyTransaction + + with_witness do |w| + with_dummy_connection do |cnn| + t = return_from_txn(cnn).as(DummyDriver::DummyTransaction) + w.check + end + end + + t.rolledback.should be_false + t.committed.should be_true + end + it "transaction can be committed within block" do with_dummy_connection do |cnn| cnn.transaction do |tx| @@ -211,3 +225,9 @@ describe DB::Transaction do end end end + +private def return_from_txn(cnn) + cnn.transaction do |tx| + return tx + end +end diff --git a/src/db/begin_transaction.cr b/src/db/begin_transaction.cr index 58db073..7a8f80c 100644 --- a/src/db/begin_transaction.cr +++ b/src/db/begin_transaction.cr @@ -13,25 +13,26 @@ module DB # can be called explicitly. # Returns the value of the block. def transaction(& : Transaction -> T) : T? forall T + rollback = false # TODO: Cast to workaround crystal-lang/crystal#9483 # begin_transaction returns a Tx where Tx < Transaction tx = begin_transaction.as(Transaction) begin res = yield tx rescue DB::Rollback - tx.rollback unless tx.closed? + rollback = true res rescue e - unless tx.closed? - # Ignore error in rollback. - # It would only be a secondary error to the original one, caused by - # corrupted connection state. - tx.rollback rescue nil - end + rollback = true raise e - else - tx.commit unless tx.closed? - res + ensure + unless tx.closed? + if rollback + tx.rollback + else + tx.commit + end + end end end end