diff --git a/.gitignore b/.gitignore index 7e4c4e7..db4419f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ **/zig-cache **.db /config.json +/files diff --git a/build.zig b/build.zig index cf43658..d644d4f 100644 --- a/build.zig +++ b/build.zig @@ -99,13 +99,23 @@ pub fn build(b: *std.build.Builder) !void { exe.addSystemIncludePath("/usr/include/"); const unittest_http_cmd = b.step("unit:http", "Run tests for http package"); - const unittest_http = b.addTest("src/http/test.zig"); + const unittest_http = b.addTest("src/http/lib.zig"); unittest_http_cmd.dependOn(&unittest_http.step); unittest_http.addPackage(pkgs.util); - //const unittest_util_cmd = b.step("unit:util", "Run tests for util package"); - //const unittest_util = b.addTest("src/util/Uuid.zig"); - //unittest_util_cmd.dependOn(&unittest_util.step); + const unittest_util_cmd = b.step("unit:util", "Run tests for util package"); + const unittest_util = b.addTest("src/util/lib.zig"); + unittest_util_cmd.dependOn(&unittest_util.step); + + const unittest_sql_cmd = b.step("unit:sql", "Run tests for sql package"); + const unittest_sql = b.addTest("src/sql/lib.zig"); + unittest_sql_cmd.dependOn(&unittest_sql.step); + unittest_sql.addPackage(pkgs.util); + + const unittest_template_cmd = b.step("unit:template", "Run tests for template package"); + const unittest_template = b.addTest("src/template/lib.zig"); + unittest_template_cmd.dependOn(&unittest_template.step); + //unittest_template.addPackage(pkgs.util); //const util_tests = b.addTest("src/util/lib.zig"); //const sql_tests = b.addTest("src/sql/lib.zig"); @@ -115,7 +125,9 @@ pub fn build(b: *std.build.Builder) !void { //const unit_tests = b.step("unit-tests", "Run tests"); const unittest_all = b.step("unit", "Run unit tests"); unittest_all.dependOn(unittest_http_cmd); - //unittest_all.dependOn(unittest_util_cmd); + unittest_all.dependOn(unittest_util_cmd); + unittest_all.dependOn(unittest_sql_cmd); + unittest_all.dependOn(unittest_template_cmd); const api_integration = b.addTest("./tests/api_integration/lib.zig"); api_integration.addPackage(pkgs.opts); diff --git a/src/api/lib.zig b/src/api/lib.zig index 1a46fa9..065fdc7 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -9,6 +9,7 @@ const services = struct { const communities = @import("./services/communities.zig"); const actors = @import("./services/actors.zig"); const auth = @import("./services/auth.zig"); + const drive = @import("./services/files.zig"); const invites = @import("./services/invites.zig"); const notes = @import("./services/notes.zig"); const follows = @import("./services/follows.zig"); @@ -136,6 +137,14 @@ pub const FollowerQueryResult = FollowQueryResult; pub const FollowingQueryArgs = FollowQueryArgs; pub const FollowingQueryResult = FollowQueryResult; +pub const UploadFileArgs = struct { + filename: []const u8, + dir: ?[]const u8, + description: ?[]const u8, + content_type: []const u8, + sensitive: bool, +}; + pub fn isAdminSetup(db: sql.Db) !bool { _ = services.communities.adminCommunityId(db) catch |err| switch (err) { error.NotFound => return false, @@ -509,5 +518,23 @@ fn ApiConn(comptime DbConn: type) type { self.allocator, ); } + + pub fn uploadFile(self: *Self, meta: UploadFileArgs, body: []const u8) !void { + const user_id = self.user_id orelse return error.NoToken; + return try services.drive.createFile(self.db, .{ + .dir = meta.dir orelse "/", + .filename = meta.filename, + .owner = .{ .user_id = user_id }, + .created_by = user_id, + .description = meta.description, + .content_type = meta.content_type, + .sensitive = meta.sensitive, + }, body, self.allocator); + } + + pub fn driveMkdir(self: *Self, path: []const u8) !void { + const user_id = self.user_id orelse return error.NoToken; + try services.drive.mkdir(self.db, .{ .user_id = user_id }, path, self.allocator); + } }; } diff --git a/src/api/services/files.zig b/src/api/services/files.zig index 18c0e9d..147d049 100644 --- a/src/api/services/files.zig +++ b/src/api/services/files.zig @@ -11,59 +11,224 @@ pub const FileOwner = union(enum) { pub const DriveFile = struct { id: Uuid, + + path: []const u8, + filename: []const u8, + + owner: FileOwner, + + size: usize, + + description: []const u8, + content_type: []const u8, + sensitive: bool, + + created_at: DateTime, + updated_at: DateTime, +}; + +const EntryType = enum { + dir, + file, +}; + +pub const CreateFileArgs = struct { + dir: []const u8, filename: []const u8, owner: FileOwner, - size: usize, - created_at: DateTime, + created_by: Uuid, + description: ?[]const u8, + content_type: ?[]const u8, + sensitive: bool, }; -pub const files = struct { - pub fn create(db: anytype, owner: FileOwner, filename: []const u8, data: []const u8, alloc: std.mem.Allocator) !void { - const id = Uuid.randV4(util.getThreadPrng()); - const now = DateTime.now(); +fn lookupDirectory(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid { + return (try db.queryRow( + std.meta.Tuple( + &.{util.Uuid}, + ), + \\SELECT id + \\FROM drive_entry_path + \\WHERE + \\ path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END) + \\ AND account_owner_id IS NOT DISTINCT FROM $2 + \\ AND community_owner_id IS NOT DISTINCT FROM $3 + \\ AND kind = 'dir' + \\LIMIT 1 + , + .{ + std.mem.trim(u8, path, "/"), + if (owner == .user_id) owner.user_id else null, + if (owner == .community_id) owner.community_id else null, + }, + alloc, + ))[0]; +} - // TODO: assert we're not in a transaction - db.insert("drive_file", .{ +fn lookup(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid { + return (try db.queryRow( + std.meta.Tuple( + &.{util.Uuid}, + ), + \\SELECT id + \\FROM drive_entry_path + \\WHERE + \\ path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END) + \\ AND account_owner_id IS NOT DISTINCT FROM $2 + \\ AND community_owner_id IS NOT DISTINCT FROM $3 + \\LIMIT 1 + , + .{ + std.mem.trim(u8, path, "/"), + if (owner == .user_id) owner.user_id else null, + if (owner == .community_id) owner.community_id else null, + }, + alloc, + ))[0]; +} + +pub fn mkdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void { + var split = std.mem.splitBackwards(u8, std.mem.trim(u8, path, "/"), "/"); + const name = split.first(); + const dir = split.rest(); + std.log.debug("'{s}' / '{s}'", .{ name, dir }); + + if (name.len == 0) return error.EmptyName; + + const id = Uuid.randV4(util.getThreadPrng()); + + const tx = try db.begin(); + errdefer tx.rollback(); + + const parent = try lookupDirectory(tx, owner, dir, alloc); + + try tx.insert("drive_entry", .{ + .id = id, + + .account_owner_id = if (owner == .user_id) owner.user_id else null, + .community_owner_id = if (owner == .community_id) owner.community_id else null, + + .name = name, + .parent_directory_id = parent, + }, alloc); + try tx.commit(); +} + +pub fn rmdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void { + const tx = try db.begin(); + errdefer tx.rollback(); + + const id = try lookupDirectory(tx, owner, path, alloc); + try tx.exec("DELETE FROM drive_directory WHERE id = $1", .{id}, alloc); + try tx.commit(); +} + +fn insertFileRow(tx: anytype, id: Uuid, filename: []const u8, owner: FileOwner, dir: Uuid, alloc: std.mem.Allocator) !void { + try tx.insert("drive_entry", .{ + .id = id, + + .account_owner_id = if (owner == .user_id) owner.user_id else null, + .community_owner_id = if (owner == .community_id) owner.community_id else null, + + .parent_directory_id = dir, + .name = filename, + + .file_id = id, + }, alloc); +} + +pub fn createFile(db: anytype, args: CreateFileArgs, data: []const u8, alloc: std.mem.Allocator) !void { + const id = Uuid.randV4(util.getThreadPrng()); + const now = DateTime.now(); + + { + var tx = try db.begin(); + errdefer tx.rollback(); + + const dir_id = try lookupDirectory(tx, args.owner, args.dir, alloc); + + try tx.insert("file_upload", .{ .id = id, - .filename = filename, - .owner = owner, + + .filename = args.filename, + + .created_by = args.created_by, + .size = data.len, + + .description = args.description, + .content_type = args.content_type, + .sensitive = args.sensitive, + + .is_deleted = false, + .created_at = now, - }, alloc) catch return error.DatabaseFailure; - // Assume the previous statement succeeded and is not stuck in a transaction - errdefer { - db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch |err| { - std.log.err("Unable to remove file record in DB: {}", .{err}); - }; + .updated_at = now, + }, alloc); + + var sub_tx = try tx.savepoint(); + if (insertFileRow(sub_tx, id, args.filename, args.owner, dir_id, alloc)) |_| { + try sub_tx.release(); + } else |err| { + std.log.debug("{}", .{err}); + switch (err) { + error.UniqueViolation => { + try sub_tx.rollbackSavepoint(); + // Rename the file before trying again + var split = std.mem.split(u8, args.filename, "."); + const name = split.first(); + const ext = split.rest(); + var buf: [256]u8 = undefined; + const drive_filename = try std.fmt.bufPrint(&buf, "{s}.{}.{s}", .{ name, id, ext }); + try insertFileRow(tx, id, drive_filename, args.owner, dir_id, alloc); + }, + else => return error.DatabaseFailure, + } } - try saveFile(id, data); + try tx.commit(); } - const data_root = "./files"; - fn saveFile(id: Uuid, data: []const u8) !void { - var dir = try std.fs.cwd().openDir(data_root); - defer dir.close(); - - var file = try dir.createFile(id.toCharArray(), .{ .exclusive = true }); - defer file.close(); - - try file.writer().writeAll(data); - try file.sync(); + errdefer { + db.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch |err| { + std.log.err("Unable to remove file record in DB: {}", .{err}); + }; + db.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch |err| { + std.log.err("Unable to remove file record in DB: {}", .{err}); + }; } - pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 { - var dir = try std.fs.cwd().openDir(data_root); - defer dir.close(); + try saveFile(id, data); +} - return dir.readFileAlloc(alloc, id.toCharArray(), 1 << 32); - } +const data_root = "./files"; +fn saveFile(id: Uuid, data: []const u8) !void { + var dir = try std.fs.cwd().openDir(data_root, .{}); + defer dir.close(); - pub fn delete(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void { - var dir = try std.fs.cwd().openDir(data_root); - defer dir.close(); + var file = try dir.createFile(&id.toCharArray(), .{ .exclusive = true }); + defer file.close(); - try dir.deleteFile(id.toCharArray()); + try file.writer().writeAll(data); + try file.sync(); +} - db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure; - } -}; +pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 { + var dir = try std.fs.cwd().openDir(data_root, .{}); + defer dir.close(); + + return dir.readFileAlloc(alloc, &id.toCharArray(), 1 << 32); +} + +pub fn deleteFile(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void { + var dir = try std.fs.cwd().openDir(data_root, .{}); + defer dir.close(); + + try dir.deleteFile(id.toCharArray()); + + const tx = try db.beginOrSavepoint(); + errdefer tx.rollback(); + + tx.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure; + tx.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure; + try tx.commitOrRelease(); +} diff --git a/src/http/headers.zig b/src/http/fields.zig similarity index 76% rename from src/http/headers.zig rename to src/http/fields.zig index 1b91865..d6319da 100644 --- a/src/http/headers.zig +++ b/src/http/fields.zig @@ -1,5 +1,60 @@ const std = @import("std"); +pub const ParamIter = struct { + str: []const u8, + index: usize = 0, + + const Param = struct { + name: []const u8, + value: []const u8, + }; + + pub fn from(str: []const u8) ParamIter { + return .{ .str = str, .index = std.mem.indexOfScalar(u8, str, ';') orelse str.len }; + } + + pub fn fieldValue(self: *ParamIter) []const u8 { + return std.mem.sliceTo(self.str, ';'); + } + + pub fn next(self: *ParamIter) ?Param { + if (self.index >= self.str.len) return null; + + const start = self.index + 1; + const new_start = std.mem.indexOfScalarPos(u8, self.str, start, ';') orelse self.str.len; + self.index = new_start; + + const param = std.mem.trim(u8, self.str[start..new_start], " \t"); + var split = std.mem.split(u8, param, "="); + const name = split.first(); + const value = std.mem.trimLeft(u8, split.rest(), " \t"); + // TODO: handle quoted values + // TODO: handle parse errors + + return Param{ + .name = name, + .value = value, + }; + } +}; + +pub fn getParam(field: []const u8, name: ?[]const u8) ?[]const u8 { + var iter = ParamIter.from(field); + + if (name) |param| { + while (iter.next()) |p| { + if (std.ascii.eqlIgnoreCase(param, p.name)) { + const trimmed = std.mem.trim(u8, p.value, " \t"); + if (trimmed.len >= 2 and trimmed[0] == '"' and trimmed[trimmed.len - 1] == '"') { + return trimmed[1 .. trimmed.len - 1]; + } + return trimmed; + } + } + return null; + } else return iter.fieldValue(); +} + pub const Fields = struct { const HashContext = struct { const hash_seed = 1; diff --git a/src/http/lib.zig b/src/http/lib.zig index 9fc5a61..c114bea 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -1,27 +1,55 @@ const std = @import("std"); -const ciutf8 = @import("util").ciutf8; const request = @import("./request.zig"); - const server = @import("./server.zig"); - +pub const urlencode = @import("./urlencode.zig"); pub const socket = @import("./socket.zig"); +const json = @import("./json.zig"); +const multipart = @import("./multipart.zig"); +pub const fields = @import("./fields.zig"); + +pub const Method = enum { + GET, + HEAD, + POST, + PUT, + DELETE, + CONNECT, + OPTIONS, + TRACE, + PATCH, + + // WebDAV methods (we use some of them for the drive system) + MKCOL, + MOVE, + + pub fn requestHasBody(self: Method) bool { + return switch (self) { + .POST, .PUT, .PATCH, .MKCOL, .MOVE => true, + else => false, + }; + } +}; -pub const Method = std.http.Method; pub const Status = std.http.Status; pub const Request = request.Request(server.Stream.Reader); pub const Response = server.Response; -pub const Handler = server.Handler; +//pub const Handler = server.Handler; pub const Server = server.Server; pub const middleware = @import("./middleware.zig"); -pub const queryStringify = @import("./query.zig").queryStringify; -pub const Fields = @import("./headers.zig").Fields; +pub const Fields = fields.Fields; + +pub const FormFile = multipart.FormFile; pub const Protocol = enum { http_1_0, http_1_1, http_1_x, }; + +test { + _ = std.testing.refAllDecls(@This()); +} diff --git a/src/http/middleware.zig b/src/http/middleware.zig index 97855b1..ce4d307 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -14,10 +14,11 @@ /// Terminal middlewares that are not implemented using other middlewares should /// only accept a `void` value for `next_handler`. const std = @import("std"); -const http = @import("./lib.zig"); const util = @import("util"); -const query_utils = @import("./query.zig"); +const http = @import("./lib.zig"); +const urlencode = @import("./urlencode.zig"); const json_utils = @import("./json.zig"); +const fields = @import("./fields.zig"); /// Takes an iterable of middlewares and chains them together. pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) { @@ -29,20 +30,20 @@ pub fn Apply(comptime Middlewares: type) type { return ApplyInternal(std.meta.fields(Middlewares)); } -fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { - if (fields.len == 0) return void; +fn ApplyInternal(comptime which: []const std.builtin.Type.StructField) type { + if (which.len == 0) return void; return HandlerList( - fields[0].field_type, - ApplyInternal(fields[1..]), + which[0].field_type, + ApplyInternal(which[1..]), ); } -fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { - if (fields.len == 0) return {}; +fn applyInternal(middlewares: anytype, comptime which: []const std.builtin.Type.StructField) ApplyInternal(which) { + if (which.len == 0) return {}; return .{ - .first = @field(middlewares, fields[0].name), - .next = applyInternal(middlewares, fields[1..]), + .first = @field(middlewares, which[0].name), + .next = applyInternal(middlewares, which[1..]), }; } @@ -349,15 +350,71 @@ pub fn router(routes: anytype) Router(@TypeOf(routes)) { return Router(@TypeOf(routes)){ .routes = routes }; } +pub const PathIter = struct { + is_first: bool, + iter: std.mem.SplitIterator(u8), + + pub fn from(path: []const u8) PathIter { + return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") }; + } + + pub fn next(self: *PathIter) ?[]const u8 { + defer self.is_first = false; + while (self.iter.next()) |it| if (it.len != 0) { + return it; + }; + + if (self.is_first) return self.iter.rest(); + + return null; + } + + pub fn first(self: *PathIter) []const u8 { + std.debug.assert(self.is_first); + return self.next().?; + } + + pub fn rest(self: *PathIter) []const u8 { + return self.iter.rest(); + } +}; + +test "PathIter" { + const testCase = struct { + fn case(path: []const u8, segments: []const []const u8) !void { + var iter = PathIter.from(path); + for (segments) |s| { + try std.testing.expectEqualStrings(s, iter.next() orelse return error.TestExpectedEqual); + } + try std.testing.expect(iter.next() == null); + } + }.case; + + try testCase("", &.{""}); + try testCase("*", &.{"*"}); + try testCase("/", &.{""}); + try testCase("/ab/cd", &.{ "ab", "cd" }); + try testCase("/ab/cd/", &.{ "ab", "cd" }); + try testCase("/ab/cd//", &.{ "ab", "cd" }); + try testCase("ab", &.{"ab"}); + try testCase("/ab", &.{"ab"}); + try testCase("ab/", &.{"ab"}); + try testCase("ab//ab//", &.{ "ab", "ab" }); +} + // helper function for doing route analysis fn pathMatches(route: []const u8, path: []const u8) bool { - var path_iter = util.PathIter.from(path); - var route_iter = util.PathIter.from(route); + var path_iter = PathIter.from(path); + var route_iter = PathIter.from(route); while (route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return false; + const path_segment = path_iter.next() orelse ""; + if (route_segment.len > 0 and route_segment[0] == ':') { // Route Argument - if (path_segment.len == 0) return false; + if (route_segment[route_segment.len - 1] == '*') { + // consume rest of path segments + while (path_iter.next()) |_| {} + } else if (path_segment.len == 0) return false; } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; } @@ -428,6 +485,10 @@ test "route" { try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh"); try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh"); try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh/xyz"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/"); + try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd"); try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/"); try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, ""); @@ -436,32 +497,21 @@ test "route" { try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/"); try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/"); try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo"); + try testCase(false, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "defg/abcd"); } /// Mounts a router subtree under a given path. Middlewares further down on the list /// are called with the path prefix specified by `route` removed from the path. /// Must be below `split_uri` on the middleware list. pub fn Mount(comptime route: []const u8) type { + if (std.mem.indexOfScalar(u8, route, ':') != null) @compileError("Route args cannot be mounted"); return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - var path_iter = util.PathIter.from(ctx.path); - comptime var route_iter = util.PathIter.from(route); - var path_unused: []const u8 = ctx.path; - - inline while (comptime route_iter.next()) |route_segment| { - if (comptime route_segment.len == 0) continue; - const path_segment = path_iter.next() orelse return error.RouteMismatch; - path_unused = path_iter.rest(); - if (comptime route_segment[0] == ':') { - @compileLog("Argument segments cannot be mounted"); - // Route Argument - } else { - if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; - } - } + const args = try parseArgsFromPath(route ++ "/:path*", struct { path: []const u8 }, ctx.path); var new_ctx = ctx; - new_ctx.path = path_unused; + new_ctx.path = args.path; + return next.handle(req, res, new_ctx, {}); } }; @@ -491,18 +541,33 @@ test "mount" { fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { var args: Args = undefined; - var path_iter = util.PathIter.from(path); - comptime var route_iter = util.PathIter.from(route); + var path_iter = PathIter.from(path); + comptime var route_iter = PathIter.from(route); + var path_unused: []const u8 = path; + inline while (comptime route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse return error.RouteMismatch; - if (route_segment.len > 0 and route_segment[0] == ':') { + const path_segment = path_iter.next() orelse ""; + if (route_segment[0] == ':') { + comptime var name: []const u8 = route_segment[1..]; + var value: []const u8 = path_segment; + // route segment is an argument segment - if (path_segment.len == 0) return error.RouteMismatch; - const A = @TypeOf(@field(args, route_segment[1..])); - @field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment); + if (comptime route_segment[route_segment.len - 1] == '*') { + // waste remaining args + while (path_iter.next()) |_| {} + name = route_segment[1 .. route_segment.len - 1]; + value = path_unused; + } else { + if (path_segment.len == 0) return error.RouteMismatch; + } + + const A = @TypeOf(@field(args, name)); + @field(args, name) = try parseArgFromPath(A, value); } else { + // route segment is a literal segment if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; } + path_unused = path_iter.rest(); } if (path_iter.next() != null) return error.RouteMismatch; @@ -577,6 +642,21 @@ test "ParsePathArgs" { try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" }); try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil }); + try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc", .{ .arg = "abc" }); + try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc/def", .{ .arg = "abc/def" }); + try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/", .{ .arg = "" }); + + // Compiler crashes if i keep the args named the same as above. + // TODO: Debug this and try to fix it + try testCase("/xyz/:bar*", struct { bar: []const u8 }, "/xyz", .{ .bar = "" }); + + // It's a quirk that the initial / is left in for these cases. However, it results in a path + // that's semantically equivalent so i didn't bother fixing it + try testCase("/:foo*", struct { foo: []const u8 }, "/abc", .{ .foo = "/abc" }); + try testCase("/:foo*", struct { foo: []const u8 }, "/abc/def", .{ .foo = "/abc/def" }); + try testCase("/:foo*", struct { foo: []const u8 }, "/", .{ .foo = "/" }); + try testCase("/:foo*", struct { foo: []const u8 }, "", .{ .foo = "" }); + try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{})); try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{})); try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 })); @@ -587,41 +667,51 @@ const BaseContentType = enum { json, url_encoded, octet_stream, + multipart_formdata, other, }; -fn parseBodyFromRequest(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { - //@compileLog(T); - const buf = try reader.readAllAlloc(alloc, 1 << 16); - defer alloc.free(buf); +fn parseBodyFromRequest(comptime T: type, content_type: ?[]const u8, reader: anytype, alloc: std.mem.Allocator) !T { + // Use json by default for now for testing purposes + const eff_type = content_type orelse "application/json"; + const parser_type = matchContentType(eff_type); - switch (content_type) { + switch (parser_type) { .octet_stream, .json => { + const buf = try reader.readAllAlloc(alloc, 1 << 16); + defer alloc.free(buf); const body = try json_utils.parse(T, buf, alloc); defer json_utils.parseFree(body, alloc); return try util.deepClone(alloc, body); }, - .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { - error.NoQuery => error.NoBody, - else => err, + .url_encoded => { + const buf = try reader.readAllAlloc(alloc, 1 << 16); + defer alloc.free(buf); + return urlencode.parse(alloc, T, buf) catch |err| switch (err) { + //error.NoQuery => error.NoBody, + else => err, + }; + }, + .multipart_formdata => { + const boundary = fields.getParam(eff_type, "boundary") orelse return error.MissingBoundary; + return try @import("./multipart.zig").parseFormData(T, boundary, reader, alloc); }, else => return error.UnsupportedMediaType, } } // figure out what base parser to use -fn matchContentType(hdr: ?[]const u8) ?BaseContentType { - if (hdr) |h| { - if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; - if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; - if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; +fn matchContentType(hdr: []const u8) BaseContentType { + const trimmed = std.mem.sliceTo(hdr, ';'); + if (std.ascii.eqlIgnoreCase(trimmed, "application/x-www-form-urlencoded")) return .url_encoded; + if (std.ascii.eqlIgnoreCase(trimmed, "application/json")) return .json; + if (std.ascii.endsWithIgnoreCase(trimmed, "+json")) return .json; + if (std.ascii.eqlIgnoreCase(trimmed, "application/octet-stream")) return .octet_stream; + if (std.ascii.eqlIgnoreCase(trimmed, "multipart/form-data")) return .multipart_formdata; - return .other; - } - - return null; + return .other; } /// Parses a set of body arguments from the request body based on the request's Content-Type @@ -640,10 +730,8 @@ pub fn ParseBody(comptime Body: type) type { return next.handle(req, res, new_ctx, {}); } - const base_content_type = matchContentType(content_type); - var stream = req.body orelse return error.NoBody; - const body = try parseBodyFromRequest(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); + const body = try parseBodyFromRequest(Body, content_type, stream.reader(), ctx.allocator); defer util.deepFree(ctx.allocator, body); return next.handle( @@ -659,12 +747,57 @@ pub fn parseBody(comptime Body: type) ParseBody(Body) { return .{}; } +test "parseBodyFromRequest" { + const testCase = struct { + fn case(content_type: []const u8, body: []const u8, expected: anytype) !void { + var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) }; + const result = try parseBodyFromRequest(@TypeOf(expected), content_type, stream.reader(), std.testing.allocator); + defer util.deepFree(std.testing.allocator, result); + + try util.testing.expectDeepEqual(expected, result); + } + }.case; + + const Struct = struct { + id: usize, + }; + try testCase("application/json", "{\"id\": 3}", Struct{ .id = 3 }); + try testCase("application/x-www-form-urlencoded", "id=3", Struct{ .id = 3 }); + + //try testCase("multipart/form-data; ", + //\\ + //, Struct{ .id = 3 }); +} + +test "parseBody" { + const Struct = struct { + foo: []const u8, + }; + const body = + \\{"foo": "bar"} + ; + var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) }; + var headers = http.Fields.init(std.testing.allocator); + defer headers.deinit(); + + try parseBody(Struct).handle( + .{ .body = @as(?std.io.StreamSource, stream), .headers = headers }, + .{}, + .{ .allocator = std.testing.allocator }, + struct { + fn handle(_: anytype, _: anytype, _: anytype, ctx: anytype, _: void) !void { + try util.testing.expectDeepEqual(Struct{ .foo = "bar" }, ctx.body); + } + }{}, + ); +} + /// Parses query parameters as defined in query.zig pub fn ParseQueryParams(comptime QueryParams: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {}); - const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string); + const query = try urlencode.parse(ctx.allocator, QueryParams, ctx.query_string); defer util.deepFree(ctx.allocator, query); return next.handle( diff --git a/src/http/multipart.zig b/src/http/multipart.zig new file mode 100644 index 0000000..815711d --- /dev/null +++ b/src/http/multipart.zig @@ -0,0 +1,362 @@ +const std = @import("std"); +const util = @import("util"); +const fields = @import("./fields.zig"); + +const max_boundary = 70; +const read_ahead = max_boundary + 4; + +pub fn MultipartStream(comptime ReaderType: type) type { + return struct { + const Multipart = @This(); + + pub const BaseReader = ReaderType; + pub const PartReader = std.io.Reader(*Part, ReaderType.Error, Part.read); + + stream: std.io.PeekStream(.{ .Static = read_ahead }, ReaderType), + boundary: []const u8, + + pub fn next(self: *Multipart, alloc: std.mem.Allocator) !?Part { + const reader = self.stream.reader(); + while (true) { + try reader.skipUntilDelimiterOrEof('\r'); + var line_buf: [read_ahead]u8 = undefined; + const len = try reader.readAll(line_buf[0 .. self.boundary.len + 3]); + const line = line_buf[0..len]; + if (line.len == 0) return null; + if (std.mem.startsWith(u8, line, "\n--") and std.mem.endsWith(u8, line, self.boundary)) { + // match, check for end thing + var more_buf: [2]u8 = undefined; + if (try reader.readAll(&more_buf) != 2) return error.EndOfStream; + + const more = !(more_buf[0] == '-' and more_buf[1] == '-'); + try self.stream.putBack(&more_buf); + try reader.skipUntilDelimiterOrEof('\n'); + if (more) return try Part.open(self, alloc) else return null; + } + } + } + + pub const Part = struct { + base: ?*Multipart, + fields: fields.Fields, + + pub fn open(base: *Multipart, alloc: std.mem.Allocator) !Part { + var parsed_fields = try @import("./request/parser.zig").parseHeaders(alloc, base.stream.reader()); + return .{ .base = base, .fields = parsed_fields }; + } + + pub fn reader(self: *Part) PartReader { + return .{ .context = self }; + } + + pub fn close(self: *Part) void { + self.fields.deinit(); + } + + pub fn read(self: *Part, buf: []u8) ReaderType.Error!usize { + const base = self.base orelse return 0; + + const r = base.stream.reader(); + + var count: usize = 0; + while (count < buf.len) { + const byte = r.readByte() catch |err| switch (err) { + error.EndOfStream => { + self.base = null; + return count; + }, + else => |e| return e, + }; + + buf[count] = byte; + count += 1; + if (byte != '\r') continue; + + var line_buf: [read_ahead]u8 = undefined; + const line = line_buf[0..try r.readAll(line_buf[0 .. base.boundary.len + 3])]; + if (!std.mem.startsWith(u8, line, "\n--") or !std.mem.endsWith(u8, line, base.boundary)) { + base.stream.putBack(line) catch unreachable; + continue; + } else { + base.stream.putBack(line) catch unreachable; + base.stream.putBackByte('\r') catch unreachable; + self.base = null; + return count - 1; + } + } + + return count; + } + }; + }; +} + +pub fn openMultipart(boundary: []const u8, reader: anytype) !MultipartStream(@TypeOf(reader)) { + if (boundary.len > max_boundary) return error.BoundaryTooLarge; + var stream = .{ + .stream = std.io.peekStream(read_ahead, reader), + .boundary = boundary, + }; + + stream.stream.putBack("\r\n") catch unreachable; + return stream; +} + +const MultipartFormField = struct { + name: []const u8, + value: []const u8, + + filename: ?[]const u8 = null, + content_type: ?[]const u8 = null, +}; + +pub const FormFile = struct { + data: []const u8, + filename: []const u8, + content_type: []const u8, +}; + +pub fn MultipartForm(comptime ReaderType: type) type { + return struct { + stream: MultipartStream(ReaderType), + + pub fn next(self: *@This(), alloc: std.mem.Allocator) !?MultipartFormField { + var part = (try self.stream.next(alloc)) orelse return null; + defer part.close(); + + const disposition = part.fields.get("Content-Disposition") orelse return error.MissingDisposition; + + if (!std.ascii.eqlIgnoreCase(fields.getParam(disposition, null).?, "form-data")) return error.BadDisposition; + const name = try util.deepClone(alloc, fields.getParam(disposition, "name") orelse return error.BadDisposition); + errdefer util.deepFree(alloc, name); + const filename = try util.deepClone(alloc, fields.getParam(disposition, "filename")); + errdefer util.deepFree(alloc, filename); + const content_type = try util.deepClone(alloc, part.fields.get("Content-Type")); + errdefer util.deepFree(alloc, content_type); + + const value = try part.reader().readAllAlloc(alloc, 1 << 32); + + return MultipartFormField{ + .name = name, + .value = value, + + .filename = filename, + .content_type = content_type, + }; + } + }; +} + +pub fn openForm(multipart_stream: anytype) MultipartForm(@TypeOf(multipart_stream).BaseReader) { + return .{ .stream = multipart_stream }; +} + +fn Deserializer(comptime Result: type) type { + return util.DeserializerContext(Result, MultipartFormField, struct { + pub const options = .{ .isScalar = isScalar, .embed_unions = true }; + + pub fn isScalar(comptime T: type) bool { + if (T == FormFile or T == ?FormFile) return true; + return util.serialize.defaultIsScalar(T); + } + + pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: MultipartFormField) !T { + if (T == FormFile or T == ?FormFile) return try deserializeFormFile(alloc, val); + + if (val.filename != null) return error.FilenameProvidedForNonFile; + return try util.serialize.deserializeString(alloc, T, val.value); + } + + fn deserializeFormFile(alloc: std.mem.Allocator, val: MultipartFormField) !FormFile { + const data = try util.deepClone(alloc, val.value); + errdefer util.deepFree(alloc, data); + const filename = try util.deepClone(alloc, val.filename orelse "(untitled)"); + errdefer util.deepFree(alloc, filename); + const content_type = try util.deepClone(alloc, val.content_type orelse "application/octet-stream"); + return FormFile{ + .data = data, + .filename = filename, + .content_type = content_type, + }; + } + }); +} + +pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T { + var form = openForm(try openMultipart(boundary, reader)); + + var ds = Deserializer(T){}; + defer { + var iter = ds.iterator(); + while (iter.next()) |pair| { + util.deepFree(alloc, pair.value); + } + } + while (true) { + var part = (try form.next(alloc)) orelse break; + errdefer util.deepFree(alloc, part); + + try ds.setSerializedField(part.name, part); + } + + return try ds.finish(alloc); +} + +// TODO: Fix these tests +test "MultipartStream" { + const ExpectedPart = struct { + disposition: []const u8, + value: []const u8, + }; + const testCase = struct { + fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const ExpectedPart) !void { + var src = std.io.StreamSource{ + .const_buffer = std.io.fixedBufferStream(body), + }; + + var stream = try openMultipart(boundary, src.reader()); + + for (expected_parts) |expected| { + var part = try stream.next(std.testing.allocator) orelse return error.TestExpectedEqual; + defer part.close(); + + const dispo = part.fields.get("Content-Disposition") orelse return error.TestExpectedEqual; + try std.testing.expectEqualStrings(expected.disposition, dispo); + + var buf: [128]u8 = undefined; + const count = try part.reader().read(&buf); + try std.testing.expectEqualStrings(expected.value, buf[0..count]); + } + + try std.testing.expect(try stream.next(std.testing.allocator) == null); + } + }.case; + + try testCase("--abc--\r\n", "abc", &.{}); + try testCase( + util.comptimeToCrlf( + \\------abcd + \\Content-Disposition: form-data; name=first; charset=utf8 + \\ + \\content + \\------abcd + \\content-Disposition: form-data; name=second + \\ + \\no content + \\------abcd + \\content-disposition: form-data; name=third + \\ + \\ + \\------abcd-- + \\ + ), + "----abcd", + &.{ + .{ .disposition = "form-data; name=first; charset=utf8", .value = "content" }, + .{ .disposition = "form-data; name=second", .value = "no content" }, + .{ .disposition = "form-data; name=third", .value = "" }, + }, + ); + + try testCase( + util.comptimeToCrlf( + \\--xyz + \\Content-Disposition: uhh + \\ + \\xyz + \\--xyz + \\Content-disposition: ok + \\ + \\ --xyz + \\--xyz-- + \\ + ), + "xyz", + &.{ + .{ .disposition = "uhh", .value = "xyz" }, + .{ .disposition = "ok", .value = " --xyz" }, + }, + ); +} + +test "MultipartForm" { + const testCase = struct { + fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const MultipartFormField) !void { + var src = std.io.StreamSource{ + .const_buffer = std.io.fixedBufferStream(body), + }; + + var form = openForm(try openMultipart(boundary, src.reader())); + + for (expected_parts) |expected| { + var data = try form.next(std.testing.allocator) orelse return error.TestExpectedEqual; + defer util.deepFree(std.testing.allocator, data); + + try util.testing.expectDeepEqual(expected, data); + } + + try std.testing.expect(try form.next(std.testing.allocator) == null); + } + }.case; + + try testCase( + util.comptimeToCrlf( + \\--abcd + \\Content-Disposition: form-data; name=foo + \\ + \\content + \\--abcd-- + \\ + ), + "abcd", + &.{.{ .name = "foo", .value = "content" }}, + ); + try testCase( + util.comptimeToCrlf( + \\--abcd + \\Content-Disposition: form-data; name=foo + \\ + \\content + \\--abcd + \\Content-Disposition: form-data; name=bar + \\Content-Type: blah + \\ + \\abcd + \\--abcd + \\Content-Disposition: form-data; name=baz; filename="myfile.txt" + \\Content-Type: text/plain + \\ + \\ --abcd + \\ + \\--abcd-- + \\ + ), + "abcd", + &.{ + .{ .name = "foo", .value = "content" }, + .{ .name = "bar", .value = "abcd", .content_type = "blah" }, + .{ + .name = "baz", + .value = " --abcd\r\n", + .content_type = "text/plain", + .filename = "myfile.txt", + }, + }, + ); +} + +test "parseFormData" { + const body = util.comptimeToCrlf( + \\--abcd + \\Content-Disposition: form-data; name=foo + \\ + \\content + \\--abcd-- + \\ + ); + var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) }; + const val = try parseFormData(struct { + foo: []const u8, + }, "abcd", src.reader(), std.testing.allocator); + util.deepFree(std.testing.allocator, val); +} diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index 6ffba12..a11c315 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -93,7 +93,7 @@ fn parseProto(reader: anytype) !http.Protocol { }; } -fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields { +pub fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields { var headers = Fields.init(allocator); var buf: [4096]u8 = undefined; diff --git a/src/http/request/test_parser.zig b/src/http/request/test_parser.zig index 55a66d6..b715528 100644 --- a/src/http/request/test_parser.zig +++ b/src/http/request/test_parser.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const util = @import("util"); const parser = @import("./parser.zig"); const http = @import("../lib.zig"); const t = std.testing; @@ -30,30 +31,9 @@ const test_case = struct { } }; -fn toCrlf(comptime str: []const u8) []const u8 { - comptime { - var buf: [str.len * 2]u8 = undefined; - - @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 - - var buf_len: usize = 0; - for (str) |ch| { - if (ch == '\n') { - buf[buf_len] = '\r'; - buf_len += 1; - } - - buf[buf_len] = ch; - buf_len += 1; - } - - return buf[0..buf_len]; - } -} - test "HTTP/1.x parse - No body" { try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\GET / HTTP/1.1 \\ \\ @@ -65,7 +45,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\POST / HTTP/1.1 \\ \\ @@ -77,7 +57,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\GET /url/abcd HTTP/1.1 \\ \\ @@ -89,7 +69,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\GET / HTTP/1.0 \\ \\ @@ -101,7 +81,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\GET /url/abcd HTTP/1.1 \\Content-Type: application/json \\ @@ -115,7 +95,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\GET /url/abcd HTTP/1.1 \\Content-Type: application/json \\Authorization: bearer @@ -163,7 +143,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\GET / HTTP/1.2 \\ \\ @@ -265,7 +245,7 @@ test "HTTP/1.x parse - bad requests" { test "HTTP/1.x parse - Headers" { try test_case.parse( - toCrlf( + util.comptimeToCrlf( \\GET /url/abcd HTTP/1.1 \\Content-Type: application/json \\Content-Type: application/xml diff --git a/src/http/server/response.zig b/src/http/server/response.zig index fdbe9cc..384677d 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -1,4 +1,5 @@ const std = @import("std"); +const util = @import("util"); const http = @import("../lib.zig"); const Status = http.Status; @@ -169,25 +170,7 @@ test { _ = _tests; } const _tests = struct { - fn toCrlf(comptime str: []const u8) []const u8 { - comptime { - var buf: [str.len * 2]u8 = undefined; - @setEvalBranchQuota(@as(u32, str.len * 2)); - - var len: usize = 0; - for (str) |ch| { - if (ch == '\n') { - buf[len] = '\r'; - len += 1; - } - - buf[len] = ch; - len += 1; - } - - return buf[0..len]; - } - } + const toCrlf = util.comptimeToCrlf; const test_buffer_size = chunk_size * 4; test "ResponseStream no headers empty body" { diff --git a/src/http/test.zig b/src/http/test.zig deleted file mode 100644 index c142f68..0000000 --- a/src/http/test.zig +++ /dev/null @@ -1,5 +0,0 @@ -test { - _ = @import("./request/test_parser.zig"); - _ = @import("./middleware.zig"); - _ = @import("./query.zig"); -} diff --git a/src/http/query.zig b/src/http/urlencode.zig similarity index 58% rename from src/http/query.zig rename to src/http/urlencode.zig index 36b5d33..3f49423 100644 --- a/src/http/query.zig +++ b/src/http/urlencode.zig @@ -1,7 +1,38 @@ const std = @import("std"); const util = @import("util"); -const QueryIter = util.QueryIter; +pub const Iter = struct { + const Pair = struct { + key: []const u8, + value: ?[]const u8, + }; + + iter: std.mem.SplitIterator(u8), + + pub fn from(q: []const u8) Iter { + return Iter{ + .iter = std.mem.split(u8, std.mem.trimLeft(u8, q, "?"), "&"), + }; + } + + pub fn next(self: *Iter) ?Pair { + while (true) { + const part = self.iter.next() orelse return null; + if (part.len == 0) continue; + + const key = std.mem.sliceTo(part, '='); + if (key.len == part.len) return Pair{ + .key = key, + .value = null, + }; + + return Pair{ + .key = key, + .value = part[key.len + 1 ..], + }; + } + } +}; /// Parses a set of query parameters described by the struct `T`. /// @@ -67,25 +98,44 @@ const QueryIter = util.QueryIter; /// Would be used to parse a query string like /// `?foo.baz=12345` /// -pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { - if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct"); - var iter = QueryIter.from(query); +pub fn parse(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { + var iter = Iter.from(query); + + var deserializer = Deserializer(T){}; - var fields = Intermediary(T){}; while (iter.next()) |pair| { - // TODO: Hash map - inline for (std.meta.fields(Intermediary(T))) |field| { - if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) { - @field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} }; - break; - } - } else std.log.debug("unknown param {s}", .{pair.key}); + try deserializer.setSerializedField(pair.key, pair.value); } - return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery; + return try deserializer.finish(alloc); } -pub fn parseQueryFree(alloc: std.mem.Allocator, val: anytype) void { +fn Deserializer(comptime Result: type) type { + return util.DeserializerContext(Result, ?[]const u8, struct { + pub const options = util.serialize.default_options; + pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, maybe_val: ?[]const u8) !T { + const is_optional = comptime std.meta.trait.is(.Optional)(T); + if (maybe_val) |val| { + if (val.len == 0 and is_optional) return null; + + const decoded = try decodeString(alloc, val); + defer alloc.free(decoded); + + return try util.serialize.deserializeString(alloc, T, decoded); + } else { + // If param is present, but without an associated value + return if (is_optional) + null + else if (T == bool) + true + else + error.InvalidValue; + } + } + }); +} + +pub fn parseFree(alloc: std.mem.Allocator, val: anytype) void { util.deepFree(alloc, val); } @@ -110,186 +160,6 @@ fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]u8 { return list.toOwnedSlice(); } -fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T { - const param = @field(fields, name); - return switch (param) { - .not_specified => null, - .no_value => try parseQueryValue(alloc, T, null), - .value => |v| try parseQueryValue(alloc, T, v), - }; -} - -fn parse( - alloc: std.mem.Allocator, - comptime T: type, - comptime prefix: []const u8, - comptime name: []const u8, - fields: anytype, -) !?T { - if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields); - switch (@typeInfo(T)) { - .Union => |info| { - var result: ?T = null; - inline for (info.fields) |field| { - const F = field.field_type; - - const maybe_value = try parse(alloc, F, prefix, field.name, fields); - if (maybe_value) |value| { - if (result != null) return error.DuplicateUnionField; - - result = @unionInit(T, field.name, value); - } - } - std.log.debug("{any}", .{result}); - return result; - }, - - .Struct => |info| { - var result: T = undefined; - var fields_specified: usize = 0; - errdefer inline for (info.fields) |field, i| { - if (fields_specified < i) util.deepFree(alloc, @field(result, field.name)); - }; - - inline for (info.fields) |field| { - const F = field.field_type; - - var maybe_value: ?F = null; - if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| { - maybe_value = v; - } else if (field.default_value) |default| { - if (comptime @sizeOf(F) != 0) { - maybe_value = try util.deepClone(alloc, @ptrCast(*const F, @alignCast(@alignOf(F), default)).*); - } else { - maybe_value = std.mem.zeroes(F); - } - } - - if (maybe_value) |v| { - fields_specified += 1; - @field(result, field.name) = v; - } - } - - if (fields_specified == 0) { - return null; - } else if (fields_specified != info.fields.len) { - std.log.debug("{} {s} {s}", .{ T, prefix, name }); - return error.PartiallySpecifiedStruct; - } else { - return result; - } - }, - - // Only applies to non-scalar optionals - .Optional => |info| return try parse(alloc, info.child, prefix, name, fields), - - else => @compileError("tmp"), - } -} - -fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 { - comptime { - if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix); - - var fields: []const []const u8 = &.{}; - - for (std.meta.fields(T)) |f| { - const full_name = prefix ++ f.name; - - if (isScalar(f.field_type)) { - fields = fields ++ @as([]const []const u8, &.{full_name}); - } else { - const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ "."; - fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix); - } - } - - return fields; - } -} - -const QueryParam = union(enum) { - not_specified: void, - no_value: void, - value: []const u8, -}; - -fn Intermediary(comptime T: type) type { - const field_names = recursiveFieldPaths(T, ".."); - - var fields: [field_names.len]std.builtin.Type.StructField = undefined; - for (field_names) |name, i| fields[i] = .{ - .name = name, - .field_type = QueryParam, - .default_value = &QueryParam{ .not_specified = {} }, - .is_comptime = false, - .alignment = @alignOf(QueryParam), - }; - - return @Type(.{ .Struct = .{ - .layout = .Auto, - .fields = &fields, - .decls = &.{}, - .is_tuple = false, - } }); -} - -fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, maybe_value: ?[]const u8) !T { - const is_optional = comptime std.meta.trait.is(.Optional)(T); - if (maybe_value) |value| { - const Eff = if (is_optional) std.meta.Child(T) else T; - - if (value.len == 0 and is_optional) return null; - - const decoded = try decodeString(alloc, value); - errdefer alloc.free(decoded); - - if (comptime std.meta.trait.isZigString(Eff)) return decoded; - - defer alloc.free(decoded); - - const result = if (comptime std.meta.trait.isIntegral(Eff)) - try std.fmt.parseInt(Eff, decoded, 0) - else if (comptime std.meta.trait.isFloat(Eff)) - try std.fmt.parseFloat(Eff, decoded) - else if (comptime std.meta.trait.is(.Enum)(Eff)) blk: { - _ = std.ascii.lowerString(decoded, decoded); - break :blk std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue; - } else if (Eff == bool) blk: { - _ = std.ascii.lowerString(decoded, decoded); - break :blk bool_map.get(decoded) orelse return error.InvalidBool; - } else if (comptime std.meta.trait.hasFn("parse")(Eff)) - try Eff.parse(value) - else - @compileError("Invalid type " ++ @typeName(T)); - - return result; - } else { - // If param is present, but without an associated value - return if (is_optional) - null - else if (T == bool) - true - else - error.InvalidValue; - } -} - -const bool_map = std.ComptimeStringMap(bool, .{ - .{ "true", true }, - .{ "t", true }, - .{ "yes", true }, - .{ "y", true }, - .{ "1", true }, - - .{ "false", false }, - .{ "f", false }, - .{ "no", false }, - .{ "n", false }, - .{ "0", false }, -}); - fn isScalar(comptime T: type) bool { if (comptime std.meta.trait.isZigString(T)) return true; if (comptime std.meta.trait.isIntegral(T)) return true; @@ -304,7 +174,7 @@ fn isScalar(comptime T: type) bool { return false; } -pub fn QueryStringify(comptime Params: type) type { +pub fn EncodeStruct(comptime Params: type) type { return struct { params: Params, pub fn format(v: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { @@ -312,8 +182,8 @@ pub fn QueryStringify(comptime Params: type) type { } }; } -pub fn queryStringify(val: anytype) QueryStringify(@TypeOf(val)) { - return QueryStringify(@TypeOf(val)){ .params = val }; +pub fn encodeStruct(val: anytype) EncodeStruct(@TypeOf(val)) { + return EncodeStruct(@TypeOf(val)){ .params = val }; } fn urlFormatString(writer: anytype, val: []const u8) !void { @@ -375,11 +245,11 @@ fn formatQuery(comptime prefix: []const u8, comptime name: []const u8, params: a } } -test "parseQuery" { +test "parse" { const testCase = struct { fn case(comptime T: type, expected: T, query_string: []const u8) !void { - const result = try parseQuery(std.testing.allocator, T, query_string); - defer parseQueryFree(std.testing.allocator, result); + const result = try parse(std.testing.allocator, T, query_string); + defer parseFree(std.testing.allocator, result); try util.testing.expectDeepEqual(expected, result); } }.case; @@ -465,14 +335,46 @@ test "parseQuery" { try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .baz = "abc" } } }, "sub.foo=1&sub.baz=abc"); } -test "formatQuery" { - try std.testing.expectFmt("", "{}", .{queryStringify(.{})}); - try std.testing.expectFmt("id=3&", "{}", .{queryStringify(.{ .id = 3 })}); - try std.testing.expectFmt("id=3&id2=4&", "{}", .{queryStringify(.{ .id = 3, .id2 = 4 })}); +test "encodeStruct" { + try std.testing.expectFmt("", "{}", .{encodeStruct(.{})}); + try std.testing.expectFmt("id=3&", "{}", .{encodeStruct(.{ .id = 3 })}); + try std.testing.expectFmt("id=3&id2=4&", "{}", .{encodeStruct(.{ .id = 3, .id2 = 4 })}); - try std.testing.expectFmt("str=foo&", "{}", .{queryStringify(.{ .str = "foo" })}); - try std.testing.expectFmt("enum_str=foo&", "{}", .{queryStringify(.{ .enum_str = .foo })}); + try std.testing.expectFmt("str=foo&", "{}", .{encodeStruct(.{ .str = "foo" })}); + try std.testing.expectFmt("enum_str=foo&", "{}", .{encodeStruct(.{ .enum_str = .foo })}); - try std.testing.expectFmt("boolean=false&", "{}", .{queryStringify(.{ .boolean = false })}); - try std.testing.expectFmt("boolean=true&", "{}", .{queryStringify(.{ .boolean = true })}); + try std.testing.expectFmt("boolean=false&", "{}", .{encodeStruct(.{ .boolean = false })}); + try std.testing.expectFmt("boolean=true&", "{}", .{encodeStruct(.{ .boolean = true })}); +} + +test "Iter" { + const testCase = struct { + fn case(str: []const u8, pairs: []const Iter.Pair) !void { + var iter = Iter.from(str); + for (pairs) |pair| { + try util.testing.expectDeepEqual(@as(?Iter.Pair, pair), iter.next()); + } + try std.testing.expect(iter.next() == null); + } + }.case; + + try testCase("", &.{}); + try testCase("abc", &.{.{ .key = "abc", .value = null }}); + try testCase("abc=", &.{.{ .key = "abc", .value = "" }}); + try testCase("abc=def", &.{.{ .key = "abc", .value = "def" }}); + try testCase("abc=def&", &.{.{ .key = "abc", .value = "def" }}); + try testCase("?abc=def&", &.{.{ .key = "abc", .value = "def" }}); + try testCase("?abc=def&foo&bar=baz&qux=", &.{ + .{ .key = "abc", .value = "def" }, + .{ .key = "foo", .value = null }, + .{ .key = "bar", .value = "baz" }, + .{ .key = "qux", .value = "" }, + }); + try testCase("?abc=def&&foo&bar=baz&&qux=&", &.{ + .{ .key = "abc", .value = "def" }, + .{ .key = "foo", .value = null }, + .{ .key = "bar", .value = "baz" }, + .{ .key = "qux", .value = "" }, + }); + try testCase("&=def&", &.{.{ .key = "", .value = "def" }}); } diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 10ecdb7..398424c 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -267,13 +267,13 @@ pub const helpers = struct { try std.fmt.format( writer, "<{s}://{s}/{s}?{}>; rel=\"{s}\"", - .{ @tagName(c.scheme), c.host, path, http.queryStringify(params), rel }, + .{ @tagName(c.scheme), c.host, path, http.urlencode.encodeStruct(params), rel }, ); } else { try std.fmt.format( writer, "<{s}?{}>; rel=\"{s}\"", - .{ path, http.queryStringify(params), rel }, + .{ path, http.urlencode.encodeStruct(params), rel }, ); } // TODO: percent-encode diff --git a/src/main/controllers/api.zig b/src/main/controllers/api.zig index 12f3a1f..9a76c91 100644 --- a/src/main/controllers/api.zig +++ b/src/main/controllers/api.zig @@ -2,6 +2,7 @@ const controllers = @import("../controllers.zig"); const auth = @import("./api/auth.zig"); const communities = @import("./api/communities.zig"); +const drive = @import("./api/drive.zig"); const invites = @import("./api/invites.zig"); const users = @import("./api/users.zig"); const follows = @import("./api/users/follows.zig"); @@ -26,4 +27,6 @@ pub const routes = .{ controllers.apiEndpoint(follows.delete), controllers.apiEndpoint(follows.query_followers), controllers.apiEndpoint(follows.query_following), + controllers.apiEndpoint(drive.upload), + controllers.apiEndpoint(drive.mkdir), }; diff --git a/src/main/controllers/api/drive.zig b/src/main/controllers/api/drive.zig new file mode 100644 index 0000000..f617898 --- /dev/null +++ b/src/main/controllers/api/drive.zig @@ -0,0 +1,144 @@ +const api = @import("api"); +const http = @import("http"); +const util = @import("util"); +const controller_utils = @import("../../controllers.zig").helpers; + +const Uuid = util.Uuid; +const DateTime = util.DateTime; + +pub const drive_path = "/drive/:path*"; +pub const DriveArgs = struct { + path: []const u8, +}; + +pub const query = struct { + pub const method = .GET; + pub const path = drive_path; + pub const Args = DriveArgs; + + pub const Query = struct { + const OrderBy = enum { + created_at, + filename, + }; + + max_items: usize = 20, + + like: ?[]const u8 = null, + + order_by: OrderBy = .created_at, + direction: api.Direction = .descending, + + prev: ?struct { + id: Uuid, + order_val: union(OrderBy) { + created_at: DateTime, + filename: []const u8, + }, + } = null, + + page_direction: api.PageDirection = .forward, + }; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const result = srv.driveQuery(req.args.path, req.query) catch |err| switch (err) { + error.NotADirectory => { + const meta = try srv.getFile(path); + try res.json(.ok, meta); + return; + }, + else => |e| return e, + }; + + try controller_utils.paginate(result, res, req.allocator); + } +}; + +pub const upload = struct { + pub const method = .POST; + pub const path = drive_path; + pub const Args = DriveArgs; + + pub const Body = struct { + file: http.FormFile, + description: ?[]const u8 = null, + sensitive: bool = false, + }; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const f = req.body.file; + try srv.uploadFile(.{ + .dir = req.args.path, + .filename = f.filename, + .description = req.body.description, + .content_type = f.content_type, + .sensitive = req.body.sensitive, + }, f.data); + + // TODO: print meta + try res.json(.created, .{}); + } +}; + +pub const delete = struct { + pub const method = .DELETE; + pub const path = drive_path; + pub const Args = DriveArgs; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const info = try srv.driveLookup(req.args.path); + if (info == .dir) + try srv.driveRmdir(req.args.path) + else if (info == .file) + try srv.deleteFile(req.args.path); + + return res.json(.ok, .{}); + } +}; + +pub const mkdir = struct { + pub const method = .MKCOL; + pub const path = drive_path; + pub const Args = DriveArgs; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + try srv.driveMkdir(req.args.path); + + return res.json(.created, .{}); + } +}; + +pub const update = struct { + pub const method = .PUT; + pub const path = drive_path; + pub const Args = DriveArgs; + + pub const Body = struct { + description: ?[]const u8 = null, + content_type: ?[]const u8 = null, + sensitive: ?bool = null, + }; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const info = try srv.driveLookup(req.args.path); + if (info != .file) return error.NotFile; + + const new_info = try srv.updateFile(path, req.body); + try res.json(.ok, new_info); + } +}; + +pub const move = struct { + pub const method = .MOVE; + pub const path = drive_path; + pub const Args = DriveArgs; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const destination = req.fields.get("Destination") orelse return error.NoDestination; + + try srv.driveMove(req.args.path, destination); + + try res.fields.put("Location", destination); + try srv.json(.created, .{}); + } +}; diff --git a/src/main/migrations.zig b/src/main/migrations.zig index dc97b44..a7465cb 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -19,8 +19,9 @@ fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void { const tx = try db.beginOrSavepoint(); errdefer tx.rollback(); - var iter = util.SqlStmtIter.from(script); + var iter = std.mem.split(u8, script, ";"); while (iter.next()) |stmt| { + if (stmt.len == 0) continue; try execStmt(tx, stmt, alloc); } @@ -208,22 +209,139 @@ const migrations: []const Migration = &.{ .{ .name = "files", .up = - \\CREATE TABLE drive_file( + \\CREATE TABLE file_upload( \\ id UUID NOT NULL PRIMARY KEY, \\ - \\ filename TEXT NOT NULL, - \\ account_owner_id UUID REFERENCES account(id), - \\ community_owner_id UUID REFERENCES community(id), + \\ created_by UUID REFERENCES account(id), \\ size INTEGER NOT NULL, \\ + \\ filename TEXT NOT NULL, + \\ description TEXT, + \\ content_type TEXT, + \\ sensitive BOOLEAN NOT NULL, + \\ + \\ is_deleted BOOLEAN NOT NULL DEFAULT FALSE, + \\ \\ created_at TIMESTAMPTZ NOT NULL, + \\ updated_at TIMESTAMPTZ NOT NULL + \\); + \\ + \\CREATE TABLE drive_entry( + \\ id UUID NOT NULL PRIMARY KEY, + \\ + \\ account_owner_id UUID REFERENCES account(id), + \\ community_owner_id UUID REFERENCES community(id), + \\ + \\ name TEXT, + \\ parent_directory_id UUID REFERENCES drive_entry(id), + \\ + \\ file_id UUID REFERENCES file_upload(id), \\ \\ CHECK( \\ (account_owner_id IS NULL AND community_owner_id IS NOT NULL) \\ OR (account_owner_id IS NOT NULL AND community_owner_id IS NULL) + \\ ), + \\ CHECK( + \\ (name IS NULL AND parent_directory_id IS NULL AND file_id IS NULL) + \\ OR (name IS NOT NULL AND parent_directory_id IS NOT NULL) \\ ) \\); + \\CREATE UNIQUE INDEX drive_entry_uniqueness + \\ON drive_entry( + \\ name, + \\ COALESCE(parent_directory_id, ''), + \\ COALESCE(account_owner_id, community_owner_id) + \\); , - .down = "DROP TABLE drive_file", + .down = + \\DROP INDEX drive_entry_uniqueness; + \\DROP TABLE drive_entry; + \\DROP TABLE file_upload; + , + }, + .{ + .name = "drive_entry_path", + .up = + \\CREATE VIEW drive_entry_path( + \\ id, + \\ path, + \\ account_owner_id, + \\ community_owner_id, + \\ kind + \\) AS WITH RECURSIVE full_path( + \\ id, + \\ path, + \\ account_owner_id, + \\ community_owner_id, + \\ kind + \\) AS ( + \\ SELECT + \\ id, + \\ '' AS path, + \\ account_owner_id, + \\ community_owner_id, + \\ 'dir' AS kind + \\ FROM drive_entry + \\ WHERE parent_directory_id IS NULL + \\ UNION ALL + \\ SELECT + \\ base.id, + \\ (dir.path || '/' || base.name) AS path, + \\ base.account_owner_id, + \\ base.community_owner_id, + \\ (CASE WHEN base.file_id IS NULL THEN 'dir' ELSE 'file' END) as kind + \\ FROM drive_entry AS base + \\ JOIN full_path AS dir ON + \\ base.parent_directory_id = dir.id + \\ AND base.account_owner_id IS NOT DISTINCT FROM dir.account_owner_id + \\ AND base.community_owner_id IS NOT DISTINCT FROM dir.community_owner_id + \\) + \\SELECT + \\ id, + \\ (CASE WHEN kind = 'dir' THEN path || '/' ELSE path END) AS path, + \\ account_owner_id, + \\ community_owner_id, + \\ kind + \\FROM full_path; + , + .down = + \\DROP VIEW drive_entry_path; + , + }, + .{ + .name = "create drive root directories", + .up = + \\INSERT INTO drive_entry( + \\ id, + \\ account_owner_id, + \\ community_owner_id, + \\ parent_directory_id, + \\ name, + \\ file_id + \\) SELECT + \\ id, + \\ id AS account_owner_id, + \\ NULL AS community_owner_id, + \\ NULL AS parent_directory_id, + \\ NULL AS name, + \\ NULL AS file_id + \\FROM account; + \\INSERT INTO drive_entry( + \\ id, + \\ account_owner_id, + \\ community_owner_id, + \\ parent_directory_id, + \\ name, + \\ file_id + \\) SELECT + \\ id, + \\ NULL AS account_owner_id, + \\ id AS community_owner_id, + \\ NULL AS parent_directory_id, + \\ NULL AS name, + \\ NULL AS file_id + \\FROM community; + , + .down = "", }, }; diff --git a/src/sql/engines/common.zig b/src/sql/engines/common.zig index b50b7d0..93169c4 100644 --- a/src/sql/engines/common.zig +++ b/src/sql/engines/common.zig @@ -88,7 +88,7 @@ pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]con else => |T| switch (@typeInfo(T)) { .Enum => return @tagName(val), .Optional => if (val) |v| try prepareParamText(arena, v) else null, - .Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}), + .Bool, .Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}), .Union => loop: inline for (std.meta.fields(T)) |field| { // Have to do this in a roundabout way to satisfy comptime checker const Tag = std.meta.Tag(T); diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index 5b59910..3b9c8c4 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -193,6 +193,7 @@ pub const Db = struct { .Null => return self.bindNull(stmt, idx), .Int => return self.bindInt(stmt, idx, std.math.cast(i64, val) orelse unreachable), .Float => return self.bindFloat(stmt, idx, val), + .Bool => return self.bindInt(stmt, idx, if (val) 1 else 0), else => @compileError("Unable to serialize type " ++ @typeName(T)), } } @@ -251,18 +252,20 @@ pub const Results = struct { db: *c.sqlite3, pub fn finish(self: Results) void { - switch (c.sqlite3_finalize(self.stmt)) { - c.SQLITE_OK => {}, - else => |err| { - handleUnexpectedError(self.db, err, self.getGeneratingSql()) catch {}; - }, - } + _ = c.sqlite3_finalize(self.stmt); } pub fn row(self: Results) common.RowError!?Row { return switch (c.sqlite3_step(self.stmt)) { c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db }, c.SQLITE_DONE => null, + + c.SQLITE_CONSTRAINT_UNIQUE => return error.UniqueViolation, + c.SQLITE_CONSTRAINT_CHECK => return error.CheckViolation, + c.SQLITE_CONSTRAINT_NOTNULL => return error.NotNullViolation, + c.SQLITE_CONSTRAINT_FOREIGNKEY => return error.ForeignKeyViolation, + c.SQLITE_CONSTRAINT => return error.ConstraintViolation, + else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()), }; } diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 358a2d3..69c371c 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -144,42 +144,16 @@ fn fieldPtr(ptr: anytype, comptime names: []const []const u8) FieldPtr(@TypeOf(p return fieldPtr(&@field(ptr.*, names[0]), names[1..]); } -fn isScalar(comptime T: type) bool { - if (comptime std.meta.trait.isZigString(T)) return true; - if (comptime std.meta.trait.isIntegral(T)) return true; - if (comptime std.meta.trait.isFloat(T)) return true; - if (comptime std.meta.trait.is(.Enum)(T)) return true; - if (T == bool) return true; - if (comptime std.meta.trait.hasFn("parse")(T)) return true; - - if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true; - - return false; -} - -fn recursiveFieldPaths(comptime T: type, comptime prefix: []const []const u8) []const []const []const u8 { - comptime { - var fields: []const []const []const u8 = &.{}; - - for (std.meta.fields(T)) |f| { - const full_name = prefix ++ [_][]const u8{f.name}; - if (isScalar(f.field_type)) { - fields = fields ++ [_][]const []const u8{full_name}; - } else { - fields = fields ++ recursiveFieldPaths(f.field_type, full_name); - } - } - - return fields; - } -} - // Represents a set of results. // row() must be called until it returns null, or the query may not complete // Must be deallocated by a call to finish() pub fn Results(comptime T: type) type { // would normally make this a declaration of the struct, but it causes the compiler to crash - const fields = if (T == void) .{} else recursiveFieldPaths(T, &.{}); + const fields = if (T == void) .{} else util.serialize.getRecursiveFieldList( + T, + &.{}, + util.serialize.default_options, + ); return struct { const Self = @This(); @@ -457,6 +431,7 @@ fn Tx(comptime tx_level: u8) type { pub fn rollback(self: Self) void { (if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| { std.log.err("Failed to rollback transaction: {}", .{err}); + std.log.err("{any}", .{@errorReturnTrace()}); @panic("TODO: more gracefully handle rollback failures"); }; } @@ -654,7 +629,7 @@ fn Tx(comptime tx_level: u8) type { } fn rollbackUnchecked(self: Self) !void { - try self.exec("ROLLBACK", {}, null); + try self.execInternal("ROLLBACK", {}, null, false); } }; } diff --git a/src/template/lib.zig b/src/template/lib.zig index 8c6c8e9..f1e141d 100644 --- a/src/template/lib.zig +++ b/src/template/lib.zig @@ -601,3 +601,20 @@ const ControlTokenIter = struct { self.peeked_token = token; } }; + +test "template" { + const testCase = struct { + fn case(comptime tmpl: []const u8, args: anytype, expected: []const u8) !void { + var stream = std.io.changeDetectionStream(expected, std.io.null_writer); + try execute(stream.writer(), tmpl, args); + try std.testing.expect(!stream.changeDetected()); + } + }.case; + + try testCase("", .{}, ""); + try testCase("abcd", .{}, "abcd"); + try testCase("{.val}", .{ .val = 3 }, "3"); + try testCase("{#if .val}1{/if}", .{ .val = true }, "1"); + try testCase("{#for .vals |$v|}{$v}{/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123"); + try testCase("{#for .vals |$v|=} {$v} {=/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123"); +} diff --git a/src/util/Url.zig b/src/util/Url.zig deleted file mode 100644 index 35d74e8..0000000 --- a/src/util/Url.zig +++ /dev/null @@ -1,161 +0,0 @@ -const Url = @This(); -const std = @import("std"); - -scheme: []const u8, -hostport: []const u8, -path: []const u8, -query: []const u8, -fragment: []const u8, - -pub fn parse(url: []const u8) !Url { - const scheme_end = for (url) |ch, i| { - if (ch == ':') break i; - } else return error.InvalidUrl; - - if (url.len < scheme_end + 3 or url[scheme_end + 1] != '/' or url[scheme_end + 1] != '/') return error.InvalidUrl; - - const hostport_start = scheme_end + 3; - const hostport_end = for (url[hostport_start..]) |ch, i| { - if (ch == '/' or ch == '?' or ch == '#') break i + hostport_start; - } else url.len; - - const path_end = for (url[hostport_end..]) |ch, i| { - if (ch == '?' or ch == '#') break i + hostport_end; - } else url.len; - - const query_end = if (!(url.len > path_end and url[path_end] == '?')) - path_end - else for (url[path_end..]) |ch, i| { - if (ch == '#') break i + path_end; - } else url.len; - - const query = url[path_end..query_end]; - const fragment = url[query_end..]; - - return Url{ - .scheme = url[0..scheme_end], - .hostport = url[hostport_start..hostport_end], - .path = url[hostport_end..path_end], - .query = if (query.len > 0) query[1..] else query, - .fragment = if (fragment.len > 0) fragment[1..] else fragment, - }; -} - -pub fn getQuery(self: Url, param: []const u8) ?[]const u8 { - var key_start: usize = 0; - std.log.debug("query: {s}", .{self.query}); - while (key_start < self.query.len) { - const key_end = for (self.query[key_start..]) |ch, i| { - if (ch == '=') break key_start + i; - } else return null; - - const val_start = key_end + 1; - const val_end = for (self.query[val_start..]) |ch, i| { - if (ch == '&') break val_start + i; - } else self.query.len; - - const key = self.query[key_start..key_end]; - if (std.mem.eql(u8, key, param)) return self.query[val_start..val_end]; - - key_start = val_end + 1; - } - - return null; -} - -pub fn strDecode(buf: []u8, str: []const u8) ![]u8 { - var str_i: usize = 0; - var buf_i: usize = 0; - while (str_i < str.len) : ({ - str_i += 1; - buf_i += 1; - }) { - if (buf_i >= buf.len) return error.NoSpaceLeft; - const ch = str[str_i]; - if (ch == '%') { - if (str.len < str_i + 2) return error.BadEscape; - - const hi = try std.fmt.charToDigit(str[str_i + 1], 16); - const lo = try std.fmt.charToDIgit(str[str_i + 2], 16); - str_i += 2; - - buf[buf_i] = (hi << 4) | lo; - } else { - buf[buf_i] = str[str_i]; - } - } - - return buf[0..buf_i]; -} - -fn expectEqualUrl(expected: Url, actual: Url) !void { - const t = @import("std").testing; - try t.expectEqualStrings(expected.scheme, actual.scheme); - try t.expectEqualStrings(expected.hostport, actual.hostport); - try t.expectEqualStrings(expected.path, actual.path); - try t.expectEqualStrings(expected.query, actual.query); - try t.expectEqualStrings(expected.fragment, actual.fragment); -} -test "Url" { - try expectEqualUrl(.{ - .scheme = "https", - .hostport = "example.com", - .path = "", - .query = "", - .fragment = "", - }, try Url.parse("https://example.com")); - - try expectEqualUrl(.{ - .scheme = "https", - .hostport = "example.com:1234", - .path = "", - .query = "", - .fragment = "", - }, try Url.parse("https://example.com:1234")); - - try expectEqualUrl(.{ - .scheme = "http", - .hostport = "example.com", - .path = "/home", - .query = "", - .fragment = "", - }, try Url.parse("http://example.com/home")); - - try expectEqualUrl(.{ - .scheme = "https", - .hostport = "example.com", - .path = "", - .query = "query=abc", - .fragment = "", - }, try Url.parse("https://example.com?query=abc")); - - try expectEqualUrl(.{ - .scheme = "https", - .hostport = "example.com", - .path = "", - .query = "query=abc", - .fragment = "", - }, try Url.parse("https://example.com?query=abc")); - - try expectEqualUrl(.{ - .scheme = "https", - .hostport = "example.com", - .path = "/path/to/resource", - .query = "query=abc", - .fragment = "123", - }, try Url.parse("https://example.com/path/to/resource?query=abc#123")); - - const t = @import("std").testing; - try t.expectError(error.InvalidUrl, Url.parse("https:example.com")); - try t.expectError(error.InvalidUrl, Url.parse("example.com")); -} - -test "Url.getQuery" { - const url = try Url.parse("https://example.com?a=xyz&b=jkl"); - const t = @import("std").testing; - - try t.expectEqualStrings("xyz", url.getQuery("a").?); - try t.expectEqualStrings("jkl", url.getQuery("b").?); - try t.expect(url.getQuery("c") == null); - try t.expect(url.getQuery("xyz") == null); -} diff --git a/src/util/ciutf8.zig b/src/util/ciutf8.zig deleted file mode 100644 index 39d1a20..0000000 --- a/src/util/ciutf8.zig +++ /dev/null @@ -1,106 +0,0 @@ -const std = @import("std"); - -const Hash = std.hash.Wyhash; -const View = std.unicode.Utf8View; -const toLower = std.ascii.toLower; -const isAscii = std.ascii.isASCII; -const hash_seed = 1; - -pub fn hash(str: []const u8) u64 { - // fallback to regular hash on invalid utf8 - const view = View.init(str) catch return Hash.hash(hash_seed, str); - var iter = view.iterator(); - - var h = Hash.init(hash_seed); - - var it = iter.nextCodepointSlice(); - while (it != null) : (it = iter.nextCodepointSlice()) { - if (it.?.len == 1 and isAscii(it.?[0])) { - const ch = [1]u8{toLower(it.?[0])}; - h.update(&ch); - } else { - h.update(it.?); - } - } - - return h.final(); -} - -pub fn eql(a: []const u8, b: []const u8) bool { - if (a.len != b.len) return false; - - const va = View.init(a) catch return std.mem.eql(u8, a, b); - const vb = View.init(b) catch return false; - - var iter_a = va.iterator(); - var iter_b = vb.iterator(); - - var it_a = iter_a.nextCodepointSlice(); - var it_b = iter_b.nextCodepointSlice(); - - while (it_a != null and it_b != null) : ({ - it_a = iter_a.nextCodepointSlice(); - it_b = iter_b.nextCodepointSlice(); - }) { - if (it_a.?.len != it_b.?.len) return false; - - if (it_a.?.len == 1) { - if (isAscii(it_a.?[0]) and isAscii(it_b.?[0])) { - const ch_a = toLower(it_a.?[0]); - const ch_b = toLower(it_b.?[0]); - - if (ch_a != ch_b) return false; - } else if (it_a.?[0] != it_b.?[0]) return false; - } else if (!std.mem.eql(u8, it_a.?, it_b.?)) return false; - } - - return it_a == null and it_b == null; -} - -test "case insensitive eql with utf-8 chars" { - const t = std.testing; - try t.expectEqual(true, eql("abc 💯 def", "aBc 💯 DEF")); - try t.expectEqual(false, eql("xyz 💯 ijk", "aBc 💯 DEF")); - try t.expectEqual(false, eql("abc 💯 def", "aBc x DEF")); - try t.expectEqual(true, eql("💯", "💯")); - try t.expectEqual(false, eql("💯", "a")); - try t.expectEqual(false, eql("💯", "💯 continues")); - try t.expectEqual(false, eql("💯 fsdfs", "💯")); - try t.expectEqual(false, eql("💯", "")); - try t.expectEqual(false, eql("", "💯")); - - try t.expectEqual(true, eql("abc x def", "aBc x DEF")); - try t.expectEqual(false, eql("xyz x ijk", "aBc x DEF")); - try t.expectEqual(true, eql("x", "x")); - try t.expectEqual(false, eql("x", "a")); - try t.expectEqual(false, eql("x", "x continues")); - try t.expectEqual(false, eql("x fsdfs", "x")); - try t.expectEqual(false, eql("x", "")); - try t.expectEqual(false, eql("", "x")); - - try t.expectEqual(true, eql("", "")); -} - -test "case insensitive hash with utf-8 chars" { - const t = std.testing; - try t.expect(hash("abc 💯 def") == hash("aBc 💯 DEF")); - try t.expect(hash("xyz 💯 ijk") != hash("aBc 💯 DEF")); - try t.expect(hash("abc 💯 def") != hash("aBc x DEF")); - try t.expect(hash("💯") == hash("💯")); - try t.expect(hash("💯") != hash("a")); - try t.expect(hash("💯") != hash("💯 continues")); - try t.expect(hash("💯 fsdfs") != hash("💯")); - try t.expect(hash("💯") != hash("")); - try t.expect(hash("") != hash("💯")); - - try t.expect(hash("abc x def") == hash("aBc x DEF")); - try t.expect(hash("xyz x ijk") != hash("aBc x DEF")); - try t.expect(hash("x") == hash("x")); - try t.expect(hash("x") != hash("a")); - try t.expect(hash("x") != hash("x continues")); - try t.expect(hash("x fsdfs") != hash("x")); - try t.expect(hash("x") != hash("")); - try t.expect(hash("") != hash("x")); - - try t.expect(hash("") == hash("")); -} diff --git a/src/util/iters.zig b/src/util/iters.zig deleted file mode 100644 index 5ad2258..0000000 --- a/src/util/iters.zig +++ /dev/null @@ -1,189 +0,0 @@ -const std = @import("std"); - -pub fn Separator(comptime separator: u8) type { - return struct { - const Self = @This(); - str: []const u8, - pub fn from(str: []const u8) Self { - return .{ .str = std.mem.trim(u8, str, &.{separator}) }; - } - - pub fn next(self: *Self) ?[]const u8 { - if (self.str.len == 0) return null; - - const part = std.mem.sliceTo(self.str, separator); - self.str = std.mem.trimLeft(u8, self.str[part.len..], &.{separator}); - - return part; - } - }; -} - -pub const QueryIter = struct { - const Pair = struct { - key: []const u8, - value: ?[]const u8, - }; - - iter: Separator('&'), - - pub fn from(q: []const u8) QueryIter { - return QueryIter{ .iter = Separator('&').from(std.mem.trimLeft(u8, q, "?")) }; - } - - pub fn next(self: *QueryIter) ?Pair { - const part = self.iter.next() orelse return null; - - const key = std.mem.sliceTo(part, '='); - if (key.len == part.len) return Pair{ - .key = key, - .value = null, - }; - - return Pair{ - .key = key, - .value = part[key.len + 1 ..], - }; - } -}; - -pub const PathIter = struct { - is_first: bool, - iter: std.mem.SplitIterator(u8), - - pub fn from(path: []const u8) PathIter { - return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") }; - } - - pub fn next(self: *PathIter) ?[]const u8 { - defer self.is_first = false; - while (self.iter.next()) |it| if (it.len != 0) { - return it; - }; - - if (self.is_first) return self.iter.rest(); - - return null; - } - - pub fn first(self: *PathIter) []const u8 { - std.debug.assert(self.is_first); - return self.next().?; - } - - pub fn rest(self: *PathIter) []const u8 { - return self.iter.rest(); - } -}; - -test "QueryIter" { - const t = @import("std").testing; - if (true) return error.SkipZigTest; - { - var iter = QueryIter.from(""); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } - - { - var iter = QueryIter.from("?"); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } - - { - var iter = QueryIter.from("?abc"); - try t.expectEqual(QueryIter.Pair{ - .key = "abc", - .value = null, - }, iter.next().?); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } - - { - var iter = QueryIter.from("?abc="); - try t.expectEqual(QueryIter.Pair{ - .key = "abc", - .value = "", - }, iter.next().?); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } - - { - var iter = QueryIter.from("?abc=def"); - try t.expectEqual(QueryIter.Pair{ - .key = "abc", - .value = "def", - }, iter.next().?); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } - - { - var iter = QueryIter.from("?abc=def&"); - try t.expectEqual(QueryIter.Pair{ - .key = "abc", - .value = "def", - }, iter.next().?); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } - - { - var iter = QueryIter.from("?abc=def&foo&bar=baz&qux="); - try t.expectEqual(QueryIter.Pair{ - .key = "abc", - .value = "def", - }, iter.next().?); - try t.expectEqual(QueryIter.Pair{ - .key = "foo", - .value = null, - }, iter.next().?); - try t.expectEqual(QueryIter.Pair{ - .key = "bar", - .value = "baz", - }, iter.next().?); - try t.expectEqual(QueryIter.Pair{ - .key = "qux", - .value = "", - }, iter.next().?); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } - - { - var iter = QueryIter.from("?=def&"); - try t.expectEqual(QueryIter.Pair{ - .key = "", - .value = "def", - }, iter.next().?); - try t.expect(iter.next() == null); - try t.expect(iter.next() == null); - } -} - -test "PathIter /ab/cd/" { - const path = "/ab/cd/"; - var it = PathIter.from(path); - try std.testing.expectEqualStrings("ab", it.next().?); - try std.testing.expectEqualStrings("cd", it.next().?); - try std.testing.expectEqual(@as(?[]const u8, null), it.next()); -} - -test "PathIter ''" { - const path = ""; - var it = PathIter.from(path); - try std.testing.expectEqualStrings("", it.next().?); - try std.testing.expectEqual(@as(?[]const u8, null), it.next()); -} - -test "PathIter ab/c//defg/" { - const path = "ab/c//defg/"; - var it = PathIter.from(path); - try std.testing.expectEqualStrings("ab", it.next().?); - try std.testing.expectEqualStrings("c", it.next().?); - try std.testing.expectEqualStrings("defg", it.next().?); - try std.testing.expectEqual(@as(?[]const u8, null), it.next()); -} diff --git a/src/util/lib.zig b/src/util/lib.zig index 9829ae3..fb2bee0 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -1,13 +1,10 @@ const std = @import("std"); -const iters = @import("./iters.zig"); -pub const ciutf8 = @import("./ciutf8.zig"); pub const Uuid = @import("./Uuid.zig"); pub const DateTime = @import("./DateTime.zig"); -pub const Url = @import("./Url.zig"); -pub const PathIter = iters.PathIter; -pub const QueryIter = iters.QueryIter; -pub const SqlStmtIter = iters.Separator(';'); +pub const serialize = @import("./serialize.zig"); +pub const Deserializer = serialize.Deserializer; +pub const DeserializerContext = serialize.DeserializerContext; /// Joins an array of strings, prefixing every entry with `prefix`, /// and putting `separator` in between each pair @@ -202,6 +199,16 @@ pub fn seedThreadPrng() !void { prng = std.rand.DefaultPrng.init(@bitCast(u64, buf)); } +pub fn comptimeToCrlf(comptime str: []const u8) []const u8 { + comptime { + @setEvalBranchQuota(str.len * 10); + const size = std.mem.replacementSize(u8, str, "\n", "\r\n"); + var buf: [size]u8 = undefined; + _ = std.mem.replace(u8, str, "\n", "\r\n", &buf); + return &buf; + } +} + pub const testing = struct { pub fn expectDeepEqual(expected: anytype, actual: @TypeOf(expected)) !void { const T = @TypeOf(expected); @@ -242,3 +249,7 @@ pub const testing = struct { } } }; + +test { + _ = std.testing.refAllDecls(@This()); +} diff --git a/src/util/serialize.zig b/src/util/serialize.zig new file mode 100644 index 0000000..53d882f --- /dev/null +++ b/src/util/serialize.zig @@ -0,0 +1,386 @@ +const std = @import("std"); +const util = @import("./lib.zig"); + +pub const FieldRef = []const []const u8; + +pub fn defaultIsScalar(comptime T: type) bool { + if (comptime std.meta.trait.is(.Optional)(T) and defaultIsScalar(std.meta.Child(T))) return true; + if (comptime std.meta.trait.isZigString(T)) return true; + if (comptime std.meta.trait.isIntegral(T)) return true; + if (comptime std.meta.trait.isFloat(T)) return true; + if (comptime std.meta.trait.is(.Enum)(T)) return true; + if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true; + if (comptime std.meta.trait.hasFn("parse")(T)) return true; + if (T == bool) return true; + + return false; +} + +pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T { + if (comptime std.meta.trait.is(.Optional)(T)) { + if (value.len == 0) return null; + return try deserializeString(allocator, std.meta.Child(T), value); + } + + if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value); + if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0); + if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value); + if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value); + + var buf: [64]u8 = undefined; + const lowered = std.ascii.lowerString(&buf, value); + + if (T == bool) return bool_map.get(lowered) orelse return error.InvalidBool; + if (comptime std.meta.trait.is(.Enum)(T)) { + return std.meta.stringToEnum(T, lowered) orelse return error.InvalidEnumTag; + } + + @compileError("Invalid type " ++ @typeName(T)); +} + +pub fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef { + comptime { + if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) { + @compileError("Cannot embed a union into nothing"); + } + + if (options.isScalar(T)) return &.{prefix}; + if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options); + + const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions) + prefix[0 .. prefix.len - 1] + else + prefix; + + var fields: []const FieldRef = &.{}; + + for (std.meta.fields(T)) |f| { + const new_prefix = eff_prefix ++ &[_][]const u8{f.name}; + const F = f.field_type; + fields = fields ++ getRecursiveFieldList(F, new_prefix, options); + } + + return fields; + } +} + +pub const SerializationOptions = struct { + embed_unions: bool, + isScalar: fn (type) bool, +}; + +pub const default_options = SerializationOptions{ + .embed_unions = true, + .isScalar = defaultIsScalar, +}; + +fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type { + const field_refs = getRecursiveFieldList(Result, &.{}, options); + + var fields: [field_refs.len]std.builtin.Type.StructField = undefined; + for (field_refs) |ref, i| { + fields[i] = .{ + .name = util.comptimeJoin(".", ref), + .field_type = ?From, + .default_value = &@as(?From, null), + .is_comptime = false, + .alignment = @alignOf(?From), + }; + } + + return @Type(.{ .Struct = .{ + .layout = .Auto, + .fields = &fields, + .decls = &.{}, + .is_tuple = false, + } }); +} + +pub fn Deserializer(comptime Result: type) type { + return DeserializerContext(Result, []const u8, struct { + const options = default_options; + fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T { + return try deserializeString(alloc, T, val); + } + }); +} + +pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type { + return struct { + const Data = Intermediary(Result, From, Context.options); + + data: Data = .{}, + context: Context = .{}, + + pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void { + const field = std.meta.stringToEnum(std.meta.FieldEnum(Data), key) orelse return error.UnknownField; + inline for (comptime std.meta.fieldNames(Data)) |field_name| { + @setEvalBranchQuota(10000); + const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(Data), field_name); + if (field == f) { + @field(self.data, field_name) = value; + return; + } + } + + unreachable; + } + + pub const Iter = struct { + data: *const Data, + field_index: usize, + + const Item = struct { + key: []const u8, + value: From, + }; + + pub fn next(self: *Iter) ?Item { + while (self.field_index < std.meta.fields(Data).len) { + const idx = self.field_index; + self.field_index += 1; + inline for (comptime std.meta.fieldNames(Data)) |field, i| { + if (i == idx) { + const maybe_value = @field(self.data.*, field); + if (maybe_value) |value| return Item{ .key = field, .value = value }; + } + } + } + + return null; + } + }; + + pub fn iterator(self: *const @This()) Iter { + return .{ .data = &self.data, .field_index = 0 }; + } + + pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void { + util.deepFree(allocator, val); + } + + pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result { + return (try self.deserialize(allocator, Result, &.{})) orelse error.MissingField; + } + + fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From { + //inline for (comptime std.meta.fieldNames(Data)) |f| @compileLog(f.ptr); + return @field(self.data, util.comptimeJoin(".", field_ref)); + } + + fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void { + util.deepFree(allocator, val); + } + + fn deserialize(self: *@This(), allocator: std.mem.Allocator, comptime T: type, comptime field_ref: FieldRef) !?T { + if (comptime Context.options.isScalar(T)) { + return try self.context.deserializeScalar(allocator, T, self.getSerializedField(field_ref) orelse return null); + } + + switch (@typeInfo(T)) { + // At most one of any union field can be active at a time, and it is embedded + // in its parent container + .Union => |info| { + var result: ?T = null; + errdefer if (result) |v| self.deserializeFree(allocator, v); + // TODO: errdefer cleanup + const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref; + inline for (info.fields) |field| { + const F = field.field_type; + const new_field_ref = union_ref ++ &[_][]const u8{field.name}; + const maybe_value = try self.deserialize(allocator, F, new_field_ref); + if (maybe_value) |value| { + // TODO: errdefer cleanup + errdefer self.deserializeFree(allocator, value); + if (result != null) return error.DuplicateUnionMember; + result = @unionInit(T, field.name, value); + } + } + return result; + }, + + .Struct => |info| { + var result: T = undefined; + + var any_explicit = false; + var any_missing = false; + var fields_alloced = [1]bool{false} ** info.fields.len; + errdefer inline for (info.fields) |field, i| { + if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name)); + }; + inline for (info.fields) |field, i| { + const F = field.field_type; + const new_field_ref = field_ref ++ &[_][]const u8{field.name}; + const maybe_value = try self.deserialize(allocator, F, new_field_ref); + if (maybe_value) |v| { + @field(result, field.name) = v; + fields_alloced[i] = true; + any_explicit = true; + } else if (field.default_value) |ptr| { + if (@sizeOf(F) != 0) { + const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr)); + @field(result, field.name) = try util.deepClone(allocator, cast_ptr.*); + fields_alloced[i] = true; + } + } else { + any_missing = true; + } + } + if (any_missing) { + return if (any_explicit) error.MissingField else null; + } + + return result; + }, + + // Specifically non-scalar optionals + .Optional => |info| return try self.deserialize(allocator, info.child, field_ref), + + else => @compileError("Unsupported type"), + } + } + }; +} + +const bool_map = std.ComptimeStringMap(bool, .{ + .{ "true", true }, + .{ "t", true }, + .{ "yes", true }, + .{ "y", true }, + .{ "1", true }, + + .{ "false", false }, + .{ "f", false }, + .{ "no", false }, + .{ "n", false }, + .{ "0", false }, +}); + +test "Deserializer" { + + // Happy case - simple + { + const T = struct { foo: []const u8, bar: bool }; + + var ds = Deserializer(T){}; + try ds.setSerializedField("foo", "123"); + try ds.setSerializedField("bar", "true"); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val); + } + + // Returns error if nonexistent field set + { + const T = struct { foo: []const u8, bar: bool }; + + var ds = Deserializer(T){}; + try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123")); + } + + // Substruct dereferencing + { + const T = struct { + foo: struct { bar: bool, baz: bool }, + }; + + var ds = Deserializer(T){}; + try ds.setSerializedField("foo.bar", "true"); + try ds.setSerializedField("foo.baz", "true"); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true, .baz = true } }, val); + } + + // Union embedding + { + const T = struct { + foo: union(enum) { bar: bool, baz: bool }, + }; + + var ds = Deserializer(T){}; + try ds.setSerializedField("bar", "true"); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true } }, val); + } + + // Returns error if multiple union fields specified + { + const T = struct { + foo: union(enum) { bar: bool, baz: bool }, + }; + + var ds = Deserializer(T){}; + try ds.setSerializedField("bar", "true"); + try ds.setSerializedField("baz", "true"); + + try std.testing.expectError(error.DuplicateUnionMember, ds.finish(std.testing.allocator)); + } + + // Uses default values if fields aren't provided + { + const T = struct { foo: []const u8 = "123", bar: bool = true }; + + var ds = Deserializer(T){}; + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val); + } + + // Returns an error if fields aren't provided and no default exists + { + const T = struct { foo: []const u8, bar: bool }; + + var ds = Deserializer(T){}; + try ds.setSerializedField("foo", "123"); + + try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator)); + } + + // Handles optional containers + { + const T = struct { + foo: ?struct { bar: usize = 3, baz: usize } = null, + qux: ?union(enum) { quux: usize } = null, + }; + + var ds = Deserializer(T){}; + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = null, .qux = null }, val); + } + + { + const T = struct { + foo: ?struct { bar: usize = 3, baz: usize } = null, + qux: ?union(enum) { quux: usize } = null, + }; + + var ds = Deserializer(T){}; + try ds.setSerializedField("foo.baz", "3"); + try ds.setSerializedField("quux", "3"); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, val); + } + + { + const T = struct { + foo: ?struct { bar: usize = 3, baz: usize } = null, + qux: ?union(enum) { quux: usize } = null, + }; + + var ds = Deserializer(T){}; + try ds.setSerializedField("foo.bar", "3"); + try ds.setSerializedField("quux", "3"); + + try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator)); + } +}