Improve websocket match logic. Fixes #401

This commit is contained in:
Serdar Dogruyol 2017-09-14 19:59:22 +03:00
parent e07be72dcf
commit 00981bcf44
2 changed files with 24 additions and 1 deletions

View file

@ -47,4 +47,20 @@ describe "Kemal::WebSocketHandler" do
io_with_context = create_ws_request_and_return_io(handler, request) io_with_context = create_ws_request_and_return_io(handler, request)
io_with_context.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n") io_with_context.to_s.should eq("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-Websocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n")
end end
it "matches correct verb" do
handler = Kemal::WebSocketHandler::INSTANCE
handler.next = Kemal::RouteHandler::INSTANCE
ws "/" { }
get "/" { "get" }
request = HTTP::Request.new("GET", "/")
io = IO::Memory.new
response = HTTP::Server::Response.new(io)
context = HTTP::Server::Context.new(request, response)
handler.call(context)
response.close
io.rewind
client_response = HTTP::Client::Response.from_io(io, decompress: false)
client_response.body.should eq("get")
end
end end

View file

@ -11,7 +11,7 @@ module Kemal
end end
def call(context : HTTP::Server::Context) def call(context : HTTP::Server::Context)
return call_next(context) unless context.ws_route_defined? return call_next(context) unless context.ws_route_defined? && websocket_upgrade_request?(context)
context.request.url_params ||= context.ws_route_lookup.params context.request.url_params ||= context.ws_route_lookup.params
content = context.websocket.call(context) content = context.websocket.call(context)
context.response.print(content) context.response.print(content)
@ -34,5 +34,12 @@ module Kemal
private def radix_path(method, path) private def radix_path(method, path)
"/#{method.downcase}#{path}" "/#{method.downcase}#{path}"
end end
private def websocket_upgrade_request?(context)
return false unless upgrade = context.request.headers["Upgrade"]?
return false unless upgrade.compare("websocket", case_insensitive: true) == 0
context.request.headers.includes_word?("Connection", "Upgrade")
end
end end
end end