From 3bd1bfc9722fc307d54ea53ed275277b05caf2c1 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sat, 9 Jul 2022 22:05:01 -0700 Subject: [PATCH] Messing with routes --- src/main/db2.zig | 93 +++++++ src/main/http.zig | 578 ------------------------------------------- src/main/main.zig | 67 ++++- src/main/routing.zig | 429 -------------------------------- 4 files changed, 157 insertions(+), 1010 deletions(-) create mode 100644 src/main/db2.zig delete mode 100644 src/main/http.zig delete mode 100644 src/main/routing.zig diff --git a/src/main/db2.zig b/src/main/db2.zig new file mode 100644 index 0000000..fdefaf5 --- /dev/null +++ b/src/main/db2.zig @@ -0,0 +1,93 @@ +const std = @import("std"); +const util = @import("util"); + +pub const Stream = std.io.StreamSource; + +pub const DocDb = struct { + const Keys = struct { + public_id: ?[]const u8, + local_id: u64, + }; + + const Entry = struct { + keys: Keys, + document: []const u8, + }; + + alloc: std.mem.Allocator, + docs: std.ArrayListUnmanaged(Entry) = .{}, + id_counter: u64 = 0, + + // TODO: hash shit + + pub fn init(alloc: std.mem.Allocator) DocDb { + return DocDb{ + .alloc = alloc, + }; + } + + fn getId(self: *DocDb, doc: []const u8) ?[]const u8 { + var ts = std.json.TokenStream.init(doc); + const HasId = struct { id: []const u8 }; + const parsed = std.json.parse(HasId, &ts, .{ .allocator = self.alloc, .ignore_unknown_fields = true }) catch return null; + return parsed.id; + } + + fn genLocalId(self: *DocDb) u64 { + self.id_counter += 1; + return self.id_counter; + } + + pub fn store(self: *DocDb, doc: []const u8) !u64 { + if (!std.json.validate(doc)) return error.InvalidJson; + + // todo: check for collisions + const clone = try dupe(self.alloc, doc); + errdefer self.alloc.free(clone); + + const local_id = self.genLocalId(); + const public_id = self.getId(clone); + + try self.docs.append(self.alloc, Entry{ + .keys = .{ + .public_id = public_id, + .local_id = local_id, + }, + .document = clone, + }); + + return local_id; + } + + pub fn getByLocalId(self: *DocDb, id: u64) !?Stream { + for (self.docs.items) |doc| { + if (id == doc.keys.local_id) { + return prepareDoc(doc); + } + } + + return null; + } + + pub fn getByPublicId(self: *DocDb, id: []const u8) !?Stream { + for (self.docs.items) |doc| { + if (doc.keys.public_id) |pub_id| { + if (util.ciutf8.eql(id, pub_id)) { + return prepareDoc(doc); + } + } + } + + return null; + } + + fn dupe(alloc: std.mem.Allocator, doc: []const u8) ![]u8 { + var clone = try alloc.alloc(u8, doc.len); + std.mem.copy(u8, clone, doc); + return clone; + } + + fn prepareDoc(e: Entry) Stream { + return std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(e.document) }; + } +}; diff --git a/src/main/http.zig b/src/main/http.zig deleted file mode 100644 index a2729bf..0000000 --- a/src/main/http.zig +++ /dev/null @@ -1,578 +0,0 @@ -const std = @import("std"); -const root = @import("root"); - -const ciutf8 = @import("./util.zig").ciutf8; -const Reader = std.net.Stream.Reader; -//const Writer = std.net.Stream.Writer; - -const Status = std.http.Status; -const Method = std.http.Method; -const Connection = std.net.StreamServer.Connection; -pub const Handler = fn (*Context) anyerror!Response; - -const HeaderMap = std.HashMap([]const u8, []const u8, struct { - pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { - return ciutf8.eql(a, b); - } - - pub fn hash(_: @This(), str: []const u8) u64 { - return ciutf8.hash(str); - } -}, std.hash_map.default_max_load_percentage); -pub const Headers = HeaderMap; - -fn handleBadRequest(writer: std.net.Stream.Writer) !void { - std.log.info("400 Bad Request", .{}); - try writer.writeAll("HTTP/1.1 400 Bad Request"); -} - -fn handleNotImplemented(writer: std.net.Stream.Writer) !void { - std.log.info("501", .{}); - try writer.writeAll("HTTP/1.1 501 Not Implemented"); -} - -fn handleInternalError(writer: std.net.Stream.Writer) !void { - std.log.info("500", .{}); - try writer.writeAll("HTTP/1.1 500 Internal Server Error"); -} - -fn parseHttpMethod(reader: Reader) !Method { - var buf: [8]u8 = undefined; - const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) { - error.StreamTooLong => return error.MethodNotImplemented, - else => return err, - }; - - inline for (@typeInfo(Method).Enum.fields) |method| { - if (std.mem.eql(u8, method.name, str)) { - return @intToEnum(Method, method.value); - } - } - - return error.MethodNotImplemented; -} - -fn checkProto(reader: Reader) !void { - var buf: [8]u8 = undefined; - const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { - error.StreamTooLong => return error.UnknownProtocol, - else => return err, - }; - - if (!std.mem.eql(u8, proto, "HTTP")) { - return error.UnknownProtocol; - } - - const count = try reader.read(buf[0..3]); - if (count != 3 or buf[1] != '.') { - return error.BadRequest; - } - - if (buf[0] != '1' or buf[2] != '1') { - return error.HttpVersionNotSupported; - } -} - -fn extractHeaderName(line: []const u8) ?[]const u8 { - var index: usize = 0; - - // TODO: handle whitespace - while (index < line.len) : (index += 1) { - if (line[index] == ':') { - if (index == 0) return null; - return line[0..index]; - } - } - - return null; -} - -fn parseHeaders(allocator: std.mem.Allocator, reader: Reader) !HeaderMap { - var map = HeaderMap.init(allocator); - errdefer map.deinit(); - // TODO: free map keys/values - - var buf: [1024]u8 = undefined; - - while (true) { - const line = try reader.readUntilDelimiter(&buf, '\n'); - if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; - - // TODO: handle multi-line headers - const name = extractHeaderName(line) orelse continue; - const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len; - const value = line[name.len + 1 + 1 .. value_end]; - - if (name.len == 0 or value.len == 0) return error.BadRequest; - - const name_alloc = try allocator.alloc(u8, name.len); - errdefer allocator.free(name_alloc); - const value_alloc = try allocator.alloc(u8, value.len); - errdefer allocator.free(value_alloc); - - @memcpy(name_alloc.ptr, name.ptr, name.len); - @memcpy(value_alloc.ptr, value.ptr, value.len); - - try map.put(name_alloc, value_alloc); - } - - return map; -} - -const ConnectionId = u64; -var next_connection_id = std.atomic.Atomic(ConnectionId).init(1); -pub fn handleConnection( - base_alloc: std.mem.Allocator, - conn: std.net.StreamServer.Connection, - handler: Handler, -) void { - const conn_id = next_connection_id.fetchAdd(1, .SeqCst); - defer conn.stream.close(); - std.log.debug("New connection conn={}", .{conn_id}); - - _ = base_alloc; - handleRequest(conn.stream.reader(), conn.stream.writer(), handler) catch |err| std.log.err("conn={}; Unhandled error processing connection {}. Closing", .{ conn_id, err }); - std.log.debug("Terminating connection conn={}", .{conn_id}); -} - -fn handleRequest(reader: Reader, writer: std.net.Stream.Writer, handler: Handler) anyerror!void { - const method = try parseHttpMethod(reader); - std.log.debug("Request recieved", .{}); - - var header_buf: [1 << 16]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&header_buf); - const allocator = fba.allocator(); - const path = reader.readUntilDelimiterAlloc(allocator, ' ', header_buf.len) catch |err| switch (err) { - error.StreamTooLong => return error.URITooLong, - else => return err, - }; - - try checkProto(reader); - _ = try reader.readByte(); - _ = try reader.readByte(); - - const headers = try parseHeaders(allocator, reader); - - const has_body = method.requestHasBody() and headers.get("Content-Length") != null; - - const tfer_encoding = headers.get("Transfer-Encoding"); - if (tfer_encoding != null and !std.mem.eql(u8, tfer_encoding.?, "identity")) { - return error.UnsupportedMediaType; - } - - const encoding = headers.get("Content-Encoding"); - if (encoding != null and !std.mem.eql(u8, encoding.?, "identity")) { - return error.UnsupportedMediaType; - } - - var ctx = Context{ - .request = Request2{ - .method = method, - .path = path, - .headers = headers, - }, - .allocator = allocator, - }; - - if (has_body) { - const body_len = std.fmt.parseInt(usize, headers.get("Content-Length").?, 10) catch return error.BadRequest; - - const body = try allocator.alloc(u8, body_len); - errdefer allocator.free(body); - if ((try reader.read(body)) != body_len) return error.BadRequest; - ctx.request.body = body; - } - defer if (has_body) allocator.free(ctx.request.body.?); - - _ = ctx; - - const response = try handler(&ctx); - const status_text = response.status.phrase() orelse ""; - try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(response.status), status_text }); - - var iter = response.headers.iterator(); - while (iter.next()) |it| { - try writer.print("{s}: {s}\r\n", .{ it.key_ptr.*, it.value_ptr.* }); - } else try writer.writeAll("\r\n"); - - if (response.body) |body| { - try writer.writeAll(body); - } - _ = handler; - _ = writer; - _ = path; -} - -pub const Context = struct {}; - -const chunk_size = 16 * 1024; -pub fn openResponse( - alloc: std.mem.Allocator, - writer: anytype, - headers: *const HeaderMap, - status: Status, -) !ResponseStream(@TypeOf(writer)) { - const buf = try alloc.alloc(u8, chunk_size); - errdefer alloc.free(buf); - - try writeStatusLine(writer, status); - try writeHeaders(writer, headers); - - return ResponseStream(@TypeOf(writer)){ - .allocator = alloc, - .base_writer = writer, - .headers = headers, - .buffer = buf, - }; -} - -fn writeStatusLine(writer: anytype, status: Status) !void { - const status_text = status.phrase() orelse ""; - try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text }); -} - -fn writeHeaders(writer: anytype, headers: *const HeaderMap) !void { - var iter = headers.iterator(); - while (iter.next()) |header| { - for (header.value_ptr.*) |ch| { - if (ch == '\r' or ch == '\n') @panic("newlines not yet supported in headers"); - } - - try writer.print("{s}: {s}\r\n", .{ header.key_ptr.*, header.value_ptr.* }); - } -} - -fn writeChunk(writer: anytype, contents: []const u8) @TypeOf(writer).Error!void { - try writer.print("{x}\r\n", .{contents.len}); - try writer.writeAll(contents); - try writer.writeAll("\r\n"); -} - -fn writeLastChunk(writer: anytype) writer.Error!void { - try writer.writeAll("0\r\n"); -} - -fn ResponseStream(comptime BaseWriter: type) type { - return struct { - const Self = @This(); - const Error = BaseWriter.Error; - const Writer = std.io.Writer(*Self, Error, write); - - allocator: std.mem.Allocator, - base_writer: BaseWriter, - headers: *const HeaderMap, - buffer: []u8, - buffer_pos: usize = 0, - chunked: bool = false, - - fn writeToBuffer(self: *Self, bytes: []const u8) void { - std.mem.copy(u8, self.buffer[self.buffer_pos..], bytes); - self.buffer_pos += bytes.len; - } - - fn startChunking(self: *Self) Error!void { - try self.base_writer.writeAll("Transfer-Encoding: chunked\r\n\r\n"); - self.chunked = true; - } - - fn flushChunk(self: *Self) Error!void { - try writeChunk(self.base_writer, self.buffer[0..self.buffer_pos]); - self.buffer_pos = 0; - } - - fn writeToChunks(self: *Self, bytes: []const u8) Error!void { - var cursor: usize = 0; - while (true) { - const remaining_in_chunk = self.buffer.len - self.buffer_pos; - const remaining_to_write = bytes.len - cursor; - if (remaining_to_write <= remaining_in_chunk) { - self.writeToBuffer(bytes[cursor..]); - return; - } - - std.debug.print("{}\n", .{cursor}); - self.writeToBuffer(bytes[cursor .. cursor + remaining_in_chunk]); - cursor += remaining_in_chunk; - try self.flushChunk(); - } - } - - fn write(self: *Self, bytes: []const u8) Error!usize { - if (!self.chunked and bytes.len > self.buffer.len - self.buffer_pos) { - try self.startChunking(); - } - - if (self.chunked) { - try self.writeToChunks(bytes); - } else { - self.writeToBuffer(bytes); - } - - return bytes.len; - } - - fn flushBodyUnchunked(self: *Self) Error!void { - if (self.buffer_pos != 0) { - try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos}); - } - - try self.base_writer.writeAll("\r\n"); - - if (self.buffer_pos != 0) { - try self.base_writer.writeAll(self.buffer[0..self.buffer_pos]); - } - self.buffer_pos = 0; - } - - pub fn writer(self: *Self) Writer { - return Writer{ .context = self }; - } - - pub fn finish(self: *Self) Error!void { - if (!self.chunked) { - try self.flushBodyUnchunked(); - } else { - if (self.buffer_pos != 0) { - try self.flushChunk(); - } - try self.base_writer.writeAll("0\r\n"); - } - } - - pub fn close(self: *const Self) void { - self.allocator.free(self.buffer); - } - }; -} - -test { - _ = _tests; -} -const _tests = struct { - fn toCrlf(comptime str: []const u8) []const u8 { - comptime { - var buf: [str.len * 2]u8 = undefined; - @setEvalBranchQuota(@as(u32, str.len * 2)); - - var len: usize = 0; - for (str) |ch| { - if (ch == '\n') { - buf[len] = '\r'; - len += 1; - } - - buf[len] = ch; - len += 1; - } - - return buf[0..len]; - } - } - - const test_buffer_size = chunk_size * 4; - test "ResponseStream no headers empty body" { - var buffer: [test_buffer_size]u8 = undefined; - var test_stream = std.io.fixedBufferStream(&buffer); - var headers = HeaderMap.init(std.testing.allocator); - defer headers.deinit(); - - { - var stream = try openResponse( - std.testing.allocator, - test_stream.writer(), - &headers, - .ok, - ); - defer stream.close(); - - try stream.finish(); - } - - try std.testing.expectEqualStrings( - toCrlf( - \\HTTP/1.1 200 OK - \\ - \\ - ), - buffer[0..(try test_stream.getPos())], - ); - } - - test "ResponseStream empty body" { - var buffer: [test_buffer_size]u8 = undefined; - var test_stream = std.io.fixedBufferStream(&buffer); - var headers = HeaderMap.init(std.testing.allocator); - defer headers.deinit(); - - try headers.put("Content-Type", "text/plain"); - - { - var stream = try openResponse( - std.testing.allocator, - test_stream.writer(), - &headers, - .ok, - ); - defer stream.close(); - - try stream.finish(); - } - - try std.testing.expectEqualStrings( - toCrlf( - \\HTTP/1.1 200 OK - \\Content-Type: text/plain - \\ - \\ - ), - buffer[0..(try test_stream.getPos())], - ); - } - - test "ResponseStream not 200 OK" { - var buffer: [test_buffer_size]u8 = undefined; - var test_stream = std.io.fixedBufferStream(&buffer); - var headers = HeaderMap.init(std.testing.allocator); - defer headers.deinit(); - - try headers.put("Content-Type", "text/plain"); - - { - var stream = try openResponse( - std.testing.allocator, - test_stream.writer(), - &headers, - .not_found, - ); - defer stream.close(); - - try stream.finish(); - } - - try std.testing.expectEqualStrings( - toCrlf( - \\HTTP/1.1 404 Not Found - \\Content-Type: text/plain - \\ - \\ - ), - buffer[0..(try test_stream.getPos())], - ); - } - test "ResponseStream small body" { - var buffer: [test_buffer_size]u8 = undefined; - var test_stream = std.io.fixedBufferStream(&buffer); - var headers = HeaderMap.init(std.testing.allocator); - defer headers.deinit(); - - try headers.put("Content-Type", "text/plain"); - - { - var stream = try openResponse( - std.testing.allocator, - test_stream.writer(), - &headers, - .ok, - ); - defer stream.close(); - - try stream.writer().writeAll("Index Page"); - - try stream.finish(); - } - - try std.testing.expectEqualStrings( - toCrlf( - \\HTTP/1.1 200 OK - \\Content-Type: text/plain - \\Content-Length: 10 - \\ - \\Index Page - ), - buffer[0..(try test_stream.getPos())], - ); - } - - test "ResponseStream large body" { - var buffer: [test_buffer_size]u8 = undefined; - var test_stream = std.io.fixedBufferStream(&buffer); - var headers = HeaderMap.init(std.testing.allocator); - defer headers.deinit(); - - try headers.put("Content-Type", "text/plain"); - - { - var stream = try openResponse( - std.testing.allocator, - test_stream.writer(), - &headers, - .ok, - ); - defer stream.close(); - - try stream.writer().writeAll("quuz" ** 6000); - - try stream.finish(); - } - - // zig fmt: off - try std.testing.expectEqualStrings( - toCrlf( - \\HTTP/1.1 200 OK - \\Content-Type: text/plain - \\Transfer-Encoding: chunked - \\ - \\ - ++ "4000\n" - ++ "quuz" ** 4096 ++ "\n" - ++ "1dc0\n" - ++ "quuz" ** (1904) ++ "\n" - ++ "0\n" - ), - buffer[0..(try test_stream.getPos())], - ); - // zig fmt: on - } - - test "ResponseStream large body ending on chunk boundary" { - var buffer: [test_buffer_size]u8 = undefined; - var test_stream = std.io.fixedBufferStream(&buffer); - var headers = HeaderMap.init(std.testing.allocator); - defer headers.deinit(); - - try headers.put("Content-Type", "text/plain"); - - { - var stream = try openResponse( - std.testing.allocator, - test_stream.writer(), - &headers, - .ok, - ); - defer stream.close(); - - try stream.writer().writeAll("quuz" ** (chunk_size / 2)); - - try stream.finish(); - } - - // zig fmt: off - try std.testing.expectEqualStrings( - toCrlf( - \\HTTP/1.1 200 OK - \\Content-Type: text/plain - \\Transfer-Encoding: chunked - \\ - \\ - ++ "4000\n" - ++ "quuz" ** 4096 ++ "\n" - ++ "4000\n" - ++ "quuz" ** 4096 ++ "\n" - ++ "0\n" - ), - buffer[0..(try test_stream.getPos())], - ); - // zig fmt: on - } -}; - diff --git a/src/main/main.zig b/src/main/main.zig index 7888d9b..19b5dd5 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -1,8 +1,66 @@ const std = @import("std"); const http = @import("http"); +const db = @import("./db2.zig"); + +// this thing is overcomplicated and weird. stop this +const Router = http.Router(*RequestServer); +const Route = Router.Route; +const RouteArgs = http.RouteArgs; +const router = Router{ + .routes = &[_]Route{ + Route.new(.POST, "/object", postObject), + Route.new(.GET, "/object/:id", getObject), + }, +}; + +fn postObject(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { + const alloc = srv.alloc; + const headers = http.Headers.init(alloc); + + var stream = try ctx.openResponse(&headers, .ok); + const writer = stream.writer(); + defer stream.close(); + + try writer.print("Page for {s}\n", .{ctx.request.path}); + + const id = srv.doc_db.store(ctx.request.body.?); + try writer.print("{}", .{id}); + + try stream.finish(); +} + +fn getObject(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { + const alloc = srv.alloc; + const headers = http.Headers.init(alloc); + + var stream = try ctx.openResponse(&headers, .ok); + const writer = stream.writer(); + defer stream.close(); + + //try writer.print("id: {s}\n", .{args.get("id").?}); + const id = try std.fmt.parseInt(u64, args.get("id").?, 10); + var doc_stream = (try srv.doc_db.getByLocalId(id)).?; + const doc = doc_stream.reader(); + + var buf: [1 << 8]u8 = undefined; + const count = try doc.readAll(&buf); + try writer.writeAll(buf[0..count]); + + try stream.finish(); +} const RequestServer = struct { - fn listenAndRun(_: *RequestServer, addr: std.net.Address) noreturn { + alloc: std.mem.Allocator, + doc_db: db.DocDb, + + fn init(alloc: std.mem.Allocator) RequestServer { + return RequestServer{ + .alloc = alloc, + .doc_db = db.DocDb.init(alloc), + }; + } + + fn listenAndRun(self: *RequestServer, addr: std.net.Address) noreturn { var srv = http.Server.listen(addr) catch unreachable; defer srv.shutdown(); @@ -14,7 +72,9 @@ const RequestServer = struct { var ctx = srv.accept(alloc) catch unreachable; defer ctx.close(); - handleRequest(alloc, &ctx) catch unreachable; + router.dispatch(self, &ctx, ctx.request.method, ctx.request.path) catch unreachable; + + //handleRequest(alloc, &ctx) catch unreachable; } } @@ -35,6 +95,7 @@ const RequestServer = struct { }; pub fn main() anyerror!void { - var srv = RequestServer{}; + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + var srv = RequestServer.init(gpa.allocator()); srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); } diff --git a/src/main/routing.zig b/src/main/routing.zig deleted file mode 100644 index f8559a3..0000000 --- a/src/main/routing.zig +++ /dev/null @@ -1,429 +0,0 @@ -const std = @import("std"); -const http = @import("./http.zig"); -const ciutf8 = @import("./util.zig").ciutf8; - -const PathIter = struct { - is_first: bool, - path: []const u8, - - pub fn from(path: []const u8) PathIter { - return .{ .path = path, .is_first = true }; - } - - pub fn next(self: *PathIter) ?[]const u8 { - if (self.path.len == 0) { - if (self.is_first) { - self.is_first = false; - return self.path; - } else { - return null; - } - } - - var start: usize = 0; - var end: usize = start; - while (end < self.path.len) : (end += 1) { - // skip leading slash - if (end == start and self.path[start] == '/') { - start += 1; - continue; - } else if (self.path[end] == '/') { - break; - } - } - - if (start == end) { - self.path = self.path[end..end]; - self.is_first = false; - return null; - } - - const result = self.path[start..end]; - self.path = self.path[end..]; - self.is_first = false; - return result; - } -}; - -fn splitRoutePath(comptime path: []const u8) []const RouteSegment { - comptime { - var segments: [path.len]RouteSegment = undefined; - - var iter = PathIter.from(path); - var count = 0; - while (iter.next()) |it| { - if (it[0] == ':') { - segments[count] = .{ .param = it[1..] }; - } else { - segments[count] = .{ .literal = it }; - } - - count += 1; - } - - return segments[0..count]; - } -} - -const RouteSegment = union(enum) { - literal: []const u8, - param: []const u8, -}; - -pub fn RouteFn(comptime Context: type) type { - return fn (Context, std.http.Method, []const u8) anyerror!http.Response; -} - -/// `makeRoute` takes a route definition and a handler of the form `fn(, ) anyerror!http.Response` -/// where `Params` is a struct containing one field of type `[]const u8` for each path parameter -/// -/// Arguments: -/// method: The HTTP method to match -/// path: The path spec to match against. Path segments beginning with a `:` will cause the rest of -/// the segment to be treated as the name of a path parameter -/// handler: The code to execute on route match. This must be a function of form `fn(, ) anyerror!http.Response` -/// -/// Implicit Arguments: -/// Context: the type of a user-supplied Context that is passed through the route. typically `std.http.Context` but -/// made generic for ease of testing. There are no restrictions on this type -/// Params: the type of a struct representing the path parameters expressed in ``. This must be -/// a struct, with a one-one map between fields and path parameters. Each field must be of type -/// `[]const u8` and it must have the same name as a single path parameter. -/// -/// Returns: -/// A new route function of type `fn(, std.http.Method, []const u8) anyerror!http.Response`. When called, -/// this function will test the provided values against its specification. If they match, then -/// this function will parse path parameters and will be called with the supplied -/// context and params. If they do not match, this function will return error.Http404 -/// -/// Example: -/// route(.GET, "/user/:id/followers", struct{ -/// fn getFollowers(ctx: std.http.Context, params: struct{ id: []const u8 } anyerror { ... } -/// ).getFollowers) -/// -pub fn makeRoute( - comptime method: std.http.Method, - comptime path: []const u8, - comptime handler: anytype, -) return_type: { - const handler_info = @typeInfo(@TypeOf(handler)); - if (handler_info != .Fn) @compileError("Route expects a function"); - break :return_type RouteFn(@typeInfo(@TypeOf(handler)).Fn.args[0].arg_type.?); -} { - const handler_args = @typeInfo(@TypeOf(handler)).Fn.args; - if (handler_args.len != 2) @compileError("handler function must have signature fn(Context, Params) anyerror"); - if (@typeInfo(handler_args[1].arg_type.?) != .Struct) @compileError("Params in handler(Context, Params) must be struct"); - - const Context = handler_args[0].arg_type.?; - const Params = handler_args[1].arg_type.?; - - const params_fields = std.meta.fields(Params); - var params_field_used = [_]bool{false} ** std.meta.fields(Params).len; - const segments = splitRoutePath(path); - for (segments) |seg| { - if (seg == .param) { - const found = for (params_fields) |f, i| { - if (std.mem.eql(u8, seg.param, f.name)) { - params_field_used[i] = true; - break true; - } - } else false; - - if (!found) @compileError("Params does not contain " ++ seg.param ++ " field"); - } - } - - for (params_fields) |f, i| { - if (f.field_type != []const u8) @compileError("Params fields must be []const u8"); - if (!params_field_used[i]) @compileError("Params field " ++ f.name ++ " not found in path"); - } - - return struct { - fn func(ctx: Context, req_method: std.http.Method, req_path: []const u8) anyerror!http.Response { - if (req_method != method) return error.Http404; - - var params: Params = undefined; - var req_segments = PathIter.from(req_path); - inline for (segments) |seg| { - const req_seg = req_segments.next() orelse return error.Http404; - var match = switch (seg) { - .literal => |literal| ciutf8.eql(literal, req_seg), - .param => |param| blk: { - @field(params, param) = req_seg; - break :blk true; - }, - }; - - if (!match) return error.Http404; - } - - if (req_segments.next() != null) return error.Http404; - - return handler(ctx, params); - } - }.func; -} - -pub fn RouterFn(comptime Context: type) type { - return fn (std.http.Method, path: []const u8, Context) anyerror!http.Response; -} - -pub fn makeRouter( - comptime Context: type, - comptime routes: []const RouteFn(Context), -) RouterFn(Context) { - return struct { - fn dispatch(method: std.http.Method, path: []const u8, ctx: Context) anyerror!http.Response { - for (routes) |r| { - return r(ctx, method, path) catch |err| switch (err) { - error.Http404 => continue, - else => err, - }; - } - - return error.Http404; - } - }.dispatch; -} - -test { - _ = _tests; -} -const _tests = struct { - test "PathIter" { - const path = "/ab/cd/"; - var it = PathIter.from(path); - try std.testing.expectEqualStrings("ab", it.next().?); - try std.testing.expectEqualStrings("cd", it.next().?); - try std.testing.expectEqual(@as(?[]const u8, null), it.next()); - } - - test "PathIter empty" { - const path = ""; - var it = PathIter.from(path); - try std.testing.expectEqualStrings("", it.next().?); - try std.testing.expectEqual(@as(?[]const u8, null), it.next()); - } - - test "PathIter complex" { - const path = "ab/c//defg/"; - var it = PathIter.from(path); - try std.testing.expectEqualStrings("ab", it.next().?); - try std.testing.expectEqualStrings("c", it.next().?); - try std.testing.expectEqualStrings("defg", it.next().?); - try std.testing.expectEqual(@as(?[]const u8, null), it.next()); - } - - fn expectEqualSegments(expected: []const RouteSegment, actual: []const RouteSegment) !void { - try std.testing.expectEqual(expected.len, actual.len); - for (expected) |_, i| { - try std.testing.expectEqual( - std.meta.activeTag(expected[i]), - std.meta.activeTag(actual[i]), - ); - switch (expected[i]) { - .literal => |exp| try std.testing.expectEqualStrings(exp, actual[i].literal), - .param => |exp| try std.testing.expectEqualStrings(exp, actual[i].param), - } - } - } - - test "splitRoutePath" { - const path = "//ab/c/:de/"; - const segments = splitRoutePath(path); - - try expectEqualSegments(&[_]RouteSegment{ - .{ .literal = "ab" }, - .{ .literal = "c" }, - .{ .param = "de" }, - }, segments); - } - - fn CallTracker(comptime _uniq: anytype, comptime next: anytype) type { - _ = _uniq; - - var ctx_type: type = undefined; - var args_type: type = undefined; - switch (@typeInfo(@TypeOf(next))) { - .Fn => |func| { - if (func.args.len != 2) @compileError("next() must take 2 arguments"); - - ctx_type = func.args[0].arg_type.?; - args_type = func.args[1].arg_type.?; - //if (@typeInfo(Args) != .Struct) @compileError("second argument to next() must be struct"); - }, - else => @compileError("next must be function"), - } - - const Context = ctx_type; - const Args = args_type; - - return struct { - var calls: u32 = 0; - - var last_ctx: ?Context = null; - var last_args: ?Args = null; - - fn func(ctx: Context, args: Args) !void { - calls += 1; - last_ctx = ctx; - last_args = args; - return next(ctx, args); - } - - fn expectCalledOnceWith(exp_ctx: Context, exp_args: Args) !void { - try std.testing.expectEqual(@as(u32, 1), calls); - try std.testing.expectEqual(exp_ctx, last_ctx.?); - inline for (std.meta.fields(Args)) |f| { - try std.testing.expectEqualStrings( - @field(exp_args, f.name), - @field(last_args.?, f.name), - ); - } - } - - fn expectNotCalled() !void { - try std.testing.expectEqual(@as(u32, 0), calls); - } - - fn reset() void { - calls = 0; - last_ctx = null; - last_args = null; - } - }; - } - - const TestContext = u32; - const DummyArgs = struct {}; - fn dummyHandler(comptime Args: type) type { - comptime { - return struct { - fn func(_: TestContext, _: Args) anyerror!http.Response {} - }; - } - } - - test "makeRoute(T, ...)" { - const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func); - const mock_b = CallTracker(.{}, dummyHandler(DummyArgs).func); - - const routes = comptime [_]RouteFn(TestContext){ - makeRoute(.GET, "/a", mock_a.func), - makeRoute(.GET, "/b", mock_b.func), - }; - - const router = makeRouter(TestContext, &routes); - _ = try router(.GET, "/a", 10); - try mock_a.expectCalledOnceWith(10, .{}); - try mock_b.expectNotCalled(); - mock_a.reset(); - - _ = try router(.GET, "/b", 0); - try mock_a.expectNotCalled(); - try mock_b.expectCalledOnceWith(0, .{}); - mock_b.reset(); - - try std.testing.expectError(error.Http404, router(.GET, "/c", 0)); - } - - test "makeRoute(T, ...) same path different methods" { - const mock_get = CallTracker(.{}, dummyHandler(DummyArgs).func); - const mock_post = CallTracker(.{}, dummyHandler(DummyArgs).func); - - const routes = comptime [_]RouteFn(TestContext){ - makeRoute(.GET, "/a", mock_get.func), - makeRoute(.POST, "/a", mock_post.func), - }; - - const router = makeRouter(TestContext, &routes); - _ = try router(.GET, "/a", 10); - try mock_get.expectCalledOnceWith(10, .{}); - try mock_post.expectNotCalled(); - mock_get.reset(); - - _ = try router(.POST, "/a", 10); - try mock_get.expectNotCalled(); - try mock_post.expectCalledOnceWith(10, .{}); - } - - test "makeRoute(T, ...) route under subpath" { - const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func); - const mock_b = CallTracker(.{}, dummyHandler(DummyArgs).func); - - const routes = comptime [_]RouteFn(TestContext){ - makeRoute(.GET, "/a", mock_a.func), - makeRoute(.GET, "/a/b", mock_b.func), - }; - - const router = makeRouter(TestContext, &routes); - _ = try router(.GET, "/a", 10); - try mock_a.expectCalledOnceWith(10, .{}); - try mock_b.expectNotCalled(); - mock_a.reset(); - - _ = try router(.GET, "/a/b", 11); - try mock_a.expectNotCalled(); - try mock_b.expectCalledOnceWith(11, .{}); - } - - test "makeRoute(T, ...) case-insensitive route" { - const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func); - - const routes = comptime [_]RouteFn(TestContext){ - makeRoute(.GET, "/test/a", mock_a.func), - }; - - const router = makeRouter(TestContext, &routes); - _ = try router(.GET, "/TEST/A", 10); - try mock_a.expectCalledOnceWith(10, .{}); - mock_a.reset(); - - _ = try router(.GET, "/TesT/a", 11); - try mock_a.expectCalledOnceWith(11, .{}); - } - - test "makeRoute(T, ...) redundant /" { - const mock_a = CallTracker(.{}, dummyHandler(DummyArgs).func); - - const routes = comptime [_]RouteFn(TestContext){ - makeRoute(.GET, "/test/a", mock_a.func), - }; - - const router = makeRouter(TestContext, &routes); - _ = try router(.GET, "/test//a", 10); - try mock_a.expectCalledOnceWith(10, .{}); - mock_a.reset(); - - _ = try router(.GET, "//test///////////a////", 11); - try mock_a.expectCalledOnceWith(11, .{}); - mock_a.reset(); - - _ = try router(.GET, "test/a", 12); - try mock_a.expectCalledOnceWith(12, .{}); - } - - test "makeRoute(T, ...) with variables" { - const mock_a = CallTracker(.{}, dummyHandler(struct { id: []const u8 }).func); - const mock_b = CallTracker(.{}, dummyHandler(struct { a_id: []const u8, b_id: []const u8 }).func); - - const routes = comptime [_]RouteFn(TestContext){ - makeRoute(.GET, "/test/:id/abcd", mock_a.func), - makeRoute(.GET, "/test/:a_id/abcd/:b_id", mock_b.func), - }; - - const router = makeRouter(TestContext, &routes); - _ = try router(.GET, "/test/xyz/abcd", 10); - try mock_a.expectCalledOnceWith(10, .{ .id = "xyz" }); - try mock_b.expectNotCalled(); - mock_a.reset(); - - try std.testing.expectError(error.Http404, router(.GET, "/test//abcd", 10)); - try mock_a.expectNotCalled(); - try mock_b.expectNotCalled(); - - _ = try router(.GET, "/test/xyz/abcd/zyx", 10); - try mock_a.expectNotCalled(); - try mock_b.expectCalledOnceWith(10, .{ .a_id = "xyz", .b_id = "zyx" }); - } -};