Fix middleware compilation???
This commit is contained in:
parent
04a95a280b
commit
9c5e46ec5a
9 changed files with 115 additions and 54 deletions
|
@ -82,7 +82,7 @@ pub fn build(b: *std.build.Builder) !void {
|
|||
|
||||
const pkgs = makePkgs(b, options.getPackage("build_options"));
|
||||
|
||||
const exe = b.addExecutable("apub", "src/main/main.zig");
|
||||
const exe = b.addExecutable("fediglam", "src/main/main.zig");
|
||||
exe.setTarget(target);
|
||||
exe.setBuildMode(mode);
|
||||
|
||||
|
@ -96,6 +96,7 @@ pub fn build(b: *std.build.Builder) !void {
|
|||
if (enable_sqlite) exe.linkSystemLibrary("sqlite3");
|
||||
if (enable_postgres) exe.linkSystemLibrary("pq");
|
||||
exe.linkLibC();
|
||||
exe.addSystemIncludePath("/usr/include/");
|
||||
|
||||
//const util_tests = b.addTest("src/util/lib.zig");
|
||||
const http_tests = b.addTest("src/http/test.zig");
|
||||
|
|
|
@ -10,7 +10,7 @@ const json_utils = @import("./json.zig");
|
|||
|
||||
fn printFields(comptime fields: []const std.builtin.Type.StructField) void {
|
||||
comptime {
|
||||
inline for (fields) |f| @compileLog(f.name.ptr);
|
||||
inline for (fields) |f| @compileLog(f.name.ptr, f.field_type);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -47,6 +47,35 @@ fn addFields(lhs: anytype, rhs: anytype) AddFields(@TypeOf(lhs), @TypeOf(rhs)) {
|
|||
return result;
|
||||
}
|
||||
|
||||
fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type {
|
||||
const Ctx = @Type(.{ .Struct = .{
|
||||
.layout = .Auto,
|
||||
.fields = std.meta.fields(Lhs) ++ &[_]std.builtin.Type.StructField{
|
||||
.{
|
||||
.name = &name,
|
||||
.field_type = Val,
|
||||
.alignment = if (@sizeOf(Val) != 0) @alignOf(Val) else 0,
|
||||
.default_value = null,
|
||||
.is_comptime = false,
|
||||
},
|
||||
},
|
||||
.decls = &.{},
|
||||
.is_tuple = false,
|
||||
} });
|
||||
return Ctx;
|
||||
}
|
||||
|
||||
fn AddField(comptime Lhs: type, comptime name: []const u8, comptime Val: type) type {
|
||||
return AddUniqueField(Lhs, name.len, name[0..].*, Val);
|
||||
}
|
||||
|
||||
fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@TypeOf(lhs), name, @TypeOf(val)) {
|
||||
var result: AddField(@TypeOf(lhs), name, @TypeOf(val)) = undefined;
|
||||
inline for (std.meta.fields(@TypeOf(lhs))) |f| @field(result, f.name) = @field(lhs, f.name);
|
||||
@field(result, name) = val;
|
||||
return result;
|
||||
}
|
||||
|
||||
test {
|
||||
// apply is a plumbing function that applies a tuple of middlewares in order
|
||||
const base = apply(.{
|
||||
|
@ -94,10 +123,22 @@ pub fn InjectContext(comptime Values: type) type {
|
|||
};
|
||||
}
|
||||
|
||||
pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type {
|
||||
return struct {
|
||||
val: V,
|
||||
pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
|
||||
return next.handle(req, res, addField(ctx, name, self.val), {});
|
||||
}
|
||||
};
|
||||
}
|
||||
pub fn injectContext(values: anytype) InjectContext(@TypeOf(values)) {
|
||||
return .{ .values = values };
|
||||
}
|
||||
|
||||
pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContextValue(name, @TypeOf(val)) {
|
||||
return .{ .val = val };
|
||||
}
|
||||
|
||||
pub fn NextHandler(comptime First: type, comptime Next: type) type {
|
||||
return struct {
|
||||
first: First,
|
||||
|
@ -320,10 +361,11 @@ fn parsePathArg(comptime T: type, segment: []const u8) !T {
|
|||
pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type {
|
||||
return struct {
|
||||
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
|
||||
if (Args == void) return next.handle(req, res, addField(ctx, "args", {}), {});
|
||||
return next.handle(
|
||||
req,
|
||||
res,
|
||||
addFields(ctx, .{ .args = try parsePathArgs(route, Args, ctx.path) }),
|
||||
addField(ctx, "args", try parsePathArgs(route, Args, ctx.path)),
|
||||
{},
|
||||
);
|
||||
}
|
||||
|
@ -339,6 +381,7 @@ const BaseContentType = enum {
|
|||
};
|
||||
|
||||
fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T {
|
||||
//@compileLog(T);
|
||||
const buf = try reader.readAllAlloc(alloc, 1 << 16);
|
||||
defer alloc.free(buf);
|
||||
|
||||
|
@ -372,7 +415,15 @@ fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
|
|||
pub fn ParseBody(comptime Body: type) type {
|
||||
return struct {
|
||||
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
|
||||
const base_content_type = matchContentType(req.headers.get("Content-Type"));
|
||||
const content_type = req.headers.get("Content-Type");
|
||||
if (Body == void) {
|
||||
if (content_type != null) return error.UnexpectedBody;
|
||||
const new_ctx = addField(ctx, "body", {});
|
||||
//if (true) @compileError("bug");
|
||||
return next.handle(req, res, new_ctx, {});
|
||||
}
|
||||
|
||||
const base_content_type = matchContentType(content_type);
|
||||
|
||||
var stream = req.body orelse return error.NoBody;
|
||||
const body = try parseBody(Body, base_content_type orelse .json, stream.reader(), ctx.allocator);
|
||||
|
@ -381,7 +432,7 @@ pub fn ParseBody(comptime Body: type) type {
|
|||
return next.handle(
|
||||
req,
|
||||
res,
|
||||
addFields(ctx, .{ .body = body }),
|
||||
addField(ctx, "body", body),
|
||||
{},
|
||||
);
|
||||
}
|
||||
|
@ -391,13 +442,14 @@ pub fn ParseBody(comptime Body: type) type {
|
|||
pub fn ParseQueryParams(comptime QueryParams: type) type {
|
||||
return struct {
|
||||
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
|
||||
if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {});
|
||||
const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string);
|
||||
defer util.deepFree(ctx.allocator, query);
|
||||
|
||||
return next.handle(
|
||||
req,
|
||||
res,
|
||||
addFields(ctx, .{ .query_params = query }),
|
||||
addField(ctx, "query_params", query),
|
||||
{},
|
||||
);
|
||||
}
|
||||
|
|
|
@ -10,9 +10,12 @@ pub const Response = struct {
|
|||
alloc: std.mem.Allocator,
|
||||
stream: Stream,
|
||||
should_close: bool = false,
|
||||
was_opened: bool = true,
|
||||
|
||||
pub const ResponseStream = response.ResponseStream(Stream.Writer);
|
||||
pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !ResponseStream {
|
||||
std.debug.assert(!self.was_opened);
|
||||
self.was_opened = true;
|
||||
if (headers.get("Connection")) |hdr| {
|
||||
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
|
||||
}
|
||||
|
@ -21,6 +24,8 @@ pub const Response = struct {
|
|||
}
|
||||
|
||||
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !Stream {
|
||||
std.debug.assert(!self.was_opened);
|
||||
self.was_opened = true;
|
||||
try response.writeRequestHeader(self.stream.writer(), headers, status);
|
||||
return self.stream;
|
||||
}
|
||||
|
@ -92,6 +97,7 @@ pub const Server = struct {
|
|||
pub fn handleLoop(
|
||||
self: *Server,
|
||||
allocator: std.mem.Allocator,
|
||||
initial_context: anytype,
|
||||
handler: anytype,
|
||||
) void {
|
||||
while (true) {
|
||||
|
@ -108,6 +114,7 @@ pub const Server = struct {
|
|||
.stream = Stream{ .kind = .tcp, .socket = conn.stream.handle },
|
||||
.address = conn.address,
|
||||
},
|
||||
initial_context,
|
||||
handler,
|
||||
);
|
||||
}
|
||||
|
@ -116,6 +123,7 @@ pub const Server = struct {
|
|||
fn serveConn(
|
||||
allocator: std.mem.Allocator,
|
||||
conn: Connection,
|
||||
initial_context: anytype,
|
||||
handler: anytype,
|
||||
) void {
|
||||
while (true) {
|
||||
|
@ -127,12 +135,16 @@ pub const Server = struct {
|
|||
error.HttpVersionNotSupported => .http_version_not_supported,
|
||||
|
||||
else => blk: {
|
||||
std.log.err("Unknown error parsing request: {}\n{?s}", .{ err, @errorReturnTrace() });
|
||||
std.log.err("Unknown error parsing request: {}\n{?}", .{ err, @errorReturnTrace() });
|
||||
break :blk .internal_server_error;
|
||||
},
|
||||
};
|
||||
|
||||
try conn.stream.writer().print("HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", .{ @enumToInt(status), status.phrase() });
|
||||
conn.stream.writer().print(
|
||||
"HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n",
|
||||
.{ @enumToInt(status), status.phrase() },
|
||||
) catch {};
|
||||
return;
|
||||
};
|
||||
|
||||
var res = Response{
|
||||
|
@ -140,8 +152,8 @@ pub const Server = struct {
|
|||
.stream = conn.stream,
|
||||
};
|
||||
|
||||
handler.handle(&req, &res, .{}, {}) catch |err| {
|
||||
std.log.err("Unhandled error serving request: {}\n{?s}", .{ err, @errorReturnTrace() });
|
||||
handler.handle(&req, &res, initial_context, {}) catch |err| {
|
||||
std.log.err("Unhandled error serving request: {}\n{?}", .{ err, @errorReturnTrace() });
|
||||
return;
|
||||
};
|
||||
|
||||
|
|
|
@ -23,18 +23,18 @@ const Opcode = enum(u4) {
|
|||
}
|
||||
};
|
||||
|
||||
pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
|
||||
const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
|
||||
const connection = req.headers.get("Connection") orelse return error.BadHandshake;
|
||||
pub fn handshake(alloc: std.mem.Allocator, req_headers: *const http.Fields, res: *http.Response) !Socket {
|
||||
const upgrade = req_headers.get("Upgrade") orelse return error.BadHandshake;
|
||||
const connection = req_headers.get("Connection") orelse return error.BadHandshake;
|
||||
if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
|
||||
if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;
|
||||
|
||||
const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
|
||||
const key_hdr = req_headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
|
||||
if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake;
|
||||
var key: [16]u8 = undefined;
|
||||
std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;
|
||||
|
||||
const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
|
||||
const version = req_headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
|
||||
if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
|
||||
|
||||
var headers = http.Fields.init(alloc);
|
||||
|
|
|
@ -60,7 +60,7 @@ const inject_api_conn = struct {
|
|||
var api_conn = try getApiConn(ctx.allocator, ctx.api_source, req);
|
||||
defer api_conn.close();
|
||||
|
||||
return mdw.injectContext(.{ .api_conn = &api_conn }).handle(
|
||||
return mdw.injectContextValue("api_conn", &api_conn).handle(
|
||||
req,
|
||||
res,
|
||||
ctx,
|
||||
|
@ -85,19 +85,19 @@ pub fn EndpointRequest(comptime Endpoint: type) type {
|
|||
body: Body,
|
||||
query: Query,
|
||||
|
||||
const args_middleware = if (Args == void)
|
||||
mdw.injectContext(.{ .args = {} })
|
||||
else
|
||||
const args_middleware = //if (Args == void)
|
||||
//mdw.injectContext(.{ .args = {} })
|
||||
//else
|
||||
mdw.ParsePathArgs(Endpoint.path, Args){};
|
||||
|
||||
const body_middleware = if (Body == void)
|
||||
mdw.injectContext(.{ .body = {} })
|
||||
else
|
||||
const body_middleware = //if (Body == void)
|
||||
//mdw.injectContext(.{ .body = {} })
|
||||
//else
|
||||
mdw.ParseBody(Body){};
|
||||
|
||||
const query_middleware = if (Query == void)
|
||||
mdw.injectContext(.{ .query_params = {} })
|
||||
else
|
||||
const query_middleware = //if (Query == void)
|
||||
//mdw.injectContext(.{ .query_params = {} })
|
||||
//else
|
||||
mdw.ParseQueryParams(Query){};
|
||||
};
|
||||
}
|
||||
|
@ -157,9 +157,9 @@ const api_router = mdw.apply(.{
|
|||
|
||||
pub const router = mdw.apply(.{
|
||||
mdw.split_uri,
|
||||
mdw.catchErrors(mdw.default_error_handler),
|
||||
//mdw.router(.{api_router} ++ web_endpoints),
|
||||
mdw.router(web_endpoints),
|
||||
//mdw.catchErrors(mdw.default_error_handler),
|
||||
mdw.router(.{api_router} ++ web_endpoints),
|
||||
//mdw.router(web_endpoints),
|
||||
//api_router,
|
||||
});
|
||||
|
||||
|
|
|
@ -10,20 +10,20 @@ const streaming = @import("./api/streaming.zig");
|
|||
const timelines = @import("./api/timelines.zig");
|
||||
|
||||
pub const routes = .{
|
||||
//controllers.apiEndpoint(auth.login),
|
||||
//controllers.apiEndpoint(auth.verify_login),
|
||||
//controllers.apiEndpoint(communities.create),
|
||||
//controllers.apiEndpoint(communities.query),
|
||||
//controllers.apiEndpoint(invites.create),
|
||||
//controllers.apiEndpoint(users.create),
|
||||
//controllers.apiEndpoint(notes.create),
|
||||
//controllers.apiEndpoint(notes.get),
|
||||
controllers.apiEndpoint(auth.login),
|
||||
controllers.apiEndpoint(auth.verify_login),
|
||||
controllers.apiEndpoint(communities.create),
|
||||
controllers.apiEndpoint(communities.query),
|
||||
controllers.apiEndpoint(invites.create),
|
||||
controllers.apiEndpoint(users.create),
|
||||
controllers.apiEndpoint(notes.create),
|
||||
controllers.apiEndpoint(notes.get),
|
||||
//controllers.apiEndpoint(streaming.streaming),
|
||||
//controllers.apiEndpoint(timelines.global),
|
||||
//controllers.apiEndpoint(timelines.local),
|
||||
//controllers.apiEndpoint(timelines.home),
|
||||
//controllers.apiEndpoint(follows.create),
|
||||
//controllers.apiEndpoint(follows.delete),
|
||||
//controllers.apiEndpoint(follows.query_followers),
|
||||
//controllers.apiEndpoint(follows.query_following),
|
||||
controllers.apiEndpoint(timelines.global),
|
||||
controllers.apiEndpoint(timelines.local),
|
||||
controllers.apiEndpoint(timelines.home),
|
||||
controllers.apiEndpoint(follows.create),
|
||||
controllers.apiEndpoint(follows.delete),
|
||||
controllers.apiEndpoint(follows.query_followers),
|
||||
controllers.apiEndpoint(follows.query_following),
|
||||
};
|
||||
|
|
|
@ -2,11 +2,11 @@ const std = @import("std");
|
|||
const controllers = @import("../controllers.zig");
|
||||
|
||||
pub const routes = .{
|
||||
//controllers.apiEndpoint(index),
|
||||
//controllers.apiEndpoint(about),
|
||||
//controllers.apiEndpoint(login),
|
||||
//controllers.apiEndpoint(global_timeline),
|
||||
//controllers.apiEndpoint(cluster.overview),
|
||||
controllers.apiEndpoint(index),
|
||||
controllers.apiEndpoint(about),
|
||||
controllers.apiEndpoint(login),
|
||||
controllers.apiEndpoint(global_timeline),
|
||||
controllers.apiEndpoint(cluster.overview),
|
||||
};
|
||||
|
||||
const index = struct {
|
||||
|
|
|
@ -70,11 +70,7 @@ fn thread_main(src: *api.ApiSource, srv: *http.Server) void {
|
|||
util.seedThreadPrng() catch unreachable;
|
||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||
defer _ = gpa.deinit();
|
||||
srv.handleLoop(gpa.allocator(), .{ .src = src, .allocator = gpa.allocator() }, handle);
|
||||
}
|
||||
|
||||
fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void {
|
||||
c.routeRequest(ctx.src, req, res, ctx.allocator);
|
||||
srv.handleLoop(gpa.allocator(), .{ .api_source = src, .allocator = gpa.allocator() }, c.router);
|
||||
}
|
||||
|
||||
pub fn main() !void {
|
||||
|
|
|
@ -40,7 +40,7 @@ fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) err
|
|||
std.log.debug("Failed at char ({}:{}) of SQL:\n{s}", .{ pos.row, pos.col, sql });
|
||||
}
|
||||
}
|
||||
std.log.debug("{?s}", .{@errorReturnTrace()});
|
||||
std.log.debug("{?}", .{@errorReturnTrace()});
|
||||
|
||||
return error.Unexpected;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue