diff --git a/src/main/controllers.zig b/src/main/controllers.zig index f162fc0..4470d75 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -7,18 +7,14 @@ const util = @import("util"); const query_utils = @import("./query.zig"); const json_utils = @import("./json.zig"); -pub const auth = @import("./controllers/api/auth.zig"); -pub const communities = @import("./controllers/api/communities.zig"); -pub const invites = @import("./controllers/api/invites.zig"); -pub const users = @import("./controllers/api/users.zig"); -pub const follows = @import("./controllers/api/users/follows.zig"); -pub const notes = @import("./controllers/api/notes.zig"); -pub const streaming = @import("./controllers/api/streaming.zig"); -pub const timelines = @import("./controllers/api/timelines.zig"); - -const web = struct { - const index = @import("./controllers/web/index.zig"); -}; +pub const auth = @import("./controllers/auth.zig"); +pub const communities = @import("./controllers/communities.zig"); +pub const invites = @import("./controllers/invites.zig"); +pub const users = @import("./controllers/users.zig"); +pub const follows = @import("./controllers/users/follows.zig"); +pub const notes = @import("./controllers/notes.zig"); +pub const streaming = @import("./controllers/streaming.zig"); +pub const timelines = @import("./controllers/timelines.zig"); pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? @@ -54,75 +50,8 @@ const routes = .{ follows.create, follows.query_followers, follows.query_following, - - web.index, }; -fn parseRouteArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { - var args: Args = undefined; - var path_iter = util.PathIter.from(path); - comptime var route_iter = util.PathIter.from(route); - inline while (comptime route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return error.RouteMismatch; - if (route_segment.len > 0 and route_segment[0] == ':') { - const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = try parseRouteArg(A, path_segment); - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; - } - } - - if (path_iter.next() != null) return error.RouteMismatch; - - return args; -} - -fn parseRouteArg(comptime T: type, segment: []const u8) !T { - if (T == []const u8) return segment; - if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); - - @compileError("Unsupported Type " ++ @typeName(T)); -} - -const BaseContentType = enum { - json, - url_encoded, - octet_stream, - - other, -}; - -fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { - const buf = try reader.readAllAlloc(alloc, 1 << 16); - defer alloc.free(buf); - - switch (content_type) { - .octet_stream, .json => { - const body = try json_utils.parse(T, buf, alloc); - defer json_utils.parseFree(body, alloc); - - return try util.deepClone(alloc, body); - }, - .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { - error.NoQuery => error.NoBody, - else => err, - }, - else => return error.UnsupportedMediaType, - } -} - -fn matchContentType(hdr: ?[]const u8) ?BaseContentType { - if (hdr) |h| { - if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; - if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; - if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; - - return .other; - } - - return null; -} - pub fn Context(comptime Route: type) type { return struct { const Self = @This(); @@ -152,61 +81,38 @@ pub fn Context(comptime Route: type) type { // TODO body_buf: ?[]const u8 = null, + fn parseArgs(path: []const u8) ?Args { + var args: Args = undefined; + var path_iter = util.PathIter.from(path); + comptime var route_iter = util.PathIter.from(Route.path); + inline while (comptime route_iter.next()) |route_segment| { + const path_segment = path_iter.next() orelse return null; + if (route_segment[0] == ':') { + const A = @TypeOf(@field(args, route_segment[1..])); + @field(args, route_segment[1..]) = parseArg(A, path_segment) catch return null; + } else { + if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return null; + } + } + + if (path_iter.next() != null) return null; + + return args; + } + + fn parseArg(comptime T: type, segment: []const u8) !T { + if (T == []const u8) return segment; + if (comptime std.meta.trait.hasFn("parse")(T)) return T.parse(segment); + + @compileError("Unsupported Type " ++ @typeName(T)); + } + pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { if (req.method != Route.method) return false; var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); - var args = parseRouteArgs(Route.path, Args, path) catch return false; + var args: Args = parseArgs(path) orelse return false; - std.log.debug("Matched route {s}", .{Route.path}); - - handle(api_source, req, res, alloc, args) catch |err| { - std.log.err("{}", .{err}); - if (!res.opened) res.err(.internal_server_error, "", {}) catch {}; - }; - - return true; - } - - fn handle( - api_source: *api.ApiSource, - req: *http.Request, - res: *Response, - alloc: std.mem.Allocator, - args: Args, - ) !void { - const base_content_type = matchContentType(req.headers.get("Content-Type")); - - const body = if (Body != void) blk: { - var stream = req.body orelse return error.NoBody; - break :blk try parseBody(Body, base_content_type orelse .json, stream.reader(), alloc); - } else {}; - defer if (Body != void) util.deepFree(alloc, body); - - const query = if (Query != void) blk: { - const path = std.mem.sliceTo(req.uri, '?'); - const q = req.uri[path.len..]; - - break :blk try query_utils.parseQuery(alloc, Query, q); - }; - defer if (Query != void) util.deepFree(alloc, query); - - var api_conn = conn: { - const host = req.headers.get("Host") orelse return error.NoHost; - const auth_header = req.headers.get("Authorization"); - const token = if (auth_header) |header| blk: { - const prefix = "bearer "; - if (header.len < prefix.len) break :blk null; - if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; - break :blk header[prefix.len..]; - } else null; - - if (token) |t| break :conn try api_source.connectToken(host, t, alloc); - - break :conn try api_source.connectUnauthorized(host, alloc); - }; - defer api_conn.close(); - - const self = Self{ + var self = Self{ .allocator = alloc, .base_request = req, @@ -215,11 +121,15 @@ pub fn Context(comptime Route: type) type { .headers = req.headers, .args = args, - .body = body, - .query = query, + .body = undefined, + .query = undefined, }; - try Route.handler(self, res, &api_conn); + std.log.debug("Matched route {s}", .{path}); + + self.prepareAndHandle(api_source, req, res); + + return true; } fn errorHandler(response: *Response, status: http.Status, err: anytype) void { @@ -233,6 +143,68 @@ pub fn Context(comptime Route: type) type { std.log.err("Error printing response: {}", .{err2}); }; } + + fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void { + self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err); + defer self.freeBody(); + + self.parseQuery() catch |err| return errorHandler(response, .bad_request, err); + + var api_conn = self.getApiConn(api_source) catch |err| return errorHandler(response, .internal_server_error, err); + defer api_conn.close(); + + self.handle(response, &api_conn); + } + + fn parseBody(self: *Self, req: *http.Request) !void { + if (Body != void) { + var stream = req.body orelse return error.NoBody; + const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16); + errdefer self.allocator.free(body); + self.body = try json_utils.parse(Body, body, self.allocator); + self.body_buf = body; + } + } + + fn freeBody(self: *Self) void { + if (Body != void) { + json_utils.parseFree(self.body, self.allocator); + self.allocator.free(self.body_buf.?); + } + } + + fn parseQuery(self: *Self) !void { + if (Query != void) { + const path = std.mem.sliceTo(self.uri, '?'); + const q = std.mem.sliceTo(self.uri[path.len..], '#'); + + self.query = try query_utils.parseQuery(Query, q); + } + } + + fn handle(self: Self, response: *Response, api_conn: anytype) void { + Route.handler(self, response, api_conn) catch |err| switch (err) { + else => { + std.log.err("{}", .{err}); + if (!response.opened) response.err(.internal_server_error, "", {}) catch {}; + }, + }; + } + + fn getApiConn(self: *Self, api_source: anytype) !api.ApiSource.Conn { + const host = self.headers.get("Host") orelse return error.NoHost; + const auth_header = self.headers.get("Authorization"); + const token = if (auth_header) |header| blk: { + const prefix = "bearer "; + if (header.len < prefix.len) break :blk null; + if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; + break :blk header[prefix.len..]; + } else null; + + if (token) |t| return try api_source.connectToken(host, t, self.allocator); + + return try api_source.connectUnauthorized(host, self.allocator); + } }; } @@ -244,16 +216,22 @@ pub const Response = struct { /// Write a response with no body, only a given status pub fn status(self: *Self, status_code: http.Status) !void { - var stream = try self.open(status_code); + std.debug.assert(!self.opened); + self.opened = true; + + var stream = try self.res.open(status_code, &self.headers); defer stream.close(); try stream.finish(); } /// Write a request body as json pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { + std.debug.assert(!self.opened); + self.opened = true; + try self.headers.put("Content-Type", "application/json"); - var stream = try self.open(status_code); + var stream = try self.res.open(status_code, &self.headers); defer stream.close(); const writer = stream.writer(); @@ -262,13 +240,6 @@ pub const Response = struct { try stream.finish(); } - pub fn open(self: *Self, status_code: http.Status) !http.Response.Stream { - std.debug.assert(!self.opened); - self.opened = true; - - return try self.res.open(status_code, &self.headers); - } - /// Prints the given error as json pub fn err(self: *Self, status_code: http.Status, message: []const u8, details: anytype) !void { return self.json(status_code, .{ @@ -280,7 +251,6 @@ pub const Response = struct { /// Signals that the HTTP connection should be hijacked without writing a /// response beforehand. pub fn hijack(self: *Self) *http.Response { - std.debug.assert(!self.opened); self.opened = true; return self.res; } diff --git a/src/main/controllers/api/auth.zig b/src/main/controllers/auth.zig similarity index 100% rename from src/main/controllers/api/auth.zig rename to src/main/controllers/auth.zig diff --git a/src/main/controllers/api/communities.zig b/src/main/controllers/communities.zig similarity index 77% rename from src/main/controllers/api/communities.zig rename to src/main/controllers/communities.zig index f6f475d..89ed8a5 100644 --- a/src/main/controllers/api/communities.zig +++ b/src/main/controllers/communities.zig @@ -1,7 +1,12 @@ +const std = @import("std"); const api = @import("api"); -const controller_utils = @import("../../controllers.zig").helpers; +const util = @import("util"); +const query_utils = @import("../query.zig"); +const controller_utils = @import("../controllers.zig").helpers; const QueryArgs = api.CommunityQueryArgs; +const Uuid = util.Uuid; +const DateTime = util.DateTime; pub const create = struct { pub const method = .POST; diff --git a/src/main/controllers/api/invites.zig b/src/main/controllers/invites.zig similarity index 100% rename from src/main/controllers/api/invites.zig rename to src/main/controllers/invites.zig diff --git a/src/main/controllers/api/notes.zig b/src/main/controllers/notes.zig similarity index 100% rename from src/main/controllers/api/notes.zig rename to src/main/controllers/notes.zig diff --git a/src/main/controllers/api/streaming.zig b/src/main/controllers/streaming.zig similarity index 100% rename from src/main/controllers/api/streaming.zig rename to src/main/controllers/streaming.zig diff --git a/src/main/controllers/api/timelines.zig b/src/main/controllers/timelines.zig similarity index 91% rename from src/main/controllers/api/timelines.zig rename to src/main/controllers/timelines.zig index 8c30cc1..2d5ace6 100644 --- a/src/main/controllers/api/timelines.zig +++ b/src/main/controllers/timelines.zig @@ -1,6 +1,7 @@ const std = @import("std"); const api = @import("api"); -const controller_utils = @import("../../controllers.zig").helpers; +const query_utils = @import("../query.zig"); +const controller_utils = @import("../controllers.zig").helpers; pub const global = struct { pub const method = .GET; diff --git a/src/main/controllers/api/users.zig b/src/main/controllers/users.zig similarity index 100% rename from src/main/controllers/api/users.zig rename to src/main/controllers/users.zig diff --git a/src/main/controllers/api/users/follows.zig b/src/main/controllers/users/follows.zig similarity index 94% rename from src/main/controllers/api/users/follows.zig rename to src/main/controllers/users/follows.zig index 5ee4699..dcb36b9 100644 --- a/src/main/controllers/api/users/follows.zig +++ b/src/main/controllers/users/follows.zig @@ -1,6 +1,6 @@ const api = @import("api"); const util = @import("util"); -const controller_utils = @import("../../../controllers.zig").helpers; +const controller_utils = @import("../../controllers.zig").helpers; const Uuid = util.Uuid; diff --git a/src/main/controllers/web/index.fmt.html b/src/main/controllers/web/index.fmt.html deleted file mode 100644 index 15b263f..0000000 --- a/src/main/controllers/web/index.fmt.html +++ /dev/null @@ -1,25 +0,0 @@ - - - - - {[community_name]s} - - -
-

{[community_name]s}

- Cluster Admin pseudocommunity -
-
-

Login

- - - -
- - diff --git a/src/main/controllers/web/index.zig b/src/main/controllers/web/index.zig deleted file mode 100644 index 33f09b3..0000000 --- a/src/main/controllers/web/index.zig +++ /dev/null @@ -1,20 +0,0 @@ -const std = @import("std"); - -pub const path = "/"; -pub const method = .GET; - -pub fn handler(_: anytype, res: anytype, srv: anytype) !void { - try res.headers.put("Content-Type", "text/html"); - - var stream = try res.open(.ok); - defer stream.close(); - - try std.fmt.format(stream.writer(), template, .{ - .community_name = srv.community.name, - .community_host = srv.community.host, - }); - - try stream.finish(); -} - -const template = @embedFile("./index.fmt.html"); diff --git a/src/main/query.zig b/src/main/query.zig index 1933429..a51a7ab 100644 --- a/src/main/query.zig +++ b/src/main/query.zig @@ -70,7 +70,7 @@ const QueryIter = @import("util").QueryIter; /// /// TODO: values are currently case-sensitive, and are not url-decoded properly. /// This should be fixed. -pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { +pub fn parseQuery(comptime T: type, query: []const u8) !T { if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct"); var iter = QueryIter.from(query); @@ -85,54 +85,27 @@ pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) } else std.log.debug("unknown param {s}", .{pair.key}); } - return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery; + return (try parse(T, "", "", fields)).?; } -fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]const u8 { - var list = try std.ArrayList(u8).initCapacity(alloc, val.len); - errdefer list.deinit(); - - var idx: usize = 0; - while (idx < val.len) : (idx += 1) { - if (val[idx] != '%') { - try list.append(val[idx]); - } else { - if (val.len < idx + 2) return error.InvalidEscape; - const buf = [2]u8{ val[idx + 1], val[idx + 2] }; - idx += 2; - - const ch = try std.fmt.parseInt(u8, &buf, 16); - try list.append(ch); - } - } - - return list.toOwnedSlice(); -} - -fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T { +fn parseScalar(comptime T: type, comptime name: []const u8, fields: anytype) !?T { const param = @field(fields, name); return switch (param) { .not_specified => null, - .no_value => try parseQueryValue(alloc, T, null), - .value => |v| try parseQueryValue(alloc, T, v), + .no_value => try parseQueryValue(T, null), + .value => |v| try parseQueryValue(T, v), }; } -fn parse( - alloc: std.mem.Allocator, - comptime T: type, - comptime prefix: []const u8, - comptime name: []const u8, - fields: anytype, -) !?T { - if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields); +fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u8, fields: anytype) !?T { + if (comptime isScalar(T)) return parseScalar(T, prefix ++ "." ++ name, fields); switch (@typeInfo(T)) { .Union => |info| { var result: ?T = null; inline for (info.fields) |field| { const F = field.field_type; - const maybe_value = try parse(alloc, F, prefix, field.name, fields); + const maybe_value = try parse(F, prefix, field.name, fields); if (maybe_value) |value| { if (result != null) return error.DuplicateUnionField; @@ -151,7 +124,7 @@ fn parse( const F = field.field_type; var maybe_value: ?F = null; - if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| { + if (try parse(F, prefix ++ "." ++ name, field.name, fields)) |v| { maybe_value = v; } else if (field.default_value) |default| { if (comptime @sizeOf(F) != 0) { @@ -178,7 +151,7 @@ fn parse( }, // Only applies to non-scalar optionals - .Optional => |info| return try parse(alloc, info.child, prefix, name, fields), + .Optional => |info| return try parse(info.child, prefix, name, fields), else => @compileError("tmp"), } @@ -231,7 +204,7 @@ fn Intermediary(comptime T: type) type { } }); } -fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u8) !T { +fn parseQueryValue(comptime T: type, value: ?[]const u8) !T { const is_optional = comptime std.meta.trait.is(.Optional)(T); // If param is present, but without an associated value if (value == null) { @@ -243,7 +216,7 @@ fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u error.InvalidValue; } - return try parseQueryValueNotNull(alloc, if (is_optional) std.meta.Child(T) else T, value.?); + return try parseQueryValueNotNull(if (is_optional) std.meta.Child(T) else T, value.?); } const bool_map = std.ComptimeStringMap(bool, .{ @@ -260,27 +233,15 @@ const bool_map = std.ComptimeStringMap(bool, .{ .{ "0", false }, }); -fn parseQueryValueNotNull(alloc: std.mem.Allocator, comptime T: type, value: []const u8) !T { - const decoded = try decodeString(alloc, value); - errdefer alloc.free(decoded); +fn parseQueryValueNotNull(comptime T: type, value: []const u8) !T { + if (comptime std.meta.trait.isZigString(T)) return value; + if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0); + if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value); + if (comptime std.meta.trait.is(.Enum)(T)) return std.meta.stringToEnum(T, value) orelse error.InvalidEnumValue; + if (T == bool) return bool_map.get(value) orelse error.InvalidBool; + if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value); - if (comptime std.meta.trait.isZigString(T)) return decoded; - - const result = if (comptime std.meta.trait.isIntegral(T)) - try std.fmt.parseInt(T, decoded, 0) - else if (comptime std.meta.trait.isFloat(T)) - try std.fmt.parseFloat(T, decoded) - else if (comptime std.meta.trait.is(.Enum)(T)) - std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue - else if (T == bool) - bool_map.get(value) orelse return error.InvalidBool - else if (comptime std.meta.trait.hasFn("parse")(T)) - try T.parse(value) - else - @compileError("Invalid type " ++ @typeName(T)); - - alloc.free(decoded); - return result; + @compileError("Invalid type " ++ @typeName(T)); } fn isScalar(comptime T: type) bool { @@ -300,34 +261,14 @@ pub fn formatQuery(params: anytype, writer: anytype) !void { try format("", "", params, writer); } -fn urlFormatString(writer: anytype, val: []const u8) !void { - for (val) |ch| { - const printable = switch (ch) { - '0'...'9', 'a'...'z', 'A'...'Z' => true, - '-', '.', '_', '~', ':', '@', '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=' => true, - else => false, - }; - - try if (printable) writer.writeByte(ch) else std.fmt.format(writer, "%{x:0>2}", .{ch}); - } -} - fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void { const T = @TypeOf(val); - if (comptime std.meta.trait.is(.Optional)(T)) { - return if (val) |v| formatScalar(name, v, writer) else {}; - } - - try urlFormatString(writer, name); - try writer.writeByte('='); - if (comptime std.meta.trait.isZigString(T)) { - try urlFormatString(writer, val); - } else try switch (@typeInfo(T)) { - .Enum => urlFormatString(writer, @tagName(val)), - else => std.fmt.format(writer, "{}", .{val}), + if (comptime std.meta.trait.isZigString(T)) return std.fmt.format(writer, "{s}={s}&", .{ name, val }); + _ = try switch (@typeInfo(T)) { + .Enum => std.fmt.format(writer, "{s}={s}&", .{ name, @tagName(val) }), + .Optional => if (val) |v| formatScalar(name, v, writer), + else => std.fmt.format(writer, "{s}={}&", .{ name, val }), }; - - try writer.writeByte('&'); } fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void { diff --git a/src/util/lib.zig b/src/util/lib.zig index 9a66dc8..ae5b6ce 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -99,13 +99,7 @@ pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void { }, .Optional => if (val) |v| deepFree(alloc, v) else {}, .Struct => |struct_info| inline for (struct_info.fields) |field| deepFree(alloc, @field(val, field.name)), - .Union => |union_info| inline for (union_info.fields) |field| { - const tag = @field(std.meta.Tag(T), field.name); - if (@as(std.meta.Tag(T), val) == tag) { - deepFree(alloc, @field(val, field.name)); - } - }, - .ErrorUnion => if (val) |v| deepFree(alloc, v) else {}, + .Union, .ErrorUnion => @compileError("TODO: Unions not yet supported by deepFree"), .Array => for (val) |v| deepFree(alloc, v), .Enum, .Int, .Float, .Bool, .Void, .Type => {},