diff --git a/build.zig b/build.zig index f68bbd5..2017afc 100644 --- a/build.zig +++ b/build.zig @@ -82,7 +82,7 @@ pub fn build(b: *std.build.Builder) !void { const pkgs = makePkgs(b, options.getPackage("build_options")); - const exe = b.addExecutable("fediglam", "src/main/main.zig"); + const exe = b.addExecutable("apub", "src/main/main.zig"); exe.setTarget(target); exe.setBuildMode(mode); @@ -96,7 +96,6 @@ pub fn build(b: *std.build.Builder) !void { if (enable_sqlite) exe.linkSystemLibrary("sqlite3"); if (enable_postgres) exe.linkSystemLibrary("pq"); exe.linkLibC(); - exe.addSystemIncludePath("/usr/include/"); //const util_tests = b.addTest("src/util/lib.zig"); const http_tests = b.addTest("src/http/test.zig"); diff --git a/src/api/services/files.zig b/src/api/services/files.zig deleted file mode 100644 index 18c0e9d..0000000 --- a/src/api/services/files.zig +++ /dev/null @@ -1,69 +0,0 @@ -const std = @import("std"); -const util = @import("util"); - -const Uuid = util.Uuid; -const DateTime = util.DateTime; - -pub const FileOwner = union(enum) { - user_id: Uuid, - community_id: Uuid, -}; - -pub const DriveFile = struct { - id: Uuid, - filename: []const u8, - owner: FileOwner, - size: usize, - created_at: DateTime, -}; - -pub const files = struct { - pub fn create(db: anytype, owner: FileOwner, filename: []const u8, data: []const u8, alloc: std.mem.Allocator) !void { - const id = Uuid.randV4(util.getThreadPrng()); - const now = DateTime.now(); - - // TODO: assert we're not in a transaction - db.insert("drive_file", .{ - .id = id, - .filename = filename, - .owner = owner, - .created_at = now, - }, alloc) catch return error.DatabaseFailure; - // Assume the previous statement succeeded and is not stuck in a transaction - errdefer { - db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch |err| { - std.log.err("Unable to remove file record in DB: {}", .{err}); - }; - } - - try saveFile(id, data); - } - - const data_root = "./files"; - fn saveFile(id: Uuid, data: []const u8) !void { - var dir = try std.fs.cwd().openDir(data_root); - defer dir.close(); - - var file = try dir.createFile(id.toCharArray(), .{ .exclusive = true }); - defer file.close(); - - try file.writer().writeAll(data); - try file.sync(); - } - - pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 { - var dir = try std.fs.cwd().openDir(data_root); - defer dir.close(); - - return dir.readFileAlloc(alloc, id.toCharArray(), 1 << 32); - } - - pub fn delete(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void { - var dir = try std.fs.cwd().openDir(data_root); - defer dir.close(); - - try dir.deleteFile(id.toCharArray()); - - db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure; - } -}; diff --git a/src/http/json.zig b/src/http/json.zig deleted file mode 100644 index 21474cc..0000000 --- a/src/http/json.zig +++ /dev/null @@ -1,677 +0,0 @@ -const std = @import("std"); -const mem = std.mem; -const Allocator = std.mem.Allocator; -const assert = std.debug.assert; - -// This file is largely a copy of std.json - -const StreamingParser = std.json.StreamingParser; -const Token = std.json.Token; -const unescapeValidString = std.json.unescapeValidString; -const UnescapeValidStringError = std.json.UnescapeValidStringError; - -pub fn parse(comptime T: type, body: []const u8, alloc: std.mem.Allocator) !T { - var tokens = TokenStream.init(body); - - const options = ParseOptions{ .allocator = alloc }; - - const token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - const r = try parseInternal(T, token, &tokens, options); - errdefer parseFreeInternal(T, r, options); - if (!options.allow_trailing_data) { - if ((try tokens.next()) != null) unreachable; - assert(tokens.i >= tokens.slice.len); - } - return r; -} - -pub fn parseFree(value: anytype, alloc: std.mem.Allocator) void { - parseFreeInternal(@TypeOf(value), value, .{ .allocator = alloc }); -} - -// WARNING: the objects "parse" method must not contain a reference to the original value -fn hasCustomParse(comptime T: type) bool { - if (!std.meta.trait.hasFn("parse")(T)) return false; - if (!@hasDecl(T, "JsonParseAs")) return false; - - return true; -} - -///// The rest is (modified) from std.json - -/// A small wrapper over a StreamingParser for full slices. Returns a stream of json Tokens. -pub const TokenStream = struct { - i: usize, - slice: []const u8, - parser: StreamingParser, - token: ?Token, - - pub const Error = StreamingParser.Error || error{UnexpectedEndOfJson}; - - pub fn init(slice: []const u8) TokenStream { - return TokenStream{ - .i = 0, - .slice = slice, - .parser = StreamingParser.init(), - .token = null, - }; - } - - fn stackUsed(self: *TokenStream) usize { - return self.parser.stack.len + if (self.token != null) @as(usize, 1) else 0; - } - - pub fn next(self: *TokenStream) Error!?Token { - if (self.token) |token| { - self.token = null; - return token; - } - - var t1: ?Token = undefined; - var t2: ?Token = undefined; - - while (self.i < self.slice.len) { - try self.parser.feed(self.slice[self.i], &t1, &t2); - self.i += 1; - - if (t1) |token| { - self.token = t2; - return token; - } - } - - // Without this a bare number fails, the streaming parser doesn't know the input ended - try self.parser.feed(' ', &t1, &t2); - self.i += 1; - - if (t1) |token| { - return token; - } else if (self.parser.complete) { - return null; - } else { - return error.UnexpectedEndOfJson; - } - } -}; - -/// Checks to see if a string matches what it would be as a json-encoded string -/// Assumes that `encoded` is a well-formed json string -fn encodesTo(decoded: []const u8, encoded: []const u8) bool { - var i: usize = 0; - var j: usize = 0; - while (i < decoded.len) { - if (j >= encoded.len) return false; - if (encoded[j] != '\\') { - if (decoded[i] != encoded[j]) return false; - j += 1; - i += 1; - } else { - const escape_type = encoded[j + 1]; - if (escape_type != 'u') { - const t: u8 = switch (escape_type) { - '\\' => '\\', - '/' => '/', - 'n' => '\n', - 'r' => '\r', - 't' => '\t', - 'f' => 12, - 'b' => 8, - '"' => '"', - else => unreachable, - }; - if (decoded[i] != t) return false; - j += 2; - i += 1; - } else { - var codepoint = std.fmt.parseInt(u21, encoded[j + 2 .. j + 6], 16) catch unreachable; - j += 6; - if (codepoint >= 0xD800 and codepoint < 0xDC00) { - // surrogate pair - assert(encoded[j] == '\\'); - assert(encoded[j + 1] == 'u'); - const low_surrogate = std.fmt.parseInt(u21, encoded[j + 2 .. j + 6], 16) catch unreachable; - codepoint = 0x10000 + (((codepoint & 0x03ff) << 10) | (low_surrogate & 0x03ff)); - j += 6; - } - var buf: [4]u8 = undefined; - const len = std.unicode.utf8Encode(codepoint, &buf) catch unreachable; - if (i + len > decoded.len) return false; - if (!mem.eql(u8, decoded[i .. i + len], buf[0..len])) return false; - i += len; - } - } - } - assert(i == decoded.len); - assert(j == encoded.len); - return true; -} - -/// parse tokens from a stream, returning `false` if they do not decode to `value` -fn parsesTo(comptime T: type, value: T, tokens: *TokenStream, options: ParseOptions) !bool { - // TODO: should be able to write this function to not require an allocator - const tmp = try parse(T, tokens, options); - defer parseFree(T, tmp, options); - - return parsedEqual(tmp, value); -} - -/// Returns if a value returned by `parse` is deep-equal to another value -fn parsedEqual(a: anytype, b: @TypeOf(a)) bool { - switch (@typeInfo(@TypeOf(a))) { - .Optional => { - if (a == null and b == null) return true; - if (a == null or b == null) return false; - return parsedEqual(a.?, b.?); - }, - .Union => |info| { - if (info.tag_type) |UnionTag| { - const tag_a = std.meta.activeTag(a); - const tag_b = std.meta.activeTag(b); - if (tag_a != tag_b) return false; - - inline for (info.fields) |field_info| { - if (@field(UnionTag, field_info.name) == tag_a) { - return parsedEqual(@field(a, field_info.name), @field(b, field_info.name)); - } - } - return false; - } else { - unreachable; - } - }, - .Array => { - for (a) |e, i| - if (!parsedEqual(e, b[i])) return false; - return true; - }, - .Struct => |info| { - inline for (info.fields) |field_info| { - if (!parsedEqual(@field(a, field_info.name), @field(b, field_info.name))) return false; - } - return true; - }, - .Pointer => |ptrInfo| switch (ptrInfo.size) { - .One => return parsedEqual(a.*, b.*), - .Slice => { - if (a.len != b.len) return false; - for (a) |e, i| - if (!parsedEqual(e, b[i])) return false; - return true; - }, - .Many, .C => unreachable, - }, - else => return a == b, - } - unreachable; -} - -const ParseOptions = struct { - allocator: ?Allocator = null, - - /// Behaviour when a duplicate field is encountered. - duplicate_field_behavior: enum { - UseFirst, - Error, - UseLast, - } = .Error, - - /// If false, finding an unknown field returns an error. - ignore_unknown_fields: bool = false, - - allow_trailing_data: bool = false, -}; - -const SkipValueError = error{UnexpectedJsonDepth} || TokenStream.Error; - -fn skipValue(tokens: *TokenStream) SkipValueError!void { - const original_depth = tokens.stackUsed(); - - // Return an error if no value is found - _ = try tokens.next(); - if (tokens.stackUsed() < original_depth) return error.UnexpectedJsonDepth; - if (tokens.stackUsed() == original_depth) return; - - while (try tokens.next()) |_| { - if (tokens.stackUsed() == original_depth) return; - } -} - -fn ParseInternalError(comptime T: type) type { - // `inferred_types` is used to avoid infinite recursion for recursive type definitions. - const inferred_types = [_]type{}; - return ParseInternalErrorImpl(T, &inferred_types); -} - -fn ParseInternalErrorImpl(comptime T: type, comptime inferred_types: []const type) type { - if (hasCustomParse(T)) { - return ParseInternalError(T.JsonParseAs) || T.ParseError; - } - for (inferred_types) |ty| { - if (T == ty) return error{}; - } - - switch (@typeInfo(T)) { - .Bool => return error{UnexpectedToken}, - .Float, .ComptimeFloat => return error{UnexpectedToken} || std.fmt.ParseFloatError, - .Int, .ComptimeInt => { - return error{ UnexpectedToken, InvalidNumber, Overflow } || - std.fmt.ParseIntError || std.fmt.ParseFloatError; - }, - .Optional => |optionalInfo| { - return ParseInternalErrorImpl(optionalInfo.child, inferred_types ++ [_]type{T}); - }, - .Enum => return error{ UnexpectedToken, InvalidEnumTag } || std.fmt.ParseIntError || - std.meta.IntToEnumError || std.meta.IntToEnumError, - .Union => |unionInfo| { - if (unionInfo.tag_type) |_| { - var errors = error{NoUnionMembersMatched}; - for (unionInfo.fields) |u_field| { - errors = errors || ParseInternalErrorImpl(u_field.field_type, inferred_types ++ [_]type{T}); - } - return errors; - } else { - @compileError("Unable to parse into untagged union '" ++ @typeName(T) ++ "'"); - } - }, - .Struct => |structInfo| { - var errors = error{ - DuplicateJSONField, - UnexpectedEndOfJson, - UnexpectedToken, - UnexpectedValue, - UnknownField, - MissingField, - } || SkipValueError || TokenStream.Error; - for (structInfo.fields) |field| { - errors = errors || ParseInternalErrorImpl(field.field_type, inferred_types ++ [_]type{T}); - } - return errors; - }, - .Array => |arrayInfo| { - return error{ UnexpectedEndOfJson, UnexpectedToken } || TokenStream.Error || - UnescapeValidStringError || - ParseInternalErrorImpl(arrayInfo.child, inferred_types ++ [_]type{T}); - }, - .Pointer => |ptrInfo| { - var errors = error{AllocatorRequired} || std.mem.Allocator.Error; - switch (ptrInfo.size) { - .One => { - return errors || ParseInternalErrorImpl(ptrInfo.child, inferred_types ++ [_]type{T}); - }, - .Slice => { - return errors || error{ UnexpectedEndOfJson, UnexpectedToken } || - ParseInternalErrorImpl(ptrInfo.child, inferred_types ++ [_]type{T}) || - UnescapeValidStringError || TokenStream.Error; - }, - else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"), - } - }, - else => return error{}, - } - unreachable; -} - -fn parseInternal( - comptime T: type, - token: Token, - tokens: *TokenStream, - options: ParseOptions, -) ParseInternalError(T)!T { - if (comptime hasCustomParse(T)) { - const val = try parseInternal(T.JsonParseAs, token, tokens, options); - defer parseFreeInternal(T.JsonParseAs, val, options); - return try T.parse(val); - } - - switch (@typeInfo(T)) { - .Bool => { - return switch (token) { - .True => true, - .False => false, - else => error.UnexpectedToken, - }; - }, - .Float, .ComptimeFloat => { - switch (token) { - .Number => |numberToken| return try std.fmt.parseFloat(T, numberToken.slice(tokens.slice, tokens.i - 1)), - .String => |stringToken| return try std.fmt.parseFloat(T, stringToken.slice(tokens.slice, tokens.i - 1)), - else => return error.UnexpectedToken, - } - }, - .Int, .ComptimeInt => { - switch (token) { - .Number => |numberToken| { - if (numberToken.is_integer) - return try std.fmt.parseInt(T, numberToken.slice(tokens.slice, tokens.i - 1), 10); - const float = try std.fmt.parseFloat(f128, numberToken.slice(tokens.slice, tokens.i - 1)); - if (@round(float) != float) return error.InvalidNumber; - if (float > std.math.maxInt(T) or float < std.math.minInt(T)) return error.Overflow; - return @floatToInt(T, float); - }, - .String => |stringToken| { - return std.fmt.parseInt(T, stringToken.slice(tokens.slice, tokens.i - 1), 10) catch |err| { - switch (err) { - error.Overflow => return err, - error.InvalidCharacter => { - const float = try std.fmt.parseFloat(f128, stringToken.slice(tokens.slice, tokens.i - 1)); - if (@round(float) != float) return error.InvalidNumber; - if (float > std.math.maxInt(T) or float < std.math.minInt(T)) return error.Overflow; - return @floatToInt(T, float); - }, - } - }; - }, - else => return error.UnexpectedToken, - } - }, - .Optional => |optionalInfo| { - if (token == .Null) { - return null; - } else { - return try parseInternal(optionalInfo.child, token, tokens, options); - } - }, - .Enum => |enumInfo| { - switch (token) { - .Number => |numberToken| { - if (!numberToken.is_integer) return error.UnexpectedToken; - const n = try std.fmt.parseInt(enumInfo.tag_type, numberToken.slice(tokens.slice, tokens.i - 1), 10); - return try std.meta.intToEnum(T, n); - }, - .String => |stringToken| { - const source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - switch (stringToken.escapes) { - .None => return std.meta.stringToEnum(T, source_slice) orelse return error.InvalidEnumTag, - .Some => { - inline for (enumInfo.fields) |field| { - if (field.name.len == stringToken.decodedLength() and encodesTo(field.name, source_slice)) { - return @field(T, field.name); - } - } - return error.InvalidEnumTag; - }, - } - }, - else => return error.UnexpectedToken, - } - }, - .Union => |unionInfo| { - if (unionInfo.tag_type) |_| { - // try each of the union fields until we find one that matches - inline for (unionInfo.fields) |u_field| { - // take a copy of tokens so we can withhold mutations until success - var tokens_copy = tokens.*; - if (parseInternal(u_field.field_type, token, &tokens_copy, options)) |value| { - tokens.* = tokens_copy; - return @unionInit(T, u_field.name, value); - } else |err| { - // Bubble up error.OutOfMemory - // Parsing some types won't have OutOfMemory in their - // error-sets, for the condition to be valid, merge it in. - if (@as(@TypeOf(err) || error{OutOfMemory}, err) == error.OutOfMemory) return err; - // Bubble up AllocatorRequired, as it indicates missing option - if (@as(@TypeOf(err) || error{AllocatorRequired}, err) == error.AllocatorRequired) return err; - // otherwise continue through the `inline for` - } - } - return error.NoUnionMembersMatched; - } else { - @compileError("Unable to parse into untagged union '" ++ @typeName(T) ++ "'"); - } - }, - .Struct => |structInfo| { - switch (token) { - .ObjectBegin => {}, - else => return error.UnexpectedToken, - } - var r: T = undefined; - var fields_seen = [_]bool{false} ** structInfo.fields.len; - errdefer { - inline for (structInfo.fields) |field, i| { - if (fields_seen[i] and !field.is_comptime) { - parseFreeInternal(field.field_type, @field(r, field.name), options); - } - } - } - - while (true) { - switch ((try tokens.next()) orelse return error.UnexpectedEndOfJson) { - .ObjectEnd => break, - .String => |stringToken| { - const key_source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - var child_options = options; - child_options.allow_trailing_data = true; - var found = false; - inline for (structInfo.fields) |field, i| { - // TODO: using switches here segfault the compiler (#2727?) - if ((stringToken.escapes == .None and mem.eql(u8, field.name, key_source_slice)) or (stringToken.escapes == .Some and (field.name.len == stringToken.decodedLength() and encodesTo(field.name, key_source_slice)))) { - // if (switch (stringToken.escapes) { - // .None => mem.eql(u8, field.name, key_source_slice), - // .Some => (field.name.len == stringToken.decodedLength() and encodesTo(field.name, key_source_slice)), - // }) { - if (fields_seen[i]) { - // switch (options.duplicate_field_behavior) { - // .UseFirst => {}, - // .Error => {}, - // .UseLast => {}, - // } - if (options.duplicate_field_behavior == .UseFirst) { - // unconditonally ignore value. for comptime fields, this skips check against default_value - const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - parseFreeInternal(field.field_type, try parseInternal(field.field_type, next_token, tokens, child_options), child_options); - found = true; - break; - } else if (options.duplicate_field_behavior == .Error) { - return error.DuplicateJSONField; - } else if (options.duplicate_field_behavior == .UseLast) { - if (!field.is_comptime) { - parseFreeInternal(field.field_type, @field(r, field.name), child_options); - } - fields_seen[i] = false; - } - } - if (field.is_comptime) { - if (!try parsesTo(field.field_type, @ptrCast(*const field.field_type, field.default_value.?).*, tokens, child_options)) { - return error.UnexpectedValue; - } - } else { - const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - @field(r, field.name) = try parseInternal(field.field_type, next_token, tokens, child_options); - } - fields_seen[i] = true; - found = true; - break; - } - } - if (!found) { - if (options.ignore_unknown_fields) { - try skipValue(tokens); - continue; - } else { - return error.UnknownField; - } - } - }, - else => return error.UnexpectedToken, - } - } - inline for (structInfo.fields) |field, i| { - if (!fields_seen[i]) { - if (field.default_value) |default_ptr| { - if (!field.is_comptime) { - const default = @ptrCast(*align(1) const field.field_type, default_ptr).*; - @field(r, field.name) = default; - } - } else { - return error.MissingField; - } - } - } - return r; - }, - .Array => |arrayInfo| { - switch (token) { - .ArrayBegin => { - var r: T = undefined; - var i: usize = 0; - var child_options = options; - child_options.allow_trailing_data = true; - errdefer { - // Without the r.len check `r[i]` is not allowed - if (r.len > 0) while (true) : (i -= 1) { - parseFreeInternal(arrayInfo.child, r[i], options); - if (i == 0) break; - }; - } - while (i < r.len) : (i += 1) { - const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - r[i] = try parseInternal(arrayInfo.child, next_token, tokens, child_options); - } - const tok = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - switch (tok) { - .ArrayEnd => {}, - else => return error.UnexpectedToken, - } - return r; - }, - .String => |stringToken| { - if (arrayInfo.child != u8) return error.UnexpectedToken; - var r: T = undefined; - const source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - switch (stringToken.escapes) { - .None => mem.copy(u8, &r, source_slice), - .Some => try unescapeValidString(&r, source_slice), - } - return r; - }, - else => return error.UnexpectedToken, - } - }, - .Pointer => |ptrInfo| { - const allocator = options.allocator orelse return error.AllocatorRequired; - switch (ptrInfo.size) { - .One => { - const r: T = try allocator.create(ptrInfo.child); - errdefer allocator.destroy(r); - r.* = try parseInternal(ptrInfo.child, token, tokens, options); - return r; - }, - .Slice => { - switch (token) { - .ArrayBegin => { - var arraylist = std.ArrayList(ptrInfo.child).init(allocator); - errdefer { - while (arraylist.popOrNull()) |v| { - parseFreeInternal(ptrInfo.child, v, options); - } - arraylist.deinit(); - } - - while (true) { - const tok = (try tokens.next()) orelse return error.UnexpectedEndOfJson; - switch (tok) { - .ArrayEnd => break, - else => {}, - } - - try arraylist.ensureUnusedCapacity(1); - const v = try parseInternal(ptrInfo.child, tok, tokens, options); - arraylist.appendAssumeCapacity(v); - } - - if (ptrInfo.sentinel) |some| { - const sentinel_value = @ptrCast(*const ptrInfo.child, some).*; - try arraylist.append(sentinel_value); - const output = arraylist.toOwnedSlice(); - return output[0 .. output.len - 1 :sentinel_value]; - } - - return arraylist.toOwnedSlice(); - }, - .String => |stringToken| { - if (ptrInfo.child != u8) return error.UnexpectedToken; - const source_slice = stringToken.slice(tokens.slice, tokens.i - 1); - const len = stringToken.decodedLength(); - const output = try allocator.alloc(u8, len + @boolToInt(ptrInfo.sentinel != null)); - errdefer allocator.free(output); - switch (stringToken.escapes) { - .None => mem.copy(u8, output, source_slice), - .Some => try unescapeValidString(output, source_slice), - } - - if (ptrInfo.sentinel) |some| { - const char = @ptrCast(*const u8, some).*; - output[len] = char; - return output[0..len :char]; - } - - return output; - }, - else => return error.UnexpectedToken, - } - }, - else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"), - } - }, - else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"), - } - unreachable; -} - -fn ParseError(comptime T: type) type { - return ParseInternalError(T) || error{UnexpectedEndOfJson} || TokenStream.Error; -} - -/// Releases resources created by `parse`. -/// Should be called with the same type and `ParseOptions` that were passed to `parse` -fn parseFreeInternal(comptime T: type, value: T, options: ParseOptions) void { - switch (@typeInfo(T)) { - .Bool, .Float, .ComptimeFloat, .Int, .ComptimeInt, .Enum => {}, - .Optional => { - if (value) |v| { - return parseFreeInternal(@TypeOf(v), v, options); - } - }, - .Union => |unionInfo| { - if (unionInfo.tag_type) |UnionTagType| { - inline for (unionInfo.fields) |u_field| { - if (value == @field(UnionTagType, u_field.name)) { - parseFreeInternal(u_field.field_type, @field(value, u_field.name), options); - break; - } - } - } else { - unreachable; - } - }, - .Struct => |structInfo| { - inline for (structInfo.fields) |field| { - if (!field.is_comptime) { - parseFreeInternal(field.field_type, @field(value, field.name), options); - } - } - }, - .Array => |arrayInfo| { - for (value) |v| { - parseFreeInternal(arrayInfo.child, v, options); - } - }, - .Pointer => |ptrInfo| { - const allocator = options.allocator orelse unreachable; - switch (ptrInfo.size) { - .One => { - parseFreeInternal(ptrInfo.child, value.*, options); - allocator.destroy(value); - }, - .Slice => { - for (value) |v| { - parseFreeInternal(ptrInfo.child, v, options); - } - allocator.free(value); - }, - else => unreachable, - } - }, - else => unreachable, - } -} diff --git a/src/http/lib.zig b/src/http/lib.zig index 9a4d8c9..26f7756 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -15,8 +15,6 @@ pub const Response = server.Response; pub const Handler = server.Handler; pub const Server = server.Server; -pub const middleware = @import("./middleware.zig"); - pub const Fields = @import("./headers.zig").Fields; pub const Protocol = enum { diff --git a/src/http/middleware.zig b/src/http/middleware.zig deleted file mode 100644 index 12ddcf3..0000000 --- a/src/http/middleware.zig +++ /dev/null @@ -1,408 +0,0 @@ -const std = @import("std"); -const root = @import("root"); -const builtin = @import("builtin"); -const http = @import("./lib.zig"); -const util = @import("util"); -const query_utils = @import("./query.zig"); -const json_utils = @import("./json.zig"); -fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type { - const Ctx = @Type(.{ .Struct = .{ - .layout = .Auto, - .fields = std.meta.fields(Lhs) ++ &[_]std.builtin.Type.StructField{ - .{ - .name = &name, - .field_type = Val, - .alignment = if (@sizeOf(Val) != 0) @alignOf(Val) else 0, - .default_value = null, - .is_comptime = false, - }, - }, - .decls = &.{}, - .is_tuple = false, - } }); - return Ctx; -} - -fn AddField(comptime Lhs: type, comptime name: []const u8, comptime Val: type) type { - return AddUniqueField(Lhs, name.len, name[0..].*, Val); -} - -fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@TypeOf(lhs), name, @TypeOf(val)) { - var result: AddField(@TypeOf(lhs), name, @TypeOf(val)) = undefined; - inline for (std.meta.fields(@TypeOf(lhs))) |f| @field(result, f.name) = @field(lhs, f.name); - @field(result, name) = val; - return result; -} - -test { - // apply is a plumbing function that applies a tuple of middlewares in order - const base = apply(.{ - split_uri, - mount("/abc"), - }); - - const request = .{ .uri = "/abc/defg/hijkl?some_query=true#section" }; - const response = .{}; - const initial_context = .{}; - try base.handle(request, response, initial_context, {}); -} - -fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { - if (fields.len == 0) return void; - - return NextHandler( - fields[0].field_type, - ApplyInternal(fields[1..]), - ); -} - -fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { - if (fields.len == 0) return {}; - return .{ - .first = @field(middlewares, fields[0].name), - .next = applyInternal(middlewares, fields[1..]), - }; -} - -pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) { - return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares))); -} - -pub fn Apply(comptime Middlewares: type) type { - return ApplyInternal(std.meta.fields(Middlewares)); -} - -pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type { - return struct { - val: V, - pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - return next.handle(req, res, addField(ctx, name, self.val), {}); - } - }; -} - -pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContextValue(name, @TypeOf(val)) { - return .{ .val = val }; -} - -pub fn NextHandler(comptime First: type, comptime Next: type) type { - return struct { - first: First, - next: Next, - - pub fn handle( - self: @This(), - req: anytype, - res: anytype, - ctx: anytype, - next: void, - ) !void { - _ = next; - return self.first.handle(req, res, ctx, self.next); - } - }; -} - -pub fn CatchErrors(comptime ErrorHandler: type) type { - return struct { - error_handler: ErrorHandler, - pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - return next.handle(req, res, ctx, {}) catch |err| { - return self.error_handler.handle( - req, - res, - addField(ctx, "err", err), - next, - ); - }; - } - }; -} -pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) { - return .{ .error_handler = error_handler }; -} - -pub const default_error_handler = struct { - fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - _ = next; - std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri }); - - // Tell the server to close the connection after this request - res.should_close = true; - - var buf: [1024]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&buf); - var headers = http.Fields.init(fba.allocator()); - if (!res.was_opened) { - var stream = res.open(.internal_server_error, &headers) catch return; - defer stream.close(); - stream.finish() catch {}; - } - } -}{}; - -pub const split_uri = struct { - pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - var frag_split = std.mem.split(u8, req.uri, "#"); - const without_fragment = frag_split.first(); - const fragment = frag_split.rest(); - - var query_split = std.mem.split(u8, without_fragment, "?"); - const path = query_split.first(); - const query = query_split.rest(); - - const new_ctx = addField( - addField( - addField(ctx, "path", path), - "query_string", - query, - ), - "fragment_string", - fragment, - ); - - return next.handle( - req, - res, - new_ctx, - {}, - ); - } -}{}; - -// routes a request to the correct handler based on declared HTTP method and path -pub fn Router(comptime Routes: type) type { - return struct { - routes: Routes, - - pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: void) !void { - _ = next; - - inline for (self.routes) |r| { - if (r.handle(req, res, ctx, {})) |_| - // success - return - else |err| switch (err) { - error.RouteMismatch => {}, - else => return err, - } - } - - return error.RouteMismatch; - } - }; -} -pub fn router(routes: anytype) Router(@TypeOf(routes)) { - return Router(@TypeOf(routes)){ .routes = routes }; -} - -// helper function for doing route analysis -fn pathMatches(route: []const u8, path: []const u8) bool { - var path_iter = util.PathIter.from(path); - var route_iter = util.PathIter.from(route); - while (route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return false; - if (route_segment.len > 0 and route_segment[0] == ':') { - // Route Argument - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; - } - } - if (path_iter.next() != null) return false; - - return true; -} -pub const Route = struct { - pub const Desc = struct { - path: []const u8, - method: http.Method, - }; - - desc: Desc, - - fn applies(self: @This(), req: anytype, ctx: anytype) bool { - if (self.desc.method != req.method) return false; - - const eff_path = if (@hasField(@TypeOf(ctx), "path")) - ctx.path - else - std.mem.sliceTo(req.uri, '?'); - - return pathMatches(self.desc.path, eff_path); - } - - pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - std.log.debug("Testing path {s} against {s}", .{ ctx.path, self.desc.path }); - return if (self.applies(req, ctx)) - next.handle(req, res, ctx, {}) - else - error.RouteMismatch; - } -}; - -pub fn Mount(comptime route: []const u8) type { - return struct { - pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - var path_iter = util.PathIter.from(ctx.path); - comptime var route_iter = util.PathIter.from(route); - var path_unused = ctx.path; - - inline while (comptime route_iter.next()) |route_segment| { - if (comptime route_segment.len == 0) continue; - const path_segment = path_iter.next() orelse return error.RouteMismatch; - path_unused = path_iter.rest(); - if (comptime route_segment[0] == ':') { - @compileLog("Argument segments cannot be mounted"); - // Route Argument - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; - } - } - - var new_ctx = ctx; - new_ctx.path = path_unused; - return next.handle(req, res, new_ctx, {}); - } - }; -} -pub fn mount(comptime route: []const u8) Mount(route) { - return .{}; -} - -pub fn HandleNotFound(comptime NotFoundHandler: type) type { - return struct { - not_found: NotFoundHandler, - - pub fn handler(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - return next.handler(req, res, ctx, {}) catch |err| switch (err) { - error.RouteMismatch => return self.not_found.handler(req, res, ctx, {}), - else => return err, - }; - } - }; -} - -fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { - var args: Args = undefined; - var path_iter = util.PathIter.from(path); - comptime var route_iter = util.PathIter.from(route); - inline while (comptime route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return error.RouteMismatch; - if (route_segment.len > 0 and route_segment[0] == ':') { - // route segment is an argument segment - const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = try parsePathArg(A, path_segment); - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; - } - } - - if (path_iter.next() != null) return error.RouteMismatch; - - return args; -} - -fn parsePathArg(comptime T: type, segment: []const u8) !T { - if (T == []const u8) return segment; - if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); - - @compileError("Unsupported Type " ++ @typeName(T)); -} - -pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type { - return struct { - pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - if (Args == void) return next.handle(req, res, addField(ctx, "args", {}), {}); - return next.handle( - req, - res, - addField(ctx, "args", try parsePathArgs(route, Args, ctx.path)), - {}, - ); - } - }; -} - -const BaseContentType = enum { - json, - url_encoded, - octet_stream, - - other, -}; - -fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { - //@compileLog(T); - const buf = try reader.readAllAlloc(alloc, 1 << 16); - defer alloc.free(buf); - - switch (content_type) { - .octet_stream, .json => { - const body = try json_utils.parse(T, buf, alloc); - defer json_utils.parseFree(body, alloc); - - return try util.deepClone(alloc, body); - }, - .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { - error.NoQuery => error.NoBody, - else => err, - }, - else => return error.UnsupportedMediaType, - } -} - -fn matchContentType(hdr: ?[]const u8) ?BaseContentType { - if (hdr) |h| { - if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; - if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; - if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; - - return .other; - } - - return null; -} - -pub fn ParseBody(comptime Body: type) type { - return struct { - pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - const content_type = req.headers.get("Content-Type"); - if (Body == void) { - if (content_type != null) return error.UnexpectedBody; - const new_ctx = addField(ctx, "body", {}); - //if (true) @compileError("bug"); - return next.handle(req, res, new_ctx, {}); - } - - const base_content_type = matchContentType(content_type); - - var stream = req.body orelse return error.NoBody; - const body = try parseBody(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); - defer util.deepFree(ctx.allocator, body); - - return next.handle( - req, - res, - addField(ctx, "body", body), - {}, - ); - } - }; -} - -pub fn ParseQueryParams(comptime QueryParams: type) type { - return struct { - pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {}); - const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string); - defer util.deepFree(ctx.allocator, query); - - return next.handle( - req, - res, - addField(ctx, "query_params", query), - {}, - ); - } - }; -} diff --git a/src/http/query.zig b/src/http/query.zig deleted file mode 100644 index 1933429..0000000 --- a/src/http/query.zig +++ /dev/null @@ -1,380 +0,0 @@ -const std = @import("std"); - -const QueryIter = @import("util").QueryIter; - -/// Parses a set of query parameters described by the struct `T`. -/// -/// To specify query parameters, provide a struct similar to the following: -/// ``` -/// struct { -/// foo: bool = false, -/// bar: ?[]const u8 = null, -/// baz: usize = 10, -/// qux: enum { quux, snap } = .quux, -/// } -/// ``` -/// -/// This will allow it to parse a query string like the following: -/// `?foo&bar=abc&qux=snap` -/// -/// Every parameter must have a default value that will be used when the -/// parameter is not provided, and parameter keys. -/// Numbers are parsed from their string representations, and a parameter -/// provided in the query string without a value is parsed either as a bool -/// `true` flag or as `null` depending on the type of its param. -/// -/// Parameter types supported: -/// - []const u8 -/// - numbers (both integer and float) -/// + Numbers are parsed in base 10 -/// - bool -/// + See below for detals -/// - exhaustive enums -/// + Enums are treated as strings with values equal to the enum fields -/// - ?F (where isScalar(F) and F != bool) -/// - Any type that implements: -/// + pub fn parse([]const u8) !F -/// -/// Boolean Parameters: -/// The following query strings will all parse a `true` value for the -/// parameter `foo: bool = false`: -/// - `?foo` -/// - `?foo=true` -/// - `?foo=t` -/// - `?foo=yes` -/// - `?foo=y` -/// - `?foo=1` -/// And the following query strings all parse a `false` value: -/// - `?` -/// - `?foo=false` -/// - `?foo=f` -/// - `?foo=no` -/// - `?foo=n` -/// - `?foo=0` -/// -/// Compound Types: -/// Compound (struct) types are also supported, with the parameter key -/// for its parameters consisting of the struct's field + '.' + parameter -/// field. For example: -/// ``` -/// struct { -/// foo: struct { -/// baz: usize = 0, -/// } = .{}, -/// } -/// ``` -/// Would be used to parse a query string like -/// `?foo.baz=12345` -/// -/// Compound types cannot currently be nullable, and must be structs. -/// -/// TODO: values are currently case-sensitive, and are not url-decoded properly. -/// This should be fixed. -pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { - if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct"); - var iter = QueryIter.from(query); - - var fields = Intermediary(T){}; - while (iter.next()) |pair| { - // TODO: Hash map - inline for (std.meta.fields(Intermediary(T))) |field| { - if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) { - @field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} }; - break; - } - } else std.log.debug("unknown param {s}", .{pair.key}); - } - - return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery; -} - -fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]const u8 { - var list = try std.ArrayList(u8).initCapacity(alloc, val.len); - errdefer list.deinit(); - - var idx: usize = 0; - while (idx < val.len) : (idx += 1) { - if (val[idx] != '%') { - try list.append(val[idx]); - } else { - if (val.len < idx + 2) return error.InvalidEscape; - const buf = [2]u8{ val[idx + 1], val[idx + 2] }; - idx += 2; - - const ch = try std.fmt.parseInt(u8, &buf, 16); - try list.append(ch); - } - } - - return list.toOwnedSlice(); -} - -fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T { - const param = @field(fields, name); - return switch (param) { - .not_specified => null, - .no_value => try parseQueryValue(alloc, T, null), - .value => |v| try parseQueryValue(alloc, T, v), - }; -} - -fn parse( - alloc: std.mem.Allocator, - comptime T: type, - comptime prefix: []const u8, - comptime name: []const u8, - fields: anytype, -) !?T { - if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields); - switch (@typeInfo(T)) { - .Union => |info| { - var result: ?T = null; - inline for (info.fields) |field| { - const F = field.field_type; - - const maybe_value = try parse(alloc, F, prefix, field.name, fields); - if (maybe_value) |value| { - if (result != null) return error.DuplicateUnionField; - - result = @unionInit(T, field.name, value); - } - } - std.log.debug("{any}", .{result}); - return result; - }, - - .Struct => |info| { - var result: T = undefined; - var fields_specified: usize = 0; - - inline for (info.fields) |field| { - const F = field.field_type; - - var maybe_value: ?F = null; - if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| { - maybe_value = v; - } else if (field.default_value) |default| { - if (comptime @sizeOf(F) != 0) { - maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*; - } else { - maybe_value = std.mem.zeroes(F); - } - } - - if (maybe_value) |v| { - fields_specified += 1; - @field(result, field.name) = v; - } - } - - if (fields_specified == 0) { - return null; - } else if (fields_specified != info.fields.len) { - std.log.debug("{} {s} {s}", .{ T, prefix, name }); - return error.PartiallySpecifiedStruct; - } else { - return result; - } - }, - - // Only applies to non-scalar optionals - .Optional => |info| return try parse(alloc, info.child, prefix, name, fields), - - else => @compileError("tmp"), - } -} - -fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 { - comptime { - if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix); - - var fields: []const []const u8 = &.{}; - - for (std.meta.fields(T)) |f| { - const full_name = prefix ++ f.name; - - if (isScalar(f.field_type)) { - fields = fields ++ @as([]const []const u8, &.{full_name}); - } else { - const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ "."; - fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix); - } - } - - return fields; - } -} - -const QueryParam = union(enum) { - not_specified: void, - no_value: void, - value: []const u8, -}; - -fn Intermediary(comptime T: type) type { - const field_names = recursiveFieldPaths(T, ".."); - - var fields: [field_names.len]std.builtin.Type.StructField = undefined; - for (field_names) |name, i| fields[i] = .{ - .name = name, - .field_type = QueryParam, - .default_value = &QueryParam{ .not_specified = {} }, - .is_comptime = false, - .alignment = @alignOf(QueryParam), - }; - - return @Type(.{ .Struct = .{ - .layout = .Auto, - .fields = &fields, - .decls = &.{}, - .is_tuple = false, - } }); -} - -fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u8) !T { - const is_optional = comptime std.meta.trait.is(.Optional)(T); - // If param is present, but without an associated value - if (value == null) { - return if (is_optional) - null - else if (T == bool) - true - else - error.InvalidValue; - } - - return try parseQueryValueNotNull(alloc, if (is_optional) std.meta.Child(T) else T, value.?); -} - -const bool_map = std.ComptimeStringMap(bool, .{ - .{ "true", true }, - .{ "t", true }, - .{ "yes", true }, - .{ "y", true }, - .{ "1", true }, - - .{ "false", false }, - .{ "f", false }, - .{ "no", false }, - .{ "n", false }, - .{ "0", false }, -}); - -fn parseQueryValueNotNull(alloc: std.mem.Allocator, comptime T: type, value: []const u8) !T { - const decoded = try decodeString(alloc, value); - errdefer alloc.free(decoded); - - if (comptime std.meta.trait.isZigString(T)) return decoded; - - const result = if (comptime std.meta.trait.isIntegral(T)) - try std.fmt.parseInt(T, decoded, 0) - else if (comptime std.meta.trait.isFloat(T)) - try std.fmt.parseFloat(T, decoded) - else if (comptime std.meta.trait.is(.Enum)(T)) - std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue - else if (T == bool) - bool_map.get(value) orelse return error.InvalidBool - else if (comptime std.meta.trait.hasFn("parse")(T)) - try T.parse(value) - else - @compileError("Invalid type " ++ @typeName(T)); - - alloc.free(decoded); - return result; -} - -fn isScalar(comptime T: type) bool { - if (comptime std.meta.trait.isZigString(T)) return true; - if (comptime std.meta.trait.isIntegral(T)) return true; - if (comptime std.meta.trait.isFloat(T)) return true; - if (comptime std.meta.trait.is(.Enum)(T)) return true; - if (T == bool) return true; - if (comptime std.meta.trait.hasFn("parse")(T)) return true; - - if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true; - - return false; -} - -pub fn formatQuery(params: anytype, writer: anytype) !void { - try format("", "", params, writer); -} - -fn urlFormatString(writer: anytype, val: []const u8) !void { - for (val) |ch| { - const printable = switch (ch) { - '0'...'9', 'a'...'z', 'A'...'Z' => true, - '-', '.', '_', '~', ':', '@', '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=' => true, - else => false, - }; - - try if (printable) writer.writeByte(ch) else std.fmt.format(writer, "%{x:0>2}", .{ch}); - } -} - -fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void { - const T = @TypeOf(val); - if (comptime std.meta.trait.is(.Optional)(T)) { - return if (val) |v| formatScalar(name, v, writer) else {}; - } - - try urlFormatString(writer, name); - try writer.writeByte('='); - if (comptime std.meta.trait.isZigString(T)) { - try urlFormatString(writer, val); - } else try switch (@typeInfo(T)) { - .Enum => urlFormatString(writer, @tagName(val)), - else => std.fmt.format(writer, "{}", .{val}), - }; - - try writer.writeByte('&'); -} - -fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void { - const T = @TypeOf(params); - const eff_prefix = if (prefix.len == 0) "" else prefix ++ "."; - if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer); - - switch (@typeInfo(T)) { - .Struct => { - inline for (std.meta.fields(T)) |field| { - const val = @field(params, field.name); - try format(eff_prefix ++ name, field.name, val, writer); - } - }, - .Union => { - inline for (std.meta.fields(T)) |field| { - const tag = @field(std.meta.Tag(T), field.name); - const tag_name = field.name; - if (@as(std.meta.Tag(T), params) == tag) { - const val = @field(params, tag_name); - try format(prefix, tag_name, val, writer); - } - } - }, - .Optional => { - if (params) |p| try format(prefix, name, p, writer); - }, - else => @compileError("Unsupported query type"), - } -} - -test { - const TestQuery = struct { - int: usize = 3, - boolean: bool = false, - str_enum: ?enum { foo, bar } = null, - }; - - try std.testing.expectEqual(TestQuery{ - .int = 3, - .boolean = false, - .str_enum = null, - }, try parseQuery(TestQuery, "")); - - try std.testing.expectEqual(TestQuery{ - .int = 5, - .boolean = true, - .str_enum = .foo, - }, try parseQuery(TestQuery, "?int=5&boolean=yes&str_enum=foo")); -} diff --git a/src/http/server.zig b/src/http/server.zig index ae38c0a..b24ae75 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -10,12 +10,9 @@ pub const Response = struct { alloc: std.mem.Allocator, stream: Stream, should_close: bool = false, - was_opened: bool = false, pub const ResponseStream = response.ResponseStream(Stream.Writer); pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !ResponseStream { - std.debug.assert(!self.was_opened); - self.was_opened = true; if (headers.get("Connection")) |hdr| { if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; } @@ -24,8 +21,6 @@ pub const Response = struct { } pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !Stream { - std.debug.assert(!self.was_opened); - self.was_opened = true; try response.writeRequestHeader(self.stream.writer(), headers, status); return self.stream; } @@ -97,7 +92,7 @@ pub const Server = struct { pub fn handleLoop( self: *Server, allocator: std.mem.Allocator, - initial_context: anytype, + ctx: anytype, handler: anytype, ) void { while (true) { @@ -114,7 +109,7 @@ pub const Server = struct { .stream = Stream{ .kind = .tcp, .socket = conn.stream.handle }, .address = conn.address, }, - initial_context, + ctx, handler, ); } @@ -123,29 +118,12 @@ pub const Server = struct { fn serveConn( allocator: std.mem.Allocator, conn: Connection, - initial_context: anytype, + ctx: anytype, handler: anytype, ) void { - defer conn.stream.close(); while (true) { var req = request.parse(allocator, conn.stream.reader()) catch |err| { - const status: http.Status = switch (err) { - error.EndOfStream => return, // Do nothing, the client closed the connection - error.BadRequest => .bad_request, - error.UnsupportedMediaType => .unsupported_media_type, - error.HttpVersionNotSupported => .http_version_not_supported, - - else => blk: { - std.log.err("Unknown error parsing request: {}\n{?}", .{ err, @errorReturnTrace() }); - break :blk .internal_server_error; - }, - }; - - conn.stream.writer().print( - "HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", - .{ @enumToInt(status), status.phrase() }, - ) catch {}; - return; + return handleError(conn.stream.writer(), err) catch {}; }; var res = Response{ @@ -153,10 +131,7 @@ pub const Server = struct { .stream = conn.stream, }; - handler.handle(&req, &res, initial_context, {}) catch |err| { - std.log.err("Unhandled error serving request: {}\n{?}", .{ err, @errorReturnTrace() }); - return; - }; + handler(ctx, &req, &res); if (req.headers.get("Connection")) |hdr| { if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| return; @@ -168,3 +143,17 @@ pub const Server = struct { } } }; + +/// Writes an error response message and requests closure of the connection +fn handleError(writer: anytype, err: anyerror) !void { + const status: http.Status = switch (err) { + error.EndOfStream => return, // Do nothing, the client closed the connection + error.BadRequest => .bad_request, + error.UnsupportedMediaType => .unsupported_media_type, + error.HttpVersionNotSupported => .http_version_not_supported, + + else => .internal_server_error, + }; + + try writer.print("HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", .{ @enumToInt(status), status.phrase() }); +} diff --git a/src/http/socket.zig b/src/http/socket.zig index f481fff..ab885bb 100644 --- a/src/http/socket.zig +++ b/src/http/socket.zig @@ -23,18 +23,18 @@ const Opcode = enum(u4) { } }; -pub fn handshake(alloc: std.mem.Allocator, req_headers: *const http.Fields, res: *http.Response) !Socket { - const upgrade = req_headers.get("Upgrade") orelse return error.BadHandshake; - const connection = req_headers.get("Connection") orelse return error.BadHandshake; +pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket { + const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake; + const connection = req.headers.get("Connection") orelse return error.BadHandshake; if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake; if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake; - const key_hdr = req_headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake; + const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake; if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake; var key: [16]u8 = undefined; std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake; - const version = req_headers.get("Sec-WebSocket-version") orelse return error.BadHandshake; + const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake; if (!std.mem.eql(u8, "13", version)) return error.BadHandshake; var headers = http.Fields.init(alloc); diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 3263cfa..007e909 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -7,73 +7,144 @@ const util = @import("util"); const query_utils = @import("./query.zig"); const json_utils = @import("./json.zig"); -const web_endpoints = @import("./controllers/web.zig").routes; -const api_endpoints = @import("./controllers/api.zig").routes; +pub const auth = @import("./controllers/api/auth.zig"); +pub const communities = @import("./controllers/api/communities.zig"); +pub const invites = @import("./controllers/api/invites.zig"); +pub const users = @import("./controllers/api/users.zig"); +pub const follows = @import("./controllers/api/users/follows.zig"); +pub const notes = @import("./controllers/api/notes.zig"); +pub const streaming = @import("./controllers/api/streaming.zig"); +pub const timelines = @import("./controllers/api/timelines.zig"); -const mdw = http.middleware; +const web = @import("./controllers/web.zig"); -const not_found = struct { - pub fn handler( - _: @This(), - _: anytype, - res: anytype, - ctx: anytype, - ) !void { - var headers = http.Fields.init(ctx.allocator); - defer headers.deinit(); +pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void { + // TODO: hashmaps? + var response = Response{ .headers = http.Fields.init(alloc), .res = res }; + defer response.headers.deinit(); - var stream = try res.open(.not_found, &headers); - defer stream.close(); - try stream.finish(); + const found = routeRequestInternal(api_source, req, &response, alloc); + + if (!found) response.status(.not_found) catch {}; +} + +fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { + inline for (routes) |route| { + if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true; } + + return false; +} + +const routes = .{ + auth.login, + auth.verify_login, + communities.create, + communities.query, + invites.create, + users.create, + notes.create, + notes.get, + streaming.streaming, + timelines.global, + timelines.local, + timelines.home, + follows.create, + follows.delete, + follows.query_followers, + follows.query_following, +} ++ web.routes; + +fn parseRouteArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { + var args: Args = undefined; + var path_iter = util.PathIter.from(path); + comptime var route_iter = util.PathIter.from(route); + inline while (comptime route_iter.next()) |route_segment| { + const path_segment = path_iter.next() orelse return error.RouteMismatch; + if (route_segment.len > 0 and route_segment[0] == ':') { + const A = @TypeOf(@field(args, route_segment[1..])); + @field(args, route_segment[1..]) = try parseRouteArg(A, path_segment); + } else { + if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; + } + } + + if (path_iter.next() != null) return error.RouteMismatch; + + return args; +} + +fn parseRouteArg(comptime T: type, segment: []const u8) !T { + if (T == []const u8) return segment; + if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment); + + @compileError("Unsupported Type " ++ @typeName(T)); +} + +const BaseContentType = enum { + json, + url_encoded, + octet_stream, + + other, }; -const base_handler = mdw.SplitUri(mdw.CatchErrors(not_found, mdw.DefaultErrorHandler)); +fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { + const buf = try reader.readAllAlloc(alloc, 1 << 16); + defer alloc.free(buf); -const inject_api_conn = struct { - fn getApiConn(alloc: std.mem.Allocator, api_source: anytype, req: anytype) !@TypeOf(api_source.*).Conn { - const host = req.headers.get("Host") orelse return error.NoHost; - const auth_header = req.headers.get("Authorization"); - const token = if (auth_header) |header| blk: { - const prefix = "bearer "; - if (header.len < prefix.len) break :blk null; - if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; - break :blk header[prefix.len..]; - } else null; + switch (content_type) { + .octet_stream, .json => { + const body = try json_utils.parse(T, buf, alloc); + defer json_utils.parseFree(body, alloc); - if (token) |t| return try api_source.connectToken(host, t, alloc); + return try util.deepClone(alloc, body); + }, + .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { + error.NoQuery => error.NoBody, + else => err, + }, + else => return error.UnsupportedMediaType, + } +} - if (req.headers.getCookie("active_account") catch return error.BadRequest) |account| { - if (account.len + ("token.").len <= 64) { - var buf: [64]u8 = undefined; - const cookie_name = std.fmt.bufPrint(&buf, "token.{s}", .{account}) catch unreachable; - if (try req.headers.getCookie(cookie_name)) |token_hdr| { - return try api_source.connectToken(host, token_hdr, alloc); - } - } else return error.InvalidCookie; - } +fn matchContentType(hdr: ?[]const u8) ?BaseContentType { + if (hdr) |h| { + if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; + if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; + if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; - return try api_source.connectUnauthorized(host, alloc); + return .other; } - pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - var api_conn = try getApiConn(ctx.allocator, ctx.api_source, req); - defer api_conn.close(); + return null; +} - return mdw.injectContextValue("api_conn", &api_conn).handle( - req, - res, - ctx, - next, - ); - } -}{}; +pub const AllocationStrategy = enum { + arena, + normal, +}; -pub fn EndpointRequest(comptime Endpoint: type) type { +pub fn Context(comptime Route: type) type { return struct { - const Args = if (@hasDecl(Endpoint, "Args")) Endpoint.Args else void; - const Body = if (@hasDecl(Endpoint, "Body")) Endpoint.Body else void; - const Query = if (@hasDecl(Endpoint, "Query")) Endpoint.Query else void; + const Self = @This(); + + pub const Args = if (@hasDecl(Route, "Args")) Route.Args else void; + + // TODO: if controller does not provide a body type, maybe we should + // leave it as a simple reader instead of void + pub const Body = if (@hasDecl(Route, "Body")) Route.Body else void; + + // TODO: if controller does not provide a query type, maybe we should + // leave it as a simple string instead of void + pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; + + const allocation_strategy: AllocationStrategy = if (@hasDecl(Route, "allocation_strategy")) + Route.AllocationStrategy + else + .arena; + + base_request: *http.Request, allocator: std.mem.Allocator, @@ -85,87 +156,108 @@ pub fn EndpointRequest(comptime Endpoint: type) type { body: Body, query: Query, - const args_middleware = //if (Args == void) - //mdw.injectContext(.{ .args = {} }) - //else - mdw.ParsePathArgs(Endpoint.path, Args){}; + // TODO + body_buf: ?[]const u8 = null, - const body_middleware = //if (Body == void) - //mdw.injectContext(.{ .body = {} }) - //else - mdw.ParseBody(Body){}; + pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { + if (req.method != Route.method) return false; + var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); + var args = parseRouteArgs(Route.path, Args, path) catch return false; - const query_middleware = //if (Query == void) - //mdw.injectContext(.{ .query_params = {} }) - //else - mdw.ParseQueryParams(Query){}; - }; -} + std.log.debug("Matched route {s}", .{Route.path}); -fn CallApiEndpoint(comptime Endpoint: type) type { - return struct { - pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: void) !void { - const request = EndpointRequest(Endpoint){ - .allocator = ctx.allocator, + handle(api_source, req, res, alloc, args) catch |err| { + std.log.err("{}", .{err}); + if (!res.opened) res.err(.internal_server_error, "", {}) catch {}; + }; + + return true; + } + + fn handle( + api_source: *api.ApiSource, + req: *http.Request, + res: *Response, + base_allocator: std.mem.Allocator, + args: Args, + ) !void { + const base_content_type = matchContentType(req.headers.get("Content-Type")); + + var arena = if (allocation_strategy == .arena) + std.heap.ArenaAllocator.init(base_allocator) + else {}; + const alloc = if (allocation_strategy == .arena) arena.allocator() else base_allocator; + + const body = if (Body != void) blk: { + var stream = req.body orelse return error.NoBody; + break :blk try parseBody(Body, base_content_type orelse .json, stream.reader(), alloc); + } else {}; + defer if (Body != void) util.deepFree(alloc, body); + + const query = if (Query != void) blk: { + const path = std.mem.sliceTo(req.uri, '?'); + const q = req.uri[path.len..]; + + break :blk try query_utils.parseQuery(alloc, Query, q); + }; + defer if (Query != void) util.deepFree(alloc, query); + + var api_conn = conn: { + const host = req.headers.get("Host") orelse return error.NoHost; + const auth_header = req.headers.get("Authorization"); + const token = if (auth_header) |header| blk: { + const prefix = "bearer "; + if (header.len < prefix.len) break :blk null; + if (!std.ascii.eqlIgnoreCase(prefix, header[0..prefix.len])) break :blk null; + break :blk header[prefix.len..]; + } else null; + + if (token) |t| break :conn try api_source.connectToken(host, t, alloc); + + if (req.headers.getCookie("active_account") catch return error.BadRequest) |account| { + if (account.len + ("token.").len <= 64) { + var buf: [64]u8 = undefined; + const cookie_name = std.fmt.bufPrint(&buf, "token.{s}", .{account}) catch unreachable; + if (try req.headers.getCookie(cookie_name)) |token_hdr| { + break :conn try api_source.connectToken(host, token_hdr, alloc); + } + } else return error.InvalidCookie; + } + + break :conn try api_source.connectUnauthorized(host, alloc); + }; + defer api_conn.close(); + + const self = Self{ + .allocator = alloc, + .base_request = req, .method = req.method, .uri = req.uri, .headers = req.headers, - .args = ctx.args, - .body = ctx.body, - .query = ctx.query_params, + .args = args, + .body = body, + .query = query, }; - var response = Response{ .headers = http.Fields.init(ctx.allocator), .res = res }; - defer response.headers.deinit(); - return Endpoint.handler(request, &response, ctx.api_conn); + try Route.handler(self, res, &api_conn); + } + + fn errorHandler(response: *Response, status: http.Status, err: anytype) void { + std.log.err("Error occured on handler {s} {s}", .{ @tagName(Route.method), Route.path }); + std.log.err("{}", .{err}); + const result = if (builtin.mode == .Debug) + response.err(status, @errorName(err), {}) + else + response.status(status); + _ = result catch |err2| { + std.log.err("Error printing response: {}", .{err2}); + }; } }; } -pub fn apiEndpoint( - comptime Endpoint: type, -) return_type: { - const RequestType = EndpointRequest(Endpoint); - break :return_type mdw.Apply(std.meta.Tuple(&.{ - mdw.Route, - @TypeOf(RequestType.args_middleware), - @TypeOf(RequestType.query_middleware), - @TypeOf(RequestType.body_middleware), - // TODO: allocation strategy - @TypeOf(inject_api_conn), - CallApiEndpoint(Endpoint), - })); -} { - const RequestType = EndpointRequest(Endpoint); - return mdw.apply(.{ - mdw.Route{ .desc = .{ .path = Endpoint.path, .method = Endpoint.method } }, - RequestType.args_middleware, - RequestType.query_middleware, - RequestType.body_middleware, - // TODO: allocation strategy - inject_api_conn, - CallApiEndpoint(Endpoint){}, - }); -} - -const api_router = mdw.apply(.{ - mdw.mount("/api/v0/"), - mdw.router(api_endpoints), -}); - -pub const router = mdw.apply(.{ - mdw.split_uri, - mdw.catchErrors(mdw.default_error_handler), - mdw.router(.{api_router} ++ web_endpoints), -}); - -pub const AllocationStrategy = enum { - arena, - normal, -}; - pub const Response = struct { const Self = @This(); headers: http.Fields, diff --git a/src/main/controllers/api.zig b/src/main/controllers/api.zig deleted file mode 100644 index 12f3a1f..0000000 --- a/src/main/controllers/api.zig +++ /dev/null @@ -1,29 +0,0 @@ -const controllers = @import("../controllers.zig"); - -const auth = @import("./api/auth.zig"); -const communities = @import("./api/communities.zig"); -const invites = @import("./api/invites.zig"); -const users = @import("./api/users.zig"); -const follows = @import("./api/users/follows.zig"); -const notes = @import("./api/notes.zig"); -const streaming = @import("./api/streaming.zig"); -const timelines = @import("./api/timelines.zig"); - -pub const routes = .{ - controllers.apiEndpoint(auth.login), - controllers.apiEndpoint(auth.verify_login), - controllers.apiEndpoint(communities.create), - controllers.apiEndpoint(communities.query), - controllers.apiEndpoint(invites.create), - controllers.apiEndpoint(users.create), - controllers.apiEndpoint(notes.create), - controllers.apiEndpoint(notes.get), - //controllers.apiEndpoint(streaming.streaming), - controllers.apiEndpoint(timelines.global), - controllers.apiEndpoint(timelines.local), - controllers.apiEndpoint(timelines.home), - controllers.apiEndpoint(follows.create), - controllers.apiEndpoint(follows.delete), - controllers.apiEndpoint(follows.query_followers), - controllers.apiEndpoint(follows.query_following), -}; diff --git a/src/main/controllers/api/auth.zig b/src/main/controllers/api/auth.zig index 1b6e652..eb719d8 100644 --- a/src/main/controllers/api/auth.zig +++ b/src/main/controllers/api/auth.zig @@ -3,7 +3,7 @@ const std = @import("std"); pub const login = struct { pub const method = .POST; - pub const path = "/auth/login"; + pub const path = "/api/v0/auth/login"; pub const Body = struct { username: []const u8, @@ -21,7 +21,7 @@ pub const login = struct { pub const verify_login = struct { pub const method = .GET; - pub const path = "/auth/login"; + pub const path = "/api/v0/auth/login"; pub fn handler(_: anytype, res: anytype, srv: anytype) !void { const info = try srv.verifyAuthorization(); diff --git a/src/main/controllers/api/communities.zig b/src/main/controllers/api/communities.zig index f6f475d..87f744c 100644 --- a/src/main/controllers/api/communities.zig +++ b/src/main/controllers/api/communities.zig @@ -5,7 +5,7 @@ const QueryArgs = api.CommunityQueryArgs; pub const create = struct { pub const method = .POST; - pub const path = "/communities"; + pub const path = "/api/v0/communities"; pub const Body = struct { origin: []const u8, @@ -20,7 +20,7 @@ pub const create = struct { pub const query = struct { pub const method = .GET; - pub const path = "/communities"; + pub const path = "/api/v0/communities"; pub const Query = QueryArgs; diff --git a/src/main/controllers/api/invites.zig b/src/main/controllers/api/invites.zig index 0355bcf..7be7d5d 100644 --- a/src/main/controllers/api/invites.zig +++ b/src/main/controllers/api/invites.zig @@ -2,7 +2,7 @@ const api = @import("api"); pub const create = struct { pub const method = .POST; - pub const path = "/invites"; + pub const path = "/api/v0/invites"; pub const Body = api.InviteOptions; diff --git a/src/main/controllers/api/notes.zig b/src/main/controllers/api/notes.zig index 11da65c..2622067 100644 --- a/src/main/controllers/api/notes.zig +++ b/src/main/controllers/api/notes.zig @@ -3,7 +3,7 @@ const util = @import("util"); pub const create = struct { pub const method = .POST; - pub const path = "/notes"; + pub const path = "/api/v0/notes"; pub const Body = struct { content: []const u8, @@ -18,7 +18,7 @@ pub const create = struct { pub const get = struct { pub const method = .GET; - pub const path = "/notes/:id"; + pub const path = "/api/v0/notes/:id"; pub const Args = struct { id: util.Uuid, diff --git a/src/main/controllers/api/streaming.zig b/src/main/controllers/api/streaming.zig index 4b8745b..263fabb 100644 --- a/src/main/controllers/api/streaming.zig +++ b/src/main/controllers/api/streaming.zig @@ -3,7 +3,7 @@ const std = @import("std"); pub const streaming = struct { pub const method = .GET; - pub const path = "/streaming"; + pub const path = "/api/v0/streaming"; pub fn handler(req: anytype, response: anytype, _: anytype) !void { var iter = req.headers.iterator(); diff --git a/src/main/controllers/api/timelines.zig b/src/main/controllers/api/timelines.zig index 8c30cc1..00da7ff 100644 --- a/src/main/controllers/api/timelines.zig +++ b/src/main/controllers/api/timelines.zig @@ -4,7 +4,7 @@ const controller_utils = @import("../../controllers.zig").helpers; pub const global = struct { pub const method = .GET; - pub const path = "/timelines/global"; + pub const path = "/api/v0/timelines/global"; pub const Query = api.TimelineArgs; @@ -16,7 +16,7 @@ pub const global = struct { pub const local = struct { pub const method = .GET; - pub const path = "/timelines/local"; + pub const path = "/api/v0/timelines/local"; pub const Query = api.TimelineArgs; @@ -28,7 +28,7 @@ pub const local = struct { pub const home = struct { pub const method = .GET; - pub const path = "/timelines/home"; + pub const path = "/api/v0/timelines/home"; pub const Query = api.TimelineArgs; diff --git a/src/main/controllers/api/users.zig b/src/main/controllers/api/users.zig index fdc683e..5fdc40c 100644 --- a/src/main/controllers/api/users.zig +++ b/src/main/controllers/api/users.zig @@ -2,7 +2,7 @@ const api = @import("api"); pub const create = struct { pub const method = .POST; - pub const path = "/users"; + pub const path = "/api/v0/users"; pub const Body = struct { username: []const u8, diff --git a/src/main/controllers/api/users/follows.zig b/src/main/controllers/api/users/follows.zig index 765e7a0..6e3e6e7 100644 --- a/src/main/controllers/api/users/follows.zig +++ b/src/main/controllers/api/users/follows.zig @@ -6,7 +6,7 @@ const Uuid = util.Uuid; pub const create = struct { pub const method = .POST; - pub const path = "/users/:id/follow"; + pub const path = "/api/v0/users/:id/follow"; pub const Args = struct { id: Uuid, @@ -21,7 +21,7 @@ pub const create = struct { pub const delete = struct { pub const method = .DELETE; - pub const path = "/users/:id/follow"; + pub const path = "/api/v0/users/:id/follow"; pub const Args = struct { id: Uuid, @@ -36,7 +36,7 @@ pub const delete = struct { pub const query_followers = struct { pub const method = .GET; - pub const path = "/users/:id/followers"; + pub const path = "/api/v0/users/:id/followers"; pub const Args = struct { id: Uuid, @@ -53,7 +53,7 @@ pub const query_followers = struct { pub const query_following = struct { pub const method = .GET; - pub const path = "/users/:id/following"; + pub const path = "/api/v0/users/:id/following"; pub const Args = struct { id: Uuid, diff --git a/src/main/controllers/web.zig b/src/main/controllers/web.zig index 430e405..9ee8ac4 100644 --- a/src/main/controllers/web.zig +++ b/src/main/controllers/web.zig @@ -1,12 +1,11 @@ const std = @import("std"); -const controllers = @import("../controllers.zig"); pub const routes = .{ - controllers.apiEndpoint(index), - controllers.apiEndpoint(about), - controllers.apiEndpoint(login), - controllers.apiEndpoint(global_timeline), - controllers.apiEndpoint(cluster.overview), + index, + about, + login, + global_timeline, + cluster.overview, }; const index = struct { diff --git a/src/main/main.zig b/src/main/main.zig index 18f644a..c512d3c 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -70,7 +70,11 @@ fn thread_main(src: *api.ApiSource, srv: *http.Server) void { util.seedThreadPrng() catch unreachable; var gpa = std.heap.GeneralPurposeAllocator(.{}){}; defer _ = gpa.deinit(); - srv.handleLoop(gpa.allocator(), .{ .api_source = src, .allocator = gpa.allocator() }, c.router); + srv.handleLoop(gpa.allocator(), .{ .src = src, .allocator = gpa.allocator() }, handle); +} + +fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void { + c.routeRequest(ctx.src, req, res, ctx.allocator); } pub fn main() !void { diff --git a/src/main/migrations.zig b/src/main/migrations.zig index dc97b44..89a63d3 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -205,25 +205,4 @@ const migrations: []const Migration = &.{ , .down = "DROP TABLE follow", }, - .{ - .name = "files", - .up = - \\CREATE TABLE drive_file( - \\ id UUID NOT NULL PRIMARY KEY, - \\ - \\ filename TEXT NOT NULL, - \\ account_owner_id UUID REFERENCES account(id), - \\ community_owner_id UUID REFERENCES community(id), - \\ size INTEGER NOT NULL, - \\ - \\ created_at TIMESTAMPTZ NOT NULL, - \\ - \\ CHECK( - \\ (account_owner_id IS NULL AND community_owner_id IS NOT NULL) - \\ OR (account_owner_id IS NOT NULL AND community_owner_id IS NULL) - \\ ) - \\); - , - .down = "DROP TABLE drive_file", - }, }; diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index 5b59910..e2bb697 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -40,7 +40,7 @@ fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) err std.log.debug("Failed at char ({}:{}) of SQL:\n{s}", .{ pos.row, pos.col, sql }); } } - std.log.debug("{?}", .{@errorReturnTrace()}); + std.log.debug("{?s}", .{@errorReturnTrace()}); return error.Unexpected; } diff --git a/src/util/iters.zig b/src/util/iters.zig index 5ad2258..b19c5bd 100644 --- a/src/util/iters.zig +++ b/src/util/iters.zig @@ -49,30 +49,19 @@ pub const QueryIter = struct { pub const PathIter = struct { is_first: bool, - iter: std.mem.SplitIterator(u8), + iter: Separator('/'), pub fn from(path: []const u8) PathIter { - return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") }; + return .{ .is_first = true, .iter = Separator('/').from(path) }; } pub fn next(self: *PathIter) ?[]const u8 { - defer self.is_first = false; - while (self.iter.next()) |it| if (it.len != 0) { - return it; - }; + if (self.is_first) { + self.is_first = false; + return self.iter.next() orelse ""; + } - if (self.is_first) return self.iter.rest(); - - return null; - } - - pub fn first(self: *PathIter) []const u8 { - std.debug.assert(self.is_first); - return self.next().?; - } - - pub fn rest(self: *PathIter) []const u8 { - return self.iter.rest(); + return self.iter.next(); } };