HeadRequestHandler: run GET handler and don't return the body (#655)

This commit is contained in:
Mike Robbins 2023-02-21 23:34:47 -05:00 committed by GitHub
parent 84ea6627ac
commit 8ebe171279
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 115 additions and 8 deletions

View file

@ -29,7 +29,7 @@ describe "Config" do
config = Kemal.config config = Kemal.config
config.add_handler CustomTestHandler.new config.add_handler CustomTestHandler.new
Kemal.config.setup Kemal.config.setup
config.handlers.size.should eq(7) config.handlers.size.should eq(8)
end end
it "toggles the shutdown message" do it "toggles the shutdown message" do

View file

@ -0,0 +1,37 @@
require "./spec_helper"
describe "Kemal::HeadRequestHandler" do
it "implicitly handles GET endpoints, with Content-Length header" do
get "/" do
"hello"
end
request = HTTP::Request.new("HEAD", "/")
client_response = call_request_on_app(request)
client_response.body.should eq("")
client_response.headers["Content-Length"].should eq("5")
end
it "prefers explicit HEAD endpoint if specified" do
Kemal::RouteHandler::INSTANCE.add_route("HEAD", "/") { "hello" }
get "/" do
raise "shouldn't be called!"
end
request = HTTP::Request.new("HEAD", "/")
client_response = call_request_on_app(request)
client_response.body.should eq("")
client_response.headers["Content-Length"].should eq("5")
end
it "gives compressed Content-Length when gzip enabled" do
gzip true
get "/" do
"hello"
end
headers = HTTP::Headers{"Accept-Encoding" => "gzip"}
request = HTTP::Request.new("HEAD", "/", headers)
client_response = call_request_on_app(request)
client_response.body.should eq("")
client_response.headers["Content-Encoding"].should eq("gzip")
client_response.headers["Content-Length"].should eq("25")
end
end

View file

@ -13,7 +13,7 @@ describe "Macros" do
it "adds a custom handler" do it "adds a custom handler" do
add_handler CustomTestHandler.new add_handler CustomTestHandler.new
Kemal.config.setup Kemal.config.setup
Kemal.config.handlers.size.should eq 7 Kemal.config.handlers.size.should eq 8
end end
end end
@ -150,7 +150,7 @@ describe "Macros" do
it "adds HTTP::CompressHandler to handlers" do it "adds HTTP::CompressHandler to handlers" do
gzip true gzip true
Kemal.config.setup Kemal.config.setup
Kemal.config.handlers[4].should be_a(HTTP::CompressHandler) Kemal.config.handlers[5].should be_a(HTTP::CompressHandler)
end end
end end

View file

@ -103,6 +103,7 @@ module Kemal
unless @default_handlers_setup && @router_included unless @default_handlers_setup && @router_included
setup_init_handler setup_init_handler
setup_log_handler setup_log_handler
setup_head_request_handler
setup_error_handler setup_error_handler
setup_static_file_handler setup_static_file_handler
setup_custom_handlers setup_custom_handlers
@ -129,6 +130,11 @@ module Kemal
@handler_position += 1 @handler_position += 1
end end
private def setup_head_request_handler
HANDLERS.insert(@handler_position, Kemal::HeadRequestHandler::INSTANCE)
@handler_position += 1
end
private def setup_error_handler private def setup_error_handler
if @always_rescue if @always_rescue
@error_handler ||= Kemal::ExceptionHandler.new @error_handler ||= Kemal::ExceptionHandler.new

View file

@ -0,0 +1,60 @@
require "http/server/handler"
module Kemal
class HeadRequestHandler
include HTTP::Handler
INSTANCE = new
private class NullIO < IO
@original_output : IO
@out_count : Int32
@response : HTTP::Server::Response
def initialize(@response)
@closed = false
@original_output = @response.output
@out_count = 0
end
def read(slice : Bytes)
raise NotImplementedError.new("read")
end
def write(slice : Bytes) : Nil
@out_count += slice.bytesize
end
def close : Nil
return if @closed
@closed = true
# Matching HTTP::Server::Response#close behavior:
# Conditionally determine based on status if the `content-length` header should be added automatically.
# See https://tools.ietf.org/html/rfc7230#section-3.3.2.
status = @response.status
set_content_length = !(status.not_modified? || status.no_content? || status.informational?)
if !@response.headers.has_key?("Content-Length") && set_content_length
@response.content_length = @out_count
end
@original_output.close
end
def closed? : Bool
@closed
end
end
def call(context) : Nil
if context.request.method == "HEAD"
# Capture and count bytes of response body generated on HEAD requests without actually sending the body back.
capture_io = NullIO.new(context.response)
context.response.output = capture_io
end
call_next(context)
end
end
end

View file

@ -5,11 +5,12 @@
require "mime" require "mime"
# Adds given `Kemal::Handler` to handlers chain. # Adds given `Kemal::Handler` to handlers chain.
# There are 5 handlers by default and all the custom handlers # There are 6 handlers by default and all the custom handlers
# goes between the first 4 and the last `Kemal::RouteHandler`. # goes between the first 5 and the last `Kemal::RouteHandler`.
# #
# - `Kemal::InitHandler` # - `Kemal::InitHandler`
# - `Kemal::LogHandler` # - `Kemal::LogHandler`
# - `Kemal::HeadRequestHandler`
# - `Kemal::ExceptionHandler` # - `Kemal::ExceptionHandler`
# - `Kemal::StaticFileHandler` # - `Kemal::StaticFileHandler`
# - Here goes custom handlers # - Here goes custom handlers

View file

@ -17,11 +17,9 @@ module Kemal
process_request(context) process_request(context)
end end
# Adds a given route to routing tree. As an exception each `GET` route additionaly defines # Adds a given route to routing tree.
# a corresponding `HEAD` route.
def add_route(method : String, path : String, &handler : HTTP::Server::Context -> _) def add_route(method : String, path : String, &handler : HTTP::Server::Context -> _)
add_to_radix_tree method, path, Route.new(method, path, &handler) add_to_radix_tree method, path, Route.new(method, path, &handler)
add_to_radix_tree("HEAD", path, Route.new("HEAD", path) { }) if method == "GET"
end end
# Looks up the route from the Radix::Tree for the first time and caches to improve performance. # Looks up the route from the Radix::Tree for the first time and caches to improve performance.
@ -34,6 +32,11 @@ module Kemal
route = @routes.find(lookup_path) route = @routes.find(lookup_path)
if verb == "HEAD" && !route.found?
# On HEAD requests, implicitly fallback to running the GET handler.
route = @routes.find(radix_path("GET", path))
end
if route.found? if route.found?
@cached_routes.clear if @cached_routes.size == CACHED_ROUTES_LIMIT @cached_routes.clear if @cached_routes.size == CACHED_ROUTES_LIMIT
@cached_routes[lookup_path] = route @cached_routes[lookup_path] = route