diff --git a/src/spectator/expectations/expectation_partial.cr b/src/spectator/expectations/expectation_partial.cr index 696d54c..cf40aaf 100644 --- a/src/spectator/expectations/expectation_partial.cr +++ b/src/spectator/expectations/expectation_partial.cr @@ -25,6 +25,7 @@ module Spectator::Expectations end def to(stub : Mocks::MethodStub) : Nil + Harness.current.mocks.expect(@actual.value, stub.name) value = TestValue.new(stub.name, stub.to_s) matcher = if (arguments = stub.arguments?) Matchers::ReceiveArgumentsMatcher.new(value, arguments) diff --git a/src/spectator/mocks/double.cr b/src/spectator/mocks/double.cr index 4af8cf1..f6e3c67 100644 --- a/src/spectator/mocks/double.cr +++ b/src/spectator/mocks/double.cr @@ -62,7 +62,10 @@ module Spectator::Mocks {% if body && !body.is_a?(Nop) %} {{body.body}} {% else %} - raise ::Spectator::Mocks::UnexpectedMessageError.new("#{self} received unexpected message {{name}}") + unless ::Spectator::Harness.current.mocks.expected?(self, {{name.symbolize}}) + raise ::Spectator::Mocks::UnexpectedMessageError.new("#{self} received unexpected message {{name}}") + end + # This code shouldn't be reached, but makes the compiler happy to have a matching return type. {% if definition.is_a?(TypeDeclaration) %} %x = uninitialized {{definition.type}} @@ -75,6 +78,7 @@ module Spectator::Mocks macro method_missing(call) return self if @null + return self if ::Spectator::Harness.current.mocks.expected?(self, {{call.name.symbolize}}) raise ::Spectator::Mocks::UnexpectedMessageError.new("#{self} received unexpected message {{call.name}}") end diff --git a/src/spectator/mocks/registry.cr b/src/spectator/mocks/registry.cr index c27fb7c..d094052 100644 --- a/src/spectator/mocks/registry.cr +++ b/src/spectator/mocks/registry.cr @@ -5,6 +5,7 @@ module Spectator::Mocks private struct Entry getter stubs = Deque(MethodStub).new getter calls = Deque(MethodCall).new + getter expected = Set(Symbol).new end @all_instances = {} of String => Entry @@ -64,6 +65,14 @@ module Spectator::Mocks fetch_type(type).calls.select { |call| call.name == method_name } end + def expected?(object, method_name : Symbol) : Bool + fetch_instance(object).expected.includes?(method_name) + end + + def expect(object, method_name : Symbol) : Nil + fetch_instance(object).expected.add(method_name) + end + private def fetch_instance(object) key = unique_key(object) if @entries.has_key?(key)