Compare commits

...

9 commits

12 changed files with 587 additions and 1219 deletions

View file

@ -98,16 +98,24 @@ pub fn build(b: *std.build.Builder) !void {
exe.linkLibC();
exe.addSystemIncludePath("/usr/include/");
const unittest_http_cmd = b.step("unit:http", "Run tests for http package");
const unittest_http = b.addTest("src/http/test.zig");
unittest_http_cmd.dependOn(&unittest_http.step);
unittest_http.addPackage(pkgs.util);
//const unittest_util_cmd = b.step("unit:util", "Run tests for util package");
//const unittest_util = b.addTest("src/util/Uuid.zig");
//unittest_util_cmd.dependOn(&unittest_util.step);
//const util_tests = b.addTest("src/util/lib.zig");
const http_tests = b.addTest("src/http/test.zig");
//const sql_tests = b.addTest("src/sql/lib.zig");
http_tests.addPackage(pkgs.util);
//http_tests.addPackage(pkgs.util);
//sql_tests.addPackage(pkgs.util);
const unit_tests = b.step("unit-tests", "Run tests");
//unit_tests.dependOn(&util_tests.step);
unit_tests.dependOn(&http_tests.step);
//unit_tests.dependOn(&sql_tests.step);
//const unit_tests = b.step("unit-tests", "Run tests");
const unittest_all = b.step("unit", "Run unit tests");
unittest_all.dependOn(unittest_http_cmd);
//unittest_all.dependOn(unittest_util_cmd);
const api_integration = b.addTest("./tests/api_integration/lib.zig");
api_integration.addPackage(pkgs.opts);

View file

@ -16,6 +16,7 @@ pub const Handler = server.Handler;
pub const Server = server.Server;
pub const middleware = @import("./middleware.zig");
pub const queryStringify = @import("./query.zig").queryStringify;
pub const Fields = @import("./headers.zig").Fields;

View file

@ -1,10 +1,103 @@
/// Middlewares are types with a method of type:
/// fn handle(
/// self: @This(),
/// request: *http.Request(< some type >),
/// response: *http.Response(< some type >),
/// context: anytype,
/// next_handler: anytype,
/// ) !void
///
/// If a middleware returns error.RouteMismatch, then it is assumed that the handler
/// did not apply to the request, and this is used by routing implementations to
/// determine when to stop attempting to match a route.
///
/// Terminal middlewares that are not implemented using other middlewares should
/// only accept a `void` value for `next_handler`.
const std = @import("std");
const root = @import("root");
const builtin = @import("builtin");
const http = @import("./lib.zig");
const util = @import("util");
const query_utils = @import("./query.zig");
const json_utils = @import("./json.zig");
/// Takes an iterable of middlewares and chains them together.
pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) {
return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares)));
}
/// Helper function for the return type of `apply()`
pub fn Apply(comptime Middlewares: type) type {
return ApplyInternal(std.meta.fields(Middlewares));
}
fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type {
if (fields.len == 0) return void;
return HandlerList(
fields[0].field_type,
ApplyInternal(fields[1..]),
);
}
fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) {
if (fields.len == 0) return {};
return .{
.first = @field(middlewares, fields[0].name),
.next = applyInternal(middlewares, fields[1..]),
};
}
pub fn HandlerList(comptime First: type, comptime Next: type) type {
return struct {
first: First,
next: Next,
pub fn handle(
self: @This(),
req: anytype,
res: anytype,
ctx: anytype,
next: void,
) !void {
_ = next;
return self.first.handle(req, res, ctx, self.next);
}
};
}
test "apply" {
var count: usize = 0;
const NoOp = struct {
ptr: *usize,
fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
self.ptr.* += 1;
if (@TypeOf(next) != void) return next.handle(req, res, ctx, {});
}
};
const middlewares = .{
NoOp{ .ptr = &count },
NoOp{ .ptr = &count },
NoOp{ .ptr = &count },
NoOp{ .ptr = &count },
};
try std.testing.expectEqual(
Apply(@TypeOf(middlewares)),
HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, HandlerList(NoOp, void)))),
);
try apply(middlewares).handle(.{}, .{}, .{}, {});
try std.testing.expectEqual(count, 4);
}
test "injectContextValue - chained" {
try apply(.{
injectContextValue("abcd", @as(usize, 5)),
injectContextValue("efgh", @as(usize, 10)),
injectContextValue("ijkl", @as(usize, 15)),
ExpectContext(.{ .abcd = 5, .efgh = 10, .ijkl = 15 }){},
}).handle(.{}, .{}, .{}, {});
}
fn AddUniqueField(comptime Lhs: type, comptime N: usize, comptime name: [N]u8, comptime Val: type) type {
const Ctx = @Type(.{ .Struct = .{
.layout = .Auto,
@ -34,44 +127,19 @@ fn addField(lhs: anytype, comptime name: []const u8, val: anytype) AddField(@Typ
return result;
}
test {
// apply is a plumbing function that applies a tuple of middlewares in order
const base = apply(.{
split_uri,
mount("/abc"),
});
test "addField" {
const expect = std.testing.expect;
const eql = std.meta.eql;
const request = .{ .uri = "/abc/defg/hijkl?some_query=true#section" };
const response = .{};
const initial_context = .{};
try base.handle(request, response, initial_context, {});
}
fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type {
if (fields.len == 0) return void;
return NextHandler(
fields[0].field_type,
ApplyInternal(fields[1..]),
);
}
fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) {
if (fields.len == 0) return {};
return .{
.first = @field(middlewares, fields[0].name),
.next = applyInternal(middlewares, fields[1..]),
};
}
pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) {
return applyInternal(middlewares, std.meta.fields(@TypeOf(middlewares)));
}
pub fn Apply(comptime Middlewares: type) type {
return ApplyInternal(std.meta.fields(Middlewares));
try expect(eql(addField(.{}, "abcd", 5), .{ .abcd = 5 }));
try expect(eql(addField(.{ .abcd = 5 }, "efgh", 10), .{ .abcd = 5, .efgh = 10 }));
try expect(eql(
addField(addField(.{}, "abcd", 5), "efgh", 10),
.{ .abcd = 5, .efgh = 10 },
));
}
/// Adds a single value to the context object
pub fn InjectContextValue(comptime name: []const u8, comptime V: type) type {
return struct {
val: V,
@ -85,24 +153,43 @@ pub fn injectContextValue(comptime name: []const u8, val: anytype) InjectContext
return .{ .val = val };
}
pub fn NextHandler(comptime First: type, comptime Next: type) type {
return struct {
first: First,
next: Next,
test "InjectContextValue" {
try injectContextValue("abcd", @as(usize, 5))
.handle(.{}, .{}, .{}, ExpectContext(.{ .abcd = 5 }){});
try injectContextValue("abcd", @as(usize, 5))
.handle(.{}, .{}, .{ .efgh = @as(usize, 10) }, ExpectContext(.{ .abcd = 5, .efgh = 10 }){});
}
pub fn handle(
self: @This(),
req: anytype,
res: anytype,
ctx: anytype,
next: void,
) !void {
_ = next;
return self.first.handle(req, res, ctx, self.next);
fn expectDeepEquals(expected: anytype, actual: anytype) !void {
const E = @TypeOf(expected);
const A = @TypeOf(actual);
if (E == void) return std.testing.expect(A == void);
try std.testing.expect(std.meta.fields(E).len == std.meta.fields(A).len);
inline for (std.meta.fields(E)) |f| {
const e = @field(expected, f.name);
const a = @field(actual, f.name);
if (comptime std.meta.trait.isZigString(f.field_type)) {
try std.testing.expectEqualStrings(a, e);
} else {
try std.testing.expectEqual(a, e);
}
}
}
// Helper for testing purposes
fn ExpectContext(comptime val: anytype) type {
return struct {
pub fn handle(_: @This(), _: anytype, _: anytype, ctx: anytype, _: void) !void {
try expectDeepEquals(val, ctx);
}
};
}
fn expectContext(comptime val: anytype) ExpectContext(val) {
return .{};
}
/// Catches any errors returned by the `next` chain, and passes them via context
/// to an error handler if one occurs
pub fn CatchErrors(comptime ErrorHandler: type) type {
return struct {
error_handler: ErrorHandler,
@ -118,14 +205,17 @@ pub fn CatchErrors(comptime ErrorHandler: type) type {
}
};
}
pub fn catchErrors(error_handler: anytype) CatchErrors(@TypeOf(error_handler)) {
return .{ .error_handler = error_handler };
}
/// Default error handler for CatchErrors, logs the error and outputs responds with a 500 if
/// a response has not been written yet
pub const default_error_handler = struct {
fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
_ = next;
std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri });
fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, _: anytype) !void {
const should_log = !@import("builtin").is_test;
if (should_log) std.log.err("Error {} on uri {s}", .{ ctx.err, req.uri });
// Tell the server to close the connection after this request
res.should_close = true;
@ -141,7 +231,47 @@ pub const default_error_handler = struct {
}
}{};
pub const split_uri = struct {
test "CatchErrors" {
const TestResponse = struct {
should_close: bool = false,
was_opened: bool = false,
test_should_open: bool,
const TestStream = struct {
fn close(_: *@This()) void {}
fn finish(_: *@This()) !void {}
};
fn open(self: *@This(), status: http.Status, _: *http.Fields) !TestStream {
self.was_opened = true;
if (!self.test_should_open) return error.ResponseOpenedTwice;
try std.testing.expectEqual(status, .internal_server_error);
return .{};
}
};
const middleware_list = apply(.{
catchErrors(default_error_handler),
struct {
fn handle(_: @This(), _: anytype, _: anytype, _: anytype, _: anytype) !void {
return error.SomeError;
}
}{},
});
var response = TestResponse{ .test_should_open = true };
try middleware_list.handle(.{ .uri = "abcd" }, &response, .{}, {});
try std.testing.expect(response.should_close);
// Test that it doesn't open a response if one was already opened
response = TestResponse{ .test_should_open = false, .was_opened = true };
try middleware_list.handle(.{ .uri = "abcd" }, &response, .{}, {});
try std.testing.expect(response.should_close);
}
/// Takes the request uri provided and splits it into "path", "query_string", and "fragment_string"
/// parts, which are placed into context.
const SplitUri = struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
var frag_split = std.mem.split(u8, req.uri, "#");
const without_fragment = frag_split.first();
@ -168,9 +298,32 @@ pub const split_uri = struct {
{},
);
}
}{};
};
pub const split_uri = SplitUri{};
// routes a request to the correct handler based on declared HTTP method and path
test "split_uri" {
const testCase = struct {
fn func(uri: []const u8, ctx: anytype, expected: anytype) !void {
const v = apply(.{
split_uri,
expectContext(expected),
});
try v.handle(.{ .uri = uri }, .{}, ctx, {});
}
}.func;
try testCase("/", .{}, .{ .path = "/", .query_string = "", .fragment_string = "" });
try testCase("", .{}, .{ .path = "", .query_string = "", .fragment_string = "" });
try testCase("/path", .{}, .{ .path = "/path", .query_string = "", .fragment_string = "" });
try testCase("?abcd=1234", .{}, .{ .path = "", .query_string = "abcd=1234", .fragment_string = "" });
try testCase("#abcd", .{}, .{ .path = "", .query_string = "", .fragment_string = "abcd" });
try testCase("/abcd/efgh?query=no#frag", .{}, .{ .path = "/abcd/efgh", .query_string = "query=no", .fragment_string = "frag" });
}
/// Routes a request between the provided routes.
///
/// CURRENTLY: Does not do this intelligently, all routing is handled by the routes themselves.
/// TODO: Consider implementing this with a hashmap?
pub fn Router(comptime Routes: type) type {
return struct {
routes: Routes,
@ -204,6 +357,7 @@ fn pathMatches(route: []const u8, path: []const u8) bool {
const path_segment = path_iter.next() orelse return false;
if (route_segment.len > 0 and route_segment[0] == ':') {
// Route Argument
if (path_segment.len == 0) return false;
} else {
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false;
}
@ -212,6 +366,19 @@ fn pathMatches(route: []const u8, path: []const u8) bool {
return true;
}
/// Handler that either calls its next middleware parameter or returns error.RouteMismatch
/// depending on if the request matches the described route.
/// Must be below `split_uri` on the middleware list.
///
/// Format:
/// Each route segment can be either a literal string or an argument. Literal strings
/// must match exactly in order to constitute a matching route. Arguments must begin with
/// the character ':', with the remainer of the segment referring to the name of the argument.
/// Argument values must be nonempty.
///
/// For example, the route "/abc/:foo/def" would match "/abc/x/def" or "/abc/blahblah/def" but
/// not "/abc//def".
pub const Route = struct {
pub const Desc = struct {
path: []const u8,
@ -232,7 +399,6 @@ pub const Route = struct {
}
pub fn handle(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
std.log.debug("Testing path {s} against {s}", .{ ctx.path, self.desc.path });
return if (self.applies(req, ctx))
next.handle(req, res, ctx, {})
else
@ -240,12 +406,47 @@ pub const Route = struct {
}
};
test "route" {
const testCase = struct {
fn func(should_match: bool, route: Route.Desc, method: http.Method, path: []const u8) !void {
const no_op = struct {
fn handle(_: anytype, _: anytype, _: anytype, _: anytype, _: anytype) !void {}
}{};
const result = (Route{ .desc = route }).handle(.{ .method = method }, .{}, .{ .path = path }, no_op);
try if (should_match) result else std.testing.expectError(error.RouteMismatch, result);
}
}.func;
try testCase(true, .{ .method = .GET, .path = "/" }, .GET, "/");
try testCase(true, .{ .method = .GET, .path = "/" }, .GET, "");
try testCase(true, .{ .method = .GET, .path = "/abcd" }, .GET, "/abcd");
try testCase(true, .{ .method = .GET, .path = "/abcd" }, .GET, "abcd");
try testCase(true, .{ .method = .POST, .path = "/" }, .POST, "/");
try testCase(true, .{ .method = .POST, .path = "/" }, .POST, "");
try testCase(true, .{ .method = .POST, .path = "/abcd" }, .POST, "/abcd");
try testCase(true, .{ .method = .POST, .path = "/abcd" }, .POST, "abcd");
try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz");
try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/");
try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, "");
try testCase(false, .{ .method = .GET, .path = "/" }, .GET, "/abcd");
try testCase(false, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "efgh");
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/");
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/");
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo");
}
/// Mounts a router subtree under a given path. Middlewares further down on the list
/// are called with the path prefix specified by `route` removed from the path.
/// Must be below `split_uri` on the middleware list.
pub fn Mount(comptime route: []const u8) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
var path_iter = util.PathIter.from(ctx.path);
comptime var route_iter = util.PathIter.from(route);
var path_unused = ctx.path;
var path_unused: []const u8 = ctx.path;
inline while (comptime route_iter.next()) |route_segment| {
if (comptime route_segment.len == 0) continue;
@ -269,20 +470,26 @@ pub fn mount(comptime route: []const u8) Mount(route) {
return .{};
}
pub fn HandleNotFound(comptime NotFoundHandler: type) type {
return struct {
not_found: NotFoundHandler,
pub fn handler(self: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
return next.handler(req, res, ctx, {}) catch |err| switch (err) {
error.RouteMismatch => return self.not_found.handler(req, res, ctx, {}),
else => return err,
};
test "mount" {
const testCase = struct {
fn func(comptime base: []const u8, request: []const u8, comptime expected: ?[]const u8) !void {
const result = mount(base).handle(.{}, .{}, addField(.{}, "path", request), expectContext(.{ .path = expected orelse "" }));
try if (expected != null) result else std.testing.expectError(error.RouteMismatch, result);
}
};
}.func;
try testCase("/api/", "/api/", "");
try testCase("/api/", "/api/abcd", "abcd");
try testCase("/api/", "/api/abcd/efgh", "abcd/efgh");
try testCase("/api/", "/api/abcd/efgh/", "abcd/efgh/");
try testCase("/api/v0", "/api/v0/call", "call");
try testCase("/api/", "/web/abcd/efgh/", null);
try testCase("/api/", "/", null);
try testCase("/api/", "/ap", null);
try testCase("/api/v0", "/api/v1/", null);
}
fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const u8) !Args {
fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args {
var args: Args = undefined;
var path_iter = util.PathIter.from(path);
comptime var route_iter = util.PathIter.from(route);
@ -290,8 +497,9 @@ fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const
const path_segment = path_iter.next() orelse return error.RouteMismatch;
if (route_segment.len > 0 and route_segment[0] == ':') {
// route segment is an argument segment
if (path_segment.len == 0) return error.RouteMismatch;
const A = @TypeOf(@field(args, route_segment[1..]));
@field(args, route_segment[1..]) = try parsePathArg(A, path_segment);
@field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment);
} else {
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
}
@ -302,13 +510,35 @@ fn parsePathArgs(comptime route: []const u8, comptime Args: type, path: []const
return args;
}
fn parsePathArg(comptime T: type, segment: []const u8) !T {
fn parseArgFromPath(comptime T: type, segment: []const u8) !T {
if (T == []const u8) return segment;
if (comptime std.meta.trait.isContainer(T) and std.meta.trait.hasFn("parse")(T)) return T.parse(segment);
if (comptime std.meta.trait.is(.Int)(T)) return std.fmt.parseInt(T, segment, 0);
@compileError("Unsupported Type " ++ @typeName(T));
}
/// Parse arguments directly the request path.
/// Must be placed after a `split_uri` middleware in order to get `path` from context.
///
/// Route arguments are specified in the same format as for Route. The name of the argument
/// refers to the field name in Args that the argument will be parsed to.
///
/// This currently works with arguments of 3 different types:
/// - integers
/// - []const u8,
/// - anything with a function of the form:
/// * T.parse([]const u8) Error!T
/// * This function cannot hold a reference to the passed string once it appears
///
/// Example:
/// ParsePathArgs("/:id/foo/:name/byrank/:rank", struct {
/// id: util.Uuid,
/// name: []const u8,
/// rank: u32,
/// })
/// Would parse a path of "/00000000-0000-0000-0000-000000000000/foo/jaina/byrank/3" into
/// .{ .id = try Uuid.parse("00000000-0000-0000-0000-000000000000"), .name = "jaina", .rank = 3 }
pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
@ -316,12 +546,42 @@ pub fn ParsePathArgs(comptime route: []const u8, comptime Args: type) type {
return next.handle(
req,
res,
addField(ctx, "args", try parsePathArgs(route, Args, ctx.path)),
addField(ctx, "args", try parseArgsFromPath(route, Args, ctx.path)),
{},
);
}
};
}
pub fn parsePathArgs(comptime route: []const u8, comptime Args: type) ParsePathArgs(route, Args) {
return .{};
}
test "ParsePathArgs" {
const testCase = struct {
fn func(comptime route: []const u8, comptime Args: type, path: []const u8, expected: anytype) !void {
const check = struct {
expected: @TypeOf(expected),
path: []const u8,
fn handle(self: @This(), _: anytype, _: anytype, ctx: anytype, _: void) !void {
try expectDeepEquals(self.expected, ctx.args);
try std.testing.expectEqualStrings(self.path, ctx.path);
}
}{ .expected = expected, .path = path };
try parsePathArgs(route, Args).handle(.{}, .{}, .{ .path = path }, check);
}
}.func;
try testCase("/", void, "/", {});
try testCase("/:id", struct { id: usize }, "/3", .{ .id = 3 });
try testCase("/:str", struct { str: []const u8 }, "/abcd", .{ .str = "abcd" });
try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" });
try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil });
try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{}));
try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{}));
try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 }));
try std.testing.expectError(error.InvalidCharacter, testCase("/:id", struct { id: usize }, "/xyz", .{}));
}
const BaseContentType = enum {
json,
@ -331,7 +591,7 @@ const BaseContentType = enum {
other,
};
fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T {
fn parseBodyFromRequest(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T {
//@compileLog(T);
const buf = try reader.readAllAlloc(alloc, 1 << 16);
defer alloc.free(buf);
@ -351,6 +611,7 @@ fn parseBody(comptime T: type, content_type: BaseContentType, reader: anytype, a
}
}
// figure out what base parser to use
fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
if (hdr) |h| {
if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded;
@ -363,6 +624,11 @@ fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
return null;
}
/// Parses a set of body arguments from the request body based on the request's Content-Type
/// header.
///
/// The exact method for parsing depends partially on the Content-Type. json types are preferred
/// TODO: Need tests for this, including various Content-Type values
pub fn ParseBody(comptime Body: type) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
@ -377,7 +643,7 @@ pub fn ParseBody(comptime Body: type) type {
const base_content_type = matchContentType(content_type);
var stream = req.body orelse return error.NoBody;
const body = try parseBody(Body, base_content_type orelse .json, stream.reader(), ctx.allocator);
const body = try parseBodyFromRequest(Body, base_content_type orelse .json, stream.reader(), ctx.allocator);
defer util.deepFree(ctx.allocator, body);
return next.handle(
@ -389,7 +655,11 @@ pub fn ParseBody(comptime Body: type) type {
}
};
}
pub fn parseBody(comptime Body: type) ParseBody(Body) {
return .{};
}
/// Parses query parameters as defined in query.zig
pub fn ParseQueryParams(comptime QueryParams: type) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {

View file

@ -1,6 +1,7 @@
const std = @import("std");
const util = @import("util");
const QueryIter = @import("util").QueryIter;
const QueryIter = util.QueryIter;
/// Parses a set of query parameters described by the struct `T`.
///
@ -66,10 +67,6 @@ const QueryIter = @import("util").QueryIter;
/// Would be used to parse a query string like
/// `?foo.baz=12345`
///
/// Compound types cannot currently be nullable, and must be structs.
///
/// TODO: values are currently case-sensitive, and are not url-decoded properly.
/// This should be fixed.
pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T {
if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
var iter = QueryIter.from(query);
@ -88,7 +85,11 @@ pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8)
return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery;
}
fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]const u8 {
pub fn parseQueryFree(alloc: std.mem.Allocator, val: anytype) void {
util.deepFree(alloc, val);
}
fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]u8 {
var list = try std.ArrayList(u8).initCapacity(alloc, val.len);
errdefer list.deinit();
@ -146,6 +147,9 @@ fn parse(
.Struct => |info| {
var result: T = undefined;
var fields_specified: usize = 0;
errdefer inline for (info.fields) |field, i| {
if (fields_specified < i) util.deepFree(alloc, @field(result, field.name));
};
inline for (info.fields) |field| {
const F = field.field_type;
@ -155,7 +159,7 @@ fn parse(
maybe_value = v;
} else if (field.default_value) |default| {
if (comptime @sizeOf(F) != 0) {
maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*;
maybe_value = try util.deepClone(alloc, @ptrCast(*const F, @alignCast(@alignOf(F), default)).*);
} else {
maybe_value = std.mem.zeroes(F);
}
@ -231,10 +235,38 @@ fn Intermediary(comptime T: type) type {
} });
}
fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u8) !T {
fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, maybe_value: ?[]const u8) !T {
const is_optional = comptime std.meta.trait.is(.Optional)(T);
// If param is present, but without an associated value
if (value == null) {
if (maybe_value) |value| {
const Eff = if (is_optional) std.meta.Child(T) else T;
if (value.len == 0 and is_optional) return null;
const decoded = try decodeString(alloc, value);
errdefer alloc.free(decoded);
if (comptime std.meta.trait.isZigString(Eff)) return decoded;
defer alloc.free(decoded);
const result = if (comptime std.meta.trait.isIntegral(Eff))
try std.fmt.parseInt(Eff, decoded, 0)
else if (comptime std.meta.trait.isFloat(Eff))
try std.fmt.parseFloat(Eff, decoded)
else if (comptime std.meta.trait.is(.Enum)(Eff)) blk: {
_ = std.ascii.lowerString(decoded, decoded);
break :blk std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
} else if (Eff == bool) blk: {
_ = std.ascii.lowerString(decoded, decoded);
break :blk bool_map.get(decoded) orelse return error.InvalidBool;
} else if (comptime std.meta.trait.hasFn("parse")(Eff))
try Eff.parse(value)
else
@compileError("Invalid type " ++ @typeName(T));
return result;
} else {
// If param is present, but without an associated value
return if (is_optional)
null
else if (T == bool)
@ -242,8 +274,6 @@ fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u
else
error.InvalidValue;
}
return try parseQueryValueNotNull(alloc, if (is_optional) std.meta.Child(T) else T, value.?);
}
const bool_map = std.ComptimeStringMap(bool, .{
@ -260,34 +290,12 @@ const bool_map = std.ComptimeStringMap(bool, .{
.{ "0", false },
});
fn parseQueryValueNotNull(alloc: std.mem.Allocator, comptime T: type, value: []const u8) !T {
const decoded = try decodeString(alloc, value);
errdefer alloc.free(decoded);
if (comptime std.meta.trait.isZigString(T)) return decoded;
const result = if (comptime std.meta.trait.isIntegral(T))
try std.fmt.parseInt(T, decoded, 0)
else if (comptime std.meta.trait.isFloat(T))
try std.fmt.parseFloat(T, decoded)
else if (comptime std.meta.trait.is(.Enum)(T))
std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue
else if (T == bool)
bool_map.get(value) orelse return error.InvalidBool
else if (comptime std.meta.trait.hasFn("parse")(T))
try T.parse(value)
else
@compileError("Invalid type " ++ @typeName(T));
alloc.free(decoded);
return result;
}
fn isScalar(comptime T: type) bool {
if (comptime std.meta.trait.isZigString(T)) return true;
if (comptime std.meta.trait.isIntegral(T)) return true;
if (comptime std.meta.trait.isFloat(T)) return true;
if (comptime std.meta.trait.is(.Enum)(T)) return true;
if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true;
if (T == bool) return true;
if (comptime std.meta.trait.hasFn("parse")(T)) return true;
@ -296,8 +304,16 @@ fn isScalar(comptime T: type) bool {
return false;
}
pub fn formatQuery(params: anytype, writer: anytype) !void {
try format("", "", params, writer);
pub fn QueryStringify(comptime Params: type) type {
return struct {
params: Params,
pub fn format(v: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
try formatQuery("", "", v.params, writer);
}
};
}
pub fn queryStringify(val: anytype) QueryStringify(@TypeOf(val)) {
return QueryStringify(@TypeOf(val)){ .params = val };
}
fn urlFormatString(writer: anytype, val: []const u8) !void {
@ -323,14 +339,14 @@ fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void
if (comptime std.meta.trait.isZigString(T)) {
try urlFormatString(writer, val);
} else try switch (@typeInfo(T)) {
.Enum => urlFormatString(writer, @tagName(val)),
.EnumLiteral, .Enum => urlFormatString(writer, @tagName(val)),
else => std.fmt.format(writer, "{}", .{val}),
};
try writer.writeByte('&');
}
fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void {
fn formatQuery(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void {
const T = @TypeOf(params);
const eff_prefix = if (prefix.len == 0) "" else prefix ++ ".";
if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer);
@ -339,7 +355,7 @@ fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytyp
.Struct => {
inline for (std.meta.fields(T)) |field| {
const val = @field(params, field.name);
try format(eff_prefix ++ name, field.name, val, writer);
try formatQuery(eff_prefix ++ name, field.name, val, writer);
}
},
.Union => {
@ -348,33 +364,115 @@ fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytyp
const tag_name = field.name;
if (@as(std.meta.Tag(T), params) == tag) {
const val = @field(params, tag_name);
try format(prefix, tag_name, val, writer);
try formatQuery(prefix, tag_name, val, writer);
}
}
},
.Optional => {
if (params) |p| try format(prefix, name, p, writer);
if (params) |p| try formatQuery(prefix, name, p, writer);
},
else => @compileError("Unsupported query type"),
}
}
test {
const TestQuery = struct {
int: usize = 3,
boolean: bool = false,
str_enum: ?enum { foo, bar } = null,
test "parseQuery" {
const testCase = struct {
fn case(comptime T: type, expected: T, query_string: []const u8) !void {
const result = try parseQuery(std.testing.allocator, T, query_string);
defer parseQueryFree(std.testing.allocator, result);
try util.testing.expectDeepEqual(expected, result);
}
}.case;
try testCase(struct { int: usize = 3 }, .{ .int = 3 }, "");
try testCase(struct { int: usize = 3 }, .{ .int = 2 }, "int=2");
try testCase(struct { int: usize = 3 }, .{ .int = 2 }, "int=2&");
try testCase(struct { boolean: bool = false }, .{ .boolean = false }, "");
try testCase(struct { boolean: bool = false }, .{ .boolean = true }, "boolean");
try testCase(struct { boolean: bool = false }, .{ .boolean = true }, "boolean=true");
try testCase(struct { boolean: bool = false }, .{ .boolean = true }, "boolean=y");
try testCase(struct { boolean: bool = false }, .{ .boolean = false }, "boolean=f");
try testCase(struct { boolean: bool = false }, .{ .boolean = false }, "boolean=no");
try testCase(struct { str_enum: ?enum { foo, bar } = null }, .{ .str_enum = null }, "");
try testCase(struct { str_enum: ?enum { foo, bar } = null }, .{ .str_enum = .foo }, "str_enum=foo");
try testCase(struct { str_enum: ?enum { foo, bar } = null }, .{ .str_enum = .bar }, "str_enum=bar");
try testCase(struct { str_enum: ?enum { foo, bar } = .foo }, .{ .str_enum = .foo }, "");
try testCase(struct { str_enum: ?enum { foo, bar } = .foo }, .{ .str_enum = null }, "str_enum");
try testCase(struct { n1: usize = 5, n2: usize = 5 }, .{ .n1 = 1, .n2 = 2 }, "n1=1&n2=2");
try testCase(struct { n1: usize = 5, n2: usize = 5 }, .{ .n1 = 1, .n2 = 2 }, "n1=1&n2=2&");
try testCase(struct { n1: usize = 5, n2: usize = 5 }, .{ .n1 = 1, .n2 = 2 }, "n1=1&&n2=2&");
try testCase(struct { str: ?[]const u8 = null }, .{ .str = null }, "");
try testCase(struct { str: ?[]const u8 = null }, .{ .str = null }, "str");
try testCase(struct { str: ?[]const u8 = null }, .{ .str = null }, "str=");
try testCase(struct { str: ?[]const u8 = null }, .{ .str = "foo" }, "str=foo");
try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = "foo" }, "str=foo");
try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = "foo" }, "");
try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = null }, "str");
try testCase(struct { str: ?[]const u8 = "foo" }, .{ .str = null }, "str=");
const rand_uuid = comptime util.Uuid.parse("c1fb6578-4d0c-4eb9-9f67-d56da3ae6f5d") catch unreachable;
try testCase(struct { id: ?util.Uuid = null }, .{ .id = null }, "");
try testCase(struct { id: ?util.Uuid = null }, .{ .id = null }, "id=");
try testCase(struct { id: ?util.Uuid = null }, .{ .id = null }, "id");
try testCase(struct { id: ?util.Uuid = null }, .{ .id = rand_uuid }, "id=" ++ rand_uuid.toCharArray());
try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = rand_uuid }, "");
try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = null }, "id=");
try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = null }, "id");
try testCase(struct { id: ?util.Uuid = rand_uuid }, .{ .id = rand_uuid }, "id=" ++ rand_uuid.toCharArray());
const SubStruct = struct {
sub: struct {
foo: usize = 1,
bar: usize = 2,
} = .{},
};
try testCase(SubStruct, .{ .sub = .{ .foo = 1, .bar = 2 } }, "");
try testCase(SubStruct, .{ .sub = .{ .foo = 3, .bar = 3 } }, "sub.foo=3&sub.bar=3");
try testCase(SubStruct, .{ .sub = .{ .foo = 3, .bar = 2 } }, "sub.foo=3");
try std.testing.expectEqual(TestQuery{
.int = 3,
.boolean = false,
.str_enum = null,
}, try parseQuery(TestQuery, ""));
// TODO: Semantics are ill-defined here. What happens if the substruct doesn't have
// default values?
// const SubStruct2 = struct {
// sub: ?struct {
// foo: usize = 1,
// } = null,
// };
// try testCase(SubStruct2, .{ .sub = null }, "");
// try testCase(SubStruct2, .{ .sub = null }, "sub=");
try std.testing.expectEqual(TestQuery{
.int = 5,
.boolean = true,
.str_enum = .foo,
}, try parseQuery(TestQuery, "?int=5&boolean=yes&str_enum=foo"));
// TODO: also here (semantics are well defined it just breaks tests)
// const SubUnion = struct {
// sub: ?union(enum) {
// foo: usize,
// bar: usize,
// } = null,
// };
// try testCase(SubUnion, .{ .sub = null }, "");
// try testCase(SubUnion, .{ .sub = null }, "sub=");
const SubUnion2 = struct {
sub: ?struct {
foo: usize,
val: union(enum) {
bar: []const u8,
baz: []const u8,
},
} = null,
};
try testCase(SubUnion2, .{ .sub = null }, "");
try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .bar = "abc" } } }, "sub.foo=1&sub.bar=abc");
try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .baz = "abc" } } }, "sub.foo=1&sub.baz=abc");
}
test "formatQuery" {
try std.testing.expectFmt("", "{}", .{queryStringify(.{})});
try std.testing.expectFmt("id=3&", "{}", .{queryStringify(.{ .id = 3 })});
try std.testing.expectFmt("id=3&id2=4&", "{}", .{queryStringify(.{ .id = 3, .id2 = 4 })});
try std.testing.expectFmt("str=foo&", "{}", .{queryStringify(.{ .str = "foo" })});
try std.testing.expectFmt("enum_str=foo&", "{}", .{queryStringify(.{ .enum_str = .foo })});
try std.testing.expectFmt("boolean=false&", "{}", .{queryStringify(.{ .boolean = false })});
try std.testing.expectFmt("boolean=true&", "{}", .{queryStringify(.{ .boolean = true })});
}

View file

@ -1,3 +1,5 @@
test {
_ = @import("./request/test_parser.zig");
_ = @import("./middleware.zig");
_ = @import("./query.zig");
}

View file

@ -4,8 +4,6 @@ const builtin = @import("builtin");
const http = @import("http");
const api = @import("api");
const util = @import("util");
const query_utils = @import("./query.zig");
const json_utils = @import("./json.zig");
const web_endpoints = @import("./controllers/web.zig").routes;
const api_endpoints = @import("./controllers/api.zig").routes;
@ -244,14 +242,14 @@ const json_options = if (builtin.mode == .Debug)
};
pub const helpers = struct {
pub fn paginate(community: api.Community, path: []const u8, results: anytype, res: *Response, alloc: std.mem.Allocator) !void {
pub fn paginate(results: anytype, res: *Response, alloc: std.mem.Allocator) !void {
var link = std.ArrayList(u8).init(alloc);
const link_writer = link.writer();
defer link.deinit();
try writeLink(link_writer, community, path, results.next_page, "next");
try writeLink(link_writer, null, "", results.next_page, "next");
try link_writer.writeByte(',');
try writeLink(link_writer, community, path, results.prev_page, "prev");
try writeLink(link_writer, null, "", results.prev_page, "prev");
try res.headers.put("Link", link.items);
@ -260,24 +258,24 @@ pub const helpers = struct {
fn writeLink(
writer: anytype,
community: api.Community,
community: ?api.Community,
path: []const u8,
params: anytype,
rel: []const u8,
) !void {
if (community) |c| {
try std.fmt.format(
writer,
"<{s}://{s}/{s}?{}>; rel=\"{s}\"",
.{ @tagName(c.scheme), c.host, path, http.queryStringify(params), rel },
);
} else {
try std.fmt.format(
writer,
"<{s}?{}>; rel=\"{s}\"",
.{ path, http.queryStringify(params), rel },
);
}
// TODO: percent-encode
try std.fmt.format(
writer,
"<{s}://{s}/{s}?",
.{ @tagName(community.scheme), community.host, path },
);
try query_utils.formatQuery(params, writer);
try std.fmt.format(
writer,
">; rel=\"{s}\"",
.{rel},
);
}
};

View file

@ -27,6 +27,6 @@ pub const query = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const results = try srv.queryCommunities(req.query);
try controller_utils.paginate(srv.community, path, results, res, req.allocator);
try controller_utils.paginate(results, res, req.allocator);
}
};

View file

@ -10,7 +10,7 @@ pub const global = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const results = try srv.globalTimeline(req.query);
try controller_utils.paginate(srv.community, path, results, res, req.allocator);
try controller_utils.paginate(results, res, req.allocator);
}
};
@ -22,7 +22,7 @@ pub const local = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const results = try srv.localTimeline(req.query);
try controller_utils.paginate(srv.community, path, results, res, req.allocator);
try controller_utils.paginate(results, res, req.allocator);
}
};
@ -34,6 +34,6 @@ pub const home = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const results = try srv.homeTimeline(req.query);
try controller_utils.paginate(srv.community, path, results, res, req.allocator);
try controller_utils.paginate(results, res, req.allocator);
}
};

View file

@ -47,7 +47,7 @@ pub const query_followers = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const results = try srv.queryFollowers(req.args.id, req.query);
try controller_utils.paginate(srv.community, path, results, res, req.allocator);
try controller_utils.paginate(results, res, req.allocator);
}
};
@ -64,6 +64,6 @@ pub const query_following = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const results = try srv.queryFollowing(req.args.id, req.query);
try controller_utils.paginate(srv.community, path, results, res, req.allocator);
try controller_utils.paginate(results, res, req.allocator);
}
};

View file

@ -1,677 +0,0 @@
const std = @import("std");
const mem = std.mem;
const Allocator = std.mem.Allocator;
const assert = std.debug.assert;
// This file is largely a copy of std.json
const StreamingParser = std.json.StreamingParser;
const Token = std.json.Token;
const unescapeValidString = std.json.unescapeValidString;
const UnescapeValidStringError = std.json.UnescapeValidStringError;
pub fn parse(comptime T: type, body: []const u8, alloc: std.mem.Allocator) !T {
var tokens = TokenStream.init(body);
const options = ParseOptions{ .allocator = alloc };
const token = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
const r = try parseInternal(T, token, &tokens, options);
errdefer parseFreeInternal(T, r, options);
if (!options.allow_trailing_data) {
if ((try tokens.next()) != null) unreachable;
assert(tokens.i >= tokens.slice.len);
}
return r;
}
pub fn parseFree(value: anytype, alloc: std.mem.Allocator) void {
parseFreeInternal(@TypeOf(value), value, .{ .allocator = alloc });
}
// WARNING: the objects "parse" method must not contain a reference to the original value
fn hasCustomParse(comptime T: type) bool {
if (!std.meta.trait.hasFn("parse")(T)) return false;
if (!@hasDecl(T, "JsonParseAs")) return false;
return true;
}
///// The rest is (modified) from std.json
/// A small wrapper over a StreamingParser for full slices. Returns a stream of json Tokens.
pub const TokenStream = struct {
i: usize,
slice: []const u8,
parser: StreamingParser,
token: ?Token,
pub const Error = StreamingParser.Error || error{UnexpectedEndOfJson};
pub fn init(slice: []const u8) TokenStream {
return TokenStream{
.i = 0,
.slice = slice,
.parser = StreamingParser.init(),
.token = null,
};
}
fn stackUsed(self: *TokenStream) usize {
return self.parser.stack.len + if (self.token != null) @as(usize, 1) else 0;
}
pub fn next(self: *TokenStream) Error!?Token {
if (self.token) |token| {
self.token = null;
return token;
}
var t1: ?Token = undefined;
var t2: ?Token = undefined;
while (self.i < self.slice.len) {
try self.parser.feed(self.slice[self.i], &t1, &t2);
self.i += 1;
if (t1) |token| {
self.token = t2;
return token;
}
}
// Without this a bare number fails, the streaming parser doesn't know the input ended
try self.parser.feed(' ', &t1, &t2);
self.i += 1;
if (t1) |token| {
return token;
} else if (self.parser.complete) {
return null;
} else {
return error.UnexpectedEndOfJson;
}
}
};
/// Checks to see if a string matches what it would be as a json-encoded string
/// Assumes that `encoded` is a well-formed json string
fn encodesTo(decoded: []const u8, encoded: []const u8) bool {
var i: usize = 0;
var j: usize = 0;
while (i < decoded.len) {
if (j >= encoded.len) return false;
if (encoded[j] != '\\') {
if (decoded[i] != encoded[j]) return false;
j += 1;
i += 1;
} else {
const escape_type = encoded[j + 1];
if (escape_type != 'u') {
const t: u8 = switch (escape_type) {
'\\' => '\\',
'/' => '/',
'n' => '\n',
'r' => '\r',
't' => '\t',
'f' => 12,
'b' => 8,
'"' => '"',
else => unreachable,
};
if (decoded[i] != t) return false;
j += 2;
i += 1;
} else {
var codepoint = std.fmt.parseInt(u21, encoded[j + 2 .. j + 6], 16) catch unreachable;
j += 6;
if (codepoint >= 0xD800 and codepoint < 0xDC00) {
// surrogate pair
assert(encoded[j] == '\\');
assert(encoded[j + 1] == 'u');
const low_surrogate = std.fmt.parseInt(u21, encoded[j + 2 .. j + 6], 16) catch unreachable;
codepoint = 0x10000 + (((codepoint & 0x03ff) << 10) | (low_surrogate & 0x03ff));
j += 6;
}
var buf: [4]u8 = undefined;
const len = std.unicode.utf8Encode(codepoint, &buf) catch unreachable;
if (i + len > decoded.len) return false;
if (!mem.eql(u8, decoded[i .. i + len], buf[0..len])) return false;
i += len;
}
}
}
assert(i == decoded.len);
assert(j == encoded.len);
return true;
}
/// parse tokens from a stream, returning `false` if they do not decode to `value`
fn parsesTo(comptime T: type, value: T, tokens: *TokenStream, options: ParseOptions) !bool {
// TODO: should be able to write this function to not require an allocator
const tmp = try parse(T, tokens, options);
defer parseFree(T, tmp, options);
return parsedEqual(tmp, value);
}
/// Returns if a value returned by `parse` is deep-equal to another value
fn parsedEqual(a: anytype, b: @TypeOf(a)) bool {
switch (@typeInfo(@TypeOf(a))) {
.Optional => {
if (a == null and b == null) return true;
if (a == null or b == null) return false;
return parsedEqual(a.?, b.?);
},
.Union => |info| {
if (info.tag_type) |UnionTag| {
const tag_a = std.meta.activeTag(a);
const tag_b = std.meta.activeTag(b);
if (tag_a != tag_b) return false;
inline for (info.fields) |field_info| {
if (@field(UnionTag, field_info.name) == tag_a) {
return parsedEqual(@field(a, field_info.name), @field(b, field_info.name));
}
}
return false;
} else {
unreachable;
}
},
.Array => {
for (a) |e, i|
if (!parsedEqual(e, b[i])) return false;
return true;
},
.Struct => |info| {
inline for (info.fields) |field_info| {
if (!parsedEqual(@field(a, field_info.name), @field(b, field_info.name))) return false;
}
return true;
},
.Pointer => |ptrInfo| switch (ptrInfo.size) {
.One => return parsedEqual(a.*, b.*),
.Slice => {
if (a.len != b.len) return false;
for (a) |e, i|
if (!parsedEqual(e, b[i])) return false;
return true;
},
.Many, .C => unreachable,
},
else => return a == b,
}
unreachable;
}
const ParseOptions = struct {
allocator: ?Allocator = null,
/// Behaviour when a duplicate field is encountered.
duplicate_field_behavior: enum {
UseFirst,
Error,
UseLast,
} = .Error,
/// If false, finding an unknown field returns an error.
ignore_unknown_fields: bool = false,
allow_trailing_data: bool = false,
};
const SkipValueError = error{UnexpectedJsonDepth} || TokenStream.Error;
fn skipValue(tokens: *TokenStream) SkipValueError!void {
const original_depth = tokens.stackUsed();
// Return an error if no value is found
_ = try tokens.next();
if (tokens.stackUsed() < original_depth) return error.UnexpectedJsonDepth;
if (tokens.stackUsed() == original_depth) return;
while (try tokens.next()) |_| {
if (tokens.stackUsed() == original_depth) return;
}
}
fn ParseInternalError(comptime T: type) type {
// `inferred_types` is used to avoid infinite recursion for recursive type definitions.
const inferred_types = [_]type{};
return ParseInternalErrorImpl(T, &inferred_types);
}
fn ParseInternalErrorImpl(comptime T: type, comptime inferred_types: []const type) type {
if (hasCustomParse(T)) {
return ParseInternalError(T.JsonParseAs) || T.ParseError;
}
for (inferred_types) |ty| {
if (T == ty) return error{};
}
switch (@typeInfo(T)) {
.Bool => return error{UnexpectedToken},
.Float, .ComptimeFloat => return error{UnexpectedToken} || std.fmt.ParseFloatError,
.Int, .ComptimeInt => {
return error{ UnexpectedToken, InvalidNumber, Overflow } ||
std.fmt.ParseIntError || std.fmt.ParseFloatError;
},
.Optional => |optionalInfo| {
return ParseInternalErrorImpl(optionalInfo.child, inferred_types ++ [_]type{T});
},
.Enum => return error{ UnexpectedToken, InvalidEnumTag } || std.fmt.ParseIntError ||
std.meta.IntToEnumError || std.meta.IntToEnumError,
.Union => |unionInfo| {
if (unionInfo.tag_type) |_| {
var errors = error{NoUnionMembersMatched};
for (unionInfo.fields) |u_field| {
errors = errors || ParseInternalErrorImpl(u_field.field_type, inferred_types ++ [_]type{T});
}
return errors;
} else {
@compileError("Unable to parse into untagged union '" ++ @typeName(T) ++ "'");
}
},
.Struct => |structInfo| {
var errors = error{
DuplicateJSONField,
UnexpectedEndOfJson,
UnexpectedToken,
UnexpectedValue,
UnknownField,
MissingField,
} || SkipValueError || TokenStream.Error;
for (structInfo.fields) |field| {
errors = errors || ParseInternalErrorImpl(field.field_type, inferred_types ++ [_]type{T});
}
return errors;
},
.Array => |arrayInfo| {
return error{ UnexpectedEndOfJson, UnexpectedToken } || TokenStream.Error ||
UnescapeValidStringError ||
ParseInternalErrorImpl(arrayInfo.child, inferred_types ++ [_]type{T});
},
.Pointer => |ptrInfo| {
var errors = error{AllocatorRequired} || std.mem.Allocator.Error;
switch (ptrInfo.size) {
.One => {
return errors || ParseInternalErrorImpl(ptrInfo.child, inferred_types ++ [_]type{T});
},
.Slice => {
return errors || error{ UnexpectedEndOfJson, UnexpectedToken } ||
ParseInternalErrorImpl(ptrInfo.child, inferred_types ++ [_]type{T}) ||
UnescapeValidStringError || TokenStream.Error;
},
else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"),
}
},
else => return error{},
}
unreachable;
}
fn parseInternal(
comptime T: type,
token: Token,
tokens: *TokenStream,
options: ParseOptions,
) ParseInternalError(T)!T {
if (comptime hasCustomParse(T)) {
const val = try parseInternal(T.JsonParseAs, token, tokens, options);
defer parseFreeInternal(T.JsonParseAs, val, options);
return try T.parse(val);
}
switch (@typeInfo(T)) {
.Bool => {
return switch (token) {
.True => true,
.False => false,
else => error.UnexpectedToken,
};
},
.Float, .ComptimeFloat => {
switch (token) {
.Number => |numberToken| return try std.fmt.parseFloat(T, numberToken.slice(tokens.slice, tokens.i - 1)),
.String => |stringToken| return try std.fmt.parseFloat(T, stringToken.slice(tokens.slice, tokens.i - 1)),
else => return error.UnexpectedToken,
}
},
.Int, .ComptimeInt => {
switch (token) {
.Number => |numberToken| {
if (numberToken.is_integer)
return try std.fmt.parseInt(T, numberToken.slice(tokens.slice, tokens.i - 1), 10);
const float = try std.fmt.parseFloat(f128, numberToken.slice(tokens.slice, tokens.i - 1));
if (@round(float) != float) return error.InvalidNumber;
if (float > std.math.maxInt(T) or float < std.math.minInt(T)) return error.Overflow;
return @floatToInt(T, float);
},
.String => |stringToken| {
return std.fmt.parseInt(T, stringToken.slice(tokens.slice, tokens.i - 1), 10) catch |err| {
switch (err) {
error.Overflow => return err,
error.InvalidCharacter => {
const float = try std.fmt.parseFloat(f128, stringToken.slice(tokens.slice, tokens.i - 1));
if (@round(float) != float) return error.InvalidNumber;
if (float > std.math.maxInt(T) or float < std.math.minInt(T)) return error.Overflow;
return @floatToInt(T, float);
},
}
};
},
else => return error.UnexpectedToken,
}
},
.Optional => |optionalInfo| {
if (token == .Null) {
return null;
} else {
return try parseInternal(optionalInfo.child, token, tokens, options);
}
},
.Enum => |enumInfo| {
switch (token) {
.Number => |numberToken| {
if (!numberToken.is_integer) return error.UnexpectedToken;
const n = try std.fmt.parseInt(enumInfo.tag_type, numberToken.slice(tokens.slice, tokens.i - 1), 10);
return try std.meta.intToEnum(T, n);
},
.String => |stringToken| {
const source_slice = stringToken.slice(tokens.slice, tokens.i - 1);
switch (stringToken.escapes) {
.None => return std.meta.stringToEnum(T, source_slice) orelse return error.InvalidEnumTag,
.Some => {
inline for (enumInfo.fields) |field| {
if (field.name.len == stringToken.decodedLength() and encodesTo(field.name, source_slice)) {
return @field(T, field.name);
}
}
return error.InvalidEnumTag;
},
}
},
else => return error.UnexpectedToken,
}
},
.Union => |unionInfo| {
if (unionInfo.tag_type) |_| {
// try each of the union fields until we find one that matches
inline for (unionInfo.fields) |u_field| {
// take a copy of tokens so we can withhold mutations until success
var tokens_copy = tokens.*;
if (parseInternal(u_field.field_type, token, &tokens_copy, options)) |value| {
tokens.* = tokens_copy;
return @unionInit(T, u_field.name, value);
} else |err| {
// Bubble up error.OutOfMemory
// Parsing some types won't have OutOfMemory in their
// error-sets, for the condition to be valid, merge it in.
if (@as(@TypeOf(err) || error{OutOfMemory}, err) == error.OutOfMemory) return err;
// Bubble up AllocatorRequired, as it indicates missing option
if (@as(@TypeOf(err) || error{AllocatorRequired}, err) == error.AllocatorRequired) return err;
// otherwise continue through the `inline for`
}
}
return error.NoUnionMembersMatched;
} else {
@compileError("Unable to parse into untagged union '" ++ @typeName(T) ++ "'");
}
},
.Struct => |structInfo| {
switch (token) {
.ObjectBegin => {},
else => return error.UnexpectedToken,
}
var r: T = undefined;
var fields_seen = [_]bool{false} ** structInfo.fields.len;
errdefer {
inline for (structInfo.fields) |field, i| {
if (fields_seen[i] and !field.is_comptime) {
parseFreeInternal(field.field_type, @field(r, field.name), options);
}
}
}
while (true) {
switch ((try tokens.next()) orelse return error.UnexpectedEndOfJson) {
.ObjectEnd => break,
.String => |stringToken| {
const key_source_slice = stringToken.slice(tokens.slice, tokens.i - 1);
var child_options = options;
child_options.allow_trailing_data = true;
var found = false;
inline for (structInfo.fields) |field, i| {
// TODO: using switches here segfault the compiler (#2727?)
if ((stringToken.escapes == .None and mem.eql(u8, field.name, key_source_slice)) or (stringToken.escapes == .Some and (field.name.len == stringToken.decodedLength() and encodesTo(field.name, key_source_slice)))) {
// if (switch (stringToken.escapes) {
// .None => mem.eql(u8, field.name, key_source_slice),
// .Some => (field.name.len == stringToken.decodedLength() and encodesTo(field.name, key_source_slice)),
// }) {
if (fields_seen[i]) {
// switch (options.duplicate_field_behavior) {
// .UseFirst => {},
// .Error => {},
// .UseLast => {},
// }
if (options.duplicate_field_behavior == .UseFirst) {
// unconditonally ignore value. for comptime fields, this skips check against default_value
const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
parseFreeInternal(field.field_type, try parseInternal(field.field_type, next_token, tokens, child_options), child_options);
found = true;
break;
} else if (options.duplicate_field_behavior == .Error) {
return error.DuplicateJSONField;
} else if (options.duplicate_field_behavior == .UseLast) {
if (!field.is_comptime) {
parseFreeInternal(field.field_type, @field(r, field.name), child_options);
}
fields_seen[i] = false;
}
}
if (field.is_comptime) {
if (!try parsesTo(field.field_type, @ptrCast(*const field.field_type, field.default_value.?).*, tokens, child_options)) {
return error.UnexpectedValue;
}
} else {
const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
@field(r, field.name) = try parseInternal(field.field_type, next_token, tokens, child_options);
}
fields_seen[i] = true;
found = true;
break;
}
}
if (!found) {
if (options.ignore_unknown_fields) {
try skipValue(tokens);
continue;
} else {
return error.UnknownField;
}
}
},
else => return error.UnexpectedToken,
}
}
inline for (structInfo.fields) |field, i| {
if (!fields_seen[i]) {
if (field.default_value) |default_ptr| {
if (!field.is_comptime) {
const default = @ptrCast(*align(1) const field.field_type, default_ptr).*;
@field(r, field.name) = default;
}
} else {
return error.MissingField;
}
}
}
return r;
},
.Array => |arrayInfo| {
switch (token) {
.ArrayBegin => {
var r: T = undefined;
var i: usize = 0;
var child_options = options;
child_options.allow_trailing_data = true;
errdefer {
// Without the r.len check `r[i]` is not allowed
if (r.len > 0) while (true) : (i -= 1) {
parseFreeInternal(arrayInfo.child, r[i], options);
if (i == 0) break;
};
}
while (i < r.len) : (i += 1) {
const next_token = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
r[i] = try parseInternal(arrayInfo.child, next_token, tokens, child_options);
}
const tok = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
switch (tok) {
.ArrayEnd => {},
else => return error.UnexpectedToken,
}
return r;
},
.String => |stringToken| {
if (arrayInfo.child != u8) return error.UnexpectedToken;
var r: T = undefined;
const source_slice = stringToken.slice(tokens.slice, tokens.i - 1);
switch (stringToken.escapes) {
.None => mem.copy(u8, &r, source_slice),
.Some => try unescapeValidString(&r, source_slice),
}
return r;
},
else => return error.UnexpectedToken,
}
},
.Pointer => |ptrInfo| {
const allocator = options.allocator orelse return error.AllocatorRequired;
switch (ptrInfo.size) {
.One => {
const r: T = try allocator.create(ptrInfo.child);
errdefer allocator.destroy(r);
r.* = try parseInternal(ptrInfo.child, token, tokens, options);
return r;
},
.Slice => {
switch (token) {
.ArrayBegin => {
var arraylist = std.ArrayList(ptrInfo.child).init(allocator);
errdefer {
while (arraylist.popOrNull()) |v| {
parseFreeInternal(ptrInfo.child, v, options);
}
arraylist.deinit();
}
while (true) {
const tok = (try tokens.next()) orelse return error.UnexpectedEndOfJson;
switch (tok) {
.ArrayEnd => break,
else => {},
}
try arraylist.ensureUnusedCapacity(1);
const v = try parseInternal(ptrInfo.child, tok, tokens, options);
arraylist.appendAssumeCapacity(v);
}
if (ptrInfo.sentinel) |some| {
const sentinel_value = @ptrCast(*const ptrInfo.child, some).*;
try arraylist.append(sentinel_value);
const output = arraylist.toOwnedSlice();
return output[0 .. output.len - 1 :sentinel_value];
}
return arraylist.toOwnedSlice();
},
.String => |stringToken| {
if (ptrInfo.child != u8) return error.UnexpectedToken;
const source_slice = stringToken.slice(tokens.slice, tokens.i - 1);
const len = stringToken.decodedLength();
const output = try allocator.alloc(u8, len + @boolToInt(ptrInfo.sentinel != null));
errdefer allocator.free(output);
switch (stringToken.escapes) {
.None => mem.copy(u8, output, source_slice),
.Some => try unescapeValidString(output, source_slice),
}
if (ptrInfo.sentinel) |some| {
const char = @ptrCast(*const u8, some).*;
output[len] = char;
return output[0..len :char];
}
return output;
},
else => return error.UnexpectedToken,
}
},
else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"),
}
},
else => @compileError("Unable to parse into type '" ++ @typeName(T) ++ "'"),
}
unreachable;
}
fn ParseError(comptime T: type) type {
return ParseInternalError(T) || error{UnexpectedEndOfJson} || TokenStream.Error;
}
/// Releases resources created by `parse`.
/// Should be called with the same type and `ParseOptions` that were passed to `parse`
fn parseFreeInternal(comptime T: type, value: T, options: ParseOptions) void {
switch (@typeInfo(T)) {
.Bool, .Float, .ComptimeFloat, .Int, .ComptimeInt, .Enum => {},
.Optional => {
if (value) |v| {
return parseFreeInternal(@TypeOf(v), v, options);
}
},
.Union => |unionInfo| {
if (unionInfo.tag_type) |UnionTagType| {
inline for (unionInfo.fields) |u_field| {
if (value == @field(UnionTagType, u_field.name)) {
parseFreeInternal(u_field.field_type, @field(value, u_field.name), options);
break;
}
}
} else {
unreachable;
}
},
.Struct => |structInfo| {
inline for (structInfo.fields) |field| {
if (!field.is_comptime) {
parseFreeInternal(field.field_type, @field(value, field.name), options);
}
}
},
.Array => |arrayInfo| {
for (value) |v| {
parseFreeInternal(arrayInfo.child, v, options);
}
},
.Pointer => |ptrInfo| {
const allocator = options.allocator orelse unreachable;
switch (ptrInfo.size) {
.One => {
parseFreeInternal(ptrInfo.child, value.*, options);
allocator.destroy(value);
},
.Slice => {
for (value) |v| {
parseFreeInternal(ptrInfo.child, v, options);
}
allocator.free(value);
},
else => unreachable,
}
},
else => unreachable,
}
}

View file

@ -1,380 +0,0 @@
const std = @import("std");
const QueryIter = @import("util").QueryIter;
/// Parses a set of query parameters described by the struct `T`.
///
/// To specify query parameters, provide a struct similar to the following:
/// ```
/// struct {
/// foo: bool = false,
/// bar: ?[]const u8 = null,
/// baz: usize = 10,
/// qux: enum { quux, snap } = .quux,
/// }
/// ```
///
/// This will allow it to parse a query string like the following:
/// `?foo&bar=abc&qux=snap`
///
/// Every parameter must have a default value that will be used when the
/// parameter is not provided, and parameter keys.
/// Numbers are parsed from their string representations, and a parameter
/// provided in the query string without a value is parsed either as a bool
/// `true` flag or as `null` depending on the type of its param.
///
/// Parameter types supported:
/// - []const u8
/// - numbers (both integer and float)
/// + Numbers are parsed in base 10
/// - bool
/// + See below for detals
/// - exhaustive enums
/// + Enums are treated as strings with values equal to the enum fields
/// - ?F (where isScalar(F) and F != bool)
/// - Any type that implements:
/// + pub fn parse([]const u8) !F
///
/// Boolean Parameters:
/// The following query strings will all parse a `true` value for the
/// parameter `foo: bool = false`:
/// - `?foo`
/// - `?foo=true`
/// - `?foo=t`
/// - `?foo=yes`
/// - `?foo=y`
/// - `?foo=1`
/// And the following query strings all parse a `false` value:
/// - `?`
/// - `?foo=false`
/// - `?foo=f`
/// - `?foo=no`
/// - `?foo=n`
/// - `?foo=0`
///
/// Compound Types:
/// Compound (struct) types are also supported, with the parameter key
/// for its parameters consisting of the struct's field + '.' + parameter
/// field. For example:
/// ```
/// struct {
/// foo: struct {
/// baz: usize = 0,
/// } = .{},
/// }
/// ```
/// Would be used to parse a query string like
/// `?foo.baz=12345`
///
/// Compound types cannot currently be nullable, and must be structs.
///
/// TODO: values are currently case-sensitive, and are not url-decoded properly.
/// This should be fixed.
pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T {
if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
var iter = QueryIter.from(query);
var fields = Intermediary(T){};
while (iter.next()) |pair| {
// TODO: Hash map
inline for (std.meta.fields(Intermediary(T))) |field| {
if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) {
@field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} };
break;
}
} else std.log.debug("unknown param {s}", .{pair.key});
}
return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery;
}
fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]const u8 {
var list = try std.ArrayList(u8).initCapacity(alloc, val.len);
errdefer list.deinit();
var idx: usize = 0;
while (idx < val.len) : (idx += 1) {
if (val[idx] != '%') {
try list.append(val[idx]);
} else {
if (val.len < idx + 2) return error.InvalidEscape;
const buf = [2]u8{ val[idx + 1], val[idx + 2] };
idx += 2;
const ch = try std.fmt.parseInt(u8, &buf, 16);
try list.append(ch);
}
}
return list.toOwnedSlice();
}
fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T {
const param = @field(fields, name);
return switch (param) {
.not_specified => null,
.no_value => try parseQueryValue(alloc, T, null),
.value => |v| try parseQueryValue(alloc, T, v),
};
}
fn parse(
alloc: std.mem.Allocator,
comptime T: type,
comptime prefix: []const u8,
comptime name: []const u8,
fields: anytype,
) !?T {
if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields);
switch (@typeInfo(T)) {
.Union => |info| {
var result: ?T = null;
inline for (info.fields) |field| {
const F = field.field_type;
const maybe_value = try parse(alloc, F, prefix, field.name, fields);
if (maybe_value) |value| {
if (result != null) return error.DuplicateUnionField;
result = @unionInit(T, field.name, value);
}
}
std.log.debug("{any}", .{result});
return result;
},
.Struct => |info| {
var result: T = undefined;
var fields_specified: usize = 0;
inline for (info.fields) |field| {
const F = field.field_type;
var maybe_value: ?F = null;
if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| {
maybe_value = v;
} else if (field.default_value) |default| {
if (comptime @sizeOf(F) != 0) {
maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*;
} else {
maybe_value = std.mem.zeroes(F);
}
}
if (maybe_value) |v| {
fields_specified += 1;
@field(result, field.name) = v;
}
}
if (fields_specified == 0) {
return null;
} else if (fields_specified != info.fields.len) {
std.log.debug("{} {s} {s}", .{ T, prefix, name });
return error.PartiallySpecifiedStruct;
} else {
return result;
}
},
// Only applies to non-scalar optionals
.Optional => |info| return try parse(alloc, info.child, prefix, name, fields),
else => @compileError("tmp"),
}
}
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 {
comptime {
if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix);
var fields: []const []const u8 = &.{};
for (std.meta.fields(T)) |f| {
const full_name = prefix ++ f.name;
if (isScalar(f.field_type)) {
fields = fields ++ @as([]const []const u8, &.{full_name});
} else {
const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ ".";
fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix);
}
}
return fields;
}
}
const QueryParam = union(enum) {
not_specified: void,
no_value: void,
value: []const u8,
};
fn Intermediary(comptime T: type) type {
const field_names = recursiveFieldPaths(T, "..");
var fields: [field_names.len]std.builtin.Type.StructField = undefined;
for (field_names) |name, i| fields[i] = .{
.name = name,
.field_type = QueryParam,
.default_value = &QueryParam{ .not_specified = {} },
.is_comptime = false,
.alignment = @alignOf(QueryParam),
};
return @Type(.{ .Struct = .{
.layout = .Auto,
.fields = &fields,
.decls = &.{},
.is_tuple = false,
} });
}
fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, value: ?[]const u8) !T {
const is_optional = comptime std.meta.trait.is(.Optional)(T);
// If param is present, but without an associated value
if (value == null) {
return if (is_optional)
null
else if (T == bool)
true
else
error.InvalidValue;
}
return try parseQueryValueNotNull(alloc, if (is_optional) std.meta.Child(T) else T, value.?);
}
const bool_map = std.ComptimeStringMap(bool, .{
.{ "true", true },
.{ "t", true },
.{ "yes", true },
.{ "y", true },
.{ "1", true },
.{ "false", false },
.{ "f", false },
.{ "no", false },
.{ "n", false },
.{ "0", false },
});
fn parseQueryValueNotNull(alloc: std.mem.Allocator, comptime T: type, value: []const u8) !T {
const decoded = try decodeString(alloc, value);
errdefer alloc.free(decoded);
if (comptime std.meta.trait.isZigString(T)) return decoded;
const result = if (comptime std.meta.trait.isIntegral(T))
try std.fmt.parseInt(T, decoded, 0)
else if (comptime std.meta.trait.isFloat(T))
try std.fmt.parseFloat(T, decoded)
else if (comptime std.meta.trait.is(.Enum)(T))
std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue
else if (T == bool)
bool_map.get(value) orelse return error.InvalidBool
else if (comptime std.meta.trait.hasFn("parse")(T))
try T.parse(value)
else
@compileError("Invalid type " ++ @typeName(T));
alloc.free(decoded);
return result;
}
fn isScalar(comptime T: type) bool {
if (comptime std.meta.trait.isZigString(T)) return true;
if (comptime std.meta.trait.isIntegral(T)) return true;
if (comptime std.meta.trait.isFloat(T)) return true;
if (comptime std.meta.trait.is(.Enum)(T)) return true;
if (T == bool) return true;
if (comptime std.meta.trait.hasFn("parse")(T)) return true;
if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
return false;
}
pub fn formatQuery(params: anytype, writer: anytype) !void {
try format("", "", params, writer);
}
fn urlFormatString(writer: anytype, val: []const u8) !void {
for (val) |ch| {
const printable = switch (ch) {
'0'...'9', 'a'...'z', 'A'...'Z' => true,
'-', '.', '_', '~', ':', '@', '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=' => true,
else => false,
};
try if (printable) writer.writeByte(ch) else std.fmt.format(writer, "%{x:0>2}", .{ch});
}
}
fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void {
const T = @TypeOf(val);
if (comptime std.meta.trait.is(.Optional)(T)) {
return if (val) |v| formatScalar(name, v, writer) else {};
}
try urlFormatString(writer, name);
try writer.writeByte('=');
if (comptime std.meta.trait.isZigString(T)) {
try urlFormatString(writer, val);
} else try switch (@typeInfo(T)) {
.Enum => urlFormatString(writer, @tagName(val)),
else => std.fmt.format(writer, "{}", .{val}),
};
try writer.writeByte('&');
}
fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void {
const T = @TypeOf(params);
const eff_prefix = if (prefix.len == 0) "" else prefix ++ ".";
if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer);
switch (@typeInfo(T)) {
.Struct => {
inline for (std.meta.fields(T)) |field| {
const val = @field(params, field.name);
try format(eff_prefix ++ name, field.name, val, writer);
}
},
.Union => {
inline for (std.meta.fields(T)) |field| {
const tag = @field(std.meta.Tag(T), field.name);
const tag_name = field.name;
if (@as(std.meta.Tag(T), params) == tag) {
const val = @field(params, tag_name);
try format(prefix, tag_name, val, writer);
}
}
},
.Optional => {
if (params) |p| try format(prefix, name, p, writer);
},
else => @compileError("Unsupported query type"),
}
}
test {
const TestQuery = struct {
int: usize = 3,
boolean: bool = false,
str_enum: ?enum { foo, bar } = null,
};
try std.testing.expectEqual(TestQuery{
.int = 3,
.boolean = false,
.str_enum = null,
}, try parseQuery(TestQuery, ""));
try std.testing.expectEqual(TestQuery{
.int = 5,
.boolean = true,
.str_enum = .foo,
}, try parseQuery(TestQuery, "?int=5&boolean=yes&str_enum=foo"));
}

View file

@ -160,6 +160,13 @@ pub fn deepClone(alloc: std.mem.Allocator, val: anytype) !@TypeOf(val) {
count += 1;
}
},
.Union => {
inline for (comptime std.meta.fieldNames(T)) |f| {
if (std.meta.isTag(val, f)) {
return @unionInit(T, f, try deepClone(alloc, @field(val, f)));
}
} else unreachable;
},
.Array => {
var count: usize = 0;
errdefer for (result[0..count]) |v| deepFree(alloc, v);
@ -194,3 +201,44 @@ pub fn seedThreadPrng() !void {
prng = std.rand.DefaultPrng.init(@bitCast(u64, buf));
}
pub const testing = struct {
pub fn expectDeepEqual(expected: anytype, actual: @TypeOf(expected)) !void {
const T = @TypeOf(expected);
switch (@typeInfo(T)) {
.Null, .Void => return,
.Int, .Float, .Bool, .Enum => try std.testing.expectEqual(expected, actual),
.Struct => {
inline for (comptime std.meta.fieldNames(T)) |f| {
try expectDeepEqual(@field(expected, f), @field(actual, f));
}
},
.Union => {
inline for (comptime std.meta.fieldNames(T)) |f| {
if (std.meta.isTag(expected, f)) {
try std.testing.expect(std.meta.isTag(actual, f));
try expectDeepEqual(@field(expected, f), @field(actual, f));
}
}
},
.Pointer, .Array => {
if (comptime std.meta.trait.isIndexable(T)) {
try std.testing.expectEqual(expected.len, actual.len);
for (expected) |_, i| {
try expectDeepEqual(expected[i], actual[i]);
}
} else if (comptime std.meta.trait.isSingleItemPtr(T)) {
try expectDeepEqual(expected.*, actual.*);
}
},
.Optional => {
if (expected) |e| {
try expectDeepEqual(e, actual orelse return error.TestExpectedEqual);
} else {
try std.testing.expect(actual == null);
}
},
else => @compileError("Unsupported Type " ++ @typeName(T)),
}
}
};