diff --git a/.gitignore b/.gitignore index db4419f..7e4c4e7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,3 @@ **/zig-cache **.db /config.json -/files diff --git a/build.zig b/build.zig index d644d4f..cf43658 100644 --- a/build.zig +++ b/build.zig @@ -99,23 +99,13 @@ pub fn build(b: *std.build.Builder) !void { exe.addSystemIncludePath("/usr/include/"); const unittest_http_cmd = b.step("unit:http", "Run tests for http package"); - const unittest_http = b.addTest("src/http/lib.zig"); + const unittest_http = b.addTest("src/http/test.zig"); unittest_http_cmd.dependOn(&unittest_http.step); unittest_http.addPackage(pkgs.util); - const unittest_util_cmd = b.step("unit:util", "Run tests for util package"); - const unittest_util = b.addTest("src/util/lib.zig"); - unittest_util_cmd.dependOn(&unittest_util.step); - - const unittest_sql_cmd = b.step("unit:sql", "Run tests for sql package"); - const unittest_sql = b.addTest("src/sql/lib.zig"); - unittest_sql_cmd.dependOn(&unittest_sql.step); - unittest_sql.addPackage(pkgs.util); - - const unittest_template_cmd = b.step("unit:template", "Run tests for template package"); - const unittest_template = b.addTest("src/template/lib.zig"); - unittest_template_cmd.dependOn(&unittest_template.step); - //unittest_template.addPackage(pkgs.util); + //const unittest_util_cmd = b.step("unit:util", "Run tests for util package"); + //const unittest_util = b.addTest("src/util/Uuid.zig"); + //unittest_util_cmd.dependOn(&unittest_util.step); //const util_tests = b.addTest("src/util/lib.zig"); //const sql_tests = b.addTest("src/sql/lib.zig"); @@ -125,9 +115,7 @@ pub fn build(b: *std.build.Builder) !void { //const unit_tests = b.step("unit-tests", "Run tests"); const unittest_all = b.step("unit", "Run unit tests"); unittest_all.dependOn(unittest_http_cmd); - unittest_all.dependOn(unittest_util_cmd); - unittest_all.dependOn(unittest_sql_cmd); - unittest_all.dependOn(unittest_template_cmd); + //unittest_all.dependOn(unittest_util_cmd); const api_integration = b.addTest("./tests/api_integration/lib.zig"); api_integration.addPackage(pkgs.opts); diff --git a/src/api/lib.zig b/src/api/lib.zig index 065fdc7..1a46fa9 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -9,7 +9,6 @@ const services = struct { const communities = @import("./services/communities.zig"); const actors = @import("./services/actors.zig"); const auth = @import("./services/auth.zig"); - const drive = @import("./services/files.zig"); const invites = @import("./services/invites.zig"); const notes = @import("./services/notes.zig"); const follows = @import("./services/follows.zig"); @@ -137,14 +136,6 @@ pub const FollowerQueryResult = FollowQueryResult; pub const FollowingQueryArgs = FollowQueryArgs; pub const FollowingQueryResult = FollowQueryResult; -pub const UploadFileArgs = struct { - filename: []const u8, - dir: ?[]const u8, - description: ?[]const u8, - content_type: []const u8, - sensitive: bool, -}; - pub fn isAdminSetup(db: sql.Db) !bool { _ = services.communities.adminCommunityId(db) catch |err| switch (err) { error.NotFound => return false, @@ -518,23 +509,5 @@ fn ApiConn(comptime DbConn: type) type { self.allocator, ); } - - pub fn uploadFile(self: *Self, meta: UploadFileArgs, body: []const u8) !void { - const user_id = self.user_id orelse return error.NoToken; - return try services.drive.createFile(self.db, .{ - .dir = meta.dir orelse "/", - .filename = meta.filename, - .owner = .{ .user_id = user_id }, - .created_by = user_id, - .description = meta.description, - .content_type = meta.content_type, - .sensitive = meta.sensitive, - }, body, self.allocator); - } - - pub fn driveMkdir(self: *Self, path: []const u8) !void { - const user_id = self.user_id orelse return error.NoToken; - try services.drive.mkdir(self.db, .{ .user_id = user_id }, path, self.allocator); - } }; } diff --git a/src/api/services/files.zig b/src/api/services/files.zig index 147d049..18c0e9d 100644 --- a/src/api/services/files.zig +++ b/src/api/services/files.zig @@ -11,224 +11,59 @@ pub const FileOwner = union(enum) { pub const DriveFile = struct { id: Uuid, - - path: []const u8, filename: []const u8, - owner: FileOwner, - size: usize, - - description: []const u8, - content_type: []const u8, - sensitive: bool, - created_at: DateTime, - updated_at: DateTime, }; -const EntryType = enum { - dir, - file, -}; +pub const files = struct { + pub fn create(db: anytype, owner: FileOwner, filename: []const u8, data: []const u8, alloc: std.mem.Allocator) !void { + const id = Uuid.randV4(util.getThreadPrng()); + const now = DateTime.now(); -pub const CreateFileArgs = struct { - dir: []const u8, - filename: []const u8, - owner: FileOwner, - created_by: Uuid, - description: ?[]const u8, - content_type: ?[]const u8, - sensitive: bool, -}; - -fn lookupDirectory(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid { - return (try db.queryRow( - std.meta.Tuple( - &.{util.Uuid}, - ), - \\SELECT id - \\FROM drive_entry_path - \\WHERE - \\ path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END) - \\ AND account_owner_id IS NOT DISTINCT FROM $2 - \\ AND community_owner_id IS NOT DISTINCT FROM $3 - \\ AND kind = 'dir' - \\LIMIT 1 - , - .{ - std.mem.trim(u8, path, "/"), - if (owner == .user_id) owner.user_id else null, - if (owner == .community_id) owner.community_id else null, - }, - alloc, - ))[0]; -} - -fn lookup(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid { - return (try db.queryRow( - std.meta.Tuple( - &.{util.Uuid}, - ), - \\SELECT id - \\FROM drive_entry_path - \\WHERE - \\ path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END) - \\ AND account_owner_id IS NOT DISTINCT FROM $2 - \\ AND community_owner_id IS NOT DISTINCT FROM $3 - \\LIMIT 1 - , - .{ - std.mem.trim(u8, path, "/"), - if (owner == .user_id) owner.user_id else null, - if (owner == .community_id) owner.community_id else null, - }, - alloc, - ))[0]; -} - -pub fn mkdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void { - var split = std.mem.splitBackwards(u8, std.mem.trim(u8, path, "/"), "/"); - const name = split.first(); - const dir = split.rest(); - std.log.debug("'{s}' / '{s}'", .{ name, dir }); - - if (name.len == 0) return error.EmptyName; - - const id = Uuid.randV4(util.getThreadPrng()); - - const tx = try db.begin(); - errdefer tx.rollback(); - - const parent = try lookupDirectory(tx, owner, dir, alloc); - - try tx.insert("drive_entry", .{ - .id = id, - - .account_owner_id = if (owner == .user_id) owner.user_id else null, - .community_owner_id = if (owner == .community_id) owner.community_id else null, - - .name = name, - .parent_directory_id = parent, - }, alloc); - try tx.commit(); -} - -pub fn rmdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void { - const tx = try db.begin(); - errdefer tx.rollback(); - - const id = try lookupDirectory(tx, owner, path, alloc); - try tx.exec("DELETE FROM drive_directory WHERE id = $1", .{id}, alloc); - try tx.commit(); -} - -fn insertFileRow(tx: anytype, id: Uuid, filename: []const u8, owner: FileOwner, dir: Uuid, alloc: std.mem.Allocator) !void { - try tx.insert("drive_entry", .{ - .id = id, - - .account_owner_id = if (owner == .user_id) owner.user_id else null, - .community_owner_id = if (owner == .community_id) owner.community_id else null, - - .parent_directory_id = dir, - .name = filename, - - .file_id = id, - }, alloc); -} - -pub fn createFile(db: anytype, args: CreateFileArgs, data: []const u8, alloc: std.mem.Allocator) !void { - const id = Uuid.randV4(util.getThreadPrng()); - const now = DateTime.now(); - - { - var tx = try db.begin(); - errdefer tx.rollback(); - - const dir_id = try lookupDirectory(tx, args.owner, args.dir, alloc); - - try tx.insert("file_upload", .{ + // TODO: assert we're not in a transaction + db.insert("drive_file", .{ .id = id, - - .filename = args.filename, - - .created_by = args.created_by, - .size = data.len, - - .description = args.description, - .content_type = args.content_type, - .sensitive = args.sensitive, - - .is_deleted = false, - + .filename = filename, + .owner = owner, .created_at = now, - .updated_at = now, - }, alloc); - - var sub_tx = try tx.savepoint(); - if (insertFileRow(sub_tx, id, args.filename, args.owner, dir_id, alloc)) |_| { - try sub_tx.release(); - } else |err| { - std.log.debug("{}", .{err}); - switch (err) { - error.UniqueViolation => { - try sub_tx.rollbackSavepoint(); - // Rename the file before trying again - var split = std.mem.split(u8, args.filename, "."); - const name = split.first(); - const ext = split.rest(); - var buf: [256]u8 = undefined; - const drive_filename = try std.fmt.bufPrint(&buf, "{s}.{}.{s}", .{ name, id, ext }); - try insertFileRow(tx, id, drive_filename, args.owner, dir_id, alloc); - }, - else => return error.DatabaseFailure, - } + }, alloc) catch return error.DatabaseFailure; + // Assume the previous statement succeeded and is not stuck in a transaction + errdefer { + db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch |err| { + std.log.err("Unable to remove file record in DB: {}", .{err}); + }; } - try tx.commit(); + try saveFile(id, data); } - errdefer { - db.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch |err| { - std.log.err("Unable to remove file record in DB: {}", .{err}); - }; - db.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch |err| { - std.log.err("Unable to remove file record in DB: {}", .{err}); - }; + const data_root = "./files"; + fn saveFile(id: Uuid, data: []const u8) !void { + var dir = try std.fs.cwd().openDir(data_root); + defer dir.close(); + + var file = try dir.createFile(id.toCharArray(), .{ .exclusive = true }); + defer file.close(); + + try file.writer().writeAll(data); + try file.sync(); } - try saveFile(id, data); -} + pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 { + var dir = try std.fs.cwd().openDir(data_root); + defer dir.close(); -const data_root = "./files"; -fn saveFile(id: Uuid, data: []const u8) !void { - var dir = try std.fs.cwd().openDir(data_root, .{}); - defer dir.close(); + return dir.readFileAlloc(alloc, id.toCharArray(), 1 << 32); + } - var file = try dir.createFile(&id.toCharArray(), .{ .exclusive = true }); - defer file.close(); + pub fn delete(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void { + var dir = try std.fs.cwd().openDir(data_root); + defer dir.close(); - try file.writer().writeAll(data); - try file.sync(); -} + try dir.deleteFile(id.toCharArray()); -pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 { - var dir = try std.fs.cwd().openDir(data_root, .{}); - defer dir.close(); - - return dir.readFileAlloc(alloc, &id.toCharArray(), 1 << 32); -} - -pub fn deleteFile(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void { - var dir = try std.fs.cwd().openDir(data_root, .{}); - defer dir.close(); - - try dir.deleteFile(id.toCharArray()); - - const tx = try db.beginOrSavepoint(); - errdefer tx.rollback(); - - tx.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure; - tx.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure; - try tx.commitOrRelease(); -} + db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure; + } +}; diff --git a/src/http/fields.zig b/src/http/headers.zig similarity index 76% rename from src/http/fields.zig rename to src/http/headers.zig index d6319da..1b91865 100644 --- a/src/http/fields.zig +++ b/src/http/headers.zig @@ -1,60 +1,5 @@ const std = @import("std"); -pub const ParamIter = struct { - str: []const u8, - index: usize = 0, - - const Param = struct { - name: []const u8, - value: []const u8, - }; - - pub fn from(str: []const u8) ParamIter { - return .{ .str = str, .index = std.mem.indexOfScalar(u8, str, ';') orelse str.len }; - } - - pub fn fieldValue(self: *ParamIter) []const u8 { - return std.mem.sliceTo(self.str, ';'); - } - - pub fn next(self: *ParamIter) ?Param { - if (self.index >= self.str.len) return null; - - const start = self.index + 1; - const new_start = std.mem.indexOfScalarPos(u8, self.str, start, ';') orelse self.str.len; - self.index = new_start; - - const param = std.mem.trim(u8, self.str[start..new_start], " \t"); - var split = std.mem.split(u8, param, "="); - const name = split.first(); - const value = std.mem.trimLeft(u8, split.rest(), " \t"); - // TODO: handle quoted values - // TODO: handle parse errors - - return Param{ - .name = name, - .value = value, - }; - } -}; - -pub fn getParam(field: []const u8, name: ?[]const u8) ?[]const u8 { - var iter = ParamIter.from(field); - - if (name) |param| { - while (iter.next()) |p| { - if (std.ascii.eqlIgnoreCase(param, p.name)) { - const trimmed = std.mem.trim(u8, p.value, " \t"); - if (trimmed.len >= 2 and trimmed[0] == '"' and trimmed[trimmed.len - 1] == '"') { - return trimmed[1 .. trimmed.len - 1]; - } - return trimmed; - } - } - return null; - } else return iter.fieldValue(); -} - pub const Fields = struct { const HashContext = struct { const hash_seed = 1; diff --git a/src/http/lib.zig b/src/http/lib.zig index c114bea..9fc5a61 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -1,55 +1,27 @@ const std = @import("std"); +const ciutf8 = @import("util").ciutf8; const request = @import("./request.zig"); + const server = @import("./server.zig"); -pub const urlencode = @import("./urlencode.zig"); + pub const socket = @import("./socket.zig"); -const json = @import("./json.zig"); -const multipart = @import("./multipart.zig"); -pub const fields = @import("./fields.zig"); - -pub const Method = enum { - GET, - HEAD, - POST, - PUT, - DELETE, - CONNECT, - OPTIONS, - TRACE, - PATCH, - - // WebDAV methods (we use some of them for the drive system) - MKCOL, - MOVE, - - pub fn requestHasBody(self: Method) bool { - return switch (self) { - .POST, .PUT, .PATCH, .MKCOL, .MOVE => true, - else => false, - }; - } -}; +pub const Method = std.http.Method; pub const Status = std.http.Status; pub const Request = request.Request(server.Stream.Reader); pub const Response = server.Response; -//pub const Handler = server.Handler; +pub const Handler = server.Handler; pub const Server = server.Server; pub const middleware = @import("./middleware.zig"); +pub const queryStringify = @import("./query.zig").queryStringify; -pub const Fields = fields.Fields; - -pub const FormFile = multipart.FormFile; +pub const Fields = @import("./headers.zig").Fields; pub const Protocol = enum { http_1_0, http_1_1, http_1_x, }; - -test { - _ = std.testing.refAllDecls(@This()); -} diff --git a/src/http/middleware.zig b/src/http/middleware.zig index ce4d307..97855b1 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -14,11 +14,10 @@ /// Terminal middlewares that are not implemented using other middlewares should /// only accept a `void` value for `next_handler`. const std = @import("std"); -const util = @import("util"); const http = @import("./lib.zig"); -const urlencode = @import("./urlencode.zig"); +const util = @import("util"); +const query_utils = @import("./query.zig"); const json_utils = @import("./json.zig"); -const fields = @import("./fields.zig"); /// Takes an iterable of middlewares and chains them together. pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) { @@ -30,20 +29,20 @@ pub fn Apply(comptime Middlewares: type) type { return ApplyInternal(std.meta.fields(Middlewares)); } -fn ApplyInternal(comptime which: []const std.builtin.Type.StructField) type { - if (which.len == 0) return void; +fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type { + if (fields.len == 0) return void; return HandlerList( - which[0].field_type, - ApplyInternal(which[1..]), + fields[0].field_type, + ApplyInternal(fields[1..]), ); } -fn applyInternal(middlewares: anytype, comptime which: []const std.builtin.Type.StructField) ApplyInternal(which) { - if (which.len == 0) return {}; +fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) { + if (fields.len == 0) return {}; return .{ - .first = @field(middlewares, which[0].name), - .next = applyInternal(middlewares, which[1..]), + .first = @field(middlewares, fields[0].name), + .next = applyInternal(middlewares, fields[1..]), }; } @@ -350,71 +349,15 @@ pub fn router(routes: anytype) Router(@TypeOf(routes)) { return Router(@TypeOf(routes)){ .routes = routes }; } -pub const PathIter = struct { - is_first: bool, - iter: std.mem.SplitIterator(u8), - - pub fn from(path: []const u8) PathIter { - return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") }; - } - - pub fn next(self: *PathIter) ?[]const u8 { - defer self.is_first = false; - while (self.iter.next()) |it| if (it.len != 0) { - return it; - }; - - if (self.is_first) return self.iter.rest(); - - return null; - } - - pub fn first(self: *PathIter) []const u8 { - std.debug.assert(self.is_first); - return self.next().?; - } - - pub fn rest(self: *PathIter) []const u8 { - return self.iter.rest(); - } -}; - -test "PathIter" { - const testCase = struct { - fn case(path: []const u8, segments: []const []const u8) !void { - var iter = PathIter.from(path); - for (segments) |s| { - try std.testing.expectEqualStrings(s, iter.next() orelse return error.TestExpectedEqual); - } - try std.testing.expect(iter.next() == null); - } - }.case; - - try testCase("", &.{""}); - try testCase("*", &.{"*"}); - try testCase("/", &.{""}); - try testCase("/ab/cd", &.{ "ab", "cd" }); - try testCase("/ab/cd/", &.{ "ab", "cd" }); - try testCase("/ab/cd//", &.{ "ab", "cd" }); - try testCase("ab", &.{"ab"}); - try testCase("/ab", &.{"ab"}); - try testCase("ab/", &.{"ab"}); - try testCase("ab//ab//", &.{ "ab", "ab" }); -} - // helper function for doing route analysis fn pathMatches(route: []const u8, path: []const u8) bool { - var path_iter = PathIter.from(path); - var route_iter = PathIter.from(route); + var path_iter = util.PathIter.from(path); + var route_iter = util.PathIter.from(route); while (route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse ""; - + const path_segment = path_iter.next() orelse return false; if (route_segment.len > 0 and route_segment[0] == ':') { // Route Argument - if (route_segment[route_segment.len - 1] == '*') { - // consume rest of path segments - while (path_iter.next()) |_| {} - } else if (path_segment.len == 0) return false; + if (path_segment.len == 0) return false; } else { if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false; } @@ -485,10 +428,6 @@ test "route" { try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh"); try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh"); try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz"); - try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh/xyz"); - try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh"); - try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/"); - try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd"); try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/"); try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, ""); @@ -497,21 +436,32 @@ test "route" { try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/"); try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/"); try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo"); - try testCase(false, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "defg/abcd"); } /// Mounts a router subtree under a given path. Middlewares further down on the list /// are called with the path prefix specified by `route` removed from the path. /// Must be below `split_uri` on the middleware list. pub fn Mount(comptime route: []const u8) type { - if (std.mem.indexOfScalar(u8, route, ':') != null) @compileError("Route args cannot be mounted"); return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { - const args = try parseArgsFromPath(route ++ "/:path*", struct { path: []const u8 }, ctx.path); + var path_iter = util.PathIter.from(ctx.path); + comptime var route_iter = util.PathIter.from(route); + var path_unused: []const u8 = ctx.path; + + inline while (comptime route_iter.next()) |route_segment| { + if (comptime route_segment.len == 0) continue; + const path_segment = path_iter.next() orelse return error.RouteMismatch; + path_unused = path_iter.rest(); + if (comptime route_segment[0] == ':') { + @compileLog("Argument segments cannot be mounted"); + // Route Argument + } else { + if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; + } + } var new_ctx = ctx; - new_ctx.path = args.path; - + new_ctx.path = path_unused; return next.handle(req, res, new_ctx, {}); } }; @@ -541,33 +491,18 @@ test "mount" { fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args { var args: Args = undefined; - var path_iter = PathIter.from(path); - comptime var route_iter = PathIter.from(route); - var path_unused: []const u8 = path; - + var path_iter = util.PathIter.from(path); + comptime var route_iter = util.PathIter.from(route); inline while (comptime route_iter.next()) |route_segment| { - const path_segment = path_iter.next() orelse ""; - if (route_segment[0] == ':') { - comptime var name: []const u8 = route_segment[1..]; - var value: []const u8 = path_segment; - + const path_segment = path_iter.next() orelse return error.RouteMismatch; + if (route_segment.len > 0 and route_segment[0] == ':') { // route segment is an argument segment - if (comptime route_segment[route_segment.len - 1] == '*') { - // waste remaining args - while (path_iter.next()) |_| {} - name = route_segment[1 .. route_segment.len - 1]; - value = path_unused; - } else { - if (path_segment.len == 0) return error.RouteMismatch; - } - - const A = @TypeOf(@field(args, name)); - @field(args, name) = try parseArgFromPath(A, value); + if (path_segment.len == 0) return error.RouteMismatch; + const A = @TypeOf(@field(args, route_segment[1..])); + @field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment); } else { - // route segment is a literal segment if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch; } - path_unused = path_iter.rest(); } if (path_iter.next() != null) return error.RouteMismatch; @@ -642,21 +577,6 @@ test "ParsePathArgs" { try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" }); try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil }); - try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc", .{ .arg = "abc" }); - try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc/def", .{ .arg = "abc/def" }); - try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/", .{ .arg = "" }); - - // Compiler crashes if i keep the args named the same as above. - // TODO: Debug this and try to fix it - try testCase("/xyz/:bar*", struct { bar: []const u8 }, "/xyz", .{ .bar = "" }); - - // It's a quirk that the initial / is left in for these cases. However, it results in a path - // that's semantically equivalent so i didn't bother fixing it - try testCase("/:foo*", struct { foo: []const u8 }, "/abc", .{ .foo = "/abc" }); - try testCase("/:foo*", struct { foo: []const u8 }, "/abc/def", .{ .foo = "/abc/def" }); - try testCase("/:foo*", struct { foo: []const u8 }, "/", .{ .foo = "/" }); - try testCase("/:foo*", struct { foo: []const u8 }, "", .{ .foo = "" }); - try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{})); try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{})); try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 })); @@ -667,51 +587,41 @@ const BaseContentType = enum { json, url_encoded, octet_stream, - multipart_formdata, other, }; -fn parseBodyFromRequest(comptime T: type, content_type: ?[]const u8, reader: anytype, alloc: std.mem.Allocator) !T { - // Use json by default for now for testing purposes - const eff_type = content_type orelse "application/json"; - const parser_type = matchContentType(eff_type); +fn parseBodyFromRequest(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T { + //@compileLog(T); + const buf = try reader.readAllAlloc(alloc, 1 << 16); + defer alloc.free(buf); - switch (parser_type) { + switch (content_type) { .octet_stream, .json => { - const buf = try reader.readAllAlloc(alloc, 1 << 16); - defer alloc.free(buf); const body = try json_utils.parse(T, buf, alloc); defer json_utils.parseFree(body, alloc); return try util.deepClone(alloc, body); }, - .url_encoded => { - const buf = try reader.readAllAlloc(alloc, 1 << 16); - defer alloc.free(buf); - return urlencode.parse(alloc, T, buf) catch |err| switch (err) { - //error.NoQuery => error.NoBody, - else => err, - }; - }, - .multipart_formdata => { - const boundary = fields.getParam(eff_type, "boundary") orelse return error.MissingBoundary; - return try @import("./multipart.zig").parseFormData(T, boundary, reader, alloc); + .url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) { + error.NoQuery => error.NoBody, + else => err, }, else => return error.UnsupportedMediaType, } } // figure out what base parser to use -fn matchContentType(hdr: []const u8) BaseContentType { - const trimmed = std.mem.sliceTo(hdr, ';'); - if (std.ascii.eqlIgnoreCase(trimmed, "application/x-www-form-urlencoded")) return .url_encoded; - if (std.ascii.eqlIgnoreCase(trimmed, "application/json")) return .json; - if (std.ascii.endsWithIgnoreCase(trimmed, "+json")) return .json; - if (std.ascii.eqlIgnoreCase(trimmed, "application/octet-stream")) return .octet_stream; - if (std.ascii.eqlIgnoreCase(trimmed, "multipart/form-data")) return .multipart_formdata; +fn matchContentType(hdr: ?[]const u8) ?BaseContentType { + if (hdr) |h| { + if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded; + if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json; + if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream; - return .other; + return .other; + } + + return null; } /// Parses a set of body arguments from the request body based on the request's Content-Type @@ -730,8 +640,10 @@ pub fn ParseBody(comptime Body: type) type { return next.handle(req, res, new_ctx, {}); } + const base_content_type = matchContentType(content_type); + var stream = req.body orelse return error.NoBody; - const body = try parseBodyFromRequest(Body, content_type, stream.reader(), ctx.allocator); + const body = try parseBodyFromRequest(Body, base_content_type orelse .json, stream.reader(), ctx.allocator); defer util.deepFree(ctx.allocator, body); return next.handle( @@ -747,57 +659,12 @@ pub fn parseBody(comptime Body: type) ParseBody(Body) { return .{}; } -test "parseBodyFromRequest" { - const testCase = struct { - fn case(content_type: []const u8, body: []const u8, expected: anytype) !void { - var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) }; - const result = try parseBodyFromRequest(@TypeOf(expected), content_type, stream.reader(), std.testing.allocator); - defer util.deepFree(std.testing.allocator, result); - - try util.testing.expectDeepEqual(expected, result); - } - }.case; - - const Struct = struct { - id: usize, - }; - try testCase("application/json", "{\"id\": 3}", Struct{ .id = 3 }); - try testCase("application/x-www-form-urlencoded", "id=3", Struct{ .id = 3 }); - - //try testCase("multipart/form-data; ", - //\\ - //, Struct{ .id = 3 }); -} - -test "parseBody" { - const Struct = struct { - foo: []const u8, - }; - const body = - \\{"foo": "bar"} - ; - var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) }; - var headers = http.Fields.init(std.testing.allocator); - defer headers.deinit(); - - try parseBody(Struct).handle( - .{ .body = @as(?std.io.StreamSource, stream), .headers = headers }, - .{}, - .{ .allocator = std.testing.allocator }, - struct { - fn handle(_: anytype, _: anytype, _: anytype, ctx: anytype, _: void) !void { - try util.testing.expectDeepEqual(Struct{ .foo = "bar" }, ctx.body); - } - }{}, - ); -} - /// Parses query parameters as defined in query.zig pub fn ParseQueryParams(comptime QueryParams: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {}); - const query = try urlencode.parse(ctx.allocator, QueryParams, ctx.query_string); + const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string); defer util.deepFree(ctx.allocator, query); return next.handle( diff --git a/src/http/multipart.zig b/src/http/multipart.zig deleted file mode 100644 index 815711d..0000000 --- a/src/http/multipart.zig +++ /dev/null @@ -1,362 +0,0 @@ -const std = @import("std"); -const util = @import("util"); -const fields = @import("./fields.zig"); - -const max_boundary = 70; -const read_ahead = max_boundary + 4; - -pub fn MultipartStream(comptime ReaderType: type) type { - return struct { - const Multipart = @This(); - - pub const BaseReader = ReaderType; - pub const PartReader = std.io.Reader(*Part, ReaderType.Error, Part.read); - - stream: std.io.PeekStream(.{ .Static = read_ahead }, ReaderType), - boundary: []const u8, - - pub fn next(self: *Multipart, alloc: std.mem.Allocator) !?Part { - const reader = self.stream.reader(); - while (true) { - try reader.skipUntilDelimiterOrEof('\r'); - var line_buf: [read_ahead]u8 = undefined; - const len = try reader.readAll(line_buf[0 .. self.boundary.len + 3]); - const line = line_buf[0..len]; - if (line.len == 0) return null; - if (std.mem.startsWith(u8, line, "\n--") and std.mem.endsWith(u8, line, self.boundary)) { - // match, check for end thing - var more_buf: [2]u8 = undefined; - if (try reader.readAll(&more_buf) != 2) return error.EndOfStream; - - const more = !(more_buf[0] == '-' and more_buf[1] == '-'); - try self.stream.putBack(&more_buf); - try reader.skipUntilDelimiterOrEof('\n'); - if (more) return try Part.open(self, alloc) else return null; - } - } - } - - pub const Part = struct { - base: ?*Multipart, - fields: fields.Fields, - - pub fn open(base: *Multipart, alloc: std.mem.Allocator) !Part { - var parsed_fields = try @import("./request/parser.zig").parseHeaders(alloc, base.stream.reader()); - return .{ .base = base, .fields = parsed_fields }; - } - - pub fn reader(self: *Part) PartReader { - return .{ .context = self }; - } - - pub fn close(self: *Part) void { - self.fields.deinit(); - } - - pub fn read(self: *Part, buf: []u8) ReaderType.Error!usize { - const base = self.base orelse return 0; - - const r = base.stream.reader(); - - var count: usize = 0; - while (count < buf.len) { - const byte = r.readByte() catch |err| switch (err) { - error.EndOfStream => { - self.base = null; - return count; - }, - else => |e| return e, - }; - - buf[count] = byte; - count += 1; - if (byte != '\r') continue; - - var line_buf: [read_ahead]u8 = undefined; - const line = line_buf[0..try r.readAll(line_buf[0 .. base.boundary.len + 3])]; - if (!std.mem.startsWith(u8, line, "\n--") or !std.mem.endsWith(u8, line, base.boundary)) { - base.stream.putBack(line) catch unreachable; - continue; - } else { - base.stream.putBack(line) catch unreachable; - base.stream.putBackByte('\r') catch unreachable; - self.base = null; - return count - 1; - } - } - - return count; - } - }; - }; -} - -pub fn openMultipart(boundary: []const u8, reader: anytype) !MultipartStream(@TypeOf(reader)) { - if (boundary.len > max_boundary) return error.BoundaryTooLarge; - var stream = .{ - .stream = std.io.peekStream(read_ahead, reader), - .boundary = boundary, - }; - - stream.stream.putBack("\r\n") catch unreachable; - return stream; -} - -const MultipartFormField = struct { - name: []const u8, - value: []const u8, - - filename: ?[]const u8 = null, - content_type: ?[]const u8 = null, -}; - -pub const FormFile = struct { - data: []const u8, - filename: []const u8, - content_type: []const u8, -}; - -pub fn MultipartForm(comptime ReaderType: type) type { - return struct { - stream: MultipartStream(ReaderType), - - pub fn next(self: *@This(), alloc: std.mem.Allocator) !?MultipartFormField { - var part = (try self.stream.next(alloc)) orelse return null; - defer part.close(); - - const disposition = part.fields.get("Content-Disposition") orelse return error.MissingDisposition; - - if (!std.ascii.eqlIgnoreCase(fields.getParam(disposition, null).?, "form-data")) return error.BadDisposition; - const name = try util.deepClone(alloc, fields.getParam(disposition, "name") orelse return error.BadDisposition); - errdefer util.deepFree(alloc, name); - const filename = try util.deepClone(alloc, fields.getParam(disposition, "filename")); - errdefer util.deepFree(alloc, filename); - const content_type = try util.deepClone(alloc, part.fields.get("Content-Type")); - errdefer util.deepFree(alloc, content_type); - - const value = try part.reader().readAllAlloc(alloc, 1 << 32); - - return MultipartFormField{ - .name = name, - .value = value, - - .filename = filename, - .content_type = content_type, - }; - } - }; -} - -pub fn openForm(multipart_stream: anytype) MultipartForm(@TypeOf(multipart_stream).BaseReader) { - return .{ .stream = multipart_stream }; -} - -fn Deserializer(comptime Result: type) type { - return util.DeserializerContext(Result, MultipartFormField, struct { - pub const options = .{ .isScalar = isScalar, .embed_unions = true }; - - pub fn isScalar(comptime T: type) bool { - if (T == FormFile or T == ?FormFile) return true; - return util.serialize.defaultIsScalar(T); - } - - pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: MultipartFormField) !T { - if (T == FormFile or T == ?FormFile) return try deserializeFormFile(alloc, val); - - if (val.filename != null) return error.FilenameProvidedForNonFile; - return try util.serialize.deserializeString(alloc, T, val.value); - } - - fn deserializeFormFile(alloc: std.mem.Allocator, val: MultipartFormField) !FormFile { - const data = try util.deepClone(alloc, val.value); - errdefer util.deepFree(alloc, data); - const filename = try util.deepClone(alloc, val.filename orelse "(untitled)"); - errdefer util.deepFree(alloc, filename); - const content_type = try util.deepClone(alloc, val.content_type orelse "application/octet-stream"); - return FormFile{ - .data = data, - .filename = filename, - .content_type = content_type, - }; - } - }); -} - -pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T { - var form = openForm(try openMultipart(boundary, reader)); - - var ds = Deserializer(T){}; - defer { - var iter = ds.iterator(); - while (iter.next()) |pair| { - util.deepFree(alloc, pair.value); - } - } - while (true) { - var part = (try form.next(alloc)) orelse break; - errdefer util.deepFree(alloc, part); - - try ds.setSerializedField(part.name, part); - } - - return try ds.finish(alloc); -} - -// TODO: Fix these tests -test "MultipartStream" { - const ExpectedPart = struct { - disposition: []const u8, - value: []const u8, - }; - const testCase = struct { - fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const ExpectedPart) !void { - var src = std.io.StreamSource{ - .const_buffer = std.io.fixedBufferStream(body), - }; - - var stream = try openMultipart(boundary, src.reader()); - - for (expected_parts) |expected| { - var part = try stream.next(std.testing.allocator) orelse return error.TestExpectedEqual; - defer part.close(); - - const dispo = part.fields.get("Content-Disposition") orelse return error.TestExpectedEqual; - try std.testing.expectEqualStrings(expected.disposition, dispo); - - var buf: [128]u8 = undefined; - const count = try part.reader().read(&buf); - try std.testing.expectEqualStrings(expected.value, buf[0..count]); - } - - try std.testing.expect(try stream.next(std.testing.allocator) == null); - } - }.case; - - try testCase("--abc--\r\n", "abc", &.{}); - try testCase( - util.comptimeToCrlf( - \\------abcd - \\Content-Disposition: form-data; name=first; charset=utf8 - \\ - \\content - \\------abcd - \\content-Disposition: form-data; name=second - \\ - \\no content - \\------abcd - \\content-disposition: form-data; name=third - \\ - \\ - \\------abcd-- - \\ - ), - "----abcd", - &.{ - .{ .disposition = "form-data; name=first; charset=utf8", .value = "content" }, - .{ .disposition = "form-data; name=second", .value = "no content" }, - .{ .disposition = "form-data; name=third", .value = "" }, - }, - ); - - try testCase( - util.comptimeToCrlf( - \\--xyz - \\Content-Disposition: uhh - \\ - \\xyz - \\--xyz - \\Content-disposition: ok - \\ - \\ --xyz - \\--xyz-- - \\ - ), - "xyz", - &.{ - .{ .disposition = "uhh", .value = "xyz" }, - .{ .disposition = "ok", .value = " --xyz" }, - }, - ); -} - -test "MultipartForm" { - const testCase = struct { - fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const MultipartFormField) !void { - var src = std.io.StreamSource{ - .const_buffer = std.io.fixedBufferStream(body), - }; - - var form = openForm(try openMultipart(boundary, src.reader())); - - for (expected_parts) |expected| { - var data = try form.next(std.testing.allocator) orelse return error.TestExpectedEqual; - defer util.deepFree(std.testing.allocator, data); - - try util.testing.expectDeepEqual(expected, data); - } - - try std.testing.expect(try form.next(std.testing.allocator) == null); - } - }.case; - - try testCase( - util.comptimeToCrlf( - \\--abcd - \\Content-Disposition: form-data; name=foo - \\ - \\content - \\--abcd-- - \\ - ), - "abcd", - &.{.{ .name = "foo", .value = "content" }}, - ); - try testCase( - util.comptimeToCrlf( - \\--abcd - \\Content-Disposition: form-data; name=foo - \\ - \\content - \\--abcd - \\Content-Disposition: form-data; name=bar - \\Content-Type: blah - \\ - \\abcd - \\--abcd - \\Content-Disposition: form-data; name=baz; filename="myfile.txt" - \\Content-Type: text/plain - \\ - \\ --abcd - \\ - \\--abcd-- - \\ - ), - "abcd", - &.{ - .{ .name = "foo", .value = "content" }, - .{ .name = "bar", .value = "abcd", .content_type = "blah" }, - .{ - .name = "baz", - .value = " --abcd\r\n", - .content_type = "text/plain", - .filename = "myfile.txt", - }, - }, - ); -} - -test "parseFormData" { - const body = util.comptimeToCrlf( - \\--abcd - \\Content-Disposition: form-data; name=foo - \\ - \\content - \\--abcd-- - \\ - ); - var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) }; - const val = try parseFormData(struct { - foo: []const u8, - }, "abcd", src.reader(), std.testing.allocator); - util.deepFree(std.testing.allocator, val); -} diff --git a/src/http/urlencode.zig b/src/http/query.zig similarity index 58% rename from src/http/urlencode.zig rename to src/http/query.zig index 3f49423..36b5d33 100644 --- a/src/http/urlencode.zig +++ b/src/http/query.zig @@ -1,38 +1,7 @@ const std = @import("std"); const util = @import("util"); -pub const Iter = struct { - const Pair = struct { - key: []const u8, - value: ?[]const u8, - }; - - iter: std.mem.SplitIterator(u8), - - pub fn from(q: []const u8) Iter { - return Iter{ - .iter = std.mem.split(u8, std.mem.trimLeft(u8, q, "?"), "&"), - }; - } - - pub fn next(self: *Iter) ?Pair { - while (true) { - const part = self.iter.next() orelse return null; - if (part.len == 0) continue; - - const key = std.mem.sliceTo(part, '='); - if (key.len == part.len) return Pair{ - .key = key, - .value = null, - }; - - return Pair{ - .key = key, - .value = part[key.len + 1 ..], - }; - } - } -}; +const QueryIter = util.QueryIter; /// Parses a set of query parameters described by the struct `T`. /// @@ -98,44 +67,25 @@ pub const Iter = struct { /// Would be used to parse a query string like /// `?foo.baz=12345` /// -pub fn parse(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { - var iter = Iter.from(query); - - var deserializer = Deserializer(T){}; +pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { + if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct"); + var iter = QueryIter.from(query); + var fields = Intermediary(T){}; while (iter.next()) |pair| { - try deserializer.setSerializedField(pair.key, pair.value); + // TODO: Hash map + inline for (std.meta.fields(Intermediary(T))) |field| { + if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) { + @field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} }; + break; + } + } else std.log.debug("unknown param {s}", .{pair.key}); } - return try deserializer.finish(alloc); + return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery; } -fn Deserializer(comptime Result: type) type { - return util.DeserializerContext(Result, ?[]const u8, struct { - pub const options = util.serialize.default_options; - pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, maybe_val: ?[]const u8) !T { - const is_optional = comptime std.meta.trait.is(.Optional)(T); - if (maybe_val) |val| { - if (val.len == 0 and is_optional) return null; - - const decoded = try decodeString(alloc, val); - defer alloc.free(decoded); - - return try util.serialize.deserializeString(alloc, T, decoded); - } else { - // If param is present, but without an associated value - return if (is_optional) - null - else if (T == bool) - true - else - error.InvalidValue; - } - } - }); -} - -pub fn parseFree(alloc: std.mem.Allocator, val: anytype) void { +pub fn parseQueryFree(alloc: std.mem.Allocator, val: anytype) void { util.deepFree(alloc, val); } @@ -160,6 +110,186 @@ fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]u8 { return list.toOwnedSlice(); } +fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T { + const param = @field(fields, name); + return switch (param) { + .not_specified => null, + .no_value => try parseQueryValue(alloc, T, null), + .value => |v| try parseQueryValue(alloc, T, v), + }; +} + +fn parse( + alloc: std.mem.Allocator, + comptime T: type, + comptime prefix: []const u8, + comptime name: []const u8, + fields: anytype, +) !?T { + if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields); + switch (@typeInfo(T)) { + .Union => |info| { + var result: ?T = null; + inline for (info.fields) |field| { + const F = field.field_type; + + const maybe_value = try parse(alloc, F, prefix, field.name, fields); + if (maybe_value) |value| { + if (result != null) return error.DuplicateUnionField; + + result = @unionInit(T, field.name, value); + } + } + std.log.debug("{any}", .{result}); + return result; + }, + + .Struct => |info| { + var result: T = undefined; + var fields_specified: usize = 0; + errdefer inline for (info.fields) |field, i| { + if (fields_specified < i) util.deepFree(alloc, @field(result, field.name)); + }; + + inline for (info.fields) |field| { + const F = field.field_type; + + var maybe_value: ?F = null; + if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| { + maybe_value = v; + } else if (field.default_value) |default| { + if (comptime @sizeOf(F) != 0) { + maybe_value = try util.deepClone(alloc, @ptrCast(*const F, @alignCast(@alignOf(F), default)).*); + } else { + maybe_value = std.mem.zeroes(F); + } + } + + if (maybe_value) |v| { + fields_specified += 1; + @field(result, field.name) = v; + } + } + + if (fields_specified == 0) { + return null; + } else if (fields_specified != info.fields.len) { + std.log.debug("{} {s} {s}", .{ T, prefix, name }); + return error.PartiallySpecifiedStruct; + } else { + return result; + } + }, + + // Only applies to non-scalar optionals + .Optional => |info| return try parse(alloc, info.child, prefix, name, fields), + + else => @compileError("tmp"), + } +} + +fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 { + comptime { + if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix); + + var fields: []const []const u8 = &.{}; + + for (std.meta.fields(T)) |f| { + const full_name = prefix ++ f.name; + + if (isScalar(f.field_type)) { + fields = fields ++ @as([]const []const u8, &.{full_name}); + } else { + const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ "."; + fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix); + } + } + + return fields; + } +} + +const QueryParam = union(enum) { + not_specified: void, + no_value: void, + value: []const u8, +}; + +fn Intermediary(comptime T: type) type { + const field_names = recursiveFieldPaths(T, ".."); + + var fields: [field_names.len]std.builtin.Type.StructField = undefined; + for (field_names) |name, i| fields[i] = .{ + .name = name, + .field_type = QueryParam, + .default_value = &QueryParam{ .not_specified = {} }, + .is_comptime = false, + .alignment = @alignOf(QueryParam), + }; + + return @Type(.{ .Struct = .{ + .layout = .Auto, + .fields = &fields, + .decls = &.{}, + .is_tuple = false, + } }); +} + +fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, maybe_value: ?[]const u8) !T { + const is_optional = comptime std.meta.trait.is(.Optional)(T); + if (maybe_value) |value| { + const Eff = if (is_optional) std.meta.Child(T) else T; + + if (value.len == 0 and is_optional) return null; + + const decoded = try decodeString(alloc, value); + errdefer alloc.free(decoded); + + if (comptime std.meta.trait.isZigString(Eff)) return decoded; + + defer alloc.free(decoded); + + const result = if (comptime std.meta.trait.isIntegral(Eff)) + try std.fmt.parseInt(Eff, decoded, 0) + else if (comptime std.meta.trait.isFloat(Eff)) + try std.fmt.parseFloat(Eff, decoded) + else if (comptime std.meta.trait.is(.Enum)(Eff)) blk: { + _ = std.ascii.lowerString(decoded, decoded); + break :blk std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue; + } else if (Eff == bool) blk: { + _ = std.ascii.lowerString(decoded, decoded); + break :blk bool_map.get(decoded) orelse return error.InvalidBool; + } else if (comptime std.meta.trait.hasFn("parse")(Eff)) + try Eff.parse(value) + else + @compileError("Invalid type " ++ @typeName(T)); + + return result; + } else { + // If param is present, but without an associated value + return if (is_optional) + null + else if (T == bool) + true + else + error.InvalidValue; + } +} + +const bool_map = std.ComptimeStringMap(bool, .{ + .{ "true", true }, + .{ "t", true }, + .{ "yes", true }, + .{ "y", true }, + .{ "1", true }, + + .{ "false", false }, + .{ "f", false }, + .{ "no", false }, + .{ "n", false }, + .{ "0", false }, +}); + fn isScalar(comptime T: type) bool { if (comptime std.meta.trait.isZigString(T)) return true; if (comptime std.meta.trait.isIntegral(T)) return true; @@ -174,7 +304,7 @@ fn isScalar(comptime T: type) bool { return false; } -pub fn EncodeStruct(comptime Params: type) type { +pub fn QueryStringify(comptime Params: type) type { return struct { params: Params, pub fn format(v: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { @@ -182,8 +312,8 @@ pub fn EncodeStruct(comptime Params: type) type { } }; } -pub fn encodeStruct(val: anytype) EncodeStruct(@TypeOf(val)) { - return EncodeStruct(@TypeOf(val)){ .params = val }; +pub fn queryStringify(val: anytype) QueryStringify(@TypeOf(val)) { + return QueryStringify(@TypeOf(val)){ .params = val }; } fn urlFormatString(writer: anytype, val: []const u8) !void { @@ -245,11 +375,11 @@ fn formatQuery(comptime prefix: []const u8, comptime name: []const u8, params: a } } -test "parse" { +test "parseQuery" { const testCase = struct { fn case(comptime T: type, expected: T, query_string: []const u8) !void { - const result = try parse(std.testing.allocator, T, query_string); - defer parseFree(std.testing.allocator, result); + const result = try parseQuery(std.testing.allocator, T, query_string); + defer parseQueryFree(std.testing.allocator, result); try util.testing.expectDeepEqual(expected, result); } }.case; @@ -335,46 +465,14 @@ test "parse" { try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .baz = "abc" } } }, "sub.foo=1&sub.baz=abc"); } -test "encodeStruct" { - try std.testing.expectFmt("", "{}", .{encodeStruct(.{})}); - try std.testing.expectFmt("id=3&", "{}", .{encodeStruct(.{ .id = 3 })}); - try std.testing.expectFmt("id=3&id2=4&", "{}", .{encodeStruct(.{ .id = 3, .id2 = 4 })}); +test "formatQuery" { + try std.testing.expectFmt("", "{}", .{queryStringify(.{})}); + try std.testing.expectFmt("id=3&", "{}", .{queryStringify(.{ .id = 3 })}); + try std.testing.expectFmt("id=3&id2=4&", "{}", .{queryStringify(.{ .id = 3, .id2 = 4 })}); - try std.testing.expectFmt("str=foo&", "{}", .{encodeStruct(.{ .str = "foo" })}); - try std.testing.expectFmt("enum_str=foo&", "{}", .{encodeStruct(.{ .enum_str = .foo })}); + try std.testing.expectFmt("str=foo&", "{}", .{queryStringify(.{ .str = "foo" })}); + try std.testing.expectFmt("enum_str=foo&", "{}", .{queryStringify(.{ .enum_str = .foo })}); - try std.testing.expectFmt("boolean=false&", "{}", .{encodeStruct(.{ .boolean = false })}); - try std.testing.expectFmt("boolean=true&", "{}", .{encodeStruct(.{ .boolean = true })}); -} - -test "Iter" { - const testCase = struct { - fn case(str: []const u8, pairs: []const Iter.Pair) !void { - var iter = Iter.from(str); - for (pairs) |pair| { - try util.testing.expectDeepEqual(@as(?Iter.Pair, pair), iter.next()); - } - try std.testing.expect(iter.next() == null); - } - }.case; - - try testCase("", &.{}); - try testCase("abc", &.{.{ .key = "abc", .value = null }}); - try testCase("abc=", &.{.{ .key = "abc", .value = "" }}); - try testCase("abc=def", &.{.{ .key = "abc", .value = "def" }}); - try testCase("abc=def&", &.{.{ .key = "abc", .value = "def" }}); - try testCase("?abc=def&", &.{.{ .key = "abc", .value = "def" }}); - try testCase("?abc=def&foo&bar=baz&qux=", &.{ - .{ .key = "abc", .value = "def" }, - .{ .key = "foo", .value = null }, - .{ .key = "bar", .value = "baz" }, - .{ .key = "qux", .value = "" }, - }); - try testCase("?abc=def&&foo&bar=baz&&qux=&", &.{ - .{ .key = "abc", .value = "def" }, - .{ .key = "foo", .value = null }, - .{ .key = "bar", .value = "baz" }, - .{ .key = "qux", .value = "" }, - }); - try testCase("&=def&", &.{.{ .key = "", .value = "def" }}); + try std.testing.expectFmt("boolean=false&", "{}", .{queryStringify(.{ .boolean = false })}); + try std.testing.expectFmt("boolean=true&", "{}", .{queryStringify(.{ .boolean = true })}); } diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index a11c315..6ffba12 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -93,7 +93,7 @@ fn parseProto(reader: anytype) !http.Protocol { }; } -pub fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields { +fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields { var headers = Fields.init(allocator); var buf: [4096]u8 = undefined; diff --git a/src/http/request/test_parser.zig b/src/http/request/test_parser.zig index b715528..55a66d6 100644 --- a/src/http/request/test_parser.zig +++ b/src/http/request/test_parser.zig @@ -1,5 +1,4 @@ const std = @import("std"); -const util = @import("util"); const parser = @import("./parser.zig"); const http = @import("../lib.zig"); const t = std.testing; @@ -31,9 +30,30 @@ const test_case = struct { } }; +fn toCrlf(comptime str: []const u8) []const u8 { + comptime { + var buf: [str.len * 2]u8 = undefined; + + @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 + + var buf_len: usize = 0; + for (str) |ch| { + if (ch == '\n') { + buf[buf_len] = '\r'; + buf_len += 1; + } + + buf[buf_len] = ch; + buf_len += 1; + } + + return buf[0..buf_len]; + } +} + test "HTTP/1.x parse - No body" { try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\GET / HTTP/1.1 \\ \\ @@ -45,7 +65,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\POST / HTTP/1.1 \\ \\ @@ -57,7 +77,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\GET /url/abcd HTTP/1.1 \\ \\ @@ -69,7 +89,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\GET / HTTP/1.0 \\ \\ @@ -81,7 +101,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\GET /url/abcd HTTP/1.1 \\Content-Type: application/json \\ @@ -95,7 +115,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\GET /url/abcd HTTP/1.1 \\Content-Type: application/json \\Authorization: bearer @@ -143,7 +163,7 @@ test "HTTP/1.x parse - No body" { }, ); try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\GET / HTTP/1.2 \\ \\ @@ -245,7 +265,7 @@ test "HTTP/1.x parse - bad requests" { test "HTTP/1.x parse - Headers" { try test_case.parse( - util.comptimeToCrlf( + toCrlf( \\GET /url/abcd HTTP/1.1 \\Content-Type: application/json \\Content-Type: application/xml diff --git a/src/http/server/response.zig b/src/http/server/response.zig index 384677d..fdbe9cc 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -1,5 +1,4 @@ const std = @import("std"); -const util = @import("util"); const http = @import("../lib.zig"); const Status = http.Status; @@ -170,7 +169,25 @@ test { _ = _tests; } const _tests = struct { - const toCrlf = util.comptimeToCrlf; + fn toCrlf(comptime str: []const u8) []const u8 { + comptime { + var buf: [str.len * 2]u8 = undefined; + @setEvalBranchQuota(@as(u32, str.len * 2)); + + var len: usize = 0; + for (str) |ch| { + if (ch == '\n') { + buf[len] = '\r'; + len += 1; + } + + buf[len] = ch; + len += 1; + } + + return buf[0..len]; + } + } const test_buffer_size = chunk_size * 4; test "ResponseStream no headers empty body" { diff --git a/src/http/test.zig b/src/http/test.zig new file mode 100644 index 0000000..c142f68 --- /dev/null +++ b/src/http/test.zig @@ -0,0 +1,5 @@ +test { + _ = @import("./request/test_parser.zig"); + _ = @import("./middleware.zig"); + _ = @import("./query.zig"); +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 398424c..10ecdb7 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -267,13 +267,13 @@ pub const helpers = struct { try std.fmt.format( writer, "<{s}://{s}/{s}?{}>; rel=\"{s}\"", - .{ @tagName(c.scheme), c.host, path, http.urlencode.encodeStruct(params), rel }, + .{ @tagName(c.scheme), c.host, path, http.queryStringify(params), rel }, ); } else { try std.fmt.format( writer, "<{s}?{}>; rel=\"{s}\"", - .{ path, http.urlencode.encodeStruct(params), rel }, + .{ path, http.queryStringify(params), rel }, ); } // TODO: percent-encode diff --git a/src/main/controllers/api.zig b/src/main/controllers/api.zig index 9a76c91..12f3a1f 100644 --- a/src/main/controllers/api.zig +++ b/src/main/controllers/api.zig @@ -2,7 +2,6 @@ const controllers = @import("../controllers.zig"); const auth = @import("./api/auth.zig"); const communities = @import("./api/communities.zig"); -const drive = @import("./api/drive.zig"); const invites = @import("./api/invites.zig"); const users = @import("./api/users.zig"); const follows = @import("./api/users/follows.zig"); @@ -27,6 +26,4 @@ pub const routes = .{ controllers.apiEndpoint(follows.delete), controllers.apiEndpoint(follows.query_followers), controllers.apiEndpoint(follows.query_following), - controllers.apiEndpoint(drive.upload), - controllers.apiEndpoint(drive.mkdir), }; diff --git a/src/main/controllers/api/drive.zig b/src/main/controllers/api/drive.zig deleted file mode 100644 index f617898..0000000 --- a/src/main/controllers/api/drive.zig +++ /dev/null @@ -1,144 +0,0 @@ -const api = @import("api"); -const http = @import("http"); -const util = @import("util"); -const controller_utils = @import("../../controllers.zig").helpers; - -const Uuid = util.Uuid; -const DateTime = util.DateTime; - -pub const drive_path = "/drive/:path*"; -pub const DriveArgs = struct { - path: []const u8, -}; - -pub const query = struct { - pub const method = .GET; - pub const path = drive_path; - pub const Args = DriveArgs; - - pub const Query = struct { - const OrderBy = enum { - created_at, - filename, - }; - - max_items: usize = 20, - - like: ?[]const u8 = null, - - order_by: OrderBy = .created_at, - direction: api.Direction = .descending, - - prev: ?struct { - id: Uuid, - order_val: union(OrderBy) { - created_at: DateTime, - filename: []const u8, - }, - } = null, - - page_direction: api.PageDirection = .forward, - }; - - pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - const result = srv.driveQuery(req.args.path, req.query) catch |err| switch (err) { - error.NotADirectory => { - const meta = try srv.getFile(path); - try res.json(.ok, meta); - return; - }, - else => |e| return e, - }; - - try controller_utils.paginate(result, res, req.allocator); - } -}; - -pub const upload = struct { - pub const method = .POST; - pub const path = drive_path; - pub const Args = DriveArgs; - - pub const Body = struct { - file: http.FormFile, - description: ?[]const u8 = null, - sensitive: bool = false, - }; - - pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - const f = req.body.file; - try srv.uploadFile(.{ - .dir = req.args.path, - .filename = f.filename, - .description = req.body.description, - .content_type = f.content_type, - .sensitive = req.body.sensitive, - }, f.data); - - // TODO: print meta - try res.json(.created, .{}); - } -}; - -pub const delete = struct { - pub const method = .DELETE; - pub const path = drive_path; - pub const Args = DriveArgs; - - pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - const info = try srv.driveLookup(req.args.path); - if (info == .dir) - try srv.driveRmdir(req.args.path) - else if (info == .file) - try srv.deleteFile(req.args.path); - - return res.json(.ok, .{}); - } -}; - -pub const mkdir = struct { - pub const method = .MKCOL; - pub const path = drive_path; - pub const Args = DriveArgs; - - pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - try srv.driveMkdir(req.args.path); - - return res.json(.created, .{}); - } -}; - -pub const update = struct { - pub const method = .PUT; - pub const path = drive_path; - pub const Args = DriveArgs; - - pub const Body = struct { - description: ?[]const u8 = null, - content_type: ?[]const u8 = null, - sensitive: ?bool = null, - }; - - pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - const info = try srv.driveLookup(req.args.path); - if (info != .file) return error.NotFile; - - const new_info = try srv.updateFile(path, req.body); - try res.json(.ok, new_info); - } -}; - -pub const move = struct { - pub const method = .MOVE; - pub const path = drive_path; - pub const Args = DriveArgs; - - pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - const destination = req.fields.get("Destination") orelse return error.NoDestination; - - try srv.driveMove(req.args.path, destination); - - try res.fields.put("Location", destination); - try srv.json(.created, .{}); - } -}; diff --git a/src/main/migrations.zig b/src/main/migrations.zig index a7465cb..dc97b44 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -19,9 +19,8 @@ fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void { const tx = try db.beginOrSavepoint(); errdefer tx.rollback(); - var iter = std.mem.split(u8, script, ";"); + var iter = util.SqlStmtIter.from(script); while (iter.next()) |stmt| { - if (stmt.len == 0) continue; try execStmt(tx, stmt, alloc); } @@ -209,139 +208,22 @@ const migrations: []const Migration = &.{ .{ .name = "files", .up = - \\CREATE TABLE file_upload( + \\CREATE TABLE drive_file( \\ id UUID NOT NULL PRIMARY KEY, \\ - \\ created_by UUID REFERENCES account(id), - \\ size INTEGER NOT NULL, - \\ \\ filename TEXT NOT NULL, - \\ description TEXT, - \\ content_type TEXT, - \\ sensitive BOOLEAN NOT NULL, - \\ - \\ is_deleted BOOLEAN NOT NULL DEFAULT FALSE, - \\ - \\ created_at TIMESTAMPTZ NOT NULL, - \\ updated_at TIMESTAMPTZ NOT NULL - \\); - \\ - \\CREATE TABLE drive_entry( - \\ id UUID NOT NULL PRIMARY KEY, - \\ \\ account_owner_id UUID REFERENCES account(id), \\ community_owner_id UUID REFERENCES community(id), + \\ size INTEGER NOT NULL, \\ - \\ name TEXT, - \\ parent_directory_id UUID REFERENCES drive_entry(id), - \\ - \\ file_id UUID REFERENCES file_upload(id), + \\ created_at TIMESTAMPTZ NOT NULL, \\ \\ CHECK( \\ (account_owner_id IS NULL AND community_owner_id IS NOT NULL) \\ OR (account_owner_id IS NOT NULL AND community_owner_id IS NULL) - \\ ), - \\ CHECK( - \\ (name IS NULL AND parent_directory_id IS NULL AND file_id IS NULL) - \\ OR (name IS NOT NULL AND parent_directory_id IS NOT NULL) \\ ) \\); - \\CREATE UNIQUE INDEX drive_entry_uniqueness - \\ON drive_entry( - \\ name, - \\ COALESCE(parent_directory_id, ''), - \\ COALESCE(account_owner_id, community_owner_id) - \\); , - .down = - \\DROP INDEX drive_entry_uniqueness; - \\DROP TABLE drive_entry; - \\DROP TABLE file_upload; - , - }, - .{ - .name = "drive_entry_path", - .up = - \\CREATE VIEW drive_entry_path( - \\ id, - \\ path, - \\ account_owner_id, - \\ community_owner_id, - \\ kind - \\) AS WITH RECURSIVE full_path( - \\ id, - \\ path, - \\ account_owner_id, - \\ community_owner_id, - \\ kind - \\) AS ( - \\ SELECT - \\ id, - \\ '' AS path, - \\ account_owner_id, - \\ community_owner_id, - \\ 'dir' AS kind - \\ FROM drive_entry - \\ WHERE parent_directory_id IS NULL - \\ UNION ALL - \\ SELECT - \\ base.id, - \\ (dir.path || '/' || base.name) AS path, - \\ base.account_owner_id, - \\ base.community_owner_id, - \\ (CASE WHEN base.file_id IS NULL THEN 'dir' ELSE 'file' END) as kind - \\ FROM drive_entry AS base - \\ JOIN full_path AS dir ON - \\ base.parent_directory_id = dir.id - \\ AND base.account_owner_id IS NOT DISTINCT FROM dir.account_owner_id - \\ AND base.community_owner_id IS NOT DISTINCT FROM dir.community_owner_id - \\) - \\SELECT - \\ id, - \\ (CASE WHEN kind = 'dir' THEN path || '/' ELSE path END) AS path, - \\ account_owner_id, - \\ community_owner_id, - \\ kind - \\FROM full_path; - , - .down = - \\DROP VIEW drive_entry_path; - , - }, - .{ - .name = "create drive root directories", - .up = - \\INSERT INTO drive_entry( - \\ id, - \\ account_owner_id, - \\ community_owner_id, - \\ parent_directory_id, - \\ name, - \\ file_id - \\) SELECT - \\ id, - \\ id AS account_owner_id, - \\ NULL AS community_owner_id, - \\ NULL AS parent_directory_id, - \\ NULL AS name, - \\ NULL AS file_id - \\FROM account; - \\INSERT INTO drive_entry( - \\ id, - \\ account_owner_id, - \\ community_owner_id, - \\ parent_directory_id, - \\ name, - \\ file_id - \\) SELECT - \\ id, - \\ NULL AS account_owner_id, - \\ id AS community_owner_id, - \\ NULL AS parent_directory_id, - \\ NULL AS name, - \\ NULL AS file_id - \\FROM community; - , - .down = "", + .down = "DROP TABLE drive_file", }, }; diff --git a/src/sql/engines/common.zig b/src/sql/engines/common.zig index 93169c4..b50b7d0 100644 --- a/src/sql/engines/common.zig +++ b/src/sql/engines/common.zig @@ -88,7 +88,7 @@ pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]con else => |T| switch (@typeInfo(T)) { .Enum => return @tagName(val), .Optional => if (val) |v| try prepareParamText(arena, v) else null, - .Bool, .Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}), + .Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}), .Union => loop: inline for (std.meta.fields(T)) |field| { // Have to do this in a roundabout way to satisfy comptime checker const Tag = std.meta.Tag(T); diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index 3b9c8c4..5b59910 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -193,7 +193,6 @@ pub const Db = struct { .Null => return self.bindNull(stmt, idx), .Int => return self.bindInt(stmt, idx, std.math.cast(i64, val) orelse unreachable), .Float => return self.bindFloat(stmt, idx, val), - .Bool => return self.bindInt(stmt, idx, if (val) 1 else 0), else => @compileError("Unable to serialize type " ++ @typeName(T)), } } @@ -252,20 +251,18 @@ pub const Results = struct { db: *c.sqlite3, pub fn finish(self: Results) void { - _ = c.sqlite3_finalize(self.stmt); + switch (c.sqlite3_finalize(self.stmt)) { + c.SQLITE_OK => {}, + else => |err| { + handleUnexpectedError(self.db, err, self.getGeneratingSql()) catch {}; + }, + } } pub fn row(self: Results) common.RowError!?Row { return switch (c.sqlite3_step(self.stmt)) { c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db }, c.SQLITE_DONE => null, - - c.SQLITE_CONSTRAINT_UNIQUE => return error.UniqueViolation, - c.SQLITE_CONSTRAINT_CHECK => return error.CheckViolation, - c.SQLITE_CONSTRAINT_NOTNULL => return error.NotNullViolation, - c.SQLITE_CONSTRAINT_FOREIGNKEY => return error.ForeignKeyViolation, - c.SQLITE_CONSTRAINT => return error.ConstraintViolation, - else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()), }; } diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 69c371c..358a2d3 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -144,16 +144,42 @@ fn fieldPtr(ptr: anytype, comptime names: []const []const u8) FieldPtr(@TypeOf(p return fieldPtr(&@field(ptr.*, names[0]), names[1..]); } +fn isScalar(comptime T: type) bool { + if (comptime std.meta.trait.isZigString(T)) return true; + if (comptime std.meta.trait.isIntegral(T)) return true; + if (comptime std.meta.trait.isFloat(T)) return true; + if (comptime std.meta.trait.is(.Enum)(T)) return true; + if (T == bool) return true; + if (comptime std.meta.trait.hasFn("parse")(T)) return true; + + if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true; + + return false; +} + +fn recursiveFieldPaths(comptime T: type, comptime prefix: []const []const u8) []const []const []const u8 { + comptime { + var fields: []const []const []const u8 = &.{}; + + for (std.meta.fields(T)) |f| { + const full_name = prefix ++ [_][]const u8{f.name}; + if (isScalar(f.field_type)) { + fields = fields ++ [_][]const []const u8{full_name}; + } else { + fields = fields ++ recursiveFieldPaths(f.field_type, full_name); + } + } + + return fields; + } +} + // Represents a set of results. // row() must be called until it returns null, or the query may not complete // Must be deallocated by a call to finish() pub fn Results(comptime T: type) type { // would normally make this a declaration of the struct, but it causes the compiler to crash - const fields = if (T == void) .{} else util.serialize.getRecursiveFieldList( - T, - &.{}, - util.serialize.default_options, - ); + const fields = if (T == void) .{} else recursiveFieldPaths(T, &.{}); return struct { const Self = @This(); @@ -431,7 +457,6 @@ fn Tx(comptime tx_level: u8) type { pub fn rollback(self: Self) void { (if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| { std.log.err("Failed to rollback transaction: {}", .{err}); - std.log.err("{any}", .{@errorReturnTrace()}); @panic("TODO: more gracefully handle rollback failures"); }; } @@ -629,7 +654,7 @@ fn Tx(comptime tx_level: u8) type { } fn rollbackUnchecked(self: Self) !void { - try self.execInternal("ROLLBACK", {}, null, false); + try self.exec("ROLLBACK", {}, null); } }; } diff --git a/src/template/lib.zig b/src/template/lib.zig index f1e141d..8c6c8e9 100644 --- a/src/template/lib.zig +++ b/src/template/lib.zig @@ -601,20 +601,3 @@ const ControlTokenIter = struct { self.peeked_token = token; } }; - -test "template" { - const testCase = struct { - fn case(comptime tmpl: []const u8, args: anytype, expected: []const u8) !void { - var stream = std.io.changeDetectionStream(expected, std.io.null_writer); - try execute(stream.writer(), tmpl, args); - try std.testing.expect(!stream.changeDetected()); - } - }.case; - - try testCase("", .{}, ""); - try testCase("abcd", .{}, "abcd"); - try testCase("{.val}", .{ .val = 3 }, "3"); - try testCase("{#if .val}1{/if}", .{ .val = true }, "1"); - try testCase("{#for .vals |$v|}{$v}{/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123"); - try testCase("{#for .vals |$v|=} {$v} {=/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123"); -} diff --git a/src/util/Url.zig b/src/util/Url.zig new file mode 100644 index 0000000..35d74e8 --- /dev/null +++ b/src/util/Url.zig @@ -0,0 +1,161 @@ +const Url = @This(); +const std = @import("std"); + +scheme: []const u8, +hostport: []const u8, +path: []const u8, +query: []const u8, +fragment: []const u8, + +pub fn parse(url: []const u8) !Url { + const scheme_end = for (url) |ch, i| { + if (ch == ':') break i; + } else return error.InvalidUrl; + + if (url.len < scheme_end + 3 or url[scheme_end + 1] != '/' or url[scheme_end + 1] != '/') return error.InvalidUrl; + + const hostport_start = scheme_end + 3; + const hostport_end = for (url[hostport_start..]) |ch, i| { + if (ch == '/' or ch == '?' or ch == '#') break i + hostport_start; + } else url.len; + + const path_end = for (url[hostport_end..]) |ch, i| { + if (ch == '?' or ch == '#') break i + hostport_end; + } else url.len; + + const query_end = if (!(url.len > path_end and url[path_end] == '?')) + path_end + else for (url[path_end..]) |ch, i| { + if (ch == '#') break i + path_end; + } else url.len; + + const query = url[path_end..query_end]; + const fragment = url[query_end..]; + + return Url{ + .scheme = url[0..scheme_end], + .hostport = url[hostport_start..hostport_end], + .path = url[hostport_end..path_end], + .query = if (query.len > 0) query[1..] else query, + .fragment = if (fragment.len > 0) fragment[1..] else fragment, + }; +} + +pub fn getQuery(self: Url, param: []const u8) ?[]const u8 { + var key_start: usize = 0; + std.log.debug("query: {s}", .{self.query}); + while (key_start < self.query.len) { + const key_end = for (self.query[key_start..]) |ch, i| { + if (ch == '=') break key_start + i; + } else return null; + + const val_start = key_end + 1; + const val_end = for (self.query[val_start..]) |ch, i| { + if (ch == '&') break val_start + i; + } else self.query.len; + + const key = self.query[key_start..key_end]; + if (std.mem.eql(u8, key, param)) return self.query[val_start..val_end]; + + key_start = val_end + 1; + } + + return null; +} + +pub fn strDecode(buf: []u8, str: []const u8) ![]u8 { + var str_i: usize = 0; + var buf_i: usize = 0; + while (str_i < str.len) : ({ + str_i += 1; + buf_i += 1; + }) { + if (buf_i >= buf.len) return error.NoSpaceLeft; + const ch = str[str_i]; + if (ch == '%') { + if (str.len < str_i + 2) return error.BadEscape; + + const hi = try std.fmt.charToDigit(str[str_i + 1], 16); + const lo = try std.fmt.charToDIgit(str[str_i + 2], 16); + str_i += 2; + + buf[buf_i] = (hi << 4) | lo; + } else { + buf[buf_i] = str[str_i]; + } + } + + return buf[0..buf_i]; +} + +fn expectEqualUrl(expected: Url, actual: Url) !void { + const t = @import("std").testing; + try t.expectEqualStrings(expected.scheme, actual.scheme); + try t.expectEqualStrings(expected.hostport, actual.hostport); + try t.expectEqualStrings(expected.path, actual.path); + try t.expectEqualStrings(expected.query, actual.query); + try t.expectEqualStrings(expected.fragment, actual.fragment); +} +test "Url" { + try expectEqualUrl(.{ + .scheme = "https", + .hostport = "example.com", + .path = "", + .query = "", + .fragment = "", + }, try Url.parse("https://example.com")); + + try expectEqualUrl(.{ + .scheme = "https", + .hostport = "example.com:1234", + .path = "", + .query = "", + .fragment = "", + }, try Url.parse("https://example.com:1234")); + + try expectEqualUrl(.{ + .scheme = "http", + .hostport = "example.com", + .path = "/home", + .query = "", + .fragment = "", + }, try Url.parse("http://example.com/home")); + + try expectEqualUrl(.{ + .scheme = "https", + .hostport = "example.com", + .path = "", + .query = "query=abc", + .fragment = "", + }, try Url.parse("https://example.com?query=abc")); + + try expectEqualUrl(.{ + .scheme = "https", + .hostport = "example.com", + .path = "", + .query = "query=abc", + .fragment = "", + }, try Url.parse("https://example.com?query=abc")); + + try expectEqualUrl(.{ + .scheme = "https", + .hostport = "example.com", + .path = "/path/to/resource", + .query = "query=abc", + .fragment = "123", + }, try Url.parse("https://example.com/path/to/resource?query=abc#123")); + + const t = @import("std").testing; + try t.expectError(error.InvalidUrl, Url.parse("https:example.com")); + try t.expectError(error.InvalidUrl, Url.parse("example.com")); +} + +test "Url.getQuery" { + const url = try Url.parse("https://example.com?a=xyz&b=jkl"); + const t = @import("std").testing; + + try t.expectEqualStrings("xyz", url.getQuery("a").?); + try t.expectEqualStrings("jkl", url.getQuery("b").?); + try t.expect(url.getQuery("c") == null); + try t.expect(url.getQuery("xyz") == null); +} diff --git a/src/util/ciutf8.zig b/src/util/ciutf8.zig new file mode 100644 index 0000000..39d1a20 --- /dev/null +++ b/src/util/ciutf8.zig @@ -0,0 +1,106 @@ +const std = @import("std"); + +const Hash = std.hash.Wyhash; +const View = std.unicode.Utf8View; +const toLower = std.ascii.toLower; +const isAscii = std.ascii.isASCII; +const hash_seed = 1; + +pub fn hash(str: []const u8) u64 { + // fallback to regular hash on invalid utf8 + const view = View.init(str) catch return Hash.hash(hash_seed, str); + var iter = view.iterator(); + + var h = Hash.init(hash_seed); + + var it = iter.nextCodepointSlice(); + while (it != null) : (it = iter.nextCodepointSlice()) { + if (it.?.len == 1 and isAscii(it.?[0])) { + const ch = [1]u8{toLower(it.?[0])}; + h.update(&ch); + } else { + h.update(it.?); + } + } + + return h.final(); +} + +pub fn eql(a: []const u8, b: []const u8) bool { + if (a.len != b.len) return false; + + const va = View.init(a) catch return std.mem.eql(u8, a, b); + const vb = View.init(b) catch return false; + + var iter_a = va.iterator(); + var iter_b = vb.iterator(); + + var it_a = iter_a.nextCodepointSlice(); + var it_b = iter_b.nextCodepointSlice(); + + while (it_a != null and it_b != null) : ({ + it_a = iter_a.nextCodepointSlice(); + it_b = iter_b.nextCodepointSlice(); + }) { + if (it_a.?.len != it_b.?.len) return false; + + if (it_a.?.len == 1) { + if (isAscii(it_a.?[0]) and isAscii(it_b.?[0])) { + const ch_a = toLower(it_a.?[0]); + const ch_b = toLower(it_b.?[0]); + + if (ch_a != ch_b) return false; + } else if (it_a.?[0] != it_b.?[0]) return false; + } else if (!std.mem.eql(u8, it_a.?, it_b.?)) return false; + } + + return it_a == null and it_b == null; +} + +test "case insensitive eql with utf-8 chars" { + const t = std.testing; + try t.expectEqual(true, eql("abc 💯 def", "aBc 💯 DEF")); + try t.expectEqual(false, eql("xyz 💯 ijk", "aBc 💯 DEF")); + try t.expectEqual(false, eql("abc 💯 def", "aBc x DEF")); + try t.expectEqual(true, eql("💯", "💯")); + try t.expectEqual(false, eql("💯", "a")); + try t.expectEqual(false, eql("💯", "💯 continues")); + try t.expectEqual(false, eql("💯 fsdfs", "💯")); + try t.expectEqual(false, eql("💯", "")); + try t.expectEqual(false, eql("", "💯")); + + try t.expectEqual(true, eql("abc x def", "aBc x DEF")); + try t.expectEqual(false, eql("xyz x ijk", "aBc x DEF")); + try t.expectEqual(true, eql("x", "x")); + try t.expectEqual(false, eql("x", "a")); + try t.expectEqual(false, eql("x", "x continues")); + try t.expectEqual(false, eql("x fsdfs", "x")); + try t.expectEqual(false, eql("x", "")); + try t.expectEqual(false, eql("", "x")); + + try t.expectEqual(true, eql("", "")); +} + +test "case insensitive hash with utf-8 chars" { + const t = std.testing; + try t.expect(hash("abc 💯 def") == hash("aBc 💯 DEF")); + try t.expect(hash("xyz 💯 ijk") != hash("aBc 💯 DEF")); + try t.expect(hash("abc 💯 def") != hash("aBc x DEF")); + try t.expect(hash("💯") == hash("💯")); + try t.expect(hash("💯") != hash("a")); + try t.expect(hash("💯") != hash("💯 continues")); + try t.expect(hash("💯 fsdfs") != hash("💯")); + try t.expect(hash("💯") != hash("")); + try t.expect(hash("") != hash("💯")); + + try t.expect(hash("abc x def") == hash("aBc x DEF")); + try t.expect(hash("xyz x ijk") != hash("aBc x DEF")); + try t.expect(hash("x") == hash("x")); + try t.expect(hash("x") != hash("a")); + try t.expect(hash("x") != hash("x continues")); + try t.expect(hash("x fsdfs") != hash("x")); + try t.expect(hash("x") != hash("")); + try t.expect(hash("") != hash("x")); + + try t.expect(hash("") == hash("")); +} diff --git a/src/util/iters.zig b/src/util/iters.zig new file mode 100644 index 0000000..5ad2258 --- /dev/null +++ b/src/util/iters.zig @@ -0,0 +1,189 @@ +const std = @import("std"); + +pub fn Separator(comptime separator: u8) type { + return struct { + const Self = @This(); + str: []const u8, + pub fn from(str: []const u8) Self { + return .{ .str = std.mem.trim(u8, str, &.{separator}) }; + } + + pub fn next(self: *Self) ?[]const u8 { + if (self.str.len == 0) return null; + + const part = std.mem.sliceTo(self.str, separator); + self.str = std.mem.trimLeft(u8, self.str[part.len..], &.{separator}); + + return part; + } + }; +} + +pub const QueryIter = struct { + const Pair = struct { + key: []const u8, + value: ?[]const u8, + }; + + iter: Separator('&'), + + pub fn from(q: []const u8) QueryIter { + return QueryIter{ .iter = Separator('&').from(std.mem.trimLeft(u8, q, "?")) }; + } + + pub fn next(self: *QueryIter) ?Pair { + const part = self.iter.next() orelse return null; + + const key = std.mem.sliceTo(part, '='); + if (key.len == part.len) return Pair{ + .key = key, + .value = null, + }; + + return Pair{ + .key = key, + .value = part[key.len + 1 ..], + }; + } +}; + +pub const PathIter = struct { + is_first: bool, + iter: std.mem.SplitIterator(u8), + + pub fn from(path: []const u8) PathIter { + return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") }; + } + + pub fn next(self: *PathIter) ?[]const u8 { + defer self.is_first = false; + while (self.iter.next()) |it| if (it.len != 0) { + return it; + }; + + if (self.is_first) return self.iter.rest(); + + return null; + } + + pub fn first(self: *PathIter) []const u8 { + std.debug.assert(self.is_first); + return self.next().?; + } + + pub fn rest(self: *PathIter) []const u8 { + return self.iter.rest(); + } +}; + +test "QueryIter" { + const t = @import("std").testing; + if (true) return error.SkipZigTest; + { + var iter = QueryIter.from(""); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } + + { + var iter = QueryIter.from("?"); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } + + { + var iter = QueryIter.from("?abc"); + try t.expectEqual(QueryIter.Pair{ + .key = "abc", + .value = null, + }, iter.next().?); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } + + { + var iter = QueryIter.from("?abc="); + try t.expectEqual(QueryIter.Pair{ + .key = "abc", + .value = "", + }, iter.next().?); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } + + { + var iter = QueryIter.from("?abc=def"); + try t.expectEqual(QueryIter.Pair{ + .key = "abc", + .value = "def", + }, iter.next().?); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } + + { + var iter = QueryIter.from("?abc=def&"); + try t.expectEqual(QueryIter.Pair{ + .key = "abc", + .value = "def", + }, iter.next().?); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } + + { + var iter = QueryIter.from("?abc=def&foo&bar=baz&qux="); + try t.expectEqual(QueryIter.Pair{ + .key = "abc", + .value = "def", + }, iter.next().?); + try t.expectEqual(QueryIter.Pair{ + .key = "foo", + .value = null, + }, iter.next().?); + try t.expectEqual(QueryIter.Pair{ + .key = "bar", + .value = "baz", + }, iter.next().?); + try t.expectEqual(QueryIter.Pair{ + .key = "qux", + .value = "", + }, iter.next().?); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } + + { + var iter = QueryIter.from("?=def&"); + try t.expectEqual(QueryIter.Pair{ + .key = "", + .value = "def", + }, iter.next().?); + try t.expect(iter.next() == null); + try t.expect(iter.next() == null); + } +} + +test "PathIter /ab/cd/" { + const path = "/ab/cd/"; + var it = PathIter.from(path); + try std.testing.expectEqualStrings("ab", it.next().?); + try std.testing.expectEqualStrings("cd", it.next().?); + try std.testing.expectEqual(@as(?[]const u8, null), it.next()); +} + +test "PathIter ''" { + const path = ""; + var it = PathIter.from(path); + try std.testing.expectEqualStrings("", it.next().?); + try std.testing.expectEqual(@as(?[]const u8, null), it.next()); +} + +test "PathIter ab/c//defg/" { + const path = "ab/c//defg/"; + var it = PathIter.from(path); + try std.testing.expectEqualStrings("ab", it.next().?); + try std.testing.expectEqualStrings("c", it.next().?); + try std.testing.expectEqualStrings("defg", it.next().?); + try std.testing.expectEqual(@as(?[]const u8, null), it.next()); +} diff --git a/src/util/lib.zig b/src/util/lib.zig index fb2bee0..9829ae3 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -1,10 +1,13 @@ const std = @import("std"); +const iters = @import("./iters.zig"); +pub const ciutf8 = @import("./ciutf8.zig"); pub const Uuid = @import("./Uuid.zig"); pub const DateTime = @import("./DateTime.zig"); -pub const serialize = @import("./serialize.zig"); -pub const Deserializer = serialize.Deserializer; -pub const DeserializerContext = serialize.DeserializerContext; +pub const Url = @import("./Url.zig"); +pub const PathIter = iters.PathIter; +pub const QueryIter = iters.QueryIter; +pub const SqlStmtIter = iters.Separator(';'); /// Joins an array of strings, prefixing every entry with `prefix`, /// and putting `separator` in between each pair @@ -199,16 +202,6 @@ pub fn seedThreadPrng() !void { prng = std.rand.DefaultPrng.init(@bitCast(u64, buf)); } -pub fn comptimeToCrlf(comptime str: []const u8) []const u8 { - comptime { - @setEvalBranchQuota(str.len * 10); - const size = std.mem.replacementSize(u8, str, "\n", "\r\n"); - var buf: [size]u8 = undefined; - _ = std.mem.replace(u8, str, "\n", "\r\n", &buf); - return &buf; - } -} - pub const testing = struct { pub fn expectDeepEqual(expected: anytype, actual: @TypeOf(expected)) !void { const T = @TypeOf(expected); @@ -249,7 +242,3 @@ pub const testing = struct { } } }; - -test { - _ = std.testing.refAllDecls(@This()); -} diff --git a/src/util/serialize.zig b/src/util/serialize.zig deleted file mode 100644 index 53d882f..0000000 --- a/src/util/serialize.zig +++ /dev/null @@ -1,386 +0,0 @@ -const std = @import("std"); -const util = @import("./lib.zig"); - -pub const FieldRef = []const []const u8; - -pub fn defaultIsScalar(comptime T: type) bool { - if (comptime std.meta.trait.is(.Optional)(T) and defaultIsScalar(std.meta.Child(T))) return true; - if (comptime std.meta.trait.isZigString(T)) return true; - if (comptime std.meta.trait.isIntegral(T)) return true; - if (comptime std.meta.trait.isFloat(T)) return true; - if (comptime std.meta.trait.is(.Enum)(T)) return true; - if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true; - if (comptime std.meta.trait.hasFn("parse")(T)) return true; - if (T == bool) return true; - - return false; -} - -pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T { - if (comptime std.meta.trait.is(.Optional)(T)) { - if (value.len == 0) return null; - return try deserializeString(allocator, std.meta.Child(T), value); - } - - if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value); - if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0); - if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value); - if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value); - - var buf: [64]u8 = undefined; - const lowered = std.ascii.lowerString(&buf, value); - - if (T == bool) return bool_map.get(lowered) orelse return error.InvalidBool; - if (comptime std.meta.trait.is(.Enum)(T)) { - return std.meta.stringToEnum(T, lowered) orelse return error.InvalidEnumTag; - } - - @compileError("Invalid type " ++ @typeName(T)); -} - -pub fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef { - comptime { - if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) { - @compileError("Cannot embed a union into nothing"); - } - - if (options.isScalar(T)) return &.{prefix}; - if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options); - - const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions) - prefix[0 .. prefix.len - 1] - else - prefix; - - var fields: []const FieldRef = &.{}; - - for (std.meta.fields(T)) |f| { - const new_prefix = eff_prefix ++ &[_][]const u8{f.name}; - const F = f.field_type; - fields = fields ++ getRecursiveFieldList(F, new_prefix, options); - } - - return fields; - } -} - -pub const SerializationOptions = struct { - embed_unions: bool, - isScalar: fn (type) bool, -}; - -pub const default_options = SerializationOptions{ - .embed_unions = true, - .isScalar = defaultIsScalar, -}; - -fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type { - const field_refs = getRecursiveFieldList(Result, &.{}, options); - - var fields: [field_refs.len]std.builtin.Type.StructField = undefined; - for (field_refs) |ref, i| { - fields[i] = .{ - .name = util.comptimeJoin(".", ref), - .field_type = ?From, - .default_value = &@as(?From, null), - .is_comptime = false, - .alignment = @alignOf(?From), - }; - } - - return @Type(.{ .Struct = .{ - .layout = .Auto, - .fields = &fields, - .decls = &.{}, - .is_tuple = false, - } }); -} - -pub fn Deserializer(comptime Result: type) type { - return DeserializerContext(Result, []const u8, struct { - const options = default_options; - fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T { - return try deserializeString(alloc, T, val); - } - }); -} - -pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type { - return struct { - const Data = Intermediary(Result, From, Context.options); - - data: Data = .{}, - context: Context = .{}, - - pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void { - const field = std.meta.stringToEnum(std.meta.FieldEnum(Data), key) orelse return error.UnknownField; - inline for (comptime std.meta.fieldNames(Data)) |field_name| { - @setEvalBranchQuota(10000); - const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(Data), field_name); - if (field == f) { - @field(self.data, field_name) = value; - return; - } - } - - unreachable; - } - - pub const Iter = struct { - data: *const Data, - field_index: usize, - - const Item = struct { - key: []const u8, - value: From, - }; - - pub fn next(self: *Iter) ?Item { - while (self.field_index < std.meta.fields(Data).len) { - const idx = self.field_index; - self.field_index += 1; - inline for (comptime std.meta.fieldNames(Data)) |field, i| { - if (i == idx) { - const maybe_value = @field(self.data.*, field); - if (maybe_value) |value| return Item{ .key = field, .value = value }; - } - } - } - - return null; - } - }; - - pub fn iterator(self: *const @This()) Iter { - return .{ .data = &self.data, .field_index = 0 }; - } - - pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void { - util.deepFree(allocator, val); - } - - pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result { - return (try self.deserialize(allocator, Result, &.{})) orelse error.MissingField; - } - - fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From { - //inline for (comptime std.meta.fieldNames(Data)) |f| @compileLog(f.ptr); - return @field(self.data, util.comptimeJoin(".", field_ref)); - } - - fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void { - util.deepFree(allocator, val); - } - - fn deserialize(self: *@This(), allocator: std.mem.Allocator, comptime T: type, comptime field_ref: FieldRef) !?T { - if (comptime Context.options.isScalar(T)) { - return try self.context.deserializeScalar(allocator, T, self.getSerializedField(field_ref) orelse return null); - } - - switch (@typeInfo(T)) { - // At most one of any union field can be active at a time, and it is embedded - // in its parent container - .Union => |info| { - var result: ?T = null; - errdefer if (result) |v| self.deserializeFree(allocator, v); - // TODO: errdefer cleanup - const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref; - inline for (info.fields) |field| { - const F = field.field_type; - const new_field_ref = union_ref ++ &[_][]const u8{field.name}; - const maybe_value = try self.deserialize(allocator, F, new_field_ref); - if (maybe_value) |value| { - // TODO: errdefer cleanup - errdefer self.deserializeFree(allocator, value); - if (result != null) return error.DuplicateUnionMember; - result = @unionInit(T, field.name, value); - } - } - return result; - }, - - .Struct => |info| { - var result: T = undefined; - - var any_explicit = false; - var any_missing = false; - var fields_alloced = [1]bool{false} ** info.fields.len; - errdefer inline for (info.fields) |field, i| { - if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name)); - }; - inline for (info.fields) |field, i| { - const F = field.field_type; - const new_field_ref = field_ref ++ &[_][]const u8{field.name}; - const maybe_value = try self.deserialize(allocator, F, new_field_ref); - if (maybe_value) |v| { - @field(result, field.name) = v; - fields_alloced[i] = true; - any_explicit = true; - } else if (field.default_value) |ptr| { - if (@sizeOf(F) != 0) { - const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr)); - @field(result, field.name) = try util.deepClone(allocator, cast_ptr.*); - fields_alloced[i] = true; - } - } else { - any_missing = true; - } - } - if (any_missing) { - return if (any_explicit) error.MissingField else null; - } - - return result; - }, - - // Specifically non-scalar optionals - .Optional => |info| return try self.deserialize(allocator, info.child, field_ref), - - else => @compileError("Unsupported type"), - } - } - }; -} - -const bool_map = std.ComptimeStringMap(bool, .{ - .{ "true", true }, - .{ "t", true }, - .{ "yes", true }, - .{ "y", true }, - .{ "1", true }, - - .{ "false", false }, - .{ "f", false }, - .{ "no", false }, - .{ "n", false }, - .{ "0", false }, -}); - -test "Deserializer" { - - // Happy case - simple - { - const T = struct { foo: []const u8, bar: bool }; - - var ds = Deserializer(T){}; - try ds.setSerializedField("foo", "123"); - try ds.setSerializedField("bar", "true"); - - const val = try ds.finish(std.testing.allocator); - defer ds.finishFree(std.testing.allocator, val); - try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val); - } - - // Returns error if nonexistent field set - { - const T = struct { foo: []const u8, bar: bool }; - - var ds = Deserializer(T){}; - try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123")); - } - - // Substruct dereferencing - { - const T = struct { - foo: struct { bar: bool, baz: bool }, - }; - - var ds = Deserializer(T){}; - try ds.setSerializedField("foo.bar", "true"); - try ds.setSerializedField("foo.baz", "true"); - - const val = try ds.finish(std.testing.allocator); - defer ds.finishFree(std.testing.allocator, val); - try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true, .baz = true } }, val); - } - - // Union embedding - { - const T = struct { - foo: union(enum) { bar: bool, baz: bool }, - }; - - var ds = Deserializer(T){}; - try ds.setSerializedField("bar", "true"); - - const val = try ds.finish(std.testing.allocator); - defer ds.finishFree(std.testing.allocator, val); - try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true } }, val); - } - - // Returns error if multiple union fields specified - { - const T = struct { - foo: union(enum) { bar: bool, baz: bool }, - }; - - var ds = Deserializer(T){}; - try ds.setSerializedField("bar", "true"); - try ds.setSerializedField("baz", "true"); - - try std.testing.expectError(error.DuplicateUnionMember, ds.finish(std.testing.allocator)); - } - - // Uses default values if fields aren't provided - { - const T = struct { foo: []const u8 = "123", bar: bool = true }; - - var ds = Deserializer(T){}; - - const val = try ds.finish(std.testing.allocator); - defer ds.finishFree(std.testing.allocator, val); - try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val); - } - - // Returns an error if fields aren't provided and no default exists - { - const T = struct { foo: []const u8, bar: bool }; - - var ds = Deserializer(T){}; - try ds.setSerializedField("foo", "123"); - - try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator)); - } - - // Handles optional containers - { - const T = struct { - foo: ?struct { bar: usize = 3, baz: usize } = null, - qux: ?union(enum) { quux: usize } = null, - }; - - var ds = Deserializer(T){}; - - const val = try ds.finish(std.testing.allocator); - defer ds.finishFree(std.testing.allocator, val); - try util.testing.expectDeepEqual(T{ .foo = null, .qux = null }, val); - } - - { - const T = struct { - foo: ?struct { bar: usize = 3, baz: usize } = null, - qux: ?union(enum) { quux: usize } = null, - }; - - var ds = Deserializer(T){}; - try ds.setSerializedField("foo.baz", "3"); - try ds.setSerializedField("quux", "3"); - - const val = try ds.finish(std.testing.allocator); - defer ds.finishFree(std.testing.allocator, val); - try util.testing.expectDeepEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, val); - } - - { - const T = struct { - foo: ?struct { bar: usize = 3, baz: usize } = null, - qux: ?union(enum) { quux: usize } = null, - }; - - var ds = Deserializer(T){}; - try ds.setSerializedField("foo.bar", "3"); - try ds.setSerializedField("quux", "3"); - - try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator)); - } -}