diff --git a/spec/issues/github_issue_48_spec.cr b/spec/issues/github_issue_48_spec.cr index d5e68fb..21e4664 100644 --- a/spec/issues/github_issue_48_spec.cr +++ b/spec/issues/github_issue_48_spec.cr @@ -29,6 +29,15 @@ Spectator.describe "GitHub Issue #48" do def union : Int32 | String 42.as(Int32 | String) end + + def capture(&block : -> T) forall T + block + end + + def capture(thing : T, &block : T -> T) forall T + block.call(thing) + block + end end mock Test, make_nilable: nil @@ -95,4 +104,22 @@ Spectator.describe "GitHub Issue #48" do allow(fake).to receive(:union).and_return(:test) expect { fake.union }.to raise_error(TypeCastError, /Symbol/) end + + it "handles captured blocks" do + proc = ->{} + allow(fake).to receive(:capture).and_return(proc) + expect(fake.capture { nil }).to be(proc) + end + + it "raises on type cast error with captured blocks" do + proc = ->{ 42 } + allow(fake).to receive(:capture).and_return(proc) + expect { fake.capture { "other" } }.to raise_error(TypeCastError, /Proc\(String\)/) + end + + it "handles captured blocks with arguments" do + proc = ->(x : Int32) { x * 2 } + allow(fake).to receive(:capture).and_return(proc) + expect(fake.capture(5) { 5 }).to be(proc) + end end diff --git a/src/spectator/mocks/stubbable.cr b/src/spectator/mocks/stubbable.cr index 7c022cd..6ded8ee 100644 --- a/src/spectator/mocks/stubbable.cr +++ b/src/spectator/mocks/stubbable.cr @@ -126,7 +126,31 @@ module Spectator {{method.body}} end - {% original = "previous_def#{" { |*_spectator_yargs| yield *_spectator_yargs }".id if method.accepts_block?}".id %} + {% original = "previous_def" + # Workaround for Crystal not propagating block with previous_def/super. + if method.accepts_block? + original += "(" + method.args.each_with_index do |arg, i| + original += '*' if method.splat_index == i + original += arg.name.stringify + original += ", " + end + if method.double_splat + original += method.double_splat.stringify + original += ", " + end + # If the block is captured (i.e. `&block` syntax), it must be passed along as an argument. + # Otherwise, use `yield` to forward the block. + captured_block = if method.block_arg && method.block_arg.name && method.block_arg.name.size > 0 + method.block_arg.name + else + nil + end + original += "&#{captured_block}" if captured_block + original += ")" + original += " { |*_spectator_yargs| yield *_spectator_yargs }" unless captured_block + end + original = original.id %} {% # Reconstruct the method signature. # I wish there was a better way of doing this, but there isn't (at least not that I'm aware of). @@ -241,7 +265,32 @@ module Spectator {{method.body}} end - {% original = "previous_def#{" { |*_spectator_yargs| yield *_spectator_yargs }".id if method.accepts_block?}".id %} + {% original = "previous_def" + # Workaround for Crystal not propagating block with previous_def/super. + if method.accepts_block? + original += "(" + method.args.each_with_index do |arg, i| + original += '*' if method.splat_index == i + original += arg.name.stringify + original += ", " + end + if method.double_splat + original += method.double_splat.stringify + original += ", " + end + # If the block is captured (i.e. `&block` syntax), it must be passed along as an argument. + # Otherwise, use `yield` to forward the block. + captured_block = if method.block_arg && method.block_arg.name && method.block_arg.name.size > 0 + method.block_arg.name + else + nil + end + original += "&#{captured_block}" if captured_block + original += ")" + original += " { |*_spectator_yargs| yield *_spectator_yargs }" unless captured_block + end + original = original.id %} + {% end %} {% # Reconstruct the method signature. @@ -418,7 +467,16 @@ module Spectator {% if method.block_arg %}&{{method.block_arg}}{% elsif method.accepts_block? %}&{% end %} ){% if method.return_type %} : {{method.return_type}}{% end %}{% if !method.free_vars.empty? %} forall {{method.free_vars.splat}}{% end %} {% unless method.abstract? %} - {{scope}}{% if method.accepts_block? %} { |*%yargs| yield *%yargs }{% end %} + {{scope}}{% if method.accepts_block? %}( + {% for arg, i in method.args %}{% if i == method.splat_index %}*{% end %}{{arg.name}}, {% end %} + {% if method.double_splat %}**{{method.double_splat}}, {% end %} + {% captured_block = if method.block_arg && method.block_arg.name && method.block_arg.name.size > 0 + method.block_arg.name + else + nil + end %} + {% if captured_block %}&{{captured_block}}{% end %} + ){% if !captured_block %} { |*%yargs| yield *%yargs }{% end %}{% end %} end {% end %} {% end %}