From 00981bcf4424bcaa73d439a45cb005b9a0d8114e Mon Sep 17 00:00:00 2001 From: Serdar Dogruyol Date: Thu, 14 Sep 2017 19:59:22 +0300 Subject: [PATCH] Improve websocket match logic. Fixes #401 --- spec/websocket_handler_spec.cr | 16 ++++++++++++++++ src/kemal/websocket_handler.cr | 9 ++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/spec/websocket_handler_spec.cr b/spec/websocket_handler_spec.cr index 1e454ad..5f2974d 100644 --- a/spec/websocket_handler_spec.cr +++ b/spec/websocket_handler_spec.cr @@ -47,4 +47,20 @@ describe "Kemal::WebSocketHandler" do io_with_context = create_ws_request_and_return_io(handler, request) io_with_context.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n") end + + it "matches correct verb" do + handler = Kemal::WebSocketHandler::INSTANCE + handler.next = Kemal::RouteHandler::INSTANCE + ws "/" { } + get "/" { "get" } + request = HTTP::Request.new("GET", "/") + io = IO::Memory.new + response = HTTP::Server::Response.new(io) + context = HTTP::Server::Context.new(request, response) + handler.call(context) + response.close + io.rewind + client_response = HTTP::Client::Response.from_io(io, decompress: false) + client_response.body.should eq("get") + end end diff --git a/src/kemal/websocket_handler.cr b/src/kemal/websocket_handler.cr index 0761809..c370513 100644 --- a/src/kemal/websocket_handler.cr +++ b/src/kemal/websocket_handler.cr @@ -11,7 +11,7 @@ module Kemal end def call(context : HTTP::Server::Context) - return call_next(context) unless context.ws_route_defined? + return call_next(context) unless context.ws_route_defined? && websocket_upgrade_request?(context) context.request.url_params ||= context.ws_route_lookup.params content = context.websocket.call(context) context.response.print(content) @@ -34,5 +34,12 @@ module Kemal private def radix_path(method, path) "/#{method.downcase}#{path}" end + + private def websocket_upgrade_request?(context) + return false unless upgrade = context.request.headers["Upgrade"]? + return false unless upgrade.compare("websocket", case_insensitive: true) == 0 + + context.request.headers.includes_word?("Connection", "Upgrade") + end end end