Compare commits

...

25 commits

Author SHA1 Message Date
208007c0f7 Drive - Uploads & dirs 2022-12-03 07:09:29 -08:00
31f676580d Rework fs db schema 2022-12-03 07:09:03 -08:00
6cfd035883 parse content-type form header 2022-12-03 06:36:54 -08:00
e27d0064ee Fix some bugs in sql engine 2022-12-03 06:36:31 -08:00
a97850964e Stub out filesystem apis 2022-12-03 01:00:04 -08:00
a45ccfe0e4 Basic file upload 2022-12-02 23:44:27 -08:00
2bcef49e5e Add star segment support in routes 2022-12-02 23:21:49 -08:00
2206cd6ac9 Form File support 2022-12-02 22:34:12 -08:00
e6f57495c0 Cleaner multipart handling 2022-12-02 22:20:24 -08:00
6e56775d61 Multipart/form-data 2022-12-02 21:49:27 -08:00
0b13f210c7 Refactor 2022-12-02 21:49:17 -08:00
ba4f3a7bf4 Reorganize tests 2022-12-01 21:02:33 -08:00
f7bcafe1b1 Remove dead code 2022-12-01 20:52:51 -08:00
16c574bdd6 Refactoring 2022-12-01 20:41:52 -08:00
b2093128de Remove ciutf8 2022-12-01 19:46:51 -08:00
04c593ffdd add util.comptimeToCrlf 2022-12-01 19:46:07 -08:00
8400cd74fd Use deserialization utils 2022-12-01 01:56:17 -08:00
83af6a40e4 More serialization refactor 2022-11-30 21:11:54 -08:00
c7dcded04a Add tests for deserialization 2022-11-30 20:01:17 -08:00
aa632ace8b Work on deserialization refactor 2022-11-30 19:21:55 -08:00
96a46a98c9 Multipart deserialization 2022-11-27 22:33:05 -08:00
2f78490545 Add rudimentary scalar parsing 2022-11-27 06:24:41 -08:00
4a98b6a9c4 Parse form params 2022-11-27 06:11:01 -08:00
938ee61477 Start work on multipart form parser 2022-11-27 05:43:06 -08:00
b99a0095d4 Rudimentary test cases for ParseBody 2022-11-27 02:21:22 -08:00
26 changed files with 1724 additions and 880 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@
**/zig-cache
**.db
/config.json
/files

View file

@ -99,13 +99,23 @@ pub fn build(b: *std.build.Builder) !void {
exe.addSystemIncludePath("/usr/include/");
const unittest_http_cmd = b.step("unit:http", "Run tests for http package");
const unittest_http = b.addTest("src/http/test.zig");
const unittest_http = b.addTest("src/http/lib.zig");
unittest_http_cmd.dependOn(&unittest_http.step);
unittest_http.addPackage(pkgs.util);
//const unittest_util_cmd = b.step("unit:util", "Run tests for util package");
//const unittest_util = b.addTest("src/util/Uuid.zig");
//unittest_util_cmd.dependOn(&unittest_util.step);
const unittest_util_cmd = b.step("unit:util", "Run tests for util package");
const unittest_util = b.addTest("src/util/lib.zig");
unittest_util_cmd.dependOn(&unittest_util.step);
const unittest_sql_cmd = b.step("unit:sql", "Run tests for sql package");
const unittest_sql = b.addTest("src/sql/lib.zig");
unittest_sql_cmd.dependOn(&unittest_sql.step);
unittest_sql.addPackage(pkgs.util);
const unittest_template_cmd = b.step("unit:template", "Run tests for template package");
const unittest_template = b.addTest("src/template/lib.zig");
unittest_template_cmd.dependOn(&unittest_template.step);
//unittest_template.addPackage(pkgs.util);
//const util_tests = b.addTest("src/util/lib.zig");
//const sql_tests = b.addTest("src/sql/lib.zig");
@ -115,7 +125,9 @@ pub fn build(b: *std.build.Builder) !void {
//const unit_tests = b.step("unit-tests", "Run tests");
const unittest_all = b.step("unit", "Run unit tests");
unittest_all.dependOn(unittest_http_cmd);
//unittest_all.dependOn(unittest_util_cmd);
unittest_all.dependOn(unittest_util_cmd);
unittest_all.dependOn(unittest_sql_cmd);
unittest_all.dependOn(unittest_template_cmd);
const api_integration = b.addTest("./tests/api_integration/lib.zig");
api_integration.addPackage(pkgs.opts);

View file

@ -9,6 +9,7 @@ const services = struct {
const communities = @import("./services/communities.zig");
const actors = @import("./services/actors.zig");
const auth = @import("./services/auth.zig");
const drive = @import("./services/files.zig");
const invites = @import("./services/invites.zig");
const notes = @import("./services/notes.zig");
const follows = @import("./services/follows.zig");
@ -136,6 +137,14 @@ pub const FollowerQueryResult = FollowQueryResult;
pub const FollowingQueryArgs = FollowQueryArgs;
pub const FollowingQueryResult = FollowQueryResult;
/// Arguments accepted by `uploadFile`. These are exactly the caller-supplied
/// fields forwarded to `services.drive.createFile`; the owner and creator are
/// derived from the authenticated user, so they are not part of this struct.
pub const UploadFileArgs = struct {
    /// Name to store the file under.
    filename: []const u8,
    /// Target drive directory; `uploadFile` substitutes "/" when null.
    dir: ?[]const u8,
    /// Optional human-readable description.
    description: ?[]const u8,
    /// MIME type supplied by the client.
    content_type: []const u8,
    /// Whether the upload is flagged as sensitive.
    sensitive: bool,
};
pub fn isAdminSetup(db: sql.Db) !bool {
_ = services.communities.adminCommunityId(db) catch |err| switch (err) {
error.NotFound => return false,
@ -509,5 +518,23 @@ fn ApiConn(comptime DbConn: type) type {
self.allocator,
);
}
/// Uploads `body` as a new file in the authenticated user's drive.
/// Returns `error.NoToken` when the connection has no logged-in user.
pub fn uploadFile(self: *Self, meta: UploadFileArgs, body: []const u8) !void {
    const user_id = self.user_id orelse return error.NoToken;
    return try services.drive.createFile(self.db, .{
        // Default to the drive root when the client did not specify a directory.
        .dir = meta.dir orelse "/",
        .filename = meta.filename,
        // API uploads are always owned and created by the requesting user.
        .owner = .{ .user_id = user_id },
        .created_by = user_id,
        .description = meta.description,
        .content_type = meta.content_type,
        .sensitive = meta.sensitive,
    }, body, self.allocator);
}
/// Creates a directory at `path` in the authenticated user's drive.
/// Returns `error.NoToken` when the connection has no logged-in user.
pub fn driveMkdir(self: *Self, path: []const u8) !void {
    const user_id = self.user_id orelse return error.NoToken;
    try services.drive.mkdir(self.db, .{ .user_id = user_id }, path, self.allocator);
}
};
}

View file

@ -11,59 +11,224 @@ pub const FileOwner = union(enum) {
pub const DriveFile = struct {
id: Uuid,
path: []const u8,
filename: []const u8,
owner: FileOwner,
size: usize,
description: []const u8,
content_type: []const u8,
sensitive: bool,
created_at: DateTime,
updated_at: DateTime,
};
// Kind of a drive entry: a directory or a file.
const EntryType = enum {
    dir,
    file,
};
pub const CreateFileArgs = struct {
dir: []const u8,
filename: []const u8,
owner: FileOwner,
size: usize,
created_at: DateTime,
created_by: Uuid,
description: ?[]const u8,
content_type: ?[]const u8,
sensitive: bool,
};
pub const files = struct {
pub fn create(db: anytype, owner: FileOwner, filename: []const u8, data: []const u8, alloc: std.mem.Allocator) !void {
const id = Uuid.randV4(util.getThreadPrng());
const now = DateTime.now();
/// Resolves a directory path to its entry id for the given owner.
/// `path` is normalized by trimming surrounding slashes; an empty path means
/// the root. Only rows with kind = 'dir' match, so a file at the same path is
/// not returned. Errors if the query yields no row (propagated from queryRow).
fn lookupDirectory(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid {
    return (try db.queryRow(
        std.meta.Tuple(
            &.{util.Uuid},
        ),
        // The CASE rebuilds the canonical stored form '/<path>/' (just '/'
        // for the root) from the trimmed path bound as $1.
        \\SELECT id
        \\FROM drive_entry_path
        \\WHERE
        \\  path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END)
        \\  AND account_owner_id IS NOT DISTINCT FROM $2
        \\  AND community_owner_id IS NOT DISTINCT FROM $3
        \\  AND kind = 'dir'
        \\LIMIT 1
    ,
        .{
            std.mem.trim(u8, path, "/"),
            // Exactly one owner column is non-null, mirroring the FileOwner union.
            if (owner == .user_id) owner.user_id else null,
            if (owner == .community_id) owner.community_id else null,
        },
        alloc,
    ))[0];
}
// TODO: assert we're not in a transaction
db.insert("drive_file", .{
/// Resolves a path to its entry id for the given owner, matching both files
/// and directories. Identical to `lookupDirectory` except that no
/// kind = 'dir' filter is applied.
fn lookup(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid {
    return (try db.queryRow(
        std.meta.Tuple(
            &.{util.Uuid},
        ),
        \\SELECT id
        \\FROM drive_entry_path
        \\WHERE
        \\  path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END)
        \\  AND account_owner_id IS NOT DISTINCT FROM $2
        \\  AND community_owner_id IS NOT DISTINCT FROM $3
        \\LIMIT 1
    ,
        .{
            std.mem.trim(u8, path, "/"),
            if (owner == .user_id) owner.user_id else null,
            if (owner == .community_id) owner.community_id else null,
        },
        alloc,
    ))[0];
}
/// Creates a new directory entry at `path` for `owner`.
/// The final path segment becomes the directory name; the parent path must
/// already exist as a directory (resolved via `lookupDirectory` inside the
/// same transaction). Returns `error.EmptyName` for paths like "" or "/".
pub fn mkdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void {
    // Split "a/b/c" (slashes trimmed) into name "c" and parent path "a/b".
    var split = std.mem.splitBackwards(u8, std.mem.trim(u8, path, "/"), "/");
    const name = split.first();
    const dir = split.rest();
    std.log.debug("'{s}' / '{s}'", .{ name, dir });
    if (name.len == 0) return error.EmptyName;
    const id = Uuid.randV4(util.getThreadPrng());
    const tx = try db.begin();
    errdefer tx.rollback();
    // Fails if the parent directory does not exist for this owner.
    const parent = try lookupDirectory(tx, owner, dir, alloc);
    try tx.insert("drive_entry", .{
        .id = id,
        // Exactly one owner column is set, mirroring the FileOwner union.
        .account_owner_id = if (owner == .user_id) owner.user_id else null,
        .community_owner_id = if (owner == .community_id) owner.community_id else null,
        .name = name,
        .parent_directory_id = parent,
    }, alloc);
    try tx.commit();
}
/// Removes the directory entry at `path` for `owner`.
/// Resolves the path to an entry id and deletes that row, all inside one
/// transaction so a failed delete leaves nothing half-done.
///
/// Fix: the DELETE previously targeted `drive_directory`, but every other
/// function in this file stores directory entries in `drive_entry` (see
/// `mkdir` / `insertFileRow`), so the stale table name could never remove the
/// row that `lookupDirectory` found. TODO(review): confirm `drive_directory`
/// was dropped in the schema rework.
pub fn rmdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void {
    const tx = try db.begin();
    errdefer tx.rollback();
    const id = try lookupDirectory(tx, owner, path, alloc);
    try tx.exec("DELETE FROM drive_entry WHERE id = $1", .{id}, alloc);
    try tx.commit();
}
/// Inserts the drive_entry row that links an uploaded file into a directory.
/// The entry deliberately reuses the file's id (`.id` equals `.file_id`).
fn insertFileRow(tx: anytype, id: Uuid, filename: []const u8, owner: FileOwner, dir: Uuid, alloc: std.mem.Allocator) !void {
    try tx.insert("drive_entry", .{
        .id = id,
        // Exactly one owner column is set, mirroring the FileOwner union.
        .account_owner_id = if (owner == .user_id) owner.user_id else null,
        .community_owner_id = if (owner == .community_id) owner.community_id else null,
        .parent_directory_id = dir,
        .name = filename,
        .file_id = id,
    }, alloc);
}
pub fn createFile(db: anytype, args: CreateFileArgs, data: []const u8, alloc: std.mem.Allocator) !void {
const id = Uuid.randV4(util.getThreadPrng());
const now = DateTime.now();
{
var tx = try db.begin();
errdefer tx.rollback();
const dir_id = try lookupDirectory(tx, args.owner, args.dir, alloc);
try tx.insert("file_upload", .{
.id = id,
.filename = filename,
.owner = owner,
.filename = args.filename,
.created_by = args.created_by,
.size = data.len,
.description = args.description,
.content_type = args.content_type,
.sensitive = args.sensitive,
.is_deleted = false,
.created_at = now,
}, alloc) catch return error.DatabaseFailure;
// Assume the previous statement succeeded and is not stuck in a transaction
errdefer {
db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch |err| {
std.log.err("Unable to remove file record in DB: {}", .{err});
};
.updated_at = now,
}, alloc);
var sub_tx = try tx.savepoint();
if (insertFileRow(sub_tx, id, args.filename, args.owner, dir_id, alloc)) |_| {
try sub_tx.release();
} else |err| {
std.log.debug("{}", .{err});
switch (err) {
error.UniqueViolation => {
try sub_tx.rollbackSavepoint();
// Rename the file before trying again
var split = std.mem.split(u8, args.filename, ".");
const name = split.first();
const ext = split.rest();
var buf: [256]u8 = undefined;
const drive_filename = try std.fmt.bufPrint(&buf, "{s}.{}.{s}", .{ name, id, ext });
try insertFileRow(tx, id, drive_filename, args.owner, dir_id, alloc);
},
else => return error.DatabaseFailure,
}
}
try saveFile(id, data);
try tx.commit();
}
const data_root = "./files";
fn saveFile(id: Uuid, data: []const u8) !void {
var dir = try std.fs.cwd().openDir(data_root);
defer dir.close();
var file = try dir.createFile(id.toCharArray(), .{ .exclusive = true });
defer file.close();
try file.writer().writeAll(data);
try file.sync();
errdefer {
db.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch |err| {
std.log.err("Unable to remove file record in DB: {}", .{err});
};
db.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch |err| {
std.log.err("Unable to remove file record in DB: {}", .{err});
};
}
pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 {
var dir = try std.fs.cwd().openDir(data_root);
defer dir.close();
try saveFile(id, data);
}
return dir.readFileAlloc(alloc, id.toCharArray(), 1 << 32);
}
const data_root = "./files";
fn saveFile(id: Uuid, data: []const u8) !void {
var dir = try std.fs.cwd().openDir(data_root, .{});
defer dir.close();
pub fn delete(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void {
var dir = try std.fs.cwd().openDir(data_root);
defer dir.close();
var file = try dir.createFile(&id.toCharArray(), .{ .exclusive = true });
defer file.close();
try dir.deleteFile(id.toCharArray());
try file.writer().writeAll(data);
try file.sync();
}
db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure;
}
};
/// Reads the stored bytes of file `id` from the on-disk data directory.
/// Caller owns the returned buffer; size is capped at 1 << 32 bytes.
pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 {
    var dir = try std.fs.cwd().openDir(data_root, .{});
    defer dir.close();
    // Files are stored flat under data_root, named by their UUID.
    return dir.readFileAlloc(alloc, &id.toCharArray(), 1 << 32);
}
/// Deletes file `id`: removes its database rows, then its bytes on disk.
///
/// Fixes:
/// - `Dir.deleteFile` takes a `[]const u8` sub-path, but the `[N]u8` array
///   returned by `toCharArray()` was passed by value; take its address, as
///   `deref` already does.
/// - The on-disk file was unlinked before the transaction ran, so a DB
///   failure left rows pointing at missing data. Delete the rows first and
///   only unlink the file after a successful commit. (A failure to unlink
///   after commit now leaves an orphaned file rather than a dangling row.)
pub fn deleteFile(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void {
    const tx = try db.beginOrSavepoint();
    errdefer tx.rollback();
    tx.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure;
    tx.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure;
    try tx.commitOrRelease();

    var dir = try std.fs.cwd().openDir(data_root, .{});
    defer dir.close();
    try dir.deleteFile(&id.toCharArray());
}

View file

@ -1,5 +1,60 @@
const std = @import("std");
/// Iterates over the ';'-separated parameters of a structured header field,
/// e.g. "multipart/form-data; boundary=xyz" yields one param (boundary=xyz).
pub const ParamIter = struct {
    str: []const u8,
    // Offset of the ';' that introduces the next parameter (str.len when done).
    index: usize = 0,

    const Param = struct {
        name: []const u8,
        value: []const u8,
    };

    pub fn from(str: []const u8) ParamIter {
        // Start at the first ';' so next() skips the field value itself.
        return .{ .str = str, .index = std.mem.indexOfScalar(u8, str, ';') orelse str.len };
    }

    /// Returns the field's main value: everything before the first ';'.
    pub fn fieldValue(self: *ParamIter) []const u8 {
        return std.mem.sliceTo(self.str, ';');
    }

    /// Returns the next name/value parameter, or null when exhausted.
    /// The value is everything after the first '=' with leading blanks trimmed;
    /// a parameter without '=' yields an empty value.
    pub fn next(self: *ParamIter) ?Param {
        if (self.index >= self.str.len) return null;
        const start = self.index + 1; // skip the ';'
        const new_start = std.mem.indexOfScalarPos(u8, self.str, start, ';') orelse self.str.len;
        self.index = new_start;
        const param = std.mem.trim(u8, self.str[start..new_start], " \t");
        var split = std.mem.split(u8, param, "=");
        const name = split.first();
        const value = std.mem.trimLeft(u8, split.rest(), " \t");
        // TODO: handle quoted values
        // TODO: handle parse errors
        return Param{
            .name = name,
            .value = value,
        };
    }
};
/// Extracts a value from a structured header field.
/// With `name == null`, returns the field's main value (before the first ';').
/// Otherwise returns the named parameter's value — trimmed of blanks and with
/// one surrounding pair of double quotes stripped — or null if absent.
pub fn getParam(field: []const u8, name: ?[]const u8) ?[]const u8 {
    var params = ParamIter.from(field);
    const wanted = name orelse return params.fieldValue();
    while (params.next()) |p| {
        if (!std.ascii.eqlIgnoreCase(wanted, p.name)) continue;
        const v = std.mem.trim(u8, p.value, " \t");
        const is_quoted = v.len >= 2 and v[0] == '"' and v[v.len - 1] == '"';
        return if (is_quoted) v[1 .. v.len - 1] else v;
    }
    return null;
}
pub const Fields = struct {
const HashContext = struct {
const hash_seed = 1;

View file

@ -1,27 +1,55 @@
const std = @import("std");
const ciutf8 = @import("util").ciutf8;
const request = @import("./request.zig");
const server = @import("./server.zig");
pub const urlencode = @import("./urlencode.zig");
pub const socket = @import("./socket.zig");
const json = @import("./json.zig");
const multipart = @import("./multipart.zig");
pub const fields = @import("./fields.zig");
pub const Method = enum {
GET,
HEAD,
POST,
PUT,
DELETE,
CONNECT,
OPTIONS,
TRACE,
PATCH,
// WebDAV methods (we use some of them for the drive system)
MKCOL,
MOVE,
pub fn requestHasBody(self: Method) bool {
return switch (self) {
.POST, .PUT, .PATCH, .MKCOL, .MOVE => true,
else => false,
};
}
};
pub const Method = std.http.Method;
pub const Status = std.http.Status;
pub const Request = request.Request(server.Stream.Reader);
pub const Response = server.Response;
pub const Handler = server.Handler;
//pub const Handler = server.Handler;
pub const Server = server.Server;
pub const middleware = @import("./middleware.zig");
pub const queryStringify = @import("./query.zig").queryStringify;
pub const Fields = @import("./headers.zig").Fields;
pub const Fields = fields.Fields;
pub const FormFile = multipart.FormFile;
/// HTTP protocol versions distinguished by this package.
pub const Protocol = enum {
    http_1_0,
    http_1_1,
    // Other HTTP/1.x minor version — presumably a catch-all; confirm at the parser.
    http_1_x,
};
test {
_ = std.testing.refAllDecls(@This());
}

View file

@ -14,10 +14,11 @@
/// Terminal middlewares that are not implemented using other middlewares should
/// only accept a `void` value for `next_handler`.
const std = @import("std");
const http = @import("./lib.zig");
const util = @import("util");
const query_utils = @import("./query.zig");
const http = @import("./lib.zig");
const urlencode = @import("./urlencode.zig");
const json_utils = @import("./json.zig");
const fields = @import("./fields.zig");
/// Takes an iterable of middlewares and chains them together.
pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) {
@ -29,20 +30,20 @@ pub fn Apply(comptime Middlewares: type) type {
return ApplyInternal(std.meta.fields(Middlewares));
}
fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type {
if (fields.len == 0) return void;
fn ApplyInternal(comptime which: []const std.builtin.Type.StructField) type {
if (which.len == 0) return void;
return HandlerList(
fields[0].field_type,
ApplyInternal(fields[1..]),
which[0].field_type,
ApplyInternal(which[1..]),
);
}
fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) {
if (fields.len == 0) return {};
fn applyInternal(middlewares: anytype, comptime which: []const std.builtin.Type.StructField) ApplyInternal(which) {
if (which.len == 0) return {};
return .{
.first = @field(middlewares, fields[0].name),
.next = applyInternal(middlewares, fields[1..]),
.first = @field(middlewares, which[0].name),
.next = applyInternal(middlewares, which[1..]),
};
}
@ -349,15 +350,71 @@ pub fn router(routes: anytype) Router(@TypeOf(routes)) {
return Router(@TypeOf(routes)){ .routes = routes };
}
/// Iterates over the non-empty segments of a '/'-separated path.
/// Empty segments from leading, trailing, or doubled slashes are skipped,
/// except that a fully-empty path ("" or "/") yields a single "" segment on
/// the first call so the root route can still match.
pub const PathIter = struct {
    is_first: bool,
    iter: std.mem.SplitIterator(u8),

    pub fn from(path: []const u8) PathIter {
        return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") };
    }

    pub fn next(self: *PathIter) ?[]const u8 {
        defer self.is_first = false;
        // Skip the empty segments produced by consecutive or edge slashes.
        while (self.iter.next()) |it| if (it.len != 0) {
            return it;
        };
        // Exhausted without any non-empty segment: an all-empty path still
        // yields one "" on the very first call.
        if (self.is_first) return self.iter.rest();
        return null;
    }

    /// First segment; must be called before any next().
    pub fn first(self: *PathIter) []const u8 {
        std.debug.assert(self.is_first);
        return self.next().?;
    }

    /// Remainder of the path not yet consumed by next().
    pub fn rest(self: *PathIter) []const u8 {
        return self.iter.rest();
    }
};
test "PathIter" {
    // Each case lists the exact segment sequence the iterator must produce,
    // including the single-""-segment behavior for empty paths.
    const testCase = struct {
        fn case(path: []const u8, segments: []const []const u8) !void {
            var iter = PathIter.from(path);
            for (segments) |s| {
                try std.testing.expectEqualStrings(s, iter.next() orelse return error.TestExpectedEqual);
            }
            // The iterator must be exhausted once all segments are consumed.
            try std.testing.expect(iter.next() == null);
        }
    }.case;
    try testCase("", &.{""});
    try testCase("*", &.{"*"});
    try testCase("/", &.{""});
    try testCase("/ab/cd", &.{ "ab", "cd" });
    try testCase("/ab/cd/", &.{ "ab", "cd" });
    try testCase("/ab/cd//", &.{ "ab", "cd" });
    try testCase("ab", &.{"ab"});
    try testCase("/ab", &.{"ab"});
    try testCase("ab/", &.{"ab"});
    try testCase("ab//ab//", &.{ "ab", "ab" });
}
// helper function for doing route analysis
fn pathMatches(route: []const u8, path: []const u8) bool {
var path_iter = util.PathIter.from(path);
var route_iter = util.PathIter.from(route);
var path_iter = PathIter.from(path);
var route_iter = PathIter.from(route);
while (route_iter.next()) |route_segment| {
const path_segment = path_iter.next() orelse return false;
const path_segment = path_iter.next() orelse "";
if (route_segment.len > 0 and route_segment[0] == ':') {
// Route Argument
if (path_segment.len == 0) return false;
if (route_segment[route_segment.len - 1] == '*') {
// consume rest of path segments
while (path_iter.next()) |_| {}
} else if (path_segment.len == 0) return false;
} else {
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false;
}
@ -428,6 +485,10 @@ test "route" {
try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh/xyz");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/");
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd");
try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/");
try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, "");
@ -436,32 +497,21 @@ test "route" {
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/");
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/");
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo");
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "defg/abcd");
}
/// Mounts a router subtree under a given path. Middlewares further down on the list
/// are called with the path prefix specified by `route` removed from the path.
/// Must be below `split_uri` on the middleware list.
pub fn Mount(comptime route: []const u8) type {
if (std.mem.indexOfScalar(u8, route, ':') != null) @compileError("Route args cannot be mounted");
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
var path_iter = util.PathIter.from(ctx.path);
comptime var route_iter = util.PathIter.from(route);
var path_unused: []const u8 = ctx.path;
inline while (comptime route_iter.next()) |route_segment| {
if (comptime route_segment.len == 0) continue;
const path_segment = path_iter.next() orelse return error.RouteMismatch;
path_unused = path_iter.rest();
if (comptime route_segment[0] == ':') {
@compileLog("Argument segments cannot be mounted");
// Route Argument
} else {
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
}
}
const args = try parseArgsFromPath(route ++ "/:path*", struct { path: []const u8 }, ctx.path);
var new_ctx = ctx;
new_ctx.path = path_unused;
new_ctx.path = args.path;
return next.handle(req, res, new_ctx, {});
}
};
@ -491,18 +541,33 @@ test "mount" {
fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args {
var args: Args = undefined;
var path_iter = util.PathIter.from(path);
comptime var route_iter = util.PathIter.from(route);
var path_iter = PathIter.from(path);
comptime var route_iter = PathIter.from(route);
var path_unused: []const u8 = path;
inline while (comptime route_iter.next()) |route_segment| {
const path_segment = path_iter.next() orelse return error.RouteMismatch;
if (route_segment.len > 0 and route_segment[0] == ':') {
const path_segment = path_iter.next() orelse "";
if (route_segment[0] == ':') {
comptime var name: []const u8 = route_segment[1..];
var value: []const u8 = path_segment;
// route segment is an argument segment
if (path_segment.len == 0) return error.RouteMismatch;
const A = @TypeOf(@field(args, route_segment[1..]));
@field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment);
if (comptime route_segment[route_segment.len - 1] == '*') {
// waste remaining args
while (path_iter.next()) |_| {}
name = route_segment[1 .. route_segment.len - 1];
value = path_unused;
} else {
if (path_segment.len == 0) return error.RouteMismatch;
}
const A = @TypeOf(@field(args, name));
@field(args, name) = try parseArgFromPath(A, value);
} else {
// route segment is a literal segment
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
}
path_unused = path_iter.rest();
}
if (path_iter.next() != null) return error.RouteMismatch;
@ -577,6 +642,21 @@ test "ParsePathArgs" {
try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" });
try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil });
try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc", .{ .arg = "abc" });
try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc/def", .{ .arg = "abc/def" });
try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/", .{ .arg = "" });
// Compiler crashes if i keep the args named the same as above.
// TODO: Debug this and try to fix it
try testCase("/xyz/:bar*", struct { bar: []const u8 }, "/xyz", .{ .bar = "" });
// It's a quirk that the initial / is left in for these cases. However, it results in a path
// that's semantically equivalent so i didn't bother fixing it
try testCase("/:foo*", struct { foo: []const u8 }, "/abc", .{ .foo = "/abc" });
try testCase("/:foo*", struct { foo: []const u8 }, "/abc/def", .{ .foo = "/abc/def" });
try testCase("/:foo*", struct { foo: []const u8 }, "/", .{ .foo = "/" });
try testCase("/:foo*", struct { foo: []const u8 }, "", .{ .foo = "" });
try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{}));
try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{}));
try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 }));
@ -587,41 +667,51 @@ const BaseContentType = enum {
json,
url_encoded,
octet_stream,
multipart_formdata,
other,
};
fn parseBodyFromRequest(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T {
//@compileLog(T);
const buf = try reader.readAllAlloc(alloc, 1 << 16);
defer alloc.free(buf);
fn parseBodyFromRequest(comptime T: type, content_type: ?[]const u8, reader: anytype, alloc: std.mem.Allocator) !T {
// Use json by default for now for testing purposes
const eff_type = content_type orelse "application/json";
const parser_type = matchContentType(eff_type);
switch (content_type) {
switch (parser_type) {
.octet_stream, .json => {
const buf = try reader.readAllAlloc(alloc, 1 << 16);
defer alloc.free(buf);
const body = try json_utils.parse(T, buf, alloc);
defer json_utils.parseFree(body, alloc);
return try util.deepClone(alloc, body);
},
.url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) {
error.NoQuery => error.NoBody,
else => err,
.url_encoded => {
const buf = try reader.readAllAlloc(alloc, 1 << 16);
defer alloc.free(buf);
return urlencode.parse(alloc, T, buf) catch |err| switch (err) {
//error.NoQuery => error.NoBody,
else => err,
};
},
.multipart_formdata => {
const boundary = fields.getParam(eff_type, "boundary") orelse return error.MissingBoundary;
return try @import("./multipart.zig").parseFormData(T, boundary, reader, alloc);
},
else => return error.UnsupportedMediaType,
}
}
// figure out what base parser to use
fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
if (hdr) |h| {
if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded;
if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json;
if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream;
fn matchContentType(hdr: []const u8) BaseContentType {
const trimmed = std.mem.sliceTo(hdr, ';');
if (std.ascii.eqlIgnoreCase(trimmed, "application/x-www-form-urlencoded")) return .url_encoded;
if (std.ascii.eqlIgnoreCase(trimmed, "application/json")) return .json;
if (std.ascii.endsWithIgnoreCase(trimmed, "+json")) return .json;
if (std.ascii.eqlIgnoreCase(trimmed, "application/octet-stream")) return .octet_stream;
if (std.ascii.eqlIgnoreCase(trimmed, "multipart/form-data")) return .multipart_formdata;
return .other;
}
return null;
return .other;
}
/// Parses a set of body arguments from the request body based on the request's Content-Type
@ -640,10 +730,8 @@ pub fn ParseBody(comptime Body: type) type {
return next.handle(req, res, new_ctx, {});
}
const base_content_type = matchContentType(content_type);
var stream = req.body orelse return error.NoBody;
const body = try parseBodyFromRequest(Body, base_content_type orelse .json, stream.reader(), ctx.allocator);
const body = try parseBodyFromRequest(Body, content_type, stream.reader(), ctx.allocator);
defer util.deepFree(ctx.allocator, body);
return next.handle(
@ -659,12 +747,57 @@ pub fn parseBody(comptime Body: type) ParseBody(Body) {
return .{};
}
test "parseBodyFromRequest" {
    // Parses the same target struct from each supported Content-Type.
    const testCase = struct {
        fn case(content_type: []const u8, body: []const u8, expected: anytype) !void {
            var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
            const result = try parseBodyFromRequest(@TypeOf(expected), content_type, stream.reader(), std.testing.allocator);
            defer util.deepFree(std.testing.allocator, result);
            try util.testing.expectDeepEqual(expected, result);
        }
    }.case;
    const Struct = struct {
        id: usize,
    };
    try testCase("application/json", "{\"id\": 3}", Struct{ .id = 3 });
    try testCase("application/x-www-form-urlencoded", "id=3", Struct{ .id = 3 });
    // TODO: enable a multipart/form-data case once a fixture body is written.
    //try testCase("multipart/form-data; ",
    //\\
    //, Struct{ .id = 3 });
}
test "parseBody" {
    // End-to-end check of the parseBody middleware: a JSON body plus headers
    // go in, and the next handler observes the deserialized ctx.body.
    const Struct = struct {
        foo: []const u8,
    };
    const body =
        \\{"foo": "bar"}
    ;
    var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
    var headers = http.Fields.init(std.testing.allocator);
    defer headers.deinit();
    try parseBody(Struct).handle(
        .{ .body = @as(?std.io.StreamSource, stream), .headers = headers },
        .{},
        .{ .allocator = std.testing.allocator },
        // Terminal handler that asserts on the parsed body.
        struct {
            fn handle(_: anytype, _: anytype, _: anytype, ctx: anytype, _: void) !void {
                try util.testing.expectDeepEqual(Struct{ .foo = "bar" }, ctx.body);
            }
        }{},
    );
}
/// Parses query parameters as defined in query.zig
pub fn ParseQueryParams(comptime QueryParams: type) type {
return struct {
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {});
const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string);
const query = try urlencode.parse(ctx.allocator, QueryParams, ctx.query_string);
defer util.deepFree(ctx.allocator, query);
return next.handle(

362
src/http/multipart.zig Normal file
View file

@ -0,0 +1,362 @@
const std = @import("std");
const util = @import("util");
const fields = @import("./fields.zig");
const max_boundary = 70;
const read_ahead = max_boundary + 4;
/// Streaming splitter for multipart bodies: yields one Part per section
/// delimited by "\r\n--<boundary>" lines on the underlying reader.
pub fn MultipartStream(comptime ReaderType: type) type {
    return struct {
        const Multipart = @This();
        pub const BaseReader = ReaderType;
        pub const PartReader = std.io.Reader(*Part, ReaderType.Error, Part.read);

        // Peek buffer so boundary-lookahead bytes can be pushed back.
        stream: std.io.PeekStream(.{ .Static = read_ahead }, ReaderType),
        boundary: []const u8,

        /// Advances to the next boundary line and opens the part after it.
        /// Returns null at the closing "--boundary--" or at end of stream.
        pub fn next(self: *Multipart, alloc: std.mem.Allocator) !?Part {
            const reader = self.stream.reader();
            while (true) {
                try reader.skipUntilDelimiterOrEof('\r');
                var line_buf: [read_ahead]u8 = undefined;
                // Enough bytes to hold "\n--" plus the boundary.
                const len = try reader.readAll(line_buf[0 .. self.boundary.len + 3]);
                const line = line_buf[0..len];
                if (line.len == 0) return null; // end of stream
                if (std.mem.startsWith(u8, line, "\n--") and std.mem.endsWith(u8, line, self.boundary)) {
                    // Boundary matched; "--" immediately after marks the final one.
                    var more_buf: [2]u8 = undefined;
                    if (try reader.readAll(&more_buf) != 2) return error.EndOfStream;
                    const more = !(more_buf[0] == '-' and more_buf[1] == '-');
                    try self.stream.putBack(&more_buf);
                    try reader.skipUntilDelimiterOrEof('\n');
                    if (more) return try Part.open(self, alloc) else return null;
                }
            }
        }

        pub const Part = struct {
            // Null once this part has been fully consumed.
            base: ?*Multipart,
            fields: fields.Fields,

            pub fn open(base: *Multipart, alloc: std.mem.Allocator) !Part {
                // Part headers use the same wire syntax as HTTP headers.
                var parsed_fields = try @import("./request/parser.zig").parseHeaders(alloc, base.stream.reader());
                return .{ .base = base, .fields = parsed_fields };
            }

            pub fn reader(self: *Part) PartReader {
                return .{ .context = self };
            }

            pub fn close(self: *Part) void {
                self.fields.deinit();
            }

            /// Reads part content up to, but not including, the next boundary.
            pub fn read(self: *Part, buf: []u8) ReaderType.Error!usize {
                const base = self.base orelse return 0;
                const r = base.stream.reader();
                var count: usize = 0;
                while (count < buf.len) {
                    const byte = r.readByte() catch |err| switch (err) {
                        error.EndOfStream => {
                            self.base = null;
                            return count;
                        },
                        else => |e| return e,
                    };
                    buf[count] = byte;
                    count += 1;
                    if (byte != '\r') continue;
                    // A '\r' may begin a boundary; look ahead to check.
                    var line_buf: [read_ahead]u8 = undefined;
                    const line = line_buf[0..try r.readAll(line_buf[0 .. base.boundary.len + 3])];
                    if (!std.mem.startsWith(u8, line, "\n--") or !std.mem.endsWith(u8, line, base.boundary)) {
                        // Not a boundary: un-read the lookahead and keep copying.
                        base.stream.putBack(line) catch unreachable;
                        continue;
                    } else {
                        // Boundary found: push it (and the '\r') back for
                        // Multipart.next, and drop the '\r' already in buf.
                        base.stream.putBack(line) catch unreachable;
                        base.stream.putBackByte('\r') catch unreachable;
                        self.base = null;
                        return count - 1;
                    }
                }
                return count;
            }
        };
    };
}
/// Opens a multipart stream over `reader` with the given boundary token.
/// Returns `error.BoundaryTooLarge` for boundaries longer than the lookahead
/// buffer allows (RFC 2046 caps boundaries at 70 characters).
pub fn openMultipart(boundary: []const u8, reader: anytype) !MultipartStream(@TypeOf(reader)) {
    if (boundary.len > max_boundary) return error.BoundaryTooLarge;
    // Name the result type explicitly instead of relying on anonymous-struct
    // coercion at the return site — clearer and checked where it is built.
    var stream = MultipartStream(@TypeOf(reader)){
        .stream = std.io.peekStream(read_ahead, reader),
        .boundary = boundary,
    };
    // Seed the peek buffer with "\r\n" so the very first boundary (which has
    // no preceding CRLF on the wire) still matches next()'s scan; cannot fail
    // because the buffer is empty and larger than two bytes.
    stream.stream.putBack("\r\n") catch unreachable;
    return stream;
}
// One decoded form-data field. filename/content_type are non-null only when
// the part supplied them (i.e. for file uploads).
const MultipartFormField = struct {
    name: []const u8,
    value: []const u8,
    filename: ?[]const u8 = null,
    content_type: ?[]const u8 = null,
};
/// An uploaded file as exposed to deserialized form structs.
pub const FormFile = struct {
    data: []const u8,
    filename: []const u8,
    content_type: []const u8,
};
/// Wraps a MultipartStream and decodes each part as a form-data field
/// (parts must carry a `Content-Disposition: form-data` header).
pub fn MultipartForm(comptime ReaderType: type) type {
    return struct {
        stream: MultipartStream(ReaderType),

        /// Returns the next decoded field, or null when the form is exhausted.
        /// All slices in the returned field are allocated from `alloc` and
        /// owned by the caller.
        pub fn next(self: *@This(), alloc: std.mem.Allocator) !?MultipartFormField {
            var part = (try self.stream.next(alloc)) orelse return null;
            defer part.close();
            const disposition = part.fields.get("Content-Disposition") orelse return error.MissingDisposition;
            // The disposition's main value must be "form-data".
            if (!std.ascii.eqlIgnoreCase(fields.getParam(disposition, null).?, "form-data")) return error.BadDisposition;
            // Copy everything out of part.fields before the deferred close()
            // frees the header storage.
            const name = try util.deepClone(alloc, fields.getParam(disposition, "name") orelse return error.BadDisposition);
            errdefer util.deepFree(alloc, name);
            const filename = try util.deepClone(alloc, fields.getParam(disposition, "filename"));
            errdefer util.deepFree(alloc, filename);
            const content_type = try util.deepClone(alloc, part.fields.get("Content-Type"));
            errdefer util.deepFree(alloc, content_type);
            const value = try part.reader().readAllAlloc(alloc, 1 << 32);
            return MultipartFormField{
                .name = name,
                .value = value,
                .filename = filename,
                .content_type = content_type,
            };
        }
    };
}
/// Wraps an already-opened multipart stream in a form-field decoder.
pub fn openForm(multipart_stream: anytype) MultipartForm(@TypeOf(multipart_stream).BaseReader) {
    const Form = MultipartForm(@TypeOf(multipart_stream).BaseReader);
    return Form{ .stream = multipart_stream };
}
/// Deserializer for multipart form bodies. Treats FormFile (and optional
/// FormFile) fields as scalars so they receive the raw upload instead of
/// being parsed as strings.
fn Deserializer(comptime Result: type) type {
    return util.DeserializerContext(Result, MultipartFormField, struct {
        pub const options = .{ .isScalar = isScalar, .embed_unions = true };

        pub fn isScalar(comptime T: type) bool {
            return T == FormFile or T == ?FormFile or util.serialize.defaultIsScalar(T);
        }

        pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, field: MultipartFormField) !T {
            if (T == FormFile or T == ?FormFile) return try deserializeFormFile(alloc, field);
            // A filename on a non-file field indicates a malformed form.
            if (field.filename != null) return error.FilenameProvidedForNonFile;
            return try util.serialize.deserializeString(alloc, T, field.value);
        }

        /// Clones the field into an owned FormFile, supplying defaults for
        /// missing filename/content-type metadata.
        fn deserializeFormFile(alloc: std.mem.Allocator, field: MultipartFormField) !FormFile {
            var result: FormFile = undefined;
            result.data = try util.deepClone(alloc, field.value);
            errdefer util.deepFree(alloc, result.data);
            result.filename = try util.deepClone(alloc, field.filename orelse "(untitled)");
            errdefer util.deepFree(alloc, result.filename);
            result.content_type = try util.deepClone(alloc, field.content_type orelse "application/octet-stream");
            return result;
        }
    });
}
/// Parses a complete multipart/form-data body into a value of type `T`.
/// Parts are matched to fields of `T` by their form name. Caller owns the
/// result (free with util.deepFree).
pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T {
var form = openForm(try openMultipart(boundary, reader));
var ds = Deserializer(T){};
// Free all intermediate per-field values on every exit path; finish()
// presumably deep-copies what it keeps into the result — TODO confirm
// against util.DeserializerContext's ownership contract.
defer {
var iter = ds.iterator();
while (iter.next()) |pair| {
util.deepFree(alloc, pair.value);
}
}
while (true) {
var part = (try form.next(alloc)) orelse break;
// If registering the field fails, the value is not tracked by the
// deserializer yet and must be freed here.
errdefer util.deepFree(alloc, part);
try ds.setSerializedField(part.name, part);
}
return try ds.finish(alloc);
}
// TODO: Fix these tests
test "MultipartStream" {
// Expected header value and body for one part of the stream.
const ExpectedPart = struct {
disposition: []const u8,
value: []const u8,
};
const testCase = struct {
fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const ExpectedPart) !void {
var src = std.io.StreamSource{
.const_buffer = std.io.fixedBufferStream(body),
};
var stream = try openMultipart(boundary, src.reader());
for (expected_parts) |expected| {
var part = try stream.next(std.testing.allocator) orelse return error.TestExpectedEqual;
defer part.close();
const dispo = part.fields.get("Content-Disposition") orelse return error.TestExpectedEqual;
try std.testing.expectEqualStrings(expected.disposition, dispo);
// Each test body fits in a single 128-byte read.
var buf: [128]u8 = undefined;
const count = try part.reader().read(&buf);
try std.testing.expectEqualStrings(expected.value, buf[0..count]);
}
// The stream must report exhaustion after the final boundary.
try std.testing.expect(try stream.next(std.testing.allocator) == null);
}
}.case;
// Degenerate form: nothing but the closing boundary.
try testCase("--abc--\r\n", "abc", &.{});
// Header names match case-insensitively; empty part bodies are allowed.
try testCase(
util.comptimeToCrlf(
\\------abcd
\\Content-Disposition: form-data; name=first; charset=utf8
\\
\\content
\\------abcd
\\content-Disposition: form-data; name=second
\\
\\no content
\\------abcd
\\content-disposition: form-data; name=third
\\
\\
\\------abcd--
\\
),
"----abcd",
&.{
.{ .disposition = "form-data; name=first; charset=utf8", .value = "content" },
.{ .disposition = "form-data; name=second", .value = "no content" },
.{ .disposition = "form-data; name=third", .value = "" },
},
);
// Boundary-like bytes inside a part body must not end the part early.
try testCase(
util.comptimeToCrlf(
\\--xyz
\\Content-Disposition: uhh
\\
\\xyz
\\--xyz
\\Content-disposition: ok
\\
\\ --xyz
\\--xyz--
\\
),
"xyz",
&.{
.{ .disposition = "uhh", .value = "xyz" },
.{ .disposition = "ok", .value = " --xyz" },
},
);
}
test "MultipartForm" {
const testCase = struct {
fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const MultipartFormField) !void {
var src = std.io.StreamSource{
.const_buffer = std.io.fixedBufferStream(body),
};
var form = openForm(try openMultipart(boundary, src.reader()));
for (expected_parts) |expected| {
var data = try form.next(std.testing.allocator) orelse return error.TestExpectedEqual;
defer util.deepFree(std.testing.allocator, data);
try util.testing.expectDeepEqual(expected, data);
}
// The form must report exhaustion after the final boundary.
try std.testing.expect(try form.next(std.testing.allocator) == null);
}
}.case;
// Single text field.
try testCase(
util.comptimeToCrlf(
\\--abcd
\\Content-Disposition: form-data; name=foo
\\
\\content
\\--abcd--
\\
),
"abcd",
&.{.{ .name = "foo", .value = "content" }},
);
// Mixed fields: plain value, value with content type, and a file upload
// whose body contains a boundary-like line plus a trailing CRLF.
try testCase(
util.comptimeToCrlf(
\\--abcd
\\Content-Disposition: form-data; name=foo
\\
\\content
\\--abcd
\\Content-Disposition: form-data; name=bar
\\Content-Type: blah
\\
\\abcd
\\--abcd
\\Content-Disposition: form-data; name=baz; filename="myfile.txt"
\\Content-Type: text/plain
\\
\\ --abcd
\\
\\--abcd--
\\
),
"abcd",
&.{
.{ .name = "foo", .value = "content" },
.{ .name = "bar", .value = "abcd", .content_type = "blah" },
.{
.name = "baz",
.value = " --abcd\r\n",
.content_type = "text/plain",
.filename = "myfile.txt",
},
},
);
}
test "parseFormData" {
    const body = util.comptimeToCrlf(
        \\--abcd
        \\Content-Disposition: form-data; name=foo
        \\
        \\content
        \\--abcd--
        \\
    );
    var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
    const val = try parseFormData(struct {
        foo: []const u8,
    }, "abcd", src.reader(), std.testing.allocator);
    defer util.deepFree(std.testing.allocator, val);
    // Previously the parsed value was freed without being checked at all;
    // assert the field actually round-tripped through the multipart parser.
    try std.testing.expectEqualStrings("content", val.foo);
}

View file

@ -93,7 +93,7 @@ fn parseProto(reader: anytype) !http.Protocol {
};
}
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
pub fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
var headers = Fields.init(allocator);
var buf: [4096]u8 = undefined;

View file

@ -1,4 +1,5 @@
const std = @import("std");
const util = @import("util");
const parser = @import("./parser.zig");
const http = @import("../lib.zig");
const t = std.testing;
@ -30,30 +31,9 @@ const test_case = struct {
}
};
fn toCrlf(comptime str: []const u8) []const u8 {
comptime {
var buf: [str.len * 2]u8 = undefined;
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
var buf_len: usize = 0;
for (str) |ch| {
if (ch == '\n') {
buf[buf_len] = '\r';
buf_len += 1;
}
buf[buf_len] = ch;
buf_len += 1;
}
return buf[0..buf_len];
}
}
test "HTTP/1.x parse - No body" {
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\GET / HTTP/1.1
\\
\\
@ -65,7 +45,7 @@ test "HTTP/1.x parse - No body" {
},
);
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\POST / HTTP/1.1
\\
\\
@ -77,7 +57,7 @@ test "HTTP/1.x parse - No body" {
},
);
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\GET /url/abcd HTTP/1.1
\\
\\
@ -89,7 +69,7 @@ test "HTTP/1.x parse - No body" {
},
);
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\GET / HTTP/1.0
\\
\\
@ -101,7 +81,7 @@ test "HTTP/1.x parse - No body" {
},
);
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\
@ -115,7 +95,7 @@ test "HTTP/1.x parse - No body" {
},
);
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\Authorization: bearer <token>
@ -163,7 +143,7 @@ test "HTTP/1.x parse - No body" {
},
);
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\GET / HTTP/1.2
\\
\\
@ -265,7 +245,7 @@ test "HTTP/1.x parse - bad requests" {
test "HTTP/1.x parse - Headers" {
try test_case.parse(
toCrlf(
util.comptimeToCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\Content-Type: application/xml

View file

@ -1,4 +1,5 @@
const std = @import("std");
const util = @import("util");
const http = @import("../lib.zig");
const Status = http.Status;
@ -169,25 +170,7 @@ test {
_ = _tests;
}
const _tests = struct {
fn toCrlf(comptime str: []const u8) []const u8 {
comptime {
var buf: [str.len * 2]u8 = undefined;
@setEvalBranchQuota(@as(u32, str.len * 2));
var len: usize = 0;
for (str) |ch| {
if (ch == '\n') {
buf[len] = '\r';
len += 1;
}
buf[len] = ch;
len += 1;
}
return buf[0..len];
}
}
const toCrlf = util.comptimeToCrlf;
const test_buffer_size = chunk_size * 4;
test "ResponseStream no headers empty body" {

View file

@ -1,5 +0,0 @@
test {
_ = @import("./request/test_parser.zig");
_ = @import("./middleware.zig");
_ = @import("./query.zig");
}

View file

@ -1,7 +1,38 @@
const std = @import("std");
const util = @import("util");
const QueryIter = util.QueryIter;
/// Iterator over the key/value pairs of a URL query string.
/// Empty segments (consecutive '&') are skipped; a key without '=' yields a
/// null value, while "key=" yields an empty-string value.
pub const Iter = struct {
    const Pair = struct {
        key: []const u8,
        value: ?[]const u8,
    };
    iter: std.mem.SplitIterator(u8),
    /// Builds an iterator, tolerating an optional leading '?'.
    pub fn from(q: []const u8) Iter {
        const trimmed = std.mem.trimLeft(u8, q, "?");
        return .{ .iter = std.mem.split(u8, trimmed, "&") };
    }
    /// Returns the next non-empty pair, or null when exhausted.
    pub fn next(self: *Iter) ?Pair {
        while (self.iter.next()) |segment| {
            if (segment.len == 0) continue;
            const key = std.mem.sliceTo(segment, '=');
            const has_value = key.len != segment.len;
            return .{
                .key = key,
                .value = if (has_value) segment[key.len + 1 ..] else null,
            };
        }
        return null;
    }
};
/// Parses a set of query parameters described by the struct `T`.
///
@ -67,25 +98,44 @@ const QueryIter = util.QueryIter;
/// Would be used to parse a query string like
/// `?foo.baz=12345`
///
pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T {
if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
var iter = QueryIter.from(query);
pub fn parse(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T {
var iter = Iter.from(query);
var deserializer = Deserializer(T){};
var fields = Intermediary(T){};
while (iter.next()) |pair| {
// TODO: Hash map
inline for (std.meta.fields(Intermediary(T))) |field| {
if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) {
@field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} };
break;
}
} else std.log.debug("unknown param {s}", .{pair.key});
try deserializer.setSerializedField(pair.key, pair.value);
}
return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery;
return try deserializer.finish(alloc);
}
pub fn parseQueryFree(alloc: std.mem.Allocator, val: anytype) void {
/// Deserializer for URL query parameters. Each serialized value is an
/// optional string: null when the key appeared without "=value".
fn Deserializer(comptime Result: type) type {
return util.DeserializerContext(Result, ?[]const u8, struct {
pub const options = util.serialize.default_options;
pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, maybe_val: ?[]const u8) !T {
const is_optional = comptime std.meta.trait.is(.Optional)(T);
if (maybe_val) |val| {
// An empty value maps to null for optional targets.
if (val.len == 0 and is_optional) return null;
// Percent-decode before parsing; the decoded buffer is only
// needed for the duration of deserializeString.
const decoded = try decodeString(alloc, val);
defer alloc.free(decoded);
return try util.serialize.deserializeString(alloc, T, decoded);
} else {
// If param is present, but without an associated value
return if (is_optional)
null
else if (T == bool)
true
else
error.InvalidValue;
}
}
});
}
pub fn parseFree(alloc: std.mem.Allocator, val: anytype) void {
util.deepFree(alloc, val);
}
@ -110,186 +160,6 @@ fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]u8 {
return list.toOwnedSlice();
}
fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T {
const param = @field(fields, name);
return switch (param) {
.not_specified => null,
.no_value => try parseQueryValue(alloc, T, null),
.value => |v| try parseQueryValue(alloc, T, v),
};
}
fn parse(
alloc: std.mem.Allocator,
comptime T: type,
comptime prefix: []const u8,
comptime name: []const u8,
fields: anytype,
) !?T {
if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields);
switch (@typeInfo(T)) {
.Union => |info| {
var result: ?T = null;
inline for (info.fields) |field| {
const F = field.field_type;
const maybe_value = try parse(alloc, F, prefix, field.name, fields);
if (maybe_value) |value| {
if (result != null) return error.DuplicateUnionField;
result = @unionInit(T, field.name, value);
}
}
std.log.debug("{any}", .{result});
return result;
},
.Struct => |info| {
var result: T = undefined;
var fields_specified: usize = 0;
errdefer inline for (info.fields) |field, i| {
if (fields_specified < i) util.deepFree(alloc, @field(result, field.name));
};
inline for (info.fields) |field| {
const F = field.field_type;
var maybe_value: ?F = null;
if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| {
maybe_value = v;
} else if (field.default_value) |default| {
if (comptime @sizeOf(F) != 0) {
maybe_value = try util.deepClone(alloc, @ptrCast(*const F, @alignCast(@alignOf(F), default)).*);
} else {
maybe_value = std.mem.zeroes(F);
}
}
if (maybe_value) |v| {
fields_specified += 1;
@field(result, field.name) = v;
}
}
if (fields_specified == 0) {
return null;
} else if (fields_specified != info.fields.len) {
std.log.debug("{} {s} {s}", .{ T, prefix, name });
return error.PartiallySpecifiedStruct;
} else {
return result;
}
},
// Only applies to non-scalar optionals
.Optional => |info| return try parse(alloc, info.child, prefix, name, fields),
else => @compileError("tmp"),
}
}
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 {
comptime {
if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix);
var fields: []const []const u8 = &.{};
for (std.meta.fields(T)) |f| {
const full_name = prefix ++ f.name;
if (isScalar(f.field_type)) {
fields = fields ++ @as([]const []const u8, &.{full_name});
} else {
const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ ".";
fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix);
}
}
return fields;
}
}
const QueryParam = union(enum) {
not_specified: void,
no_value: void,
value: []const u8,
};
fn Intermediary(comptime T: type) type {
const field_names = recursiveFieldPaths(T, "..");
var fields: [field_names.len]std.builtin.Type.StructField = undefined;
for (field_names) |name, i| fields[i] = .{
.name = name,
.field_type = QueryParam,
.default_value = &QueryParam{ .not_specified = {} },
.is_comptime = false,
.alignment = @alignOf(QueryParam),
};
return @Type(.{ .Struct = .{
.layout = .Auto,
.fields = &fields,
.decls = &.{},
.is_tuple = false,
} });
}
fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, maybe_value: ?[]const u8) !T {
const is_optional = comptime std.meta.trait.is(.Optional)(T);
if (maybe_value) |value| {
const Eff = if (is_optional) std.meta.Child(T) else T;
if (value.len == 0 and is_optional) return null;
const decoded = try decodeString(alloc, value);
errdefer alloc.free(decoded);
if (comptime std.meta.trait.isZigString(Eff)) return decoded;
defer alloc.free(decoded);
const result = if (comptime std.meta.trait.isIntegral(Eff))
try std.fmt.parseInt(Eff, decoded, 0)
else if (comptime std.meta.trait.isFloat(Eff))
try std.fmt.parseFloat(Eff, decoded)
else if (comptime std.meta.trait.is(.Enum)(Eff)) blk: {
_ = std.ascii.lowerString(decoded, decoded);
break :blk std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
} else if (Eff == bool) blk: {
_ = std.ascii.lowerString(decoded, decoded);
break :blk bool_map.get(decoded) orelse return error.InvalidBool;
} else if (comptime std.meta.trait.hasFn("parse")(Eff))
try Eff.parse(value)
else
@compileError("Invalid type " ++ @typeName(T));
return result;
} else {
// If param is present, but without an associated value
return if (is_optional)
null
else if (T == bool)
true
else
error.InvalidValue;
}
}
const bool_map = std.ComptimeStringMap(bool, .{
.{ "true", true },
.{ "t", true },
.{ "yes", true },
.{ "y", true },
.{ "1", true },
.{ "false", false },
.{ "f", false },
.{ "no", false },
.{ "n", false },
.{ "0", false },
});
fn isScalar(comptime T: type) bool {
if (comptime std.meta.trait.isZigString(T)) return true;
if (comptime std.meta.trait.isIntegral(T)) return true;
@ -304,7 +174,7 @@ fn isScalar(comptime T: type) bool {
return false;
}
pub fn QueryStringify(comptime Params: type) type {
pub fn EncodeStruct(comptime Params: type) type {
return struct {
params: Params,
pub fn format(v: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
@ -312,8 +182,8 @@ pub fn QueryStringify(comptime Params: type) type {
}
};
}
pub fn queryStringify(val: anytype) QueryStringify(@TypeOf(val)) {
return QueryStringify(@TypeOf(val)){ .params = val };
pub fn encodeStruct(val: anytype) EncodeStruct(@TypeOf(val)) {
return EncodeStruct(@TypeOf(val)){ .params = val };
}
fn urlFormatString(writer: anytype, val: []const u8) !void {
@ -375,11 +245,11 @@ fn formatQuery(comptime prefix: []const u8, comptime name: []const u8, params: a
}
}
test "parseQuery" {
test "parse" {
const testCase = struct {
fn case(comptime T: type, expected: T, query_string: []const u8) !void {
const result = try parseQuery(std.testing.allocator, T, query_string);
defer parseQueryFree(std.testing.allocator, result);
const result = try parse(std.testing.allocator, T, query_string);
defer parseFree(std.testing.allocator, result);
try util.testing.expectDeepEqual(expected, result);
}
}.case;
@ -465,14 +335,46 @@ test "parseQuery" {
try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .baz = "abc" } } }, "sub.foo=1&sub.baz=abc");
}
test "formatQuery" {
try std.testing.expectFmt("", "{}", .{queryStringify(.{})});
try std.testing.expectFmt("id=3&", "{}", .{queryStringify(.{ .id = 3 })});
try std.testing.expectFmt("id=3&id2=4&", "{}", .{queryStringify(.{ .id = 3, .id2 = 4 })});
test "encodeStruct" {
try std.testing.expectFmt("", "{}", .{encodeStruct(.{})});
try std.testing.expectFmt("id=3&", "{}", .{encodeStruct(.{ .id = 3 })});
try std.testing.expectFmt("id=3&id2=4&", "{}", .{encodeStruct(.{ .id = 3, .id2 = 4 })});
try std.testing.expectFmt("str=foo&", "{}", .{queryStringify(.{ .str = "foo" })});
try std.testing.expectFmt("enum_str=foo&", "{}", .{queryStringify(.{ .enum_str = .foo })});
try std.testing.expectFmt("str=foo&", "{}", .{encodeStruct(.{ .str = "foo" })});
try std.testing.expectFmt("enum_str=foo&", "{}", .{encodeStruct(.{ .enum_str = .foo })});
try std.testing.expectFmt("boolean=false&", "{}", .{queryStringify(.{ .boolean = false })});
try std.testing.expectFmt("boolean=true&", "{}", .{queryStringify(.{ .boolean = true })});
try std.testing.expectFmt("boolean=false&", "{}", .{encodeStruct(.{ .boolean = false })});
try std.testing.expectFmt("boolean=true&", "{}", .{encodeStruct(.{ .boolean = true })});
}
test "Iter" {
const testCase = struct {
fn case(str: []const u8, pairs: []const Iter.Pair) !void {
var iter = Iter.from(str);
for (pairs) |pair| {
try util.testing.expectDeepEqual(@as(?Iter.Pair, pair), iter.next());
}
// Iterator must be exhausted after the expected pairs.
try std.testing.expect(iter.next() == null);
}
}.case;
// Empty input yields nothing.
try testCase("", &.{});
// Bare key, empty value, and key=value forms.
try testCase("abc", &.{.{ .key = "abc", .value = null }});
try testCase("abc=", &.{.{ .key = "abc", .value = "" }});
try testCase("abc=def", &.{.{ .key = "abc", .value = "def" }});
// Trailing '&' and leading '?' are tolerated.
try testCase("abc=def&", &.{.{ .key = "abc", .value = "def" }});
try testCase("?abc=def&", &.{.{ .key = "abc", .value = "def" }});
try testCase("?abc=def&foo&bar=baz&qux=", &.{
.{ .key = "abc", .value = "def" },
.{ .key = "foo", .value = null },
.{ .key = "bar", .value = "baz" },
.{ .key = "qux", .value = "" },
});
// Consecutive '&' produce no empty pairs.
try testCase("?abc=def&&foo&bar=baz&&qux=&", &.{
.{ .key = "abc", .value = "def" },
.{ .key = "foo", .value = null },
.{ .key = "bar", .value = "baz" },
.{ .key = "qux", .value = "" },
});
// An empty key with a value is preserved as-is.
try testCase("&=def&", &.{.{ .key = "", .value = "def" }});
}

View file

@ -267,13 +267,13 @@ pub const helpers = struct {
try std.fmt.format(
writer,
"<{s}://{s}/{s}?{}>; rel=\"{s}\"",
.{ @tagName(c.scheme), c.host, path, http.queryStringify(params), rel },
.{ @tagName(c.scheme), c.host, path, http.urlencode.encodeStruct(params), rel },
);
} else {
try std.fmt.format(
writer,
"<{s}?{}>; rel=\"{s}\"",
.{ path, http.queryStringify(params), rel },
.{ path, http.urlencode.encodeStruct(params), rel },
);
}
// TODO: percent-encode

View file

@ -2,6 +2,7 @@ const controllers = @import("../controllers.zig");
const auth = @import("./api/auth.zig");
const communities = @import("./api/communities.zig");
const drive = @import("./api/drive.zig");
const invites = @import("./api/invites.zig");
const users = @import("./api/users.zig");
const follows = @import("./api/users/follows.zig");
@ -26,4 +27,6 @@ pub const routes = .{
controllers.apiEndpoint(follows.delete),
controllers.apiEndpoint(follows.query_followers),
controllers.apiEndpoint(follows.query_following),
controllers.apiEndpoint(drive.upload),
controllers.apiEndpoint(drive.mkdir),
};

View file

@ -0,0 +1,144 @@
const api = @import("api");
const http = @import("http");
const util = @import("util");
const controller_utils = @import("../../controllers.zig").helpers;
const Uuid = util.Uuid;
const DateTime = util.DateTime;
/// Route pattern shared by all drive endpoints; the `:path*` star segment
/// captures the remainder of the URL as the drive-relative path.
pub const drive_path = "/drive/:path*";
/// Path arguments extracted from `drive_path`.
pub const DriveArgs = struct {
/// Drive-relative path captured by the `:path*` star segment.
path: []const u8,
};
pub const query = struct {
    pub const method = .GET;
    pub const path = drive_path;
    pub const Args = DriveArgs;

    /// Pagination / filtering options for directory listings.
    pub const Query = struct {
        const OrderBy = enum {
            created_at,
            filename,
        };

        max_items: usize = 20,
        like: ?[]const u8 = null,
        order_by: OrderBy = .created_at,
        direction: api.Direction = .descending,
        // Cursor describing the last entry of the previous page.
        prev: ?struct {
            id: Uuid,
            order_val: union(OrderBy) {
                created_at: DateTime,
                filename: []const u8,
            },
        } = null,
        page_direction: api.PageDirection = .forward,
    };

    /// GET on a drive path: lists the directory, or returns the file's
    /// metadata when the path names a file rather than a directory.
    pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
        const result = srv.driveQuery(req.args.path, req.query) catch |err| switch (err) {
            error.NotADirectory => {
                // BUG FIX: previously passed `path` (the route pattern
                // constant "/drive/:path*") instead of the requested path.
                const meta = try srv.getFile(req.args.path);
                try res.json(.ok, meta);
                return;
            },
            else => |e| return e,
        };
        try controller_utils.paginate(result, res, req.allocator);
    }
};
pub const upload = struct {
    pub const method = .POST;
    pub const path = drive_path;
    pub const Args = DriveArgs;

    /// Multipart body: the uploaded file plus optional metadata.
    pub const Body = struct {
        file: http.FormFile,
        description: ?[]const u8 = null,
        sensitive: bool = false,
    };

    /// Stores the uploaded file under the requested drive directory.
    pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
        const uploaded = req.body.file;
        try srv.uploadFile(.{
            .dir = req.args.path,
            .filename = uploaded.filename,
            .description = req.body.description,
            .content_type = uploaded.content_type,
            .sensitive = req.body.sensitive,
        }, uploaded.data);
        // TODO: print meta
        try res.json(.created, .{});
    }
};
pub const delete = struct {
    pub const method = .DELETE;
    pub const path = drive_path;
    pub const Args = DriveArgs;

    /// Removes the entry at the requested drive path: directories via
    /// driveRmdir, files via deleteFile. Other entry kinds fall through
    /// and still report success, matching the previous behavior.
    pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
        const target = req.args.path;
        const entry = try srv.driveLookup(target);
        if (entry == .dir) {
            try srv.driveRmdir(target);
        } else if (entry == .file) {
            try srv.deleteFile(target);
        }
        return res.json(.ok, .{});
    }
};
pub const mkdir = struct {
    pub const method = .MKCOL;
    pub const path = drive_path;
    pub const Args = DriveArgs;

    /// Creates a new directory at the requested drive path (WebDAV MKCOL).
    pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
        try srv.driveMkdir(req.args.path);
        try res.json(.created, .{});
    }
};
pub const update = struct {
    pub const method = .PUT;
    pub const path = drive_path;
    pub const Args = DriveArgs;

    /// Metadata fields to change; null fields are presumably left
    /// untouched — TODO confirm against srv.updateFile semantics.
    pub const Body = struct {
        description: ?[]const u8 = null,
        content_type: ?[]const u8 = null,
        sensitive: ?bool = null,
    };

    /// Updates the metadata of the file at the requested drive path.
    /// Directories are rejected with error.NotFile.
    pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
        const info = try srv.driveLookup(req.args.path);
        if (info != .file) return error.NotFile;
        // BUG FIX: previously passed `path` (the route pattern constant
        // "/drive/:path*") instead of the requested path.
        const new_info = try srv.updateFile(req.args.path, req.body);
        try res.json(.ok, new_info);
    }
};
pub const move = struct {
    pub const method = .MOVE;
    pub const path = drive_path;
    pub const Args = DriveArgs;

    /// Moves the entry at the requested drive path to the location given in
    /// the Destination header (WebDAV-style MOVE); echoes it back as the
    /// Location header of the 201 response.
    pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
        const destination = req.fields.get("Destination") orelse return error.NoDestination;
        try srv.driveMove(req.args.path, destination);
        try res.fields.put("Location", destination);
        // BUG FIX: the response object renders JSON, not the server
        // context — this previously called `srv.json(...)`.
        try res.json(.created, .{});
    }
};

View file

@ -19,8 +19,9 @@ fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void {
const tx = try db.beginOrSavepoint();
errdefer tx.rollback();
var iter = util.SqlStmtIter.from(script);
var iter = std.mem.split(u8, script, ";");
while (iter.next()) |stmt| {
if (stmt.len == 0) continue;
try execStmt(tx, stmt, alloc);
}
@ -208,22 +209,139 @@ const migrations: []const Migration = &.{
.{
.name = "files",
.up =
\\CREATE TABLE drive_file(
\\CREATE TABLE file_upload(
\\ id UUID NOT NULL PRIMARY KEY,
\\
\\ filename TEXT NOT NULL,
\\ account_owner_id UUID REFERENCES account(id),
\\ community_owner_id UUID REFERENCES community(id),
\\ created_by UUID REFERENCES account(id),
\\ size INTEGER NOT NULL,
\\
\\ filename TEXT NOT NULL,
\\ description TEXT,
\\ content_type TEXT,
\\ sensitive BOOLEAN NOT NULL,
\\
\\ is_deleted BOOLEAN NOT NULL DEFAULT FALSE,
\\
\\ created_at TIMESTAMPTZ NOT NULL,
\\ updated_at TIMESTAMPTZ NOT NULL
\\);
\\
\\CREATE TABLE drive_entry(
\\ id UUID NOT NULL PRIMARY KEY,
\\
\\ account_owner_id UUID REFERENCES account(id),
\\ community_owner_id UUID REFERENCES community(id),
\\
\\ name TEXT,
\\ parent_directory_id UUID REFERENCES drive_entry(id),
\\
\\ file_id UUID REFERENCES file_upload(id),
\\
\\ CHECK(
\\ (account_owner_id IS NULL AND community_owner_id IS NOT NULL)
\\ OR (account_owner_id IS NOT NULL AND community_owner_id IS NULL)
\\ ),
\\ CHECK(
\\ (name IS NULL AND parent_directory_id IS NULL AND file_id IS NULL)
\\ OR (name IS NOT NULL AND parent_directory_id IS NOT NULL)
\\ )
\\);
\\CREATE UNIQUE INDEX drive_entry_uniqueness
\\ON drive_entry(
\\ name,
\\ COALESCE(parent_directory_id, ''),
\\ COALESCE(account_owner_id, community_owner_id)
\\);
,
.down = "DROP TABLE drive_file",
.down =
\\DROP INDEX drive_entry_uniqueness;
\\DROP TABLE drive_entry;
\\DROP TABLE file_upload;
,
},
.{
.name = "drive_entry_path",
.up =
\\CREATE VIEW drive_entry_path(
\\ id,
\\ path,
\\ account_owner_id,
\\ community_owner_id,
\\ kind
\\) AS WITH RECURSIVE full_path(
\\ id,
\\ path,
\\ account_owner_id,
\\ community_owner_id,
\\ kind
\\) AS (
\\ SELECT
\\ id,
\\ '' AS path,
\\ account_owner_id,
\\ community_owner_id,
\\ 'dir' AS kind
\\ FROM drive_entry
\\ WHERE parent_directory_id IS NULL
\\ UNION ALL
\\ SELECT
\\ base.id,
\\ (dir.path || '/' || base.name) AS path,
\\ base.account_owner_id,
\\ base.community_owner_id,
\\ (CASE WHEN base.file_id IS NULL THEN 'dir' ELSE 'file' END) as kind
\\ FROM drive_entry AS base
\\ JOIN full_path AS dir ON
\\ base.parent_directory_id = dir.id
\\ AND base.account_owner_id IS NOT DISTINCT FROM dir.account_owner_id
\\ AND base.community_owner_id IS NOT DISTINCT FROM dir.community_owner_id
\\)
\\SELECT
\\ id,
\\ (CASE WHEN kind = 'dir' THEN path || '/' ELSE path END) AS path,
\\ account_owner_id,
\\ community_owner_id,
\\ kind
\\FROM full_path;
,
.down =
\\DROP VIEW drive_entry_path;
,
},
.{
.name = "create drive root directories",
.up =
\\INSERT INTO drive_entry(
\\ id,
\\ account_owner_id,
\\ community_owner_id,
\\ parent_directory_id,
\\ name,
\\ file_id
\\) SELECT
\\ id,
\\ id AS account_owner_id,
\\ NULL AS community_owner_id,
\\ NULL AS parent_directory_id,
\\ NULL AS name,
\\ NULL AS file_id
\\FROM account;
\\INSERT INTO drive_entry(
\\ id,
\\ account_owner_id,
\\ community_owner_id,
\\ parent_directory_id,
\\ name,
\\ file_id
\\) SELECT
\\ id,
\\ NULL AS account_owner_id,
\\ id AS community_owner_id,
\\ NULL AS parent_directory_id,
\\ NULL AS name,
\\ NULL AS file_id
\\FROM community;
,
.down = "",
},
};

View file

@ -88,7 +88,7 @@ pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]con
else => |T| switch (@typeInfo(T)) {
.Enum => return @tagName(val),
.Optional => if (val) |v| try prepareParamText(arena, v) else null,
.Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}),
.Bool, .Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}),
.Union => loop: inline for (std.meta.fields(T)) |field| {
// Have to do this in a roundabout way to satisfy comptime checker
const Tag = std.meta.Tag(T);

View file

@ -193,6 +193,7 @@ pub const Db = struct {
.Null => return self.bindNull(stmt, idx),
.Int => return self.bindInt(stmt, idx, std.math.cast(i64, val) orelse unreachable),
.Float => return self.bindFloat(stmt, idx, val),
.Bool => return self.bindInt(stmt, idx, if (val) 1 else 0),
else => @compileError("Unable to serialize type " ++ @typeName(T)),
}
}
@ -251,18 +252,20 @@ pub const Results = struct {
db: *c.sqlite3,
pub fn finish(self: Results) void {
switch (c.sqlite3_finalize(self.stmt)) {
c.SQLITE_OK => {},
else => |err| {
handleUnexpectedError(self.db, err, self.getGeneratingSql()) catch {};
},
}
_ = c.sqlite3_finalize(self.stmt);
}
pub fn row(self: Results) common.RowError!?Row {
return switch (c.sqlite3_step(self.stmt)) {
c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db },
c.SQLITE_DONE => null,
c.SQLITE_CONSTRAINT_UNIQUE => return error.UniqueViolation,
c.SQLITE_CONSTRAINT_CHECK => return error.CheckViolation,
c.SQLITE_CONSTRAINT_NOTNULL => return error.NotNullViolation,
c.SQLITE_CONSTRAINT_FOREIGNKEY => return error.ForeignKeyViolation,
c.SQLITE_CONSTRAINT => return error.ConstraintViolation,
else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()),
};
}

View file

@ -144,42 +144,16 @@ fn fieldPtr(ptr: anytype, comptime names: []const []const u8) FieldPtr(@TypeOf(p
return fieldPtr(&@field(ptr.*, names[0]), names[1..]);
}
fn isScalar(comptime T: type) bool {
if (comptime std.meta.trait.isZigString(T)) return true;
if (comptime std.meta.trait.isIntegral(T)) return true;
if (comptime std.meta.trait.isFloat(T)) return true;
if (comptime std.meta.trait.is(.Enum)(T)) return true;
if (T == bool) return true;
if (comptime std.meta.trait.hasFn("parse")(T)) return true;
if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
return false;
}
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const []const u8) []const []const []const u8 {
comptime {
var fields: []const []const []const u8 = &.{};
for (std.meta.fields(T)) |f| {
const full_name = prefix ++ [_][]const u8{f.name};
if (isScalar(f.field_type)) {
fields = fields ++ [_][]const []const u8{full_name};
} else {
fields = fields ++ recursiveFieldPaths(f.field_type, full_name);
}
}
return fields;
}
}
// Represents a set of results.
// row() must be called until it returns null, or the query may not complete
// Must be deallocated by a call to finish()
pub fn Results(comptime T: type) type {
// would normally make this a declaration of the struct, but it causes the compiler to crash
const fields = if (T == void) .{} else recursiveFieldPaths(T, &.{});
const fields = if (T == void) .{} else util.serialize.getRecursiveFieldList(
T,
&.{},
util.serialize.default_options,
);
return struct {
const Self = @This();
@ -457,6 +431,7 @@ fn Tx(comptime tx_level: u8) type {
pub fn rollback(self: Self) void {
(if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| {
std.log.err("Failed to rollback transaction: {}", .{err});
std.log.err("{any}", .{@errorReturnTrace()});
@panic("TODO: more gracefully handle rollback failures");
};
}
@ -654,7 +629,7 @@ fn Tx(comptime tx_level: u8) type {
}
fn rollbackUnchecked(self: Self) !void {
try self.exec("ROLLBACK", {}, null);
try self.execInternal("ROLLBACK", {}, null, false);
}
};
}

View file

@ -601,3 +601,20 @@ const ControlTokenIter = struct {
self.peeked_token = token;
}
};
// Smoke tests for the template engine: render into a change-detection
// stream and assert the output equals `expected` byte-for-byte.
test "template" {
const testCase = struct {
fn case(comptime tmpl: []const u8, args: anytype, expected: []const u8) !void {
var stream = std.io.changeDetectionStream(expected, std.io.null_writer);
try execute(stream.writer(), tmpl, args);
// changeDetected() is true iff the rendered bytes diverged.
try std.testing.expect(!stream.changeDetected());
}
}.case;
try testCase("", .{}, "");
try testCase("abcd", .{}, "abcd");
// Field interpolation.
try testCase("{.val}", .{ .val = 3 }, "3");
// Conditional section.
try testCase("{#if .val}1{/if}", .{ .val = true }, "1");
// Loop over a slice with a capture variable.
try testCase("{#for .vals |$v|}{$v}{/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123");
// '=' inside the tag delimiters trims adjacent whitespace — TODO confirm.
try testCase("{#for .vals |$v|=} {$v} {=/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123");
}

View file

@ -1,161 +0,0 @@
const Url = @This();
const std = @import("std");
scheme: []const u8,
hostport: []const u8,
path: []const u8,
query: []const u8,
fragment: []const u8,
/// Parses `url` into its components. The slices in the returned Url point
/// into `url`; no allocation is performed.
/// Returns error.InvalidUrl when there is no scheme or no "//" authority
/// marker after it.
pub fn parse(url: []const u8) !Url {
    const scheme_end = for (url) |ch, i| {
        if (ch == ':') break i;
    } else return error.InvalidUrl;
    // The scheme must be followed by exactly "//" (e.g. "https://").
    // Fix: the original checked index scheme_end + 1 twice, so a single
    // slash ("https:/host") slipped through.
    if (url.len < scheme_end + 3 or url[scheme_end + 1] != '/' or url[scheme_end + 2] != '/') return error.InvalidUrl;
    const hostport_start = scheme_end + 3;
    const hostport_end = for (url[hostport_start..]) |ch, i| {
        if (ch == '/' or ch == '?' or ch == '#') break i + hostport_start;
    } else url.len;
    const path_end = for (url[hostport_end..]) |ch, i| {
        if (ch == '?' or ch == '#') break i + hostport_end;
    } else url.len;
    const query_end = if (!(url.len > path_end and url[path_end] == '?'))
        path_end
    else for (url[path_end..]) |ch, i| {
        if (ch == '#') break i + path_end;
    } else url.len;
    const query = url[path_end..query_end];
    const fragment = url[query_end..];
    return Url{
        .scheme = url[0..scheme_end],
        .hostport = url[hostport_start..hostport_end],
        .path = url[hostport_end..path_end],
        // Strip the leading '?' / '#' marker when present.
        .query = if (query.len > 0) query[1..] else query,
        .fragment = if (fragment.len > 0) fragment[1..] else fragment,
    };
}
/// Finds the value of query parameter `param` (exact, case-sensitive key
/// match). Returns the raw (still percent-encoded) value, or null if the
/// parameter is absent. Keys without an '=' terminate the scan.
/// Fix: removed a leftover std.log.debug that dumped the full query string
/// on every lookup.
pub fn getQuery(self: Url, param: []const u8) ?[]const u8 {
    var key_start: usize = 0;
    while (key_start < self.query.len) {
        const key_end = for (self.query[key_start..]) |ch, i| {
            if (ch == '=') break key_start + i;
        } else return null;
        const val_start = key_end + 1;
        const val_end = for (self.query[val_start..]) |ch, i| {
            if (ch == '&') break val_start + i;
        } else self.query.len;
        const key = self.query[key_start..key_end];
        if (std.mem.eql(u8, key, param)) return self.query[val_start..val_end];
        // Skip past the '&' separator to the next key=value pair.
        key_start = val_end + 1;
    }
    return null;
}
/// Percent-decodes `str` into `buf`, returning the decoded slice of `buf`.
/// Errors: NoSpaceLeft when buf is too small, BadEscape on a truncated
/// "%XY" sequence, InvalidCharacter on non-hex escape digits.
/// Fixes: `charToDIgit` typo (did not compile) and an off-by-one bounds
/// check that read one byte past the end on a "%X"-at-end input.
pub fn strDecode(buf: []u8, str: []const u8) ![]u8 {
    var str_i: usize = 0;
    var buf_i: usize = 0;
    while (str_i < str.len) : ({
        str_i += 1;
        buf_i += 1;
    }) {
        if (buf_i >= buf.len) return error.NoSpaceLeft;
        const ch = str[str_i];
        if (ch == '%') {
            // A '%' must be followed by exactly two hex digits; we index
            // str_i + 1 and str_i + 2 below, so both must exist.
            if (str.len <= str_i + 2) return error.BadEscape;
            const hi = try std.fmt.charToDigit(str[str_i + 1], 16);
            const lo = try std.fmt.charToDigit(str[str_i + 2], 16);
            str_i += 2;
            buf[buf_i] = (hi << 4) | lo;
        } else {
            buf[buf_i] = str[str_i];
        }
    }
    return buf[0..buf_i];
}
// Test helper: asserts every component of `actual` equals `expected`,
// using expectEqualStrings so failures print a readable string diff.
fn expectEqualUrl(expected: Url, actual: Url) !void {
    const t = @import("std").testing;
    try t.expectEqualStrings(expected.scheme, actual.scheme);
    try t.expectEqualStrings(expected.hostport, actual.hostport);
    try t.expectEqualStrings(expected.path, actual.path);
    try t.expectEqualStrings(expected.query, actual.query);
    try t.expectEqualStrings(expected.fragment, actual.fragment);
}
test "Url" {
    // Round-trips a representative set of URLs through parse().
    try expectEqualUrl(.{
        .scheme = "https",
        .hostport = "example.com",
        .path = "",
        .query = "",
        .fragment = "",
    }, try Url.parse("https://example.com"));
    try expectEqualUrl(.{
        .scheme = "https",
        .hostport = "example.com:1234",
        .path = "",
        .query = "",
        .fragment = "",
    }, try Url.parse("https://example.com:1234"));
    try expectEqualUrl(.{
        .scheme = "http",
        .hostport = "example.com",
        .path = "/home",
        .query = "",
        .fragment = "",
    }, try Url.parse("http://example.com/home"));
    try expectEqualUrl(.{
        .scheme = "https",
        .hostport = "example.com",
        .path = "",
        .query = "query=abc",
        .fragment = "",
    }, try Url.parse("https://example.com?query=abc"));
    // This case was a verbatim duplicate of the previous one; exercise a
    // fragment without a path or query instead.
    try expectEqualUrl(.{
        .scheme = "https",
        .hostport = "example.com",
        .path = "",
        .query = "",
        .fragment = "abc",
    }, try Url.parse("https://example.com#abc"));
    try expectEqualUrl(.{
        .scheme = "https",
        .hostport = "example.com",
        .path = "/path/to/resource",
        .query = "query=abc",
        .fragment = "123",
    }, try Url.parse("https://example.com/path/to/resource?query=abc#123"));
    const t = @import("std").testing;
    // The scheme must be followed by "//".
    try t.expectError(error.InvalidUrl, Url.parse("https:example.com"));
    try t.expectError(error.InvalidUrl, Url.parse("example.com"));
}
test "Url.getQuery" {
    // Looks up both present and absent parameters; values must not bleed
    // across '&' boundaries, and a value must never match as a key.
    const url = try Url.parse("https://example.com?a=xyz&b=jkl");
    const t = @import("std").testing;
    try t.expectEqualStrings("xyz", url.getQuery("a").?);
    try t.expectEqualStrings("jkl", url.getQuery("b").?);
    try t.expect(url.getQuery("c") == null);
    try t.expect(url.getQuery("xyz") == null);
}

View file

@ -1,106 +0,0 @@
const std = @import("std");
const Hash = std.hash.Wyhash;
const View = std.unicode.Utf8View;
const toLower = std.ascii.toLower;
const isAscii = std.ascii.isASCII;

const hash_seed = 1;

/// ASCII-case-insensitive hash of a UTF-8 string: single-byte ASCII
/// codepoints are lowercased before hashing, multi-byte codepoints are
/// hashed verbatim. Invalid UTF-8 falls back to a plain byte-wise hash.
pub fn hash(str: []const u8) u64 {
    const view = View.init(str) catch return Hash.hash(hash_seed, str);
    var hasher = Hash.init(hash_seed);
    var iter = view.iterator();
    while (iter.nextCodepointSlice()) |slice| {
        if (slice.len == 1 and isAscii(slice[0])) {
            hasher.update(&[1]u8{toLower(slice[0])});
        } else {
            hasher.update(slice);
        }
    }
    return hasher.final();
}
/// ASCII-case-insensitive equality of two UTF-8 strings. Multi-byte
/// codepoints must match exactly; single-byte ASCII bytes are compared
/// after lowercasing. Invalid UTF-8 in `a` degrades to an exact byte
/// comparison; invalid UTF-8 in `b` alone never compares equal.
pub fn eql(a: []const u8, b: []const u8) bool {
    if (a.len != b.len) return false;
    const view_a = View.init(a) catch return std.mem.eql(u8, a, b);
    const view_b = View.init(b) catch return false;
    var iter_a = view_a.iterator();
    var iter_b = view_b.iterator();
    while (true) {
        const cp_a = iter_a.nextCodepointSlice();
        const cp_b = iter_b.nextCodepointSlice();
        // Equal only if both streams end together.
        if (cp_a == null or cp_b == null) return cp_a == null and cp_b == null;
        const sa = cp_a.?;
        const sb = cp_b.?;
        if (sa.len != sb.len) return false;
        if (sa.len == 1) {
            if (isAscii(sa[0]) and isAscii(sb[0])) {
                if (toLower(sa[0]) != toLower(sb[0])) return false;
            } else if (sa[0] != sb[0]) return false;
        } else if (!std.mem.eql(u8, sa, sb)) return false;
    }
}
test "case insensitive eql with utf-8 chars" {
    const t = std.testing;
    // Mixed-case ASCII compares equal; multi-byte codepoints must match
    // exactly; differing lengths never compare equal.
    try t.expectEqual(true, eql("abc 💯 def", "aBc 💯 DEF"));
    try t.expectEqual(false, eql("xyz 💯 ijk", "aBc 💯 DEF"));
    try t.expectEqual(false, eql("abc 💯 def", "aBc x DEF"));
    try t.expectEqual(true, eql("💯", "💯"));
    try t.expectEqual(false, eql("💯", "a"));
    try t.expectEqual(false, eql("💯", "💯 continues"));
    try t.expectEqual(false, eql("💯 fsdfs", "💯"));
    try t.expectEqual(false, eql("💯", ""));
    try t.expectEqual(false, eql("", "💯"));
    try t.expectEqual(true, eql("abc x def", "aBc x DEF"));
    try t.expectEqual(false, eql("xyz x ijk", "aBc x DEF"));
    try t.expectEqual(true, eql("x", "x"));
    try t.expectEqual(false, eql("x", "a"));
    try t.expectEqual(false, eql("x", "x continues"));
    try t.expectEqual(false, eql("x fsdfs", "x"));
    try t.expectEqual(false, eql("x", ""));
    try t.expectEqual(false, eql("", "x"));
    try t.expectEqual(true, eql("", ""));
}
test "case insensitive hash with utf-8 chars" {
    const t = std.testing;
    // hash() must agree with eql(): equal strings hash equal; the distinct
    // pairs below are additionally expected not to collide.
    try t.expect(hash("abc 💯 def") == hash("aBc 💯 DEF"));
    try t.expect(hash("xyz 💯 ijk") != hash("aBc 💯 DEF"));
    try t.expect(hash("abc 💯 def") != hash("aBc x DEF"));
    try t.expect(hash("💯") == hash("💯"));
    try t.expect(hash("💯") != hash("a"));
    try t.expect(hash("💯") != hash("💯 continues"));
    try t.expect(hash("💯 fsdfs") != hash("💯"));
    try t.expect(hash("💯") != hash(""));
    try t.expect(hash("") != hash("💯"));
    try t.expect(hash("abc x def") == hash("aBc x DEF"));
    try t.expect(hash("xyz x ijk") != hash("aBc x DEF"));
    try t.expect(hash("x") == hash("x"));
    try t.expect(hash("x") != hash("a"));
    try t.expect(hash("x") != hash("x continues"));
    try t.expect(hash("x fsdfs") != hash("x"));
    try t.expect(hash("x") != hash(""));
    try t.expect(hash("") != hash("x"));
    try t.expect(hash("") == hash(""));
}

View file

@ -1,189 +0,0 @@
const std = @import("std");
/// Iterator over the non-empty parts of a string delimited by `separator`.
/// Leading/trailing separators are trimmed up front and runs of separators
/// are collapsed, so empty parts are never produced.
pub fn Separator(comptime separator: u8) type {
    return struct {
        const Self = @This();
        const sep = [_]u8{separator};

        str: []const u8,

        /// Builds an iterator, dropping leading/trailing separators.
        pub fn from(str: []const u8) Self {
            return Self{ .str = std.mem.trim(u8, str, &sep) };
        }

        /// Returns the next part, or null once the input is exhausted.
        pub fn next(self: *Self) ?[]const u8 {
            if (self.str.len == 0) return null;
            const part = std.mem.sliceTo(self.str, separator);
            const remainder = self.str[part.len..];
            self.str = std.mem.trimLeft(u8, remainder, &sep);
            return part;
        }
    };
}
/// Iterates over the key/value pairs of a URL query string.
/// A leading '?' is ignored. "key" (no '=') yields a null value, while
/// "key=" yields an empty-string value.
pub const QueryIter = struct {
    const Pair = struct {
        key: []const u8,
        value: ?[]const u8,
    };
    iter: Separator('&'),
    pub fn from(q: []const u8) QueryIter {
        return QueryIter{ .iter = Separator('&').from(std.mem.trimLeft(u8, q, "?")) };
    }
    pub fn next(self: *QueryIter) ?Pair {
        const part = self.iter.next() orelse return null;
        const key = std.mem.sliceTo(part, '=');
        // No '=' at all: a bare key with no value.
        if (key.len == part.len) return Pair{
            .key = key,
            .value = null,
        };
        return Pair{
            .key = key,
            .value = part[key.len + 1 ..],
        };
    }
};
/// Iterates over the non-empty segments of a '/'-separated path.
/// Special case: when the input has no non-empty segment (e.g. "" or "/"),
/// the first call to next() yields one empty segment before returning null.
pub const PathIter = struct {
    is_first: bool,
    iter: std.mem.SplitIterator(u8),

    pub fn from(path: []const u8) PathIter {
        return PathIter{
            .is_first = true,
            .iter = std.mem.split(u8, path, "/"),
        };
    }

    /// Returns the next non-empty segment, or null when exhausted.
    pub fn next(self: *PathIter) ?[]const u8 {
        const was_first = self.is_first;
        self.is_first = false;
        while (self.iter.next()) |segment| {
            if (segment.len != 0) return segment;
        }
        // Nothing left: only the very first call still yields the
        // (empty) remainder.
        return if (was_first) self.iter.rest() else null;
    }

    /// Returns the first segment; must be the first call on this iterator.
    pub fn first(self: *PathIter) []const u8 {
        std.debug.assert(self.is_first);
        return self.next().?;
    }

    /// Returns everything not yet consumed by the underlying splitter.
    pub fn rest(self: *PathIter) []const u8 {
        return self.iter.rest();
    }
};
test "QueryIter" {
    const t = @import("std").testing;
    // NOTE(review): this test is deliberately disabled — presumably because
    // expectEqual on Pair compares the slice fields by pointer rather than
    // by content; confirm before re-enabling.
    if (true) return error.SkipZigTest;
    {
        var iter = QueryIter.from("");
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
    {
        var iter = QueryIter.from("?");
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
    {
        var iter = QueryIter.from("?abc");
        try t.expectEqual(QueryIter.Pair{
            .key = "abc",
            .value = null,
        }, iter.next().?);
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
    {
        var iter = QueryIter.from("?abc=");
        try t.expectEqual(QueryIter.Pair{
            .key = "abc",
            .value = "",
        }, iter.next().?);
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
    {
        var iter = QueryIter.from("?abc=def");
        try t.expectEqual(QueryIter.Pair{
            .key = "abc",
            .value = "def",
        }, iter.next().?);
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
    {
        var iter = QueryIter.from("?abc=def&");
        try t.expectEqual(QueryIter.Pair{
            .key = "abc",
            .value = "def",
        }, iter.next().?);
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
    {
        var iter = QueryIter.from("?abc=def&foo&bar=baz&qux=");
        try t.expectEqual(QueryIter.Pair{
            .key = "abc",
            .value = "def",
        }, iter.next().?);
        try t.expectEqual(QueryIter.Pair{
            .key = "foo",
            .value = null,
        }, iter.next().?);
        try t.expectEqual(QueryIter.Pair{
            .key = "bar",
            .value = "baz",
        }, iter.next().?);
        try t.expectEqual(QueryIter.Pair{
            .key = "qux",
            .value = "",
        }, iter.next().?);
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
    {
        var iter = QueryIter.from("?=def&");
        try t.expectEqual(QueryIter.Pair{
            .key = "",
            .value = "def",
        }, iter.next().?);
        try t.expect(iter.next() == null);
        try t.expect(iter.next() == null);
    }
}
test "PathIter /ab/cd/" {
    // Empty segments created by leading/trailing slashes are skipped.
    const path = "/ab/cd/";
    var it = PathIter.from(path);
    try std.testing.expectEqualStrings("ab", it.next().?);
    try std.testing.expectEqualStrings("cd", it.next().?);
    try std.testing.expectEqual(@as(?[]const u8, null), it.next());
}
test "PathIter ''" {
    // An empty path yields exactly one empty segment.
    const path = "";
    var it = PathIter.from(path);
    try std.testing.expectEqualStrings("", it.next().?);
    try std.testing.expectEqual(@as(?[]const u8, null), it.next());
}
test "PathIter ab/c//defg/" {
    // Consecutive slashes do not produce empty segments.
    const path = "ab/c//defg/";
    var it = PathIter.from(path);
    try std.testing.expectEqualStrings("ab", it.next().?);
    try std.testing.expectEqualStrings("c", it.next().?);
    try std.testing.expectEqualStrings("defg", it.next().?);
    try std.testing.expectEqual(@as(?[]const u8, null), it.next());
}

View file

@ -1,13 +1,10 @@
const std = @import("std");
const iters = @import("./iters.zig");
pub const ciutf8 = @import("./ciutf8.zig");
pub const Uuid = @import("./Uuid.zig");
pub const DateTime = @import("./DateTime.zig");
pub const Url = @import("./Url.zig");
pub const PathIter = iters.PathIter;
pub const QueryIter = iters.QueryIter;
pub const SqlStmtIter = iters.Separator(';');
pub const serialize = @import("./serialize.zig");
pub const Deserializer = serialize.Deserializer;
pub const DeserializerContext = serialize.DeserializerContext;
/// Joins an array of strings, prefixing every entry with `prefix`,
/// and putting `separator` in between each pair
@ -202,6 +199,16 @@ pub fn seedThreadPrng() !void {
prng = std.rand.DefaultPrng.init(@bitCast(u64, buf));
}
/// Converts "\n" line endings to "\r\n" at compile time, returning a
/// comptime-allocated string.
/// NOTE(review): assumes the input uses bare "\n" endings — an existing
/// "\r\n" would become "\r\r\n"; confirm callers never pass CRLF text.
pub fn comptimeToCrlf(comptime str: []const u8) []const u8 {
    comptime {
        // The replacement scans every byte; scale the branch quota with
        // the input size so long templates still evaluate.
        @setEvalBranchQuota(str.len * 10);
        const size = std.mem.replacementSize(u8, str, "\n", "\r\n");
        var buf: [size]u8 = undefined;
        _ = std.mem.replace(u8, str, "\n", "\r\n", &buf);
        return &buf;
    }
}
pub const testing = struct {
pub fn expectDeepEqual(expected: anytype, actual: @TypeOf(expected)) !void {
const T = @TypeOf(expected);
@ -242,3 +249,7 @@ pub const testing = struct {
}
}
};
test {
    // Reference all public declarations so tests in re-exported modules run.
    _ = std.testing.refAllDecls(@This());
}

386
src/util/serialize.zig Normal file
View file

@ -0,0 +1,386 @@
const std = @import("std");
const util = @import("./lib.zig");
pub const FieldRef = []const []const u8;
/// Decides whether `T` is deserialized from a single string value
/// ("scalar") rather than by recursing into its fields.
/// Scalars: bool, optionals of scalars, Zig strings, integers, floats,
/// enums, enum literals, and any type providing its own `parse` function.
pub fn defaultIsScalar(comptime T: type) bool {
    const trait = std.meta.trait;
    if (T == bool) return true;
    if (comptime trait.is(.Optional)(T)) {
        if (comptime defaultIsScalar(std.meta.Child(T))) return true;
    }
    if (comptime trait.isZigString(T)) return true;
    if (comptime trait.isIntegral(T)) return true;
    if (comptime trait.isFloat(T)) return true;
    if (comptime trait.is(.Enum)(T)) return true;
    if (comptime trait.is(.EnumLiteral)(T)) return true;
    if (comptime trait.hasFn("parse")(T)) return true;
    return false;
}
/// Parses `value` into a `T`.
/// - optionals: empty string becomes null, otherwise parsed as the child type
/// - strings are duplicated into allocator-owned memory (caller frees)
/// - integers use base-0 parsing (0x/0o/0b prefixes accepted); floats via parseFloat
/// - types with their own `parse` function delegate to it
/// - bools and enums are matched case-insensitively
pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T {
    if (comptime std.meta.trait.is(.Optional)(T)) {
        if (value.len == 0) return null;
        return try deserializeString(allocator, std.meta.Child(T), value);
    }
    if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value);
    if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0);
    if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value);
    if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value);
    var buf: [64]u8 = undefined;
    if (value.len > buf.len) {
        // Too long to be any bool token or enum tag. Reject it up front
        // instead of tripping lowerString's length assertion (which would
        // panic on untrusted input).
        if (T == bool) return error.InvalidBool;
        if (comptime std.meta.trait.is(.Enum)(T)) return error.InvalidEnumTag;
    }
    const lowered = std.ascii.lowerString(&buf, value);
    if (T == bool) return bool_map.get(lowered) orelse return error.InvalidBool;
    if (comptime std.meta.trait.is(.Enum)(T)) {
        return std.meta.stringToEnum(T, lowered) orelse return error.InvalidEnumTag;
    }
    @compileError("Invalid type " ++ @typeName(T));
}
/// Computes, at compile time, the list of field paths ("refs") at which
/// scalar values can be set when deserializing a `T`. Structs and
/// optionals are recursed into; with `options.embed_unions`, a union's
/// members replace the union field's own path segment, so they appear
/// alongside the parent's fields.
pub fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef {
    comptime {
        if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) {
            @compileError("Cannot embed a union into nothing");
        }
        if (options.isScalar(T)) return &.{prefix};
        // A non-scalar optional is traversed as its child type.
        if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options);
        // Embedded unions drop their own path segment.
        const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions)
            prefix[0 .. prefix.len - 1]
        else
            prefix;
        var fields: []const FieldRef = &.{};
        for (std.meta.fields(T)) |f| {
            const new_prefix = eff_prefix ++ &[_][]const u8{f.name};
            const F = f.field_type;
            fields = fields ++ getRecursiveFieldList(F, new_prefix, options);
        }
        return fields;
    }
}
/// Tuning knobs for (de)serialization field-path generation.
pub const SerializationOptions = struct {
    // When true, union members are embedded into the parent container's
    // namespace instead of being addressed through the union field's name.
    embed_unions: bool,
    // Predicate deciding which types are treated as single scalar values.
    isScalar: fn (type) bool,
};
/// Options used by the plain string Deserializer.
pub const default_options = SerializationOptions{
    .embed_unions = true,
    .isScalar = defaultIsScalar,
};
// Builds a struct type with one optional `From` field per scalar field
// path in `Result` (named e.g. "foo.bar"), all defaulting to null. This
// acts as staging storage for raw values before final deserialization.
fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
    const field_refs = getRecursiveFieldList(Result, &.{}, options);
    var fields: [field_refs.len]std.builtin.Type.StructField = undefined;
    for (field_refs) |ref, i| {
        fields[i] = .{
            // Path segments joined with '.' form the flat field name.
            .name = util.comptimeJoin(".", ref),
            .field_type = ?From,
            .default_value = &@as(?From, null),
            .is_comptime = false,
            .alignment = @alignOf(?From),
        };
    }
    return @Type(.{ .Struct = .{
        .layout = .Auto,
        .fields = &fields,
        .decls = &.{},
        .is_tuple = false,
    } });
}
/// Deserializer that builds a `Result` from string key/value pairs, using
/// the default options and string scalar parsing (deserializeString).
pub fn Deserializer(comptime Result: type) type {
    return DeserializerContext(Result, []const u8, struct {
        const options = default_options;
        fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T {
            return try deserializeString(alloc, T, val);
        }
    });
}
/// Generic deserializer: raw `From` values are collected keyed by their
/// flattened field path via setSerializedField(), then assembled into a
/// `Result` by finish(). `Context` supplies `options` and a
/// `deserializeScalar` function for leaf values.
pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type {
    return struct {
        const Data = Intermediary(Result, From, Context.options);
        data: Data = .{},
        context: Context = .{},
        /// Stores a raw value for the field addressed by `key`
        /// (e.g. "foo.bar"). Returns error.UnknownField for unknown keys.
        pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void {
            const field = std.meta.stringToEnum(std.meta.FieldEnum(Data), key) orelse return error.UnknownField;
            // Map the runtime enum tag back onto a comptime field name by
            // scanning all fields with an unrolled loop.
            inline for (comptime std.meta.fieldNames(Data)) |field_name| {
                @setEvalBranchQuota(10000);
                const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(Data), field_name);
                if (field == f) {
                    @field(self.data, field_name) = value;
                    return;
                }
            }
            // stringToEnum succeeded above, so one field always matches.
            unreachable;
        }
        /// Iterates over the fields that currently hold a raw value.
        pub const Iter = struct {
            data: *const Data,
            field_index: usize,
            const Item = struct {
                key: []const u8,
                value: From,
            };
            pub fn next(self: *Iter) ?Item {
                while (self.field_index < std.meta.fields(Data).len) {
                    const idx = self.field_index;
                    self.field_index += 1;
                    // Map the runtime index onto the comptime field list.
                    inline for (comptime std.meta.fieldNames(Data)) |field, i| {
                        if (i == idx) {
                            const maybe_value = @field(self.data.*, field);
                            if (maybe_value) |value| return Item{ .key = field, .value = value };
                        }
                    }
                }
                return null;
            }
        };
        pub fn iterator(self: *const @This()) Iter {
            return .{ .data = &self.data, .field_index = 0 };
        }
        /// Frees a value previously returned by finish().
        pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
            util.deepFree(allocator, val);
        }
        /// Assembles the final Result from the collected raw values.
        /// Returns error.MissingField if required fields were never set.
        pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result {
            return (try self.deserialize(allocator, Result, &.{})) orelse error.MissingField;
        }
        fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From {
            return @field(self.data, util.comptimeJoin(".", field_ref));
        }
        fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
            util.deepFree(allocator, val);
        }
        /// Recursively builds the value at `field_ref`. Returns null when
        /// no raw value was provided for the field or any of its subfields.
        fn deserialize(self: *@This(), allocator: std.mem.Allocator, comptime T: type, comptime field_ref: FieldRef) !?T {
            if (comptime Context.options.isScalar(T)) {
                return try self.context.deserializeScalar(allocator, T, self.getSerializedField(field_ref) orelse return null);
            }
            switch (@typeInfo(T)) {
                // At most one of any union field can be active at a time, and it is embedded
                // in its parent container
                .Union => |info| {
                    var result: ?T = null;
                    errdefer if (result) |v| self.deserializeFree(allocator, v);
                    // TODO: errdefer cleanup
                    const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref;
                    inline for (info.fields) |field| {
                        const F = field.field_type;
                        const new_field_ref = union_ref ++ &[_][]const u8{field.name};
                        const maybe_value = try self.deserialize(allocator, F, new_field_ref);
                        if (maybe_value) |value| {
                            // TODO: errdefer cleanup
                            errdefer self.deserializeFree(allocator, value);
                            if (result != null) return error.DuplicateUnionMember;
                            result = @unionInit(T, field.name, value);
                        }
                    }
                    return result;
                },
                .Struct => |info| {
                    var result: T = undefined;
                    var any_explicit = false;
                    var any_missing = false;
                    // Tracks which fields hold owned memory, so the
                    // errdefer below frees exactly what was allocated.
                    var fields_alloced = [1]bool{false} ** info.fields.len;
                    errdefer inline for (info.fields) |field, i| {
                        if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name));
                    };
                    inline for (info.fields) |field, i| {
                        const F = field.field_type;
                        const new_field_ref = field_ref ++ &[_][]const u8{field.name};
                        const maybe_value = try self.deserialize(allocator, F, new_field_ref);
                        if (maybe_value) |v| {
                            @field(result, field.name) = v;
                            fields_alloced[i] = true;
                            any_explicit = true;
                        } else if (field.default_value) |ptr| {
                            if (@sizeOf(F) != 0) {
                                const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr));
                                @field(result, field.name) = try util.deepClone(allocator, cast_ptr.*);
                                fields_alloced[i] = true;
                            }
                        } else {
                            any_missing = true;
                        }
                    }
                    if (any_missing) {
                        // A partially-specified struct is an error; a fully
                        // unspecified one just means "not provided".
                        return if (any_explicit) error.MissingField else null;
                    }
                    return result;
                },
                // Specifically non-scalar optionals
                .Optional => |info| return try self.deserialize(allocator, info.child, field_ref),
                else => @compileError("Unsupported type"),
            }
        }
    };
}
// Accepted spellings for boolean values; lookups happen on the
// already-lowercased input.
const bool_map = std.ComptimeStringMap(bool, .{
    .{ "true", true },
    .{ "t", true },
    .{ "yes", true },
    .{ "y", true },
    .{ "1", true },
    .{ "false", false },
    .{ "f", false },
    .{ "no", false },
    .{ "n", false },
    .{ "0", false },
});
test "Deserializer" {
    // Happy case - simple
    {
        const T = struct { foo: []const u8, bar: bool };
        var ds = Deserializer(T){};
        try ds.setSerializedField("foo", "123");
        try ds.setSerializedField("bar", "true");
        const val = try ds.finish(std.testing.allocator);
        defer ds.finishFree(std.testing.allocator, val);
        try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val);
    }
    // Returns error if nonexistent field set
    {
        const T = struct { foo: []const u8, bar: bool };
        var ds = Deserializer(T){};
        try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123"));
    }
    // Substruct dereferencing
    {
        const T = struct {
            foo: struct { bar: bool, baz: bool },
        };
        var ds = Deserializer(T){};
        try ds.setSerializedField("foo.bar", "true");
        try ds.setSerializedField("foo.baz", "true");
        const val = try ds.finish(std.testing.allocator);
        defer ds.finishFree(std.testing.allocator, val);
        try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true, .baz = true } }, val);
    }
    // Union embedding
    {
        const T = struct {
            foo: union(enum) { bar: bool, baz: bool },
        };
        var ds = Deserializer(T){};
        try ds.setSerializedField("bar", "true");
        const val = try ds.finish(std.testing.allocator);
        defer ds.finishFree(std.testing.allocator, val);
        try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true } }, val);
    }
    // Returns error if multiple union fields specified
    {
        const T = struct {
            foo: union(enum) { bar: bool, baz: bool },
        };
        var ds = Deserializer(T){};
        try ds.setSerializedField("bar", "true");
        try ds.setSerializedField("baz", "true");
        try std.testing.expectError(error.DuplicateUnionMember, ds.finish(std.testing.allocator));
    }
    // Uses default values if fields aren't provided
    {
        const T = struct { foo: []const u8 = "123", bar: bool = true };
        var ds = Deserializer(T){};
        const val = try ds.finish(std.testing.allocator);
        defer ds.finishFree(std.testing.allocator, val);
        try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val);
    }
    // Returns an error if fields aren't provided and no default exists
    {
        const T = struct { foo: []const u8, bar: bool };
        var ds = Deserializer(T){};
        try ds.setSerializedField("foo", "123");
        try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator));
    }
    // Handles optional containers
    {
        const T = struct {
            foo: ?struct { bar: usize = 3, baz: usize } = null,
            qux: ?union(enum) { quux: usize } = null,
        };
        var ds = Deserializer(T){};
        const val = try ds.finish(std.testing.allocator);
        defer ds.finishFree(std.testing.allocator, val);
        try util.testing.expectDeepEqual(T{ .foo = null, .qux = null }, val);
    }
    // Optional containers are materialized when any member is provided
    {
        const T = struct {
            foo: ?struct { bar: usize = 3, baz: usize } = null,
            qux: ?union(enum) { quux: usize } = null,
        };
        var ds = Deserializer(T){};
        try ds.setSerializedField("foo.baz", "3");
        try ds.setSerializedField("quux", "3");
        const val = try ds.finish(std.testing.allocator);
        defer ds.finishFree(std.testing.allocator, val);
        try util.testing.expectDeepEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, val);
    }
    // A partially-specified optional container missing a required member errors
    {
        const T = struct {
            foo: ?struct { bar: usize = 3, baz: usize } = null,
            qux: ?union(enum) { quux: usize } = null,
        };
        var ds = Deserializer(T){};
        try ds.setSerializedField("foo.bar", "3");
        try ds.setSerializedField("quux", "3");
        try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator));
    }
}