diff --git a/spec/transaction_spec.cr b/spec/transaction_spec.cr index 50ca671..8089d98 100644 --- a/spec/transaction_spec.cr +++ b/spec/transaction_spec.cr @@ -175,4 +175,39 @@ describe DB::Transaction do db.pool.is_available?(cnn).should be_true end end + + it "returns block value when sucess" do + with_dummy_connection do |cnn| + res = cnn.transaction do |tx| + 42 + end + + res.should eq(42) + typeof(res).should eq(Int32 | Nil) + end + end + + it "returns value on rollback via method" do + with_dummy_connection do |cnn| + res = cnn.transaction do |tx| + tx.rollback + 42 + end + + res.should eq(42) + typeof(res).should eq(Int32 | Nil) + end + end + + it "returns nil on rollback via exception" do + with_dummy_connection do |cnn| + res = cnn.transaction do |tx| + raise DB::Rollback.new + 42 + end + + res.should be_nil + typeof(res).should eq(Int32 | Nil) + end + end end diff --git a/src/db/begin_transaction.cr b/src/db/begin_transaction.cr index 5fbe2d1..47eede8 100644 --- a/src/db/begin_transaction.cr +++ b/src/db/begin_transaction.cr @@ -11,12 +11,14 @@ module DB # The exception thrown is bubbled unless it is a `DB::Rollback`. # From the yielded object `Transaction#commit` or `Transaction#rollback` # can be called explicitly. - def transaction + # Returns the value of the block. + def transaction(& : Transaction -> T) : T? forall T tx = begin_transaction begin - yield tx + res = yield tx rescue DB::Rollback tx.rollback unless tx.closed? + res rescue e unless tx.closed? # Ignore error in rollback. @@ -27,6 +29,7 @@ module DB raise e else tx.commit unless tx.closed? + res end end end