//! Middlewares are types with a method of type:
//!
//!     fn handle(
//!         self: @This(),
//!         request: *http.Request(< some type >),
//!         response: *http.Response(< some type >),
//!         context: anytype,
//!         next_handler: anytype,
//!     ) !void
//!
//! If a middleware returns error.RouteMismatch, then it is assumed that the handler
//! did not apply to the request. Routing implementations use this to decide
//! whether to keep trying other candidate routes.
//!
//! Terminal middlewares that are not implemented using other middlewares should
//! only accept a `void` value for `next_handler`.
const std = @import("std");
|
|
const util = @import("util");
|
|
const http = @import("./lib.zig");
|
|
const urlencode = @import("./urlencode.zig");
|
|
const json_utils = @import("./json.zig");
|
|
const fields = @import("./fields.zig");
|
|
|
|
const PathIter = util.PathIter;
|
|
|
|
/// Chains a tuple (or other struct) of middlewares into a single handler.
///
/// Each middleware becomes the `first` of a `HandlerList` whose `next` is the
/// chain built from the middlewares after it; the final link receives `void`.
pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) {
    const field_list = comptime std.meta.fields(@TypeOf(middlewares));
    return applyInternal(middlewares, field_list);
}
|
|
|
|
/// Computes the handler type that `apply()` produces for a middleware
/// container of type `Middlewares`.
pub fn Apply(comptime Middlewares: type) type {
    const middleware_fields = std.meta.fields(Middlewares);
    return ApplyInternal(middleware_fields);
}
|
|
|
|
/// Recursively nests one `HandlerList` per entry of `which`, terminated by `void`.
fn ApplyInternal(comptime which: []const std.builtin.Type.StructField) type {
    return if (which.len == 0)
        void
    else
        HandlerList(which[0].field_type, ApplyInternal(which[1..]));
}
|
|
|
|
/// Value-level counterpart of `ApplyInternal`: places the middleware named by
/// `which[0]` in `first` and recursively builds `next` from the remaining fields.
fn applyInternal(middlewares: anytype, comptime which: []const std.builtin.Type.StructField) ApplyInternal(which) {
    if (which.len != 0) {
        const head = @field(middlewares, which[0].name);
        const tail = applyInternal(middlewares, which[1..]);
        return .{ .first = head, .next = tail };
    }
    // End of the chain.
    return {};
}
|
|
|
|
/// One link of a middleware chain: `first` is invoked with `next` as its
/// continuation. A HandlerList is itself a terminal handler, so its own
/// `next` argument must be `void`.
pub fn HandlerList(comptime First: type, comptime Next: type) type {
    return struct {
        const Self = @This();

        first: First,
        next: Next,

        pub fn handle(
            self: Self,
            req: anytype,
            res: anytype,
            ctx: anytype,
            next: void,
        ) !void {
            // The continuation lives in `self.next`; the argument is only a marker.
            _ = next;
            const continuation = self.next;
            return self.first.handle(req, res, ctx, continuation);
        }
    };
}
|
|
|
|
test "apply" {
    var count: usize = 0;
    const NoOp = struct {
        ptr: *usize,
        fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            self.ptr.* += 1;
            // The last link receives `void`; otherwise continue down the chain.
            if (@TypeOf(next) != void) return next.handle(req, res, ctx, {});
        }
    };

    const middlewares = .{
        NoOp{ .ptr = &count },
        NoOp{ .ptr = &count },
        NoOp{ .ptr = &count },
        NoOp{ .ptr = &count },
    };

    // `apply` nests one HandlerList per middleware, terminated by `void`.
    try std.testing.expectEqual(
        HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, void)))),
        Apply(@TypeOf(middlewares)),
    );

    try apply(middlewares).handle(.{}, .{}, .{}, {});
    // Every middleware in the chain ran exactly once (expected value first).
    try std.testing.expectEqual(@as(usize, 4), count);
}
|
|
|
|
test "injectContextValue - chained" {
    // Each injector appends one field; the terminal checker must see all three.
    const chain = apply(.{
        injectContextValue("abcd", @as(usize, 5)),
        injectContextValue("efgh", @as(usize, 10)),
        injectContextValue("ijkl", @as(usize, 15)),
        ExpectContext(.{ .abcd = 5, .efgh = 10, .ijkl = 15 }){},
    });
    try chain.handle(.{}, .{}, .{}, {});
}
|
|
|
|
/// Returns a copy of struct type `Lhs` with one extra runtime field `name` of type `Val`.
///
/// The name is taken as a fixed-size array (`[N]u8`) rather than a slice; `AddField`
/// performs that conversion. NOTE(review): presumably this is so the name takes
/// part in generic-instantiation memoization by value — confirm.
///
/// Emits a descriptive compile error if `Lhs` already has a field called `name`
/// (otherwise `@Type` would fail with a less helpful duplicate-field error).
fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type {
    if (@hasField(Lhs, &name)) {
        @compileError("Context already contains a field named '" ++ &name ++ "'");
    }

    const new_field = std.builtin.Type.StructField{
        .name = &name,
        .field_type = Val,
        // Zero-sized fields report alignment 0.
        .alignment = if (@sizeOf(Val) != 0) @alignOf(Val) else 0,
        .default_value = null,
        .is_comptime = false,
    };

    return @Type(.{ .Struct = .{
        .layout = .Auto,
        .fields = std.meta.fields(Lhs) ++ &[_]std.builtin.Type.StructField{new_field},
        .decls = &.{},
        .is_tuple = false,
    } });
}
|
|
|
|
/// Slice-taking convenience wrapper around `AddUniqueField`: `Lhs` plus a new
/// field `name` of type `Val`.
fn AddField(comptime Lhs: type, comptime name: []const u8, comptime Val: type) type {
    const name_array: [name.len]u8 = name[0..name.len].*;
    return AddUniqueField(Lhs, name.len, name_array, Val);
}
|
|
|
|
/// Builds a new struct value holding every field of `lhs` plus `name` set to
/// `val`. `lhs` itself is left unmodified.
fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@TypeOf(lhs), name, @TypeOf(val)) {
    const Result = AddField(@TypeOf(lhs), name, @TypeOf(val));
    var out: Result = undefined;

    // Copy the existing fields across...
    inline for (comptime std.meta.fields(@TypeOf(lhs))) |field| {
        @field(out, field.name) = @field(lhs, field.name);
    }
    // ...then fill in the new one.
    @field(out, name) = val;

    return out;
}
|
|
|
|
test "addField" {
    const eql = std.meta.eql;
    const expect = std.testing.expect;

    // Adding to an empty context.
    try expect(eql(addField(.{}, "abcd", 5), .{ .abcd = 5 }));

    // Adding to an existing context keeps the prior fields.
    try expect(eql(addField(.{ .abcd = 5 }, "efgh", 10), .{ .abcd = 5, .efgh = 10 }));

    // Chained additions.
    try expect(eql(addField(addField(.{}, "abcd", 5), "efgh", 10), .{ .abcd = 5, .efgh = 10 }));
}
|
|
|
|
/// Middleware that adds a single value to the context, under field `name`,
/// before invoking the next handler.
pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type {
    return struct {
        val: V,

        pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            const extended_ctx = addField(ctx, name, self.val);
            return next.handle(req, res, extended_ctx, {});
        }
    };
}
|
|
|
|
/// Convenience constructor for `InjectContextValue`; the value type is inferred from `val`.
pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContextValue(name, @TypeOf(val)) {
    const Injector = InjectContextValue(name, @TypeOf(val));
    return Injector{ .val = val };
}
|
|
|
|
test "InjectContextValue" {
    const injector = injectContextValue("abcd", @as(usize, 5));

    // Into an empty context.
    try injector.handle(.{}, .{}, .{}, ExpectContext(.{ .abcd = 5 }){});

    // Existing context fields are preserved next to the injected one.
    try injector.handle(.{}, .{}, .{ .efgh = @as(usize, 10) }, ExpectContext(.{ .abcd = 5, .efgh = 10 }){});
}
|
|
|
|
/// Test helper: asserts that `expected` and `actual` have the same number of
/// fields and that every field of `expected` equals the same-named field of
/// `actual`. String fields are compared by content. A `void` `expected`
/// asserts that `actual` is `void` as well.
///
/// NOTE(review): `expectEqual`/`expectEqualStrings` receive (actual, expected),
/// seemingly so comptime literals in `expected` coerce to the runtime field
/// type of `actual`; failure messages therefore label the two sides reversed.
fn expectDeepEquals(expected: anytype, actual: anytype) !void {
    const Expected = @TypeOf(expected);
    const Actual = @TypeOf(actual);
    if (Expected == void) {
        return std.testing.expect(Actual == void);
    }

    const expected_fields = std.meta.fields(Expected);
    try std.testing.expect(expected_fields.len == std.meta.fields(Actual).len);

    inline for (expected_fields) |field| {
        const want = @field(expected, field.name);
        const got = @field(actual, field.name);
        if (comptime std.meta.trait.isZigString(field.field_type)) {
            try std.testing.expectEqualStrings(got, want);
        } else {
            try std.testing.expectEqual(got, want);
        }
    }
}
|
|
|
|
// Test helper: a terminal middleware that asserts the context it receives
// deep-equals `val` (via `expectDeepEquals`).
fn ExpectContext(comptime val: anytype) type {
    return struct {
        pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: void) !void {
            _ = self;
            _ = req;
            _ = res;
            _ = next;
            return expectDeepEquals(val, ctx);
        }
    };
}
|
|
/// Test helper: value-level constructor for `ExpectContext(val)`.
fn expectContext(comptime val: anytype) ExpectContext(val) {
    const Checker = ExpectContext(val);
    return Checker{};
}
|
|
|
|
/// Middleware that runs the rest of the chain and, if that fails, calls
/// `error_handler` instead. The caught error is exposed to the handler as the
/// `err` field of the context; the handler also receives the original `next`.
pub fn CatchErrors(comptime ErrorHandler: type) type {
    return struct {
        error_handler: ErrorHandler,

        pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            next.handle(req, res, ctx, {}) catch |err| {
                const err_ctx = addField(ctx, "err", err);
                return self.error_handler.handle(req, res, err_ctx, next);
            };
        }
    };
}
|
|
|
|
/// Wraps `error_handler` in a `CatchErrors` middleware.
pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) {
    const Catcher = CatchErrors(@TypeOf(error_handler));
    return Catcher{ .error_handler = error_handler };
}
|
|
|
|
/// Default error handler for CatchErrors.
///
/// Behavior:
///  - logs the error (outside of tests);
///  - marks the connection to be closed;
///  - responds with a 500 Internal Server Error if no response has been opened yet.
///
/// Failures while writing that 500 are logged rather than returned, since we
/// are already handling an error.
pub const default_error_handler = struct {
    pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: anytype) !void {
        const should_log = !@import("builtin").is_test;
        if (should_log) std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri });

        // Tell the server to close the connection after this request
        res.should_close = true;

        // A response is already in flight; opening another one is not possible.
        if (res.was_opened) return;

        // Header storage for the 500 response; a fixed buffer avoids heap allocation.
        var buf: [1024]u8 = undefined;
        var fba = std.heap.FixedBufferAllocator.init(&buf);
        var headers = http.Fields.init(fba.allocator());

        var stream = res.open(.internal_server_error, &headers) catch |err| {
            if (should_log) std.log.err("Unable to open error response: {s}", .{@errorName(err)});
            return;
        };
        defer stream.close();

        stream.finish() catch |err| {
            if (should_log) std.log.err("Unable to finish error response: {s}", .{@errorName(err)});
        };
    }
}{};
|
|
|
|
test "CatchErrors" {
    // Minimal stand-in for a response: records whether the error handler asked
    // to close the connection and whether it opened a response.
    const TestResponse = struct {
        should_close: bool = false,
        was_opened: bool = false,

        // When false, any attempt to open the response counts as a double-open.
        test_should_open: bool,

        const TestStream = struct {
            fn close(_: *@This()) void {}
            fn finish(_: *@This()) !void {}
        };

        fn open(self: *@This(), status: http.Status, _: *http.Fields) !TestStream {
            self.was_opened = true;
            if (!self.test_should_open) return error.ResponseOpenedTwice;
            try std.testing.expectEqual(status, .internal_server_error);
            return .{};
        }
    };

    // A handler that always fails, guarded by the default error handler.
    const AlwaysFails = struct {
        fn handle(_: @This(), _: anytype, _: anytype, _: anytype, _: anytype) !void {
            return error.SomeError;
        }
    };
    const guarded = apply(.{
        catchErrors(default_error_handler),
        AlwaysFails{},
    });

    // A fresh response gets a 500 and is marked for closing.
    var fresh = TestResponse{ .test_should_open = true };
    try guarded.handle(.{ .uri = "abcd" }, &fresh, .{}, {});
    try std.testing.expect(fresh.should_close);

    // An already-opened response must not be opened again, but still gets closed.
    var already_open = TestResponse{ .test_should_open = false, .was_opened = true };
    try guarded.handle(.{ .uri = "abcd" }, &already_open, .{}, {});
    try std.testing.expect(already_open.should_close);
}
|
|
|
|
/// Middleware that splits `req.uri` into "path", "query_string", and
/// "fragment_string", adding each (in that order) to the context.
/// Parts that are absent from the URI become empty strings.
const SplitUri = struct {
    pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
        // Everything after the first '#' is the fragment.
        var by_fragment = std.mem.split(u8, req.uri, "#");
        const before_fragment = by_fragment.first();
        const fragment = by_fragment.rest();

        // Of what remains, everything after the first '?' is the query string.
        var by_query = std.mem.split(u8, before_fragment, "?");
        const path = by_query.first();
        const query = by_query.rest();

        const with_path = addField(ctx, "path", path);
        const with_query = addField(with_path, "query_string", query);
        const with_fragment = addField(with_query, "fragment_string", fragment);

        return next.handle(req, res, with_fragment, {});
    }
};
|
|
/// Shared instance of the `SplitUri` middleware.
pub const split_uri: SplitUri = .{};
|
|
|
|
test "split_uri" {
    const check = struct {
        fn run(uri: []const u8, ctx: anytype, expected: anytype) !void {
            const chain = apply(.{ split_uri, expectContext(expected) });
            try chain.handle(.{ .uri = uri }, .{}, ctx, {});
        }
    }.run;

    // Path only.
    try check("/", .{}, .{ .path = "/", .query_string = "", .fragment_string = "" });
    try check("", .{}, .{ .path = "", .query_string = "", .fragment_string = "" });
    try check("/path", .{}, .{ .path = "/path", .query_string = "", .fragment_string = "" });

    // Query string or fragment only.
    try check("?abcd=1234", .{}, .{ .path = "", .query_string = "abcd=1234", .fragment_string = "" });
    try check("#abcd", .{}, .{ .path = "", .query_string = "", .fragment_string = "abcd" });

    // All three parts.
    try check("/abcd/efgh?query=no#frag", .{}, .{ .path = "/abcd/efgh", .query_string = "query=no", .fragment_string = "frag" });
}
|
|
|
|
/// Routes a request between the provided routes.
///
/// Routes are tried in order:
///  - the first one that does not return `error.RouteMismatch` handles the request;
///  - any other error it returns is propagated;
///  - if every route mismatches, `error.RouteMismatch` is returned so an
///    enclosing router can keep searching.
///
/// Routes whose error set does not include `error.RouteMismatch` (e.g. an
/// infallible fallback handler) are supported.
///
/// CURRENTLY: Does not do this intelligently, all routing is handled by the routes themselves.
/// TODO: Consider implementing this with a hashmap?
pub fn Router(comptime Routes: type) type {
    return struct {
        routes: Routes,

        pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: void) !void {
            _ = next;

            inline for (self.routes) |r| {
                if (r.handle(req, res, ctx, {})) |_| {
                    // This route accepted the request.
                    return;
                } else |err| {
                    // Compare through `anyerror`: a `switch` prong naming
                    // RouteMismatch fails to compile when this route's error set
                    // lacks it.
                    if (@as(anyerror, err) != error.RouteMismatch) return err;
                }
            }

            // No route applied.
            return error.RouteMismatch;
        }
    };
}
|
|
/// Builds a `Router` over the given tuple of routes.
pub fn router(routes: anytype) Router(@TypeOf(routes)) {
    return .{ .routes = routes };
}
|
|
|
|
/// Route-analysis helper: reports whether `path` matches the route pattern `route`.
///
/// Segment rules:
///  - Literal segments are compared ASCII case-insensitively.
///  - A segment starting with ':' matches any non-empty path segment.
///  - If that segment also ends in '*', it consumes every remaining path segment.
///  - Path segments left over once the route is exhausted cause a mismatch.
fn pathMatches(route: []const u8, path: []const u8) bool {
    var route_iter = PathIter.from(route);
    var path_iter = PathIter.from(path);

    while (route_iter.next()) |route_segment| {
        const path_segment = path_iter.next() orelse "";

        const is_argument = route_segment.len > 0 and route_segment[0] == ':';
        if (!is_argument) {
            if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false;
            continue;
        }

        const is_wildcard = route_segment[route_segment.len - 1] == '*';
        if (is_wildcard) {
            // Swallow whatever segments remain in the path.
            while (path_iter.next()) |_| {}
        } else if (path_segment.len == 0) {
            return false;
        }
    }

    return path_iter.next() == null;
}
|
|
|
|
/// Handler that either calls its next middleware parameter or returns error.RouteMismatch,
/// depending on whether the request matches the described route.
///
/// Should be placed below `split_uri` on the middleware list. Without a `path`
/// in the context, it falls back to `req.uri` with any query string and
/// fragment removed.
///
/// Format:
/// Each route segment can be either a literal string or an argument.
///  - Literal strings must match (ASCII case-insensitively) in order to
///    constitute a matching route.
///  - Arguments must begin with the character ':', with the remainder of the
///    segment referring to the name of the argument.
///  - Argument values must be nonempty. The exception is an argument ending in
///    '*', which consumes all remaining path segments, possibly none.
///
/// For example, the route "/abc/:foo/def" would match "/abc/x/def" or "/abc/blahblah/def" but
/// not "/abc//def".
pub const Route = struct {
    pub const Desc = struct {
        path: []const u8,
        method: http.Method,
    };

    desc: Desc,

    fn applies(self: @This(), req: anytype, ctx: anytype) bool {
        if (self.desc.method != req.method) return false;

        const eff_path = if (@hasField(@TypeOf(ctx), "path"))
            ctx.path
        else
            // Strip both "?query" and "#fragment", consistent with split_uri.
            req.uri[0 .. std.mem.indexOfAny(u8, req.uri, "?#") orelse req.uri.len];

        return pathMatches(self.desc.path, eff_path);
    }

    pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
        if (!self.applies(req, ctx)) return error.RouteMismatch;
        return next.handle(req, res, ctx, {});
    }
};
|
|
|
|
test "route" {
    const check = struct {
        // Terminal handler that accepts anything.
        const accept_all = struct {
            fn handle(_: anytype, _: anytype, _: anytype, _: anytype, _: anytype) !void {}
        }{};

        fn run(should_match: bool, desc: Route.Desc, method: http.Method, path: []const u8) !void {
            const route = Route{ .desc = desc };
            const outcome = route.handle(.{ .method = method }, .{}, .{ .path = path }, accept_all);
            if (should_match) {
                try outcome;
            } else {
                try std.testing.expectError(error.RouteMismatch, outcome);
            }
        }
    }.run;

    // Matching routes.
    try check(true, .{ .method = .GET, .path = "/" }, .GET, "/");
    try check(true, .{ .method = .GET, .path = "/" }, .GET, "");
    try check(true, .{ .method = .GET, .path = "/abcd" }, .GET, "/abcd");
    try check(true, .{ .method = .GET, .path = "/abcd" }, .GET, "abcd");
    try check(true, .{ .method = .POST, .path = "/" }, .POST, "/");
    try check(true, .{ .method = .POST, .path = "/" }, .POST, "");
    try check(true, .{ .method = .POST, .path = "/abcd" }, .POST, "/abcd");
    try check(true, .{ .method = .POST, .path = "/abcd" }, .POST, "abcd");
    try check(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh");
    try check(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh");
    try check(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz");
    try check(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh/xyz");
    try check(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh");
    try check(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/");
    try check(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd");

    // Mismatching routes.
    try check(false, .{ .method = .POST, .path = "/" }, .GET, "/");
    try check(false, .{ .method = .GET, .path = "/abcd" }, .GET, "");
    try check(false, .{ .method = .GET, .path = "/" }, .GET, "/abcd");
    try check(false, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "efgh");
    try check(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/");
    try check(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/");
    try check(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo");
    try check(false, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "defg/abcd");
}
|
|
|
|
/// Mounts a router subtree under the literal prefix `route`.
///
/// Middlewares further down the list see `ctx.path` with that prefix removed.
/// Paths outside the prefix fail with `error.RouteMismatch`.
/// Must be below `split_uri` on the middleware list.
pub fn Mount(comptime route: []const u8) type {
    if (std.mem.indexOfScalar(u8, route, ':') != null) @compileError("Route args cannot be mounted");

    // `route` followed by any remainder, which is captured as `path`.
    const pattern = route ++ "/:path*";
    const Captured = struct { path: []const u8 };

    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            const captured = try parseArgsFromPath(pattern, Captured, ctx.path);

            var sub_ctx = ctx;
            sub_ctx.path = captured.path;

            return next.handle(req, res, sub_ctx, {});
        }
    };
}
|
|
/// Value-level constructor for `Mount(route)`.
pub fn mount(comptime route: []const u8) Mount(route) {
    const Mounted = Mount(route);
    return Mounted{};
}
|
|
|
|
test "mount" {
    const check = struct {
        fn run(comptime base: []const u8, request: []const u8, comptime expected: ?[]const u8) !void {
            const ctx = addField(.{}, "path", request);
            const outcome = mount(base).handle(.{}, .{}, ctx, expectContext(.{ .path = expected orelse "" }));
            if (expected == null) {
                try std.testing.expectError(error.RouteMismatch, outcome);
            } else {
                try outcome;
            }
        }
    }.run;

    // Paths under the mount point have the prefix stripped.
    try check("/api/", "/api/", "");
    try check("/api/", "/api/abcd", "abcd");
    try check("/api/", "/api/abcd/efgh", "abcd/efgh");
    try check("/api/", "/api/abcd/efgh/", "abcd/efgh/");
    try check("/api/v0", "/api/v0/call", "call");

    // Paths outside the mount point are rejected.
    try check("/api/", "/web/abcd/efgh/", null);
    try check("/api/", "/", null);
    try check("/api/", "/ap", null);
    try check("/api/v0", "/api/v1/", null);
}
|
|
|
|
/// Matches `path` against the comptime route pattern `route` and parses each
/// `:name` argument into the same-named field of `Args` (via `parseArgFromPath`).
///
/// Returns `error.RouteMismatch` when:
///  - a literal segment differs;
///  - a non-wildcard argument segment is empty;
///  - `path` has segments left over.
///
/// A `:name*` argument receives the not-yet-consumed remainder of `path`
/// verbatim, which may include a leading '/'.
///
/// NOTE(review): fields of `Args` that do not appear in `route` stay `undefined`.
/// NOTE(review): assumes `PathIter` never yields an empty route segment, since
/// `route_segment[0]` is indexed unconditionally — confirm.
fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args {
    var args: Args = undefined;
    var path_iter = PathIter.from(path);
    comptime var route_iter = PathIter.from(route);
    // Portion of `path` not yet consumed, used by wildcard arguments.
    var path_unused: []const u8 = path;

    inline while (comptime route_iter.next()) |route_segment| {
        const path_segment = path_iter.next() orelse "";

        if (route_segment[0] != ':') {
            // Literal segment.
            if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
        } else {
            // Argument segment, optionally a trailing-'*' wildcard.
            const is_wildcard = comptime route_segment[route_segment.len - 1] == '*';
            const arg_name = comptime if (is_wildcard)
                route_segment[1 .. route_segment.len - 1]
            else
                route_segment[1..];

            var raw_value: []const u8 = path_segment;
            if (is_wildcard) {
                // Consume every remaining segment; the value is the unconsumed tail.
                while (path_iter.next()) |_| {}
                raw_value = path_unused;
            } else if (path_segment.len == 0) {
                return error.RouteMismatch;
            }

            const Field = @TypeOf(@field(args, arg_name));
            @field(args, arg_name) = try parseArgFromPath(Field, raw_value);
        }

        path_unused = path_iter.rest();
    }

    if (path_iter.next() != null) return error.RouteMismatch;

    return args;
}
|
|
|
|
/// Converts a single path segment into a `T`.
///
/// Supported types:
///  - `[]const u8`: the segment is returned as-is (borrowed, not copied);
///  - containers declaring a `parse` function: `T.parse(segment)` is called;
///  - integers: parsed with `std.fmt.parseInt` using base 0 (prefix auto-detection).
///
/// Any other type is a compile error.
fn parseArgFromPath(comptime T: type, segment: []const u8) !T {
    if (T == []const u8) {
        return segment;
    } else if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) {
        return T.parse(segment);
    } else if (comptime std.meta.trait.is(.Int)(T)) {
        return std.fmt.parseInt(T, segment, 0);
    } else {
        @compileError("Unsupported Type " ++ @typeName(T));
    }
}
|
|
|
|
/// Middleware that parses route arguments directly from the request path
/// (`ctx.path`) into an `Args` value, exposed as `ctx.args`.
///
/// Placement: must come after a `split_uri` middleware, which puts `path` in
/// the context.
///
/// Errors:
///  - `error.RouteMismatch` if the path does not fit `route`;
///  - the argument parser's error if a segment cannot be converted.
///
/// When `Args` is `void`, nothing is parsed and `ctx.args` is `{}`.
///
/// Route arguments are specified in the same format as for Route. The name of
/// the argument refers to the field name in Args that the argument will be
/// parsed to. An argument ending in `*` captures all remaining path segments.
///
/// This currently works with arguments of 3 different types:
///  - integers
///  - []const u8
///  - anything with a function of the form:
///     * T.parse([]const u8) Error!T
///     * This function must not retain a reference to the passed string after it returns
///
/// Example:
///     ParsePathArgs("/:id/foo/:name/byrank/:rank", struct {
///         id: util.Uuid,
///         name: []const u8,
///         rank: u32,
///     })
/// would parse a path of "/00000000-0000-0000-0000-000000000000/foo/jaina/byrank/3" into
///     .{ .id = try Uuid.parse("00000000-0000-0000-0000-000000000000"), .name = "jaina", .rank = 3 }
pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type {
    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            if (Args == void) {
                return next.handle(req, res, addField(ctx, "args", {}), {});
            }

            const parsed = try parseArgsFromPath(route, Args, ctx.path);
            const args_ctx = addField(ctx, "args", parsed);
            return next.handle(req, res, args_ctx, {});
        }
    };
}
|
|
/// Value-level constructor for `ParsePathArgs(route, Args)`.
pub fn parsePathArgs(comptime route: []const u8, comptime Args: type) ParsePathArgs(route, Args) {
    const Parser = ParsePathArgs(route, Args);
    return Parser{};
}
|
|
|
|
test "ParsePathArgs" {
    const check = struct {
        fn run(comptime route: []const u8, comptime Args: type, path: []const u8, expected: anytype) !void {
            // Terminal handler: checks the parsed args and that `path` is untouched.
            const verifier = struct {
                expected: @TypeOf(expected),
                path: []const u8,

                fn handle(self: @This(), _: anytype, _: anytype, ctx: anytype, _: void) !void {
                    try expectDeepEquals(self.expected, ctx.args);
                    try std.testing.expectEqualStrings(self.path, ctx.path);
                }
            }{ .expected = expected, .path = path };

            try parsePathArgs(route, Args).handle(.{}, .{}, .{ .path = path }, verifier);
        }
    }.run;

    // Plain arguments.
    try check("/", void, "/", {});
    try check("/:id", struct { id: usize }, "/3", .{ .id = 3 });
    try check("/:str", struct { str: []const u8 }, "/abcd", .{ .str = "abcd" });
    try check("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" });
    try check("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil });

    // Wildcard arguments.
    try check("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc", .{ .arg = "abc" });
    try check("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc/def", .{ .arg = "abc/def" });
    try check("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/", .{ .arg = "" });

    // The compiler crashed when this case reused the argument name from above,
    // hence the different name.
    // TODO: Debug this and try to fix it
    try check("/xyz/:bar*", struct { bar: []const u8 }, "/xyz", .{ .bar = "" });

    // Quirk: the initial '/' is kept for these cases. The resulting path is
    // semantically equivalent, so this is left as-is.
    try check("/:foo*", struct { foo: []const u8 }, "/abc", .{ .foo = "/abc" });
    try check("/:foo*", struct { foo: []const u8 }, "/abc/def", .{ .foo = "/abc/def" });
    try check("/:foo*", struct { foo: []const u8 }, "/", .{ .foo = "/" });
    try check("/:foo*", struct { foo: []const u8 }, "", .{ .foo = "" });

    // Failure cases.
    try std.testing.expectError(error.RouteMismatch, check("/:id", struct { id: usize }, "/", .{}));
    try std.testing.expectError(error.RouteMismatch, check("/abcd/:id", struct { id: usize }, "/123", .{}));
    try std.testing.expectError(error.RouteMismatch, check("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 }));
    try std.testing.expectError(error.InvalidCharacter, check("/:id", struct { id: usize }, "/xyz", .{}));
}
|
|
|
|
/// Body-parser families that `matchContentType` can select.
const BaseContentType = enum {
    /// "application/json" or any "+json" suffixed type
    json,
    /// "application/x-www-form-urlencoded"
    url_encoded,
    /// "application/octet-stream"
    octet_stream,
    /// "multipart/form-data"
    multipart_formdata,

    /// anything unrecognized
    other,
};
|
|
|
|
/// Reads the request body from `reader` and decodes it into a `T`, choosing
/// the decoder from `content_type` (JSON when absent).
///
/// The result is allocated with `alloc`; callers in this file release it with
/// `util.deepFree`.
///
/// JSON and url-encoded bodies are read via `readAllAlloc` with a 1 << 16 byte
/// limit. NOTE(review): `application/octet-stream` bodies are decoded as JSON
/// too — confirm this is intended.
fn parseBodyFromRequest(
    comptime T: type,
    comptime options: ParseBodyOptions,
    content_type: ?[]const u8,
    reader: anytype,
    alloc: std.mem.Allocator,
) !T {
    // Use json by default for now for testing purposes
    const eff_type = content_type orelse "application/json";
    const max_body_len = 1 << 16;

    return switch (matchContentType(eff_type)) {
        .octet_stream, .json => json: {
            const raw = try reader.readAllAlloc(alloc, max_body_len);
            defer alloc.free(raw);

            const parsed = try json_utils.parse(T, options.allow_unknown_fields, raw, alloc);
            defer json_utils.parseFree(parsed, alloc);

            // Deep-clone so the result no longer depends on the parser's allocations.
            break :json try util.deepClone(alloc, parsed);
        },
        .url_encoded => form: {
            const raw = try reader.readAllAlloc(alloc, max_body_len);
            defer alloc.free(raw);

            break :form try urlencode.parse(alloc, options.allow_unknown_fields, T, raw);
        },
        .multipart_formdata => multipart: {
            const boundary = fields.getParam(eff_type, "boundary") orelse return error.MissingBoundary;
            break :multipart try @import("./multipart.zig").parseFormData(T, options.allow_unknown_fields, boundary, reader, alloc);
        },
        else => error.UnsupportedMediaType,
    };
}
|
|
|
|
/// Picks the base parser family for a Content-Type header value.
///
/// Matching rules:
///  - Parameters after the first ';' (charset, boundary, ...) are ignored.
///  - Surrounding spaces/tabs are trimmed; RFC 9110 allows optional
///    whitespace before ';'.
///  - Comparison is ASCII case-insensitive.
///  - Any "+json" structured-syntax suffix is treated as JSON.
///
/// Unrecognized types map to `.other`.
fn matchContentType(hdr: []const u8) BaseContentType {
    const media_type = std.mem.trim(u8, std.mem.sliceTo(hdr, ';'), " \t");

    if (std.ascii.eqlIgnoreCase(media_type, "application/x-www-form-urlencoded")) return .url_encoded;
    if (std.ascii.eqlIgnoreCase(media_type, "application/json")) return .json;
    if (std.ascii.endsWithIgnoreCase(media_type, "+json")) return .json;
    if (std.ascii.eqlIgnoreCase(media_type, "application/octet-stream")) return .octet_stream;
    if (std.ascii.eqlIgnoreCase(media_type, "multipart/form-data")) return .multipart_formdata;

    return .other;
}
|
|
|
|
/// Options controlling how `ParseBody` decodes request bodies.
pub const ParseBodyOptions = struct {
    /// Forwarded as `allow_unknown_fields` to the JSON, url-encoded, and
    /// multipart decoders. Defaults to false.
    allow_unknown_fields: bool = false,
};
|
|
|
|
/// Middleware that parses the request body into a `Body` and exposes it as
/// `ctx.body`.
///
/// The parsing method depends on the request's Content-Type header (JSON when
/// the header is absent). `ctx.allocator` is used for the parsed value, which
/// is freed once the rest of the chain returns.
///
/// When `Body` is `void`:
///  - a request with a Content-Type header fails with `error.UnexpectedBody`;
///  - otherwise `ctx.body` is `{}`.
///
/// Otherwise a missing `req.body` fails with `error.NoBody`.
///
/// TODO: Need tests for this, including various Content-Type values
pub fn ParseBody(comptime Body: type, comptime options: ParseBodyOptions) type {
    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            const content_type = req.headers.get("Content-Type");

            if (Body == void) {
                if (content_type != null) return error.UnexpectedBody;
                return next.handle(req, res, addField(ctx, "body", {}), {});
            }

            var body_stream = req.body orelse return error.NoBody;
            const parsed = try parseBodyFromRequest(Body, options, content_type, body_stream.reader(), ctx.allocator);
            defer util.deepFree(ctx.allocator, parsed);

            const body_ctx = addField(ctx, "body", parsed);
            return next.handle(req, res, body_ctx, {});
        }
    };
}
|
|
/// Convenience constructor for `ParseBody` using default `ParseBodyOptions`
/// (unknown fields are rejected).
///
/// `ParseBody` takes two comptime parameters, so the options must be supplied
/// here; passing only `Body` fails to compile.
pub fn parseBody(comptime Body: type) ParseBody(Body, .{}) {
    return .{};
}
|
|
|
|
test "parseBodyFromRequest" {
    const check = struct {
        fn run(content_type: []const u8, body: []const u8, expected: anytype) !void {
            var source = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
            const parsed = try parseBodyFromRequest(@TypeOf(expected), .{}, content_type, source.reader(), std.testing.allocator);
            defer util.deepFree(std.testing.allocator, parsed);

            try util.testing.expectDeepEqual(expected, parsed);
        }
    }.run;

    const Struct = struct {
        id: usize,
    };
    try check("application/json", "{\"id\": 3}", Struct{ .id = 3 });
    try check("application/x-www-form-urlencoded", "id=3", Struct{ .id = 3 });

    // TODO: add a multipart/form-data case.
}
|
|
|
|
test "parseBody" {
    const Struct = struct {
        foo: []const u8,
    };
    const json_body =
        \\{"foo": "bar"}
    ;

    // The header set is left empty, so parsing falls back to JSON.
    var headers = http.Fields.init(std.testing.allocator);
    defer headers.deinit();
    var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(json_body) };

    const Checker = struct {
        fn handle(_: anytype, _: anytype, _: anytype, ctx: anytype, _: void) !void {
            try util.testing.expectDeepEqual(Struct{ .foo = "bar" }, ctx.body);
        }
    };

    try parseBody(Struct).handle(
        .{ .body = @as(?std.io.StreamSource, stream), .headers = headers },
        .{},
        .{ .allocator = std.testing.allocator },
        Checker{},
    );
}
|
|
|
|
/// Middleware that parses `ctx.query_string` into a `QueryParams` value with
/// `urlencode.parse` and exposes it as `ctx.query_params`.
///
/// Requirements: `ctx.query_string` (added by `split_uri`) and `ctx.allocator`.
/// The parsed value is freed once the rest of the chain returns.
/// When `QueryParams` is `void`, nothing is parsed and `ctx.query_params` is `{}`.
pub fn ParseQueryParams(comptime QueryParams: type) type {
    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            if (QueryParams == void) {
                return next.handle(req, res, addField(ctx, "query_params", {}), {});
            }

            // NOTE(review): `true` sits where parseBodyFromRequest passes
            // `allow_unknown_fields`, so unknown query parameters are
            // presumably tolerated here — confirm.
            const params = try urlencode.parse(ctx.allocator, true, QueryParams, ctx.query_string);
            defer util.deepFree(ctx.allocator, params);

            const params_ctx = addField(ctx, "query_params", params);
            return next.handle(req, res, params_ctx, {});
        }
    };
}
|