Get api source from context

This commit is contained in:
jaina heartles 2022-11-24 03:50:25 -08:00
parent 051217cdaf
commit 503ab62607
3 changed files with 73 additions and 72 deletions

View file

@ -7,16 +7,10 @@ const util = @import("util");
const query_utils = @import("./query.zig"); const query_utils = @import("./query.zig");
const json_utils = @import("./json.zig"); const json_utils = @import("./json.zig");
pub const auth = @import("./controllers/api/auth.zig"); const web_controllers = @import("./controllers/web.zig");
pub const communities = @import("./controllers/api/communities.zig"); const api_controllers = @import("./controllers/api.zig");
pub const invites = @import("./controllers/api/invites.zig");
pub const users = @import("./controllers/api/users.zig");
pub const follows = @import("./controllers/api/users/follows.zig");
pub const notes = @import("./controllers/api/notes.zig");
pub const streaming = @import("./controllers/api/streaming.zig");
pub const timelines = @import("./controllers/api/timelines.zig");
const web = @import("./controllers/web.zig"); const routes = api_controllers ++ web_controllers.routes;
const mdw = http.middleware; const mdw = http.middleware;
@ -40,47 +34,44 @@ const not_found = struct {
const base_handler = mdw.SplitUri(mdw.CatchErrors(not_found, mdw.DefaultErrorHandler)); const base_handler = mdw.SplitUri(mdw.CatchErrors(not_found, mdw.DefaultErrorHandler));
fn InjectApiConn(comptime ApiSource: type) type { const inject_api_conn = struct {
return struct { fn getApiConn(alloc: std.mem.Allocator, api_source: anytype, req: anytype) !@TypeOf(api_source).Conn {
api_source: ApiSource, const host = req.headers.get("Host") orelse return error.NoHost;
fn getApiConn(self: @This(), alloc: std.mem.Allocator, req: anytype) !ApiSource.Conn { const auth_header = req.headers.get("Authorization");
const host = req.headers.get("Host") orelse return error.NoHost; const token = if (auth_header) |header| blk: {
const auth_header = req.headers.get("Authorization"); const prefix = "bearer ";
const token = if (auth_header) |header| blk: { if (header.len < prefix.len) break :blk null;
const prefix = "bearer "; if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null;
if (header.len < prefix.len) break :blk null; break :blk header[prefix.len..];
if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; } else null;
break :blk header[prefix.len..];
} else null;
if (token) |t| return try self.api_source.connectToken(host, t, alloc); if (token) |t| return try api_source.connectToken(host, t, alloc);
if (req.headers.getCookie("active_account") catch return error.BadRequest) |account| { if (req.headers.getCookie("active_account") catch return error.BadRequest) |account| {
if (account.len + ("token.").len <= 64) { if (account.len + ("token.").len <= 64) {
var buf: [64]u8 = undefined; var buf: [64]u8 = undefined;
const cookie_name = std.fmt.bufPrint(&buf, "token.{s}", .{account}) catch unreachable; const cookie_name = std.fmt.bufPrint(&buf, "token.{s}", .{account}) catch unreachable;
if (try req.headers.getCookie(cookie_name)) |token_hdr| { if (try req.headers.getCookie(cookie_name)) |token_hdr| {
return try self.api_source.connectToken(host, token_hdr, alloc); return try api_source.connectToken(host, token_hdr, alloc);
} }
} else return error.InvalidCookie; } else return error.InvalidCookie;
}
return try self.api_source.connectUnauthorized(host, alloc);
} }
fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { return try api_source.connectUnauthorized(host, alloc);
var api_conn = try self.getApiConn(ctx.allocator, req); }
defer api_conn.close();
return next.handle( pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
req, var api_conn = try getApiConn(ctx.allocator, ctx.api_source, req);
res, defer api_conn.close();
mdw.injectContext(.{ .api_conn = &api_conn }),
{}, return next.handle(
); req,
} res,
}; mdw.injectContext(.{ .api_conn = &api_conn }),
} {},
);
}
};
pub fn EndpointRequest(comptime Endpoint: type) type { pub fn EndpointRequest(comptime Endpoint: type) type {
return struct { return struct {
@ -140,7 +131,6 @@ fn CallApiEndpoint(comptime Endpoint: type) type {
pub fn apiEndpoint( pub fn apiEndpoint(
comptime Endpoint: type, comptime Endpoint: type,
api_source: anytype,
) return_type: { ) return_type: {
const RequestType = EndpointRequest(Endpoint); const RequestType = EndpointRequest(Endpoint);
break :return_type mdw.Apply(std.meta.Tuple(.{ break :return_type mdw.Apply(std.meta.Tuple(.{
@ -149,7 +139,7 @@ pub fn apiEndpoint(
@TypeOf(RequestType.query_middleware), @TypeOf(RequestType.query_middleware),
@TypeOf(RequestType.body_middleware), @TypeOf(RequestType.body_middleware),
// TODO: allocation strategy // TODO: allocation strategy
InjectApiConn(@TypeOf(api_source)), @TypeOf(inject_api_conn),
CallApiEndpoint(Endpoint), CallApiEndpoint(Endpoint),
})); }));
} { } {
@ -160,7 +150,7 @@ pub fn apiEndpoint(
RequestType.query_middleware, RequestType.query_middleware,
RequestType.body_middleware, RequestType.body_middleware,
// TODO: allocation strategy // TODO: allocation strategy
InjectApiConn(@TypeOf(api_source)){ .api_source = api_source }, inject_api_conn,
CallApiEndpoint(Endpoint){}, CallApiEndpoint(Endpoint){},
}); });
} }
@ -177,25 +167,6 @@ pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response
//if (!found) response.status(.not_found) catch {}; //if (!found) response.status(.not_found) catch {};
} }
const routes = .{
auth.login,
auth.verify_login,
communities.create,
communities.query,
invites.create,
users.create,
notes.create,
notes.get,
streaming.streaming,
timelines.global,
timelines.local,
timelines.home,
follows.create,
follows.delete,
follows.query_followers,
follows.query_following,
} ++ web.routes;
pub const AllocationStrategy = enum { pub const AllocationStrategy = enum {
arena, arena,
normal, normal,

View file

@ -0,0 +1,29 @@
const controllers = @import("../controllers.zig");
const auth = @import("./api/auth.zig");
const communities = @import("./api/communities.zig");
const invites = @import("./api/invites.zig");
const users = @import("./api/users.zig");
const follows = @import("./api/users/follows.zig");
const notes = @import("./api/notes.zig");
const streaming = @import("./api/streaming.zig");
const timelines = @import("./api/timelines.zig");
pub const routes = .{
controllers.apiEndpoint(auth.login),
controllers.apiEndpoint(auth.verify_login),
controllers.apiEndpoint(communities.create),
controllers.apiEndpoint(communities.query),
controllers.apiEndpoint(invites.create),
controllers.apiEndpoint(users.create),
controllers.apiEndpoint(notes.create),
controllers.apiEndpoint(notes.get),
controllers.apiEndpoint(streaming.streaming),
controllers.apiEndpoint(timelines.global),
controllers.apiEndpoint(timelines.local),
controllers.apiEndpoint(timelines.home),
controllers.apiEndpoint(follows.create),
controllers.apiEndpoint(follows.delete),
controllers.apiEndpoint(follows.query_followers),
controllers.apiEndpoint(follows.query_following),
};

View file

@ -1,11 +1,12 @@
const std = @import("std"); const std = @import("std");
const controllers = @import("../controllers.zig");
pub const routes = .{ pub const routes = .{
index, controllers.apiEndpoint(index),
about, controllers.apiEndpoint(about),
login, controllers.apiEndpoint(login),
global_timeline, controllers.apiEndpoint(global_timeline),
cluster.overview, controllers.apiEndpoint(cluster.overview),
}; };
const index = struct { const index = struct {