Fix middleware compilation???

This commit is contained in:
jaina heartles 2022-11-26 17:33:46 -08:00
parent 04a95a280b
commit 9c5e46ec5a
9 changed files with 115 additions and 54 deletions

View file

@ -82,7 +82,7 @@ pub fn build(b: *std.build.Builder) !void {
const pkgs = makePkgs(b, options.getPackage("build_options"));
const exe = b.addExecutable("apub", "src/main/main.zig");
const exe = b.addExecutable("fediglam", "src/main/main.zig");
exe.setTarget(target);
exe.setBuildMode(mode);
@ -96,6 +96,7 @@ pub fn build(b: *std.build.Builder) !void {
if (enable_sqlite) exe.linkSystemLibrary("sqlite3");
if (enable_postgres) exe.linkSystemLibrary("pq");
exe.linkLibC();
exe.addSystemIncludePath("/usr/include/");
//const util_tests = b.addTest("src/util/lib.zig");
const http_tests = b.addTest("src/http/test.zig");

View file

@ -10,7 +10,7 @@ const json_utils = @import("./json.zig");
fn printFields(comptime fields: []const std.builtin.Type.StructField) void {
comptime {
inline for (fields) |f| @compileLog(f.name.ptr);
inline for (fields) |f| @compileLog(f.name.ptr, f.field_type);
}
}
@ -47,6 +47,35 @@ fn addFields(lhs: anytype, rhs: anytype) AddFields(@TypeOf(lhs), @TypeOf(rhs)) {
return result;
}
fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type {
const Ctx = @Type(.{ .Struct = .{
.layout = .Auto,
.fields = std.meta.fields(Lhs) ++ &[_]std.builtin.Type.StructField{
.{
.name = &name,
.field_type = Val,
.alignment = if (@sizeOf(Val) != 0) @alignOf(Val) else 0,
.default_value = null,
.is_comptime = false,
},
},
.decls = &.{},
.is_tuple = false,
} });
return Ctx;
}
fn AddField(comptime Lhs: type, comptime name: []const u8, comptime Val: type) type {
return AddUniqueField(Lhs, name.len, name[0..].*, Val);
}
fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@TypeOf(lhs), name, @TypeOf(val)) {
var result: AddField(@TypeOf(lhs), name, @TypeOf(val)) = undefined;
inline for (std.meta.fields(@TypeOf(lhs))) |f| @field(result, f.name) = @field(lhs, f.name);
@field(result, name) = val;
return result;
}
test {
// apply is a plumbing function that applies a tuple of middlewares in order
const base = apply(.{
@ -94,10 +123,22 @@ pub fn InjectContext(comptime Values: type) type {
};
}
pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type {
return struct {
val: V,
pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
return next.handle(req, res, addField(ctx, name, self.val), {});
}
};
}
pub fn injectContext(values: anytype) InjectContext(@TypeOf(values)) {
return .{ .values = values };
}
pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContextValue(name, @TypeOf(val)) {
return .{ .val = val };
}
pub fn NextHandler(comptime First: type, comptime Next: type) type {
return struct {
first: First,
@ -320,10 +361,11 @@ fn parsePathArg(comptime T: type, segment: []const u8) !T {
pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
if (Args == void) return next.handle(req, res, addField(ctx, "args", {}), {});
return next.handle(
req,
res,
addFields(ctx, .{ .args = try parsePathArgs(route, Args, ctx.path) }),
addField(ctx, "args", try parsePathArgs(route, Args, ctx.path)),
{},
);
}
@ -339,6 +381,7 @@ const BaseContentType = enum {
};
fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T {
//@compileLog(T);
const buf = try reader.readAllAlloc(alloc, 1 << 16);
defer alloc.free(buf);
@ -372,7 +415,15 @@ fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
pub fn ParseBody(comptime Body: type) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
const base_content_type = matchContentType(req.headers.get("Content-Type"));
const content_type = req.headers.get("Content-Type");
if (Body == void) {
if (content_type != null) return error.UnexpectedBody;
const new_ctx = addField(ctx, "body", {});
//if (true) @compileError("bug");
return next.handle(req, res, new_ctx, {});
}
const base_content_type = matchContentType(content_type);
var stream = req.body orelse return error.NoBody;
const body = try parseBody(Body, base_content_type orelse .json, stream.reader(), ctx.allocator);
@ -381,7 +432,7 @@ pub fn ParseBody(comptime Body: type) type {
return next.handle(
req,
res,
addFields(ctx, .{ .body = body }),
addField(ctx, "body", body),
{},
);
}
@ -391,13 +442,14 @@ pub fn ParseBody(comptime Body: type) type {
pub fn ParseQueryParams(comptime QueryParams: type) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {});
const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string);
defer util.deepFree(ctx.allocator, query);
return next.handle(
req,
res,
addFields(ctx, .{ .query_params = query }),
addField(ctx, "query_params", query),
{},
);
}

View file

@ -10,9 +10,12 @@ pub const Response = struct {
alloc: std.mem.Allocator,
stream: Stream,
should_close: bool = false,
was_opened: bool = true,
pub const ResponseStream = response.ResponseStream(Stream.Writer);
pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !ResponseStream {
std.debug.assert(!self.was_opened);
self.was_opened = true;
if (headers.get("Connection")) |hdr| {
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
}
@ -21,6 +24,8 @@ pub const Response = struct {
}
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !Stream {
std.debug.assert(!self.was_opened);
self.was_opened = true;
try response.writeRequestHeader(self.stream.writer(), headers, status);
return self.stream;
}
@ -92,6 +97,7 @@ pub const Server = struct {
pub fn handleLoop(
self: *Server,
allocator: std.mem.Allocator,
initial_context: anytype,
handler: anytype,
) void {
while (true) {
@ -108,6 +114,7 @@ pub const Server = struct {
.stream = Stream{ .kind = .tcp, .socket = conn.stream.handle },
.address = conn.address,
},
initial_context,
handler,
);
}
@ -116,6 +123,7 @@ pub const Server = struct {
fn serveConn(
allocator: std.mem.Allocator,
conn: Connection,
initial_context: anytype,
handler: anytype,
) void {
while (true) {
@ -127,12 +135,16 @@ pub const Server = struct {
error.HttpVersionNotSupported => .http_version_not_supported,
else => blk: {
std.log.err("Unknown error parsing request: {}\n{?s}", .{ err, @errorReturnTrace() });
std.log.err("Unknown error parsing request: {}\n{?}", .{ err, @errorReturnTrace() });
break :blk .internal_server_error;
},
};
try conn.stream.writer().print("HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", .{ @enumToInt(status), status.phrase() });
conn.stream.writer().print(
"HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n",
.{ @enumToInt(status), status.phrase() },
) catch {};
return;
};
var res = Response{
@ -140,8 +152,8 @@ pub const Server = struct {
.stream = conn.stream,
};
handler.handle(&req, &res, .{}, {}) catch |err| {
std.log.err("Unhandled error serving request: {}\n{?s}", .{ err, @errorReturnTrace() });
handler.handle(&req, &res, initial_context, {}) catch |err| {
std.log.err("Unhandled error serving request: {}\n{?}", .{ err, @errorReturnTrace() });
return;
};

View file

@ -23,18 +23,18 @@ const Opcode = enum(u4) {
}
};
pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
const connection = req.headers.get("Connection") orelse return error.BadHandshake;
pub fn handshake(alloc: std.mem.Allocator, req_headers: *const http.Fields, res: *http.Response) !Socket {
const upgrade = req_headers.get("Upgrade") orelse return error.BadHandshake;
const connection = req_headers.get("Connection") orelse return error.BadHandshake;
if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;
const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
const key_hdr = req_headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake;
var key: [16]u8 = undefined;
std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;
const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
const version = req_headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
var headers = http.Fields.init(alloc);

View file

@ -60,7 +60,7 @@ const inject_api_conn = struct {
var api_conn = try getApiConn(ctx.allocator, ctx.api_source, req);
defer api_conn.close();
return mdw.injectContext(.{ .api_conn = &api_conn }).handle(
return mdw.injectContextValue("api_conn", &api_conn).handle(
req,
res,
ctx,
@ -85,19 +85,19 @@ pub fn EndpointRequest(comptime Endpoint: type) type {
body: Body,
query: Query,
const args_middleware = if (Args == void)
mdw.injectContext(.{ .args = {} })
else
const args_middleware = //if (Args == void)
//mdw.injectContext(.{ .args = {} })
//else
mdw.ParsePathArgs(Endpoint.path, Args){};
const body_middleware = if (Body == void)
mdw.injectContext(.{ .body = {} })
else
const body_middleware = //if (Body == void)
//mdw.injectContext(.{ .body = {} })
//else
mdw.ParseBody(Body){};
const query_middleware = if (Query == void)
mdw.injectContext(.{ .query_params = {} })
else
const query_middleware = //if (Query == void)
//mdw.injectContext(.{ .query_params = {} })
//else
mdw.ParseQueryParams(Query){};
};
}
@ -157,9 +157,9 @@ const api_router = mdw.apply(.{
pub const router = mdw.apply(.{
mdw.split_uri,
mdw.catchErrors(mdw.default_error_handler),
//mdw.router(.{api_router} ++ web_endpoints),
mdw.router(web_endpoints),
//mdw.catchErrors(mdw.default_error_handler),
mdw.router(.{api_router} ++ web_endpoints),
//mdw.router(web_endpoints),
//api_router,
});

View file

@ -10,20 +10,20 @@ const streaming = @import("./api/streaming.zig");
const timelines = @import("./api/timelines.zig");
pub const routes = .{
//controllers.apiEndpoint(auth.login),
//controllers.apiEndpoint(auth.verify_login),
//controllers.apiEndpoint(communities.create),
//controllers.apiEndpoint(communities.query),
//controllers.apiEndpoint(invites.create),
//controllers.apiEndpoint(users.create),
//controllers.apiEndpoint(notes.create),
//controllers.apiEndpoint(notes.get),
controllers.apiEndpoint(auth.login),
controllers.apiEndpoint(auth.verify_login),
controllers.apiEndpoint(communities.create),
controllers.apiEndpoint(communities.query),
controllers.apiEndpoint(invites.create),
controllers.apiEndpoint(users.create),
controllers.apiEndpoint(notes.create),
controllers.apiEndpoint(notes.get),
//controllers.apiEndpoint(streaming.streaming),
//controllers.apiEndpoint(timelines.global),
//controllers.apiEndpoint(timelines.local),
//controllers.apiEndpoint(timelines.home),
//controllers.apiEndpoint(follows.create),
//controllers.apiEndpoint(follows.delete),
//controllers.apiEndpoint(follows.query_followers),
//controllers.apiEndpoint(follows.query_following),
controllers.apiEndpoint(timelines.global),
controllers.apiEndpoint(timelines.local),
controllers.apiEndpoint(timelines.home),
controllers.apiEndpoint(follows.create),
controllers.apiEndpoint(follows.delete),
controllers.apiEndpoint(follows.query_followers),
controllers.apiEndpoint(follows.query_following),
};

View file

@ -2,11 +2,11 @@ const std = @import("std");
const controllers = @import("../controllers.zig");
pub const routes = .{
//controllers.apiEndpoint(index),
//controllers.apiEndpoint(about),
//controllers.apiEndpoint(login),
//controllers.apiEndpoint(global_timeline),
//controllers.apiEndpoint(cluster.overview),
controllers.apiEndpoint(index),
controllers.apiEndpoint(about),
controllers.apiEndpoint(login),
controllers.apiEndpoint(global_timeline),
controllers.apiEndpoint(cluster.overview),
};
const index = struct {

View file

@ -70,11 +70,7 @@ fn thread_main(src: *api.ApiSource, srv: *http.Server) void {
util.seedThreadPrng() catch unreachable;
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
srv.handleLoop(gpa.allocator(), .{ .src = src, .allocator = gpa.allocator() }, handle);
}
fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void {
c.routeRequest(ctx.src, req, res, ctx.allocator);
srv.handleLoop(gpa.allocator(), .{ .api_source = src, .allocator = gpa.allocator() }, c.router);
}
pub fn main() !void {

View file

@ -40,7 +40,7 @@ fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) err
std.log.debug("Failed at char ({}:{}) of SQL:\n{s}", .{ pos.row, pos.col, sql });
}
}
std.log.debug("{?s}", .{@errorReturnTrace()});
std.log.debug("{?}", .{@errorReturnTrace()});
return error.Unexpected;
}