fediglam/src/http/middleware.zig

409 lines
13 KiB
Zig
Raw Normal View History

2022-11-24 04:51:30 +00:00
const std = @import("std");
const root = @import("root");
const builtin = @import("builtin");
const http = @import("./lib.zig");
const util = @import("util");
2022-11-26 01:43:16 +00:00
const query_utils = @import("./query.zig");
const json_utils = @import("./json.zig");
2022-11-27 01:33:46 +00:00
/// Builds a new struct type that has every field of `Lhs` followed by one extra,
/// non-comptime field called `name` of type `Val`.
///
/// The name is taken as a fixed-size array (`[N]u8`) instead of a slice;
/// `AddField` is the slice-taking convenience wrapper around this function.
fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type {
    const extra = std.builtin.Type.StructField{
        .name = &name,
        .field_type = Val,
        .default_value = null,
        .is_comptime = false,
        // Zero-sized field types get alignment 0; everything else uses its natural alignment.
        .alignment = if (@sizeOf(Val) == 0) 0 else @alignOf(Val),
    };
    const all_fields = std.meta.fields(Lhs) ++ &[_]std.builtin.Type.StructField{extra};
    return @Type(.{ .Struct = .{
        .layout = .Auto,
        .fields = all_fields,
        .decls = &.{},
        .is_tuple = false,
    } });
}
/// Slice-taking front end for `AddUniqueField`: the type of `Lhs` extended with a
/// field `name` of type `Val`.
fn AddField(comptime Lhs: type, comptime name: []const u8, comptime Val: type) type {
    // Copy the comptime-known name into an array so it can be passed by value.
    const name_array: [name.len]u8 = name[0..name.len].*;
    return AddUniqueField(Lhs, name.len, name_array, Val);
}
/// Returns a copy of the struct value `lhs` with one additional field `name` set to `val`.
///
/// All fields of `lhs` are copied across unchanged. This is how middlewares in this
/// file extend the request context they pass down the chain.
fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@TypeOf(lhs), name, @TypeOf(val)) {
    const Lhs = @TypeOf(lhs);
    var out: AddField(Lhs, name, @TypeOf(val)) = undefined;
    @field(out, name) = val;
    inline for (std.meta.fields(Lhs)) |field| {
        @field(out, field.name) = @field(lhs, field.name);
    }
    return out;
}
2022-11-24 04:51:30 +00:00
// `apply` is a plumbing function that chains a tuple of middlewares in order: each
// middleware's `next` is the remainder of the tuple. Here `split_uri` splits the URI
// into path/query/fragment and `mount("/abc")` strips the "/abc" prefix from the path.
test "apply chains split_uri and mount" {
    // The last middleware in a chain is handed `{}` (void) as its `next`, so the chain
    // must end in a handler that does not forward. This one inspects the final context.
    const check_context = struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: void) !void {
            _ = req;
            _ = res;
            _ = next;
            try std.testing.expectEqualStrings("some_query=true", ctx.query_string);
            try std.testing.expectEqualStrings("section", ctx.fragment_string);
            // mount("/abc") leaves only the remainder of the path.
            try std.testing.expect(std.mem.endsWith(u8, ctx.path, "defg/hijkl"));
        }
    }{};

    const base = apply(.{
        split_uri,
        mount("/abc"),
        check_context,
    });
    const request = .{ .uri = "/abc/defg/hijkl?some_query=true#section" };
    const response = .{};
    const initial_context = .{};
    try base.handle(request, response, initial_context, {});
}
/// Type of the handler chain built from `fields`: nested `NextHandler`s ending in `void`.
fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type {
    if (fields.len == 0) return void;
    const Head = fields[0].field_type;
    const Tail = ApplyInternal(fields[1..]);
    return NextHandler(Head, Tail);
}
/// Builds the value matching `ApplyInternal(fields)`, taking each middleware from
/// `middlewares` by field name in order.
fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) {
    if (fields.len == 0) return {};
    const head = @field(middlewares, fields[0].name);
    const tail = applyInternal(middlewares, fields[1..]);
    return .{ .first = head, .next = tail };
}
2022-11-24 11:30:49 +00:00
/// Chains the tuple `middlewares` into a single handler that runs them in order.
pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) {
    const fields = comptime std.meta.fields(@TypeOf(middlewares));
    return applyInternal(middlewares, fields);
}
2022-11-24 11:30:49 +00:00
/// Type returned by `apply` for a middleware tuple of type `Middlewares`.
pub fn Apply(comptime Middlewares: type) type {
    const fields = std.meta.fields(Middlewares);
    return ApplyInternal(fields);
}
2022-11-27 01:33:46 +00:00
/// Middleware that adds the stored value to the context as field `name`, then
/// calls `next`.
pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type {
    return struct {
        const Self = @This();

        val: V,

        pub fn handle(self: Self, req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            const extended = addField(ctx, name, self.val);
            return next.handle(req, res, extended, {});
        }
    };
}
2022-11-24 11:30:49 +00:00
2022-11-27 01:33:46 +00:00
/// Creates an `InjectContextValue` middleware that adds `val` to the context as `name`.
pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContextValue(name, @TypeOf(val)) {
    return InjectContextValue(name, @TypeOf(val)){ .val = val };
}
2022-11-24 04:51:30 +00:00
/// One link of a middleware chain: runs `first` and gives it `next` as its continuation.
///
/// The `next` argument of `handle` itself is `void` and unused, because the
/// continuation is stored in the `next` field.
pub fn NextHandler(comptime First: type, comptime Next: type) type {
    return struct {
        const Self = @This();

        first: First,
        next: Next,

        pub fn handle(self: Self, req: anytype, res: anytype, ctx: anytype, next: void) !void {
            _ = next;
            const continuation = self.next;
            return self.first.handle(req, res, ctx, continuation);
        }
    };
}
/// Middleware that runs the rest of the chain and, if it fails, calls `error_handler`.
///
/// The error handler receives the original context extended with an `err` field
/// holding the caught error, plus the same `next` continuation. Its result becomes
/// the result of this middleware.
pub fn CatchErrors(comptime ErrorHandler: type) type {
    return struct {
        error_handler: ErrorHandler,

        pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            next.handle(req, res, ctx, {}) catch |failure| {
                const error_ctx = addField(ctx, "err", failure);
                return self.error_handler.handle(req, res, error_ctx, next);
            };
        }
    };
}
/// Creates a `CatchErrors` middleware that uses `error_handler` for failures.
pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) {
    return CatchErrors(@TypeOf(error_handler)){ .error_handler = error_handler };
}
/// Error handler intended for `catchErrors`: logs the error, marks the connection to
/// be closed and, if no response has been started yet, sends an empty
/// 500 Internal Server Error response.
///
/// Expects `ctx.err` (set by `CatchErrors`) and `req.uri`. Failures while writing
/// the error response are logged and otherwise ignored, because there is nothing
/// better to do at this point.
pub const default_error_handler = struct {
    pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
        _ = next;
        std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri });
        // Tell the server to close the connection after this request
        res.should_close = true;

        // A response already in progress cannot have its status changed.
        if (res.was_opened) return;

        // The error response carries no headers of its own, so a small stack buffer suffices.
        var buf: [1024]u8 = undefined;
        var fba = std.heap.FixedBufferAllocator.init(&buf);
        var headers = http.Fields.init(fba.allocator());

        var stream = res.open(.internal_server_error, &headers) catch |open_err| {
            std.log.err("Unable to open error response: {}", .{open_err});
            return;
        };
        defer stream.close();
        stream.finish() catch |finish_err| {
            std.log.err("Unable to finish error response: {}", .{finish_err});
        };
    }
}{};
/// Middleware that splits `req.uri` into its path, query and fragment parts.
///
/// For "/a/b?x=1#top" the context gains `path = "/a/b"`, `query_string = "x=1"`
/// and `fragment_string = "top"`. A missing query or fragment yields an empty string.
pub const split_uri = struct {
    pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
        const uri: []const u8 = req.uri;

        // Everything after the first '#' is the fragment.
        const hash_index = std.mem.indexOfScalar(u8, uri, '#');
        const before_fragment = if (hash_index) |i| uri[0..i] else uri;
        const fragment = if (hash_index) |i| uri[i + 1 ..] else uri[uri.len..];

        // Of what remains, everything after the first '?' is the query string.
        const question_index = std.mem.indexOfScalar(u8, before_fragment, '?');
        const path = if (question_index) |i| before_fragment[0..i] else before_fragment;
        const query = if (question_index) |i| before_fragment[i + 1 ..] else before_fragment[before_fragment.len..];

        const with_path = addField(ctx, "path", path);
        const with_query = addField(with_path, "query_string", query);
        const with_fragment = addField(with_query, "fragment_string", fragment);
        return next.handle(req, res, with_fragment, {});
    }
}{};
2022-11-24 11:30:49 +00:00
/// Routes a request to the correct handler based on declared HTTP method and path.
///
/// `routes` is a tuple of handlers, which are tried in order. A handler failing with
/// `error.RouteMismatch` means "not mine", and the next handler is tried. The first
/// handler to succeed ends routing, and any other error is returned immediately.
/// If no handler matches, `error.RouteMismatch` is returned.
pub fn Router(comptime Routes: type) type {
    return struct {
        routes: Routes,

        pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: void) !void {
            _ = next;
            inline for (self.routes) |route| {
                const matched = if (route.handle(req, res, ctx, {})) |_|
                    true
                else |err| switch (err) {
                    error.RouteMismatch => false,
                    else => return err,
                };
                if (matched) return;
            }
            return error.RouteMismatch;
        }
    };
}
2022-11-26 01:43:16 +00:00
/// Creates a `Router` over the tuple of handlers `routes`.
pub fn router(routes: anytype) Router(@TypeOf(routes)) {
    return .{ .routes = routes };
}
2022-11-24 11:30:49 +00:00
/// Route-analysis helper: reports whether `path` matches the route pattern `route`.
///
/// Both are split into segments with `util.PathIter` and must have the same number
/// of segments. A pattern segment starting with ':' matches any segment; every other
/// pattern segment must equal its path segment, ignoring ASCII case.
fn pathMatches(route: []const u8, path: []const u8) bool {
    var path_segments = util.PathIter.from(path);
    var route_segments = util.PathIter.from(route);
    while (route_segments.next()) |pattern| {
        const actual = path_segments.next() orelse return false;
        const is_argument = pattern.len > 0 and pattern[0] == ':';
        if (!is_argument and !std.ascii.eqlIgnoreCase(pattern, actual)) return false;
    }
    // Leftover path segments mean the path is longer than the route.
    return path_segments.next() == null;
}
2022-11-24 11:30:49 +00:00
/// Middleware that matches a single HTTP method and path pattern.
///
/// Pattern segments beginning with ':' match any single path segment, while other
/// segments are compared ASCII case-insensitively (see `pathMatches`). On a match the
/// request is forwarded to `next`; otherwise `error.RouteMismatch` is returned so that
/// a `Router` can try the next route.
pub const Route = struct {
    pub const Desc = struct {
        path: []const u8,
        method: http.Method,
    };

    desc: Desc,

    /// Path to match against: `ctx.path` when an earlier middleware (such as
    /// `split_uri` or `mount`) has set one. Otherwise it is the request URI with any
    /// query string and fragment removed.
    fn effectivePath(req: anytype, ctx: anytype) []const u8 {
        if (@hasField(@TypeOf(ctx), "path")) return ctx.path;
        const uri: []const u8 = req.uri;
        const end = std.mem.indexOfAny(u8, uri, "?#") orelse uri.len;
        return uri[0..end];
    }

    fn applies(self: @This(), req: anytype, ctx: anytype) bool {
        if (self.desc.method != req.method) return false;
        return pathMatches(self.desc.path, effectivePath(req, ctx));
    }

    pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
        // Log the same path `applies` uses; `ctx` is not guaranteed to have a `path` field.
        std.log.debug("Testing path {s} against {s}", .{ effectivePath(req, ctx), self.desc.path });
        if (!self.applies(req, ctx)) return error.RouteMismatch;
        return next.handle(req, res, ctx, {});
    }
};
2022-11-24 04:51:30 +00:00
/// Middleware that strips the fixed prefix `route` from `ctx.path`.
///
/// Each non-empty segment of `route` must match the next segment of `ctx.path`
/// (ASCII case-insensitive); otherwise `error.RouteMismatch` is returned. On success
/// `next` receives a copy of `ctx` whose `path` is the unconsumed remainder.
/// `:argument` segments are rejected at compile time, since a mount prefix is literal.
pub fn Mount(comptime route: []const u8) type {
    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            var path_iter = util.PathIter.from(ctx.path);
            comptime var route_iter = util.PathIter.from(route);
            var path_unused = ctx.path;
            inline while (comptime route_iter.next()) |route_segment| {
                if (comptime route_segment.len == 0) continue;
                if (comptime route_segment[0] == ':') {
                    @compileError("Argument segments cannot be mounted: " ++ route);
                }
                const path_segment = path_iter.next() orelse return error.RouteMismatch;
                path_unused = path_iter.rest();
                if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
            }
            var new_ctx = ctx;
            new_ctx.path = path_unused;
            return next.handle(req, res, new_ctx, {});
        }
    };
}
/// Creates a `Mount` middleware for the literal path prefix `route`.
pub fn mount(comptime route: []const u8) Mount(route) {
    return Mount(route){};
}
/// Middleware that turns an unmatched route into a "not found" response.
///
/// It runs the rest of the chain. If that fails with `error.RouteMismatch` (the error
/// `Router` and `Route` report when nothing matched), the request is handed to
/// `not_found` with the same context. Any other error is propagated unchanged.
pub fn HandleNotFound(comptime NotFoundHandler: type) type {
    return struct {
        not_found: NotFoundHandler,

        // Named `handle` like every other middleware, because `NextHandler` and
        // `Router` invoke `.handle` on chain members.
        pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            return next.handle(req, res, ctx, {}) catch |err| switch (err) {
                error.RouteMismatch => return self.not_found.handle(req, res, ctx, {}),
                else => return err,
            };
        }

        /// Deprecated alias of `handle`, kept so existing references to `handler` still compile.
        pub const handler = handle;
    };
}
/// Parses the `:name` segments of `route` out of `path` into a value of type `Args`.
///
/// `route` is walked at compile time. Literal segments must equal the corresponding
/// path segment (ASCII case-insensitive). Each `:name` segment is converted with
/// `parsePathArg` into the field `name` of `Args`. Returns `error.RouteMismatch` when
/// the path has fewer or more segments than the route or a literal segment differs.
/// Conversion errors from `parsePathArg` propagate.
///
/// NOTE(review): fields of `Args` with no matching `:name` segment are left
/// `undefined`, so `Args` and `route` must be kept in sync by the caller.
/// NOTE(review): unlike `Mount`, empty route segments are not skipped here. Whether
/// `util.PathIter` ever yields them is not visible from this file.
fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args {
    var args: Args = undefined;
    var path_iter = util.PathIter.from(path);
    comptime var route_iter = util.PathIter.from(route);
    inline while (comptime route_iter.next()) |route_segment| {
        const path_segment = path_iter.next() orelse return error.RouteMismatch;
        if (route_segment.len > 0 and route_segment[0] == ':') {
            // route segment is an argument segment; its name (without ':') selects
            // the field of Args, whose type decides how the segment is parsed
            const A = @TypeOf(@field(args, route_segment[1..]));
            @field(args, route_segment[1..]) = try parsePathArg(A, path_segment);
        } else {
            if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
        }
    }
    // Extra trailing path segments mean the route does not match.
    if (path_iter.next() != null) return error.RouteMismatch;
    return args;
}
/// Converts a single path segment into a value of type `T`.
///
/// Supported types:
/// - `[]const u8`: the segment itself (borrowed, not copied);
/// - any container type with a `parse` function: `T.parse(segment)`;
/// - any integer type: decimal `std.fmt.parseInt`, whose errors
///   (`error.InvalidCharacter`, `error.Overflow`) propagate to the caller.
/// Any other type is a compile error.
fn parsePathArg(comptime T: type, segment: []const u8) !T {
    if (T == []const u8) return segment;
    if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment);
    if (comptime @typeInfo(T) == .Int) return std.fmt.parseInt(T, segment, 10);
    @compileError("Unsupported Type " ++ @typeName(T));
}
/// Middleware that parses the `:name` arguments of `route` from `ctx.path` into an
/// `Args` value and exposes it as `ctx.args`.
///
/// When `Args` is `void`, nothing is parsed and `ctx.args` is `{}`. Parse failures
/// (including `error.RouteMismatch`) are returned without calling `next`.
pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type {
    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            if (comptime Args == void) {
                return next.handle(req, res, addField(ctx, "args", {}), {});
            } else {
                const parsed = try parsePathArgs(route, Args, ctx.path);
                const with_args = addField(ctx, "args", parsed);
                return next.handle(req, res, with_args, {});
            }
        }
    };
}
/// Request body encodings distinguished by `matchContentType` and handled by `parseBody`.
const BaseContentType = enum {
    /// "application/json"
    json,
    /// "application/x-www-form-urlencoded"
    url_encoded,
    /// "application/octet-stream" (`parseBody` decodes this as JSON as well)
    octet_stream,
    /// Any other media type; `parseBody` rejects it with `error.UnsupportedMediaType`
    other,
};
/// Reads the whole request body from `reader` and decodes it into a `T`.
///
/// JSON and octet-stream bodies are decoded with `json_utils` and deep-cloned into
/// `alloc`. Url-encoded bodies go through `query_utils.parseQuery`, with its
/// `error.NoQuery` reported as `error.NoBody`. Any other content type yields
/// `error.UnsupportedMediaType`. The result is meant to be released with
/// `util.deepFree`, as `ParseBody` does.
fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T {
    // Upper bound handed to readAllAlloc for the raw body (64 KiB).
    const max_body_size = 1 << 16;
    const raw = try reader.readAllAlloc(alloc, max_body_size);
    defer alloc.free(raw);

    switch (content_type) {
        .json, .octet_stream => {
            // Presumably cloned so the result outlives both `raw` and the parser's allocations.
            const decoded = try json_utils.parse(T, raw, alloc);
            defer json_utils.parseFree(decoded, alloc);
            return try util.deepClone(alloc, decoded);
        },
        .url_encoded => return query_utils.parseQuery(alloc, T, raw) catch |err| switch (err) {
            error.NoQuery => error.NoBody,
            else => err,
        },
        .other => return error.UnsupportedMediaType,
    }
}
/// Classifies a `Content-Type` header value.
///
/// Returns null when the header is absent. Media-type parameters such as
/// "; charset=utf-8" and surrounding spaces/tabs are ignored, and the comparison is
/// ASCII case-insensitive, so "Application/JSON; charset=utf-8" maps to `.json`.
/// Unrecognised media types map to `.other`.
fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
    const value = hdr orelse return null;

    // Only the media type before any ';' parameters decides how the body is parsed.
    const essence_end = std.mem.indexOfScalar(u8, value, ';') orelse value.len;
    const media_type = std.mem.trim(u8, value[0..essence_end], " \t");

    if (std.ascii.eqlIgnoreCase(media_type, "application/x-www-form-urlencoded")) return .url_encoded;
    if (std.ascii.eqlIgnoreCase(media_type, "application/json")) return .json;
    if (std.ascii.eqlIgnoreCase(media_type, "application/octet-stream")) return .octet_stream;
    return .other;
}
/// Middleware that decodes the request body into a `Body` and exposes it as `ctx.body`.
///
/// When `Body` is `void`, a request carrying a `Content-Type` header is rejected with
/// `error.UnexpectedBody`, and `ctx.body` is `{}`. Otherwise a missing body stream is
/// `error.NoBody`, and a missing `Content-Type` is treated as JSON. The decoded body
/// is allocated from `ctx.allocator` and freed with `util.deepFree` once `next` returns.
pub fn ParseBody(comptime Body: type) type {
    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            const content_type = req.headers.get("Content-Type");

            if (comptime Body == void) {
                if (content_type != null) return error.UnexpectedBody;
                return next.handle(req, res, addField(ctx, "body", {}), {});
            } else {
                var stream = req.body orelse return error.NoBody;
                const kind: BaseContentType = matchContentType(content_type) orelse .json;
                const body = try parseBody(Body, kind, stream.reader(), ctx.allocator);
                defer util.deepFree(ctx.allocator, body);

                const with_body = addField(ctx, "body", body);
                return next.handle(req, res, with_body, {});
            }
        }
    };
}
2022-11-24 11:30:49 +00:00
/// Middleware that parses `ctx.query_string` (e.g. set by `split_uri`) into a
/// `QueryParams` value and exposes it as `ctx.query_params`.
///
/// When `QueryParams` is `void`, nothing is parsed and `ctx.query_params` is `{}`.
/// Otherwise the parsed value is allocated from `ctx.allocator` and freed with
/// `util.deepFree` once `next` returns.
pub fn ParseQueryParams(comptime QueryParams: type) type {
    return struct {
        pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
            if (comptime QueryParams == void) {
                return next.handle(req, res, addField(ctx, "query_params", {}), {});
            } else {
                const params = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string);
                defer util.deepFree(ctx.allocator, params);

                const with_params = addField(ctx, "query_params", params);
                return next.handle(req, res, with_params, {});
            }
        }
    };
}