diff --git a/src/spectator/mocks/anonymous_double.cr b/src/spectator/mocks/anonymous_double.cr index d2e67c9..f04d194 100644 --- a/src/spectator/mocks/anonymous_double.cr +++ b/src/spectator/mocks/anonymous_double.cr @@ -12,7 +12,11 @@ module Spectator::Mocks call = ::Spectator::Mocks::GenericMethodCall.new({{call.name.symbolize}}, args) ::Spectator::Harness.current.mocks.record_call(self, call) if (stub = ::Spectator::Harness.current.mocks.find_stub(self, call)) - stub.call!(args, typeof(@values.fetch({{call.name.symbolize}}) { raise })) + {% if call.block.is_a?(Nop) %} + stub.call!(args, typeof(@values.fetch({{call.name.symbolize}}) { raise })) + {% else %} + stub.call!(args, typeof(@values.fetch({{call.name.symbolize}}) { raise })) { |*ya| yield *ya } + {% end %} else @values.fetch({{call.name.symbolize}}) do return nil if ::Spectator::Harness.current.mocks.expected?(self, {{call.name.symbolize}}) diff --git a/src/spectator/mocks/anonymous_null_double.cr b/src/spectator/mocks/anonymous_null_double.cr index 3da8122..960962e 100644 --- a/src/spectator/mocks/anonymous_null_double.cr +++ b/src/spectator/mocks/anonymous_null_double.cr @@ -8,7 +8,11 @@ module Spectator::Mocks call = ::Spectator::Mocks::GenericMethodCall.new({{call.name.symbolize}}, args) ::Spectator::Harness.current.mocks.record_call(self, call) if (stub = ::Spectator::Harness.current.mocks.find_stub(self, call)) - stub.call!(args, typeof(@values.fetch({{call.name.symbolize}}) { self })) + {% if call.block.is_a?(Nop) %} + stub.call!(args, typeof(@values.fetch({{call.name.symbolize}}) { self })) + {% else %} + stub.call!(args, typeof(@values.fetch({{call.name.symbolize}}) { self })) { |*ya| yield *ya } + {% end %} else @values.fetch({{call.name.symbolize}}) { self } end diff --git a/src/spectator/mocks/double.cr b/src/spectator/mocks/double.cr index f6e3c67..20d5f6f 100644 --- a/src/spectator/mocks/double.cr +++ b/src/spectator/mocks/double.cr @@ -45,12 +45,12 @@ module Spectator::Mocks end end - def {{name}}({{params.splat}}){% if definition.is_a?(TypeDeclaration) %} : {{definition.type}}{% end %} + def {{name}}({{params.splat}} &block){% if definition.is_a?(TypeDeclaration) %} : {{definition.type}}{% end %} %args = ::Spectator::Mocks::GenericArguments.create({{args.splat}}) %call = ::Spectator::Mocks::GenericMethodCall.new({{name.symbolize}}, %args) ::Spectator::Harness.current.mocks.record_call(self, %call) if (%stub = ::Spectator::Harness.current.mocks.find_stub(self, %call)) - %stub.call!(%args, typeof(%method({{args.splat}}) { |*%ya| yield *%ya })) + %stub.call!(%args, typeof(%method({{args.splat}}) { |*%ya| yield *%ya })) { |*%ya| yield *%ya } else %method({{args.splat}}) do |*%yield_args| yield *%yield_args diff --git a/src/spectator/mocks/method_stub.cr b/src/spectator/mocks/method_stub.cr index a297d14..80050b7 100644 --- a/src/spectator/mocks/method_stub.cr +++ b/src/spectator/mocks/method_stub.cr @@ -16,6 +16,10 @@ module Spectator::Mocks abstract def call(args : GenericArguments(T, NT), rt : RT.class) forall T, NT, RT + def call(args : GenericArguments(T, NT), rt : RT.class, &) forall T, NT, RT + call(args, rt) + end + def call!(args : GenericArguments(T, NT), rt : RT.class) : RT forall T, NT, RT value = call(args, rt) if value.is_a?(RT) @@ -25,6 +29,15 @@ module Spectator::Mocks end end + def call!(args : GenericArguments(T, NT), rt : RT.class) : RT forall T, NT, RT + value = call(args, rt) { |*ya| yield *ya } + if value.is_a?(RT) + value.as(RT) + else + raise TypeCastError.new("The return type of stub #{self} doesn't match the expected type #{RT}") + end + end + def to_s(io) io << '#' io << @name diff --git a/src/spectator/mocks/nil_method_stub.cr b/src/spectator/mocks/nil_method_stub.cr index 43a6f82..410ca56 100644 --- a/src/spectator/mocks/nil_method_stub.cr +++ b/src/spectator/mocks/nil_method_stub.cr @@ -36,6 +36,10 @@ module Spectator::Mocks ExceptionMethodStub.new(@name, @source, exception_type.new(*args), @args) end + def and_yield(*yield_args) + YieldMethodStub.new(@name, @source, yield_args, @args) + end + def with(*args : *T, **opts : **NT) forall T, NT args = GenericArguments.new(args, opts) NilMethodStub.new(@name, @source, args) diff --git a/src/spectator/mocks/yield_method_stub.cr b/src/spectator/mocks/yield_method_stub.cr new file mode 100644 index 0000000..8adf254 --- /dev/null +++ b/src/spectator/mocks/yield_method_stub.cr @@ -0,0 +1,19 @@ +require "./generic_arguments" +require "./generic_method_stub" + +module Spectator::Mocks + class YieldMethodStub(YieldArgs) < GenericMethodStub(Nil) + def initialize(name, source, @yield_args : YieldArgs, args = nil) + super(name, source, args) + end + + def call(_args : GenericArguments(T2, NT2), rt : RT.class) forall T2, NT2, RT + raise "Asked to yield |#{@yield_args}| but no block was passed" + end + + def call(_args : GenericArguments(T2, NT2), rt : RT.class) forall T2, NT2, RT + yield *@yield_args + nil + end + end +end