Compare commits
25 commits
8aa4f900f6
...
208007c0f7
Author | SHA1 | Date | |
---|---|---|---|
208007c0f7 | |||
31f676580d | |||
6cfd035883 | |||
e27d0064ee | |||
a97850964e | |||
a45ccfe0e4 | |||
2bcef49e5e | |||
2206cd6ac9 | |||
e6f57495c0 | |||
6e56775d61 | |||
0b13f210c7 | |||
ba4f3a7bf4 | |||
f7bcafe1b1 | |||
16c574bdd6 | |||
b2093128de | |||
04c593ffdd | |||
8400cd74fd | |||
83af6a40e4 | |||
c7dcded04a | |||
aa632ace8b | |||
96a46a98c9 | |||
2f78490545 | |||
4a98b6a9c4 | |||
938ee61477 | |||
b99a0095d4 |
26 changed files with 1724 additions and 880 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,3 +2,4 @@
|
|||
**/zig-cache
|
||||
**.db
|
||||
/config.json
|
||||
/files
|
||||
|
|
22
build.zig
22
build.zig
|
@ -99,13 +99,23 @@ pub fn build(b: *std.build.Builder) !void {
|
|||
exe.addSystemIncludePath("/usr/include/");
|
||||
|
||||
const unittest_http_cmd = b.step("unit:http", "Run tests for http package");
|
||||
const unittest_http = b.addTest("src/http/test.zig");
|
||||
const unittest_http = b.addTest("src/http/lib.zig");
|
||||
unittest_http_cmd.dependOn(&unittest_http.step);
|
||||
unittest_http.addPackage(pkgs.util);
|
||||
|
||||
//const unittest_util_cmd = b.step("unit:util", "Run tests for util package");
|
||||
//const unittest_util = b.addTest("src/util/Uuid.zig");
|
||||
//unittest_util_cmd.dependOn(&unittest_util.step);
|
||||
const unittest_util_cmd = b.step("unit:util", "Run tests for util package");
|
||||
const unittest_util = b.addTest("src/util/lib.zig");
|
||||
unittest_util_cmd.dependOn(&unittest_util.step);
|
||||
|
||||
const unittest_sql_cmd = b.step("unit:sql", "Run tests for sql package");
|
||||
const unittest_sql = b.addTest("src/sql/lib.zig");
|
||||
unittest_sql_cmd.dependOn(&unittest_sql.step);
|
||||
unittest_sql.addPackage(pkgs.util);
|
||||
|
||||
const unittest_template_cmd = b.step("unit:template", "Run tests for template package");
|
||||
const unittest_template = b.addTest("src/template/lib.zig");
|
||||
unittest_template_cmd.dependOn(&unittest_template.step);
|
||||
//unittest_template.addPackage(pkgs.util);
|
||||
|
||||
//const util_tests = b.addTest("src/util/lib.zig");
|
||||
//const sql_tests = b.addTest("src/sql/lib.zig");
|
||||
|
@ -115,7 +125,9 @@ pub fn build(b: *std.build.Builder) !void {
|
|||
//const unit_tests = b.step("unit-tests", "Run tests");
|
||||
const unittest_all = b.step("unit", "Run unit tests");
|
||||
unittest_all.dependOn(unittest_http_cmd);
|
||||
//unittest_all.dependOn(unittest_util_cmd);
|
||||
unittest_all.dependOn(unittest_util_cmd);
|
||||
unittest_all.dependOn(unittest_sql_cmd);
|
||||
unittest_all.dependOn(unittest_template_cmd);
|
||||
|
||||
const api_integration = b.addTest("./tests/api_integration/lib.zig");
|
||||
api_integration.addPackage(pkgs.opts);
|
||||
|
|
|
@ -9,6 +9,7 @@ const services = struct {
|
|||
const communities = @import("./services/communities.zig");
|
||||
const actors = @import("./services/actors.zig");
|
||||
const auth = @import("./services/auth.zig");
|
||||
const drive = @import("./services/files.zig");
|
||||
const invites = @import("./services/invites.zig");
|
||||
const notes = @import("./services/notes.zig");
|
||||
const follows = @import("./services/follows.zig");
|
||||
|
@ -136,6 +137,14 @@ pub const FollowerQueryResult = FollowQueryResult;
|
|||
pub const FollowingQueryArgs = FollowQueryArgs;
|
||||
pub const FollowingQueryResult = FollowQueryResult;
|
||||
|
||||
pub const UploadFileArgs = struct {
|
||||
filename: []const u8,
|
||||
dir: ?[]const u8,
|
||||
description: ?[]const u8,
|
||||
content_type: []const u8,
|
||||
sensitive: bool,
|
||||
};
|
||||
|
||||
pub fn isAdminSetup(db: sql.Db) !bool {
|
||||
_ = services.communities.adminCommunityId(db) catch |err| switch (err) {
|
||||
error.NotFound => return false,
|
||||
|
@ -509,5 +518,23 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
self.allocator,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn uploadFile(self: *Self, meta: UploadFileArgs, body: []const u8) !void {
|
||||
const user_id = self.user_id orelse return error.NoToken;
|
||||
return try services.drive.createFile(self.db, .{
|
||||
.dir = meta.dir orelse "/",
|
||||
.filename = meta.filename,
|
||||
.owner = .{ .user_id = user_id },
|
||||
.created_by = user_id,
|
||||
.description = meta.description,
|
||||
.content_type = meta.content_type,
|
||||
.sensitive = meta.sensitive,
|
||||
}, body, self.allocator);
|
||||
}
|
||||
|
||||
pub fn driveMkdir(self: *Self, path: []const u8) !void {
|
||||
const user_id = self.user_id orelse return error.NoToken;
|
||||
try services.drive.mkdir(self.db, .{ .user_id = user_id }, path, self.allocator);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -11,59 +11,224 @@ pub const FileOwner = union(enum) {
|
|||
|
||||
pub const DriveFile = struct {
|
||||
id: Uuid,
|
||||
|
||||
path: []const u8,
|
||||
filename: []const u8,
|
||||
|
||||
owner: FileOwner,
|
||||
|
||||
size: usize,
|
||||
|
||||
description: []const u8,
|
||||
content_type: []const u8,
|
||||
sensitive: bool,
|
||||
|
||||
created_at: DateTime,
|
||||
updated_at: DateTime,
|
||||
};
|
||||
|
||||
const EntryType = enum {
|
||||
dir,
|
||||
file,
|
||||
};
|
||||
|
||||
pub const CreateFileArgs = struct {
|
||||
dir: []const u8,
|
||||
filename: []const u8,
|
||||
owner: FileOwner,
|
||||
size: usize,
|
||||
created_at: DateTime,
|
||||
created_by: Uuid,
|
||||
description: ?[]const u8,
|
||||
content_type: ?[]const u8,
|
||||
sensitive: bool,
|
||||
};
|
||||
|
||||
pub const files = struct {
|
||||
pub fn create(db: anytype, owner: FileOwner, filename: []const u8, data: []const u8, alloc: std.mem.Allocator) !void {
|
||||
const id = Uuid.randV4(util.getThreadPrng());
|
||||
const now = DateTime.now();
|
||||
fn lookupDirectory(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid {
|
||||
return (try db.queryRow(
|
||||
std.meta.Tuple(
|
||||
&.{util.Uuid},
|
||||
),
|
||||
\\SELECT id
|
||||
\\FROM drive_entry_path
|
||||
\\WHERE
|
||||
\\ path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END)
|
||||
\\ AND account_owner_id IS NOT DISTINCT FROM $2
|
||||
\\ AND community_owner_id IS NOT DISTINCT FROM $3
|
||||
\\ AND kind = 'dir'
|
||||
\\LIMIT 1
|
||||
,
|
||||
.{
|
||||
std.mem.trim(u8, path, "/"),
|
||||
if (owner == .user_id) owner.user_id else null,
|
||||
if (owner == .community_id) owner.community_id else null,
|
||||
},
|
||||
alloc,
|
||||
))[0];
|
||||
}
|
||||
|
||||
// TODO: assert we're not in a transaction
|
||||
db.insert("drive_file", .{
|
||||
fn lookup(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !Uuid {
|
||||
return (try db.queryRow(
|
||||
std.meta.Tuple(
|
||||
&.{util.Uuid},
|
||||
),
|
||||
\\SELECT id
|
||||
\\FROM drive_entry_path
|
||||
\\WHERE
|
||||
\\ path = (CASE WHEN LENGTH($1) = 0 THEN '/' ELSE '/' || $1 || '/' END)
|
||||
\\ AND account_owner_id IS NOT DISTINCT FROM $2
|
||||
\\ AND community_owner_id IS NOT DISTINCT FROM $3
|
||||
\\LIMIT 1
|
||||
,
|
||||
.{
|
||||
std.mem.trim(u8, path, "/"),
|
||||
if (owner == .user_id) owner.user_id else null,
|
||||
if (owner == .community_id) owner.community_id else null,
|
||||
},
|
||||
alloc,
|
||||
))[0];
|
||||
}
|
||||
|
||||
pub fn mkdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void {
|
||||
var split = std.mem.splitBackwards(u8, std.mem.trim(u8, path, "/"), "/");
|
||||
const name = split.first();
|
||||
const dir = split.rest();
|
||||
std.log.debug("'{s}' / '{s}'", .{ name, dir });
|
||||
|
||||
if (name.len == 0) return error.EmptyName;
|
||||
|
||||
const id = Uuid.randV4(util.getThreadPrng());
|
||||
|
||||
const tx = try db.begin();
|
||||
errdefer tx.rollback();
|
||||
|
||||
const parent = try lookupDirectory(tx, owner, dir, alloc);
|
||||
|
||||
try tx.insert("drive_entry", .{
|
||||
.id = id,
|
||||
|
||||
.account_owner_id = if (owner == .user_id) owner.user_id else null,
|
||||
.community_owner_id = if (owner == .community_id) owner.community_id else null,
|
||||
|
||||
.name = name,
|
||||
.parent_directory_id = parent,
|
||||
}, alloc);
|
||||
try tx.commit();
|
||||
}
|
||||
|
||||
pub fn rmdir(db: anytype, owner: FileOwner, path: []const u8, alloc: std.mem.Allocator) !void {
|
||||
const tx = try db.begin();
|
||||
errdefer tx.rollback();
|
||||
|
||||
const id = try lookupDirectory(tx, owner, path, alloc);
|
||||
try tx.exec("DELETE FROM drive_directory WHERE id = $1", .{id}, alloc);
|
||||
try tx.commit();
|
||||
}
|
||||
|
||||
fn insertFileRow(tx: anytype, id: Uuid, filename: []const u8, owner: FileOwner, dir: Uuid, alloc: std.mem.Allocator) !void {
|
||||
try tx.insert("drive_entry", .{
|
||||
.id = id,
|
||||
|
||||
.account_owner_id = if (owner == .user_id) owner.user_id else null,
|
||||
.community_owner_id = if (owner == .community_id) owner.community_id else null,
|
||||
|
||||
.parent_directory_id = dir,
|
||||
.name = filename,
|
||||
|
||||
.file_id = id,
|
||||
}, alloc);
|
||||
}
|
||||
|
||||
pub fn createFile(db: anytype, args: CreateFileArgs, data: []const u8, alloc: std.mem.Allocator) !void {
|
||||
const id = Uuid.randV4(util.getThreadPrng());
|
||||
const now = DateTime.now();
|
||||
|
||||
{
|
||||
var tx = try db.begin();
|
||||
errdefer tx.rollback();
|
||||
|
||||
const dir_id = try lookupDirectory(tx, args.owner, args.dir, alloc);
|
||||
|
||||
try tx.insert("file_upload", .{
|
||||
.id = id,
|
||||
.filename = filename,
|
||||
.owner = owner,
|
||||
|
||||
.filename = args.filename,
|
||||
|
||||
.created_by = args.created_by,
|
||||
.size = data.len,
|
||||
|
||||
.description = args.description,
|
||||
.content_type = args.content_type,
|
||||
.sensitive = args.sensitive,
|
||||
|
||||
.is_deleted = false,
|
||||
|
||||
.created_at = now,
|
||||
}, alloc) catch return error.DatabaseFailure;
|
||||
// Assume the previous statement succeeded and is not stuck in a transaction
|
||||
errdefer {
|
||||
db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch |err| {
|
||||
std.log.err("Unable to remove file record in DB: {}", .{err});
|
||||
};
|
||||
.updated_at = now,
|
||||
}, alloc);
|
||||
|
||||
var sub_tx = try tx.savepoint();
|
||||
if (insertFileRow(sub_tx, id, args.filename, args.owner, dir_id, alloc)) |_| {
|
||||
try sub_tx.release();
|
||||
} else |err| {
|
||||
std.log.debug("{}", .{err});
|
||||
switch (err) {
|
||||
error.UniqueViolation => {
|
||||
try sub_tx.rollbackSavepoint();
|
||||
// Rename the file before trying again
|
||||
var split = std.mem.split(u8, args.filename, ".");
|
||||
const name = split.first();
|
||||
const ext = split.rest();
|
||||
var buf: [256]u8 = undefined;
|
||||
const drive_filename = try std.fmt.bufPrint(&buf, "{s}.{}.{s}", .{ name, id, ext });
|
||||
try insertFileRow(tx, id, drive_filename, args.owner, dir_id, alloc);
|
||||
},
|
||||
else => return error.DatabaseFailure,
|
||||
}
|
||||
}
|
||||
|
||||
try saveFile(id, data);
|
||||
try tx.commit();
|
||||
}
|
||||
|
||||
const data_root = "./files";
|
||||
fn saveFile(id: Uuid, data: []const u8) !void {
|
||||
var dir = try std.fs.cwd().openDir(data_root);
|
||||
defer dir.close();
|
||||
|
||||
var file = try dir.createFile(id.toCharArray(), .{ .exclusive = true });
|
||||
defer file.close();
|
||||
|
||||
try file.writer().writeAll(data);
|
||||
try file.sync();
|
||||
errdefer {
|
||||
db.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch |err| {
|
||||
std.log.err("Unable to remove file record in DB: {}", .{err});
|
||||
};
|
||||
db.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch |err| {
|
||||
std.log.err("Unable to remove file record in DB: {}", .{err});
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 {
|
||||
var dir = try std.fs.cwd().openDir(data_root);
|
||||
defer dir.close();
|
||||
try saveFile(id, data);
|
||||
}
|
||||
|
||||
return dir.readFileAlloc(alloc, id.toCharArray(), 1 << 32);
|
||||
}
|
||||
const data_root = "./files";
|
||||
fn saveFile(id: Uuid, data: []const u8) !void {
|
||||
var dir = try std.fs.cwd().openDir(data_root, .{});
|
||||
defer dir.close();
|
||||
|
||||
pub fn delete(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void {
|
||||
var dir = try std.fs.cwd().openDir(data_root);
|
||||
defer dir.close();
|
||||
var file = try dir.createFile(&id.toCharArray(), .{ .exclusive = true });
|
||||
defer file.close();
|
||||
|
||||
try dir.deleteFile(id.toCharArray());
|
||||
try file.writer().writeAll(data);
|
||||
try file.sync();
|
||||
}
|
||||
|
||||
db.exec("DELETE FROM drive_file WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure;
|
||||
}
|
||||
};
|
||||
pub fn deref(alloc: std.mem.Allocator, id: Uuid) ![]const u8 {
|
||||
var dir = try std.fs.cwd().openDir(data_root, .{});
|
||||
defer dir.close();
|
||||
|
||||
return dir.readFileAlloc(alloc, &id.toCharArray(), 1 << 32);
|
||||
}
|
||||
|
||||
pub fn deleteFile(db: anytype, alloc: std.mem.Allocator, id: Uuid) !void {
|
||||
var dir = try std.fs.cwd().openDir(data_root, .{});
|
||||
defer dir.close();
|
||||
|
||||
try dir.deleteFile(id.toCharArray());
|
||||
|
||||
const tx = try db.beginOrSavepoint();
|
||||
errdefer tx.rollback();
|
||||
|
||||
tx.exec("DELETE FROM drive_entry WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure;
|
||||
tx.exec("DELETE FROM file_upload WHERE ID = $1", .{id}, alloc) catch return error.DatabaseFailure;
|
||||
try tx.commitOrRelease();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,60 @@
|
|||
const std = @import("std");
|
||||
|
||||
pub const ParamIter = struct {
|
||||
str: []const u8,
|
||||
index: usize = 0,
|
||||
|
||||
const Param = struct {
|
||||
name: []const u8,
|
||||
value: []const u8,
|
||||
};
|
||||
|
||||
pub fn from(str: []const u8) ParamIter {
|
||||
return .{ .str = str, .index = std.mem.indexOfScalar(u8, str, ';') orelse str.len };
|
||||
}
|
||||
|
||||
pub fn fieldValue(self: *ParamIter) []const u8 {
|
||||
return std.mem.sliceTo(self.str, ';');
|
||||
}
|
||||
|
||||
pub fn next(self: *ParamIter) ?Param {
|
||||
if (self.index >= self.str.len) return null;
|
||||
|
||||
const start = self.index + 1;
|
||||
const new_start = std.mem.indexOfScalarPos(u8, self.str, start, ';') orelse self.str.len;
|
||||
self.index = new_start;
|
||||
|
||||
const param = std.mem.trim(u8, self.str[start..new_start], " \t");
|
||||
var split = std.mem.split(u8, param, "=");
|
||||
const name = split.first();
|
||||
const value = std.mem.trimLeft(u8, split.rest(), " \t");
|
||||
// TODO: handle quoted values
|
||||
// TODO: handle parse errors
|
||||
|
||||
return Param{
|
||||
.name = name,
|
||||
.value = value,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub fn getParam(field: []const u8, name: ?[]const u8) ?[]const u8 {
|
||||
var iter = ParamIter.from(field);
|
||||
|
||||
if (name) |param| {
|
||||
while (iter.next()) |p| {
|
||||
if (std.ascii.eqlIgnoreCase(param, p.name)) {
|
||||
const trimmed = std.mem.trim(u8, p.value, " \t");
|
||||
if (trimmed.len >= 2 and trimmed[0] == '"' and trimmed[trimmed.len - 1] == '"') {
|
||||
return trimmed[1 .. trimmed.len - 1];
|
||||
}
|
||||
return trimmed;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
} else return iter.fieldValue();
|
||||
}
|
||||
|
||||
pub const Fields = struct {
|
||||
const HashContext = struct {
|
||||
const hash_seed = 1;
|
|
@ -1,27 +1,55 @@
|
|||
const std = @import("std");
|
||||
const ciutf8 = @import("util").ciutf8;
|
||||
|
||||
const request = @import("./request.zig");
|
||||
|
||||
const server = @import("./server.zig");
|
||||
|
||||
pub const urlencode = @import("./urlencode.zig");
|
||||
pub const socket = @import("./socket.zig");
|
||||
const json = @import("./json.zig");
|
||||
const multipart = @import("./multipart.zig");
|
||||
pub const fields = @import("./fields.zig");
|
||||
|
||||
pub const Method = enum {
|
||||
GET,
|
||||
HEAD,
|
||||
POST,
|
||||
PUT,
|
||||
DELETE,
|
||||
CONNECT,
|
||||
OPTIONS,
|
||||
TRACE,
|
||||
PATCH,
|
||||
|
||||
// WebDAV methods (we use some of them for the drive system)
|
||||
MKCOL,
|
||||
MOVE,
|
||||
|
||||
pub fn requestHasBody(self: Method) bool {
|
||||
return switch (self) {
|
||||
.POST, .PUT, .PATCH, .MKCOL, .MOVE => true,
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const Method = std.http.Method;
|
||||
pub const Status = std.http.Status;
|
||||
|
||||
pub const Request = request.Request(server.Stream.Reader);
|
||||
pub const Response = server.Response;
|
||||
pub const Handler = server.Handler;
|
||||
//pub const Handler = server.Handler;
|
||||
pub const Server = server.Server;
|
||||
|
||||
pub const middleware = @import("./middleware.zig");
|
||||
pub const queryStringify = @import("./query.zig").queryStringify;
|
||||
|
||||
pub const Fields = @import("./headers.zig").Fields;
|
||||
pub const Fields = fields.Fields;
|
||||
|
||||
pub const FormFile = multipart.FormFile;
|
||||
|
||||
pub const Protocol = enum {
|
||||
http_1_0,
|
||||
http_1_1,
|
||||
http_1_x,
|
||||
};
|
||||
|
||||
test {
|
||||
_ = std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
|
|
@ -14,10 +14,11 @@
|
|||
/// Terminal middlewares that are not implemented using other middlewares should
|
||||
/// only accept a `void` value for `next_handler`.
|
||||
const std = @import("std");
|
||||
const http = @import("./lib.zig");
|
||||
const util = @import("util");
|
||||
const query_utils = @import("./query.zig");
|
||||
const http = @import("./lib.zig");
|
||||
const urlencode = @import("./urlencode.zig");
|
||||
const json_utils = @import("./json.zig");
|
||||
const fields = @import("./fields.zig");
|
||||
|
||||
/// Takes an iterable of middlewares and chains them together.
|
||||
pub fn apply(middlewares: anytype) Apply(@TypeOf(middlewares)) {
|
||||
|
@ -29,20 +30,20 @@ pub fn Apply(comptime Middlewares: type) type {
|
|||
return ApplyInternal(std.meta.fields(Middlewares));
|
||||
}
|
||||
|
||||
fn ApplyInternal(comptime fields: []const std.builtin.Type.StructField) type {
|
||||
if (fields.len == 0) return void;
|
||||
fn ApplyInternal(comptime which: []const std.builtin.Type.StructField) type {
|
||||
if (which.len == 0) return void;
|
||||
|
||||
return HandlerList(
|
||||
fields[0].field_type,
|
||||
ApplyInternal(fields[1..]),
|
||||
which[0].field_type,
|
||||
ApplyInternal(which[1..]),
|
||||
);
|
||||
}
|
||||
|
||||
fn applyInternal(middlewares: anytype, comptime fields: []const std.builtin.Type.StructField) ApplyInternal(fields) {
|
||||
if (fields.len == 0) return {};
|
||||
fn applyInternal(middlewares: anytype, comptime which: []const std.builtin.Type.StructField) ApplyInternal(which) {
|
||||
if (which.len == 0) return {};
|
||||
return .{
|
||||
.first = @field(middlewares, fields[0].name),
|
||||
.next = applyInternal(middlewares, fields[1..]),
|
||||
.first = @field(middlewares, which[0].name),
|
||||
.next = applyInternal(middlewares, which[1..]),
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -349,15 +350,71 @@ pub fn router(routes: anytype) Router(@TypeOf(routes)) {
|
|||
return Router(@TypeOf(routes)){ .routes = routes };
|
||||
}
|
||||
|
||||
pub const PathIter = struct {
|
||||
is_first: bool,
|
||||
iter: std.mem.SplitIterator(u8),
|
||||
|
||||
pub fn from(path: []const u8) PathIter {
|
||||
return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") };
|
||||
}
|
||||
|
||||
pub fn next(self: *PathIter) ?[]const u8 {
|
||||
defer self.is_first = false;
|
||||
while (self.iter.next()) |it| if (it.len != 0) {
|
||||
return it;
|
||||
};
|
||||
|
||||
if (self.is_first) return self.iter.rest();
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
pub fn first(self: *PathIter) []const u8 {
|
||||
std.debug.assert(self.is_first);
|
||||
return self.next().?;
|
||||
}
|
||||
|
||||
pub fn rest(self: *PathIter) []const u8 {
|
||||
return self.iter.rest();
|
||||
}
|
||||
};
|
||||
|
||||
test "PathIter" {
|
||||
const testCase = struct {
|
||||
fn case(path: []const u8, segments: []const []const u8) !void {
|
||||
var iter = PathIter.from(path);
|
||||
for (segments) |s| {
|
||||
try std.testing.expectEqualStrings(s, iter.next() orelse return error.TestExpectedEqual);
|
||||
}
|
||||
try std.testing.expect(iter.next() == null);
|
||||
}
|
||||
}.case;
|
||||
|
||||
try testCase("", &.{""});
|
||||
try testCase("*", &.{"*"});
|
||||
try testCase("/", &.{""});
|
||||
try testCase("/ab/cd", &.{ "ab", "cd" });
|
||||
try testCase("/ab/cd/", &.{ "ab", "cd" });
|
||||
try testCase("/ab/cd//", &.{ "ab", "cd" });
|
||||
try testCase("ab", &.{"ab"});
|
||||
try testCase("/ab", &.{"ab"});
|
||||
try testCase("ab/", &.{"ab"});
|
||||
try testCase("ab//ab//", &.{ "ab", "ab" });
|
||||
}
|
||||
|
||||
// helper function for doing route analysis
|
||||
fn pathMatches(route: []const u8, path: []const u8) bool {
|
||||
var path_iter = util.PathIter.from(path);
|
||||
var route_iter = util.PathIter.from(route);
|
||||
var path_iter = PathIter.from(path);
|
||||
var route_iter = PathIter.from(route);
|
||||
while (route_iter.next()) |route_segment| {
|
||||
const path_segment = path_iter.next() orelse return false;
|
||||
const path_segment = path_iter.next() orelse "";
|
||||
|
||||
if (route_segment.len > 0 and route_segment[0] == ':') {
|
||||
// Route Argument
|
||||
if (path_segment.len == 0) return false;
|
||||
if (route_segment[route_segment.len - 1] == '*') {
|
||||
// consume rest of path segments
|
||||
while (path_iter.next()) |_| {}
|
||||
} else if (path_segment.len == 0) return false;
|
||||
} else {
|
||||
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return false;
|
||||
}
|
||||
|
@ -428,6 +485,10 @@ test "route" {
|
|||
try testCase(true, .{ .method = .POST, .path = "/abcd/efgh" }, .POST, "abcd/efgh");
|
||||
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "abcd/efgh");
|
||||
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz");
|
||||
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh/xyz");
|
||||
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/efgh");
|
||||
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd/");
|
||||
try testCase(true, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "abcd");
|
||||
|
||||
try testCase(false, .{ .method = .POST, .path = "/" }, .GET, "/");
|
||||
try testCase(false, .{ .method = .GET, .path = "/abcd" }, .GET, "");
|
||||
|
@ -436,32 +497,21 @@ test "route" {
|
|||
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg" }, .GET, "/abcd/");
|
||||
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/");
|
||||
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg/xyz" }, .GET, "abcd/efgh/xyz/foo");
|
||||
try testCase(false, .{ .method = .GET, .path = "/abcd/:arg*" }, .GET, "defg/abcd");
|
||||
}
|
||||
|
||||
/// Mounts a router subtree under a given path. Middlewares further down on the list
|
||||
/// are called with the path prefix specified by `route` removed from the path.
|
||||
/// Must be below `split_uri` on the middleware list.
|
||||
pub fn Mount(comptime route: []const u8) type {
|
||||
if (std.mem.indexOfScalar(u8, route, ':') != null) @compileError("Route args cannot be mounted");
|
||||
return struct {
|
||||
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
|
||||
var path_iter = util.PathIter.from(ctx.path);
|
||||
comptime var route_iter = util.PathIter.from(route);
|
||||
var path_unused: []const u8 = ctx.path;
|
||||
|
||||
inline while (comptime route_iter.next()) |route_segment| {
|
||||
if (comptime route_segment.len == 0) continue;
|
||||
const path_segment = path_iter.next() orelse return error.RouteMismatch;
|
||||
path_unused = path_iter.rest();
|
||||
if (comptime route_segment[0] == ':') {
|
||||
@compileLog("Argument segments cannot be mounted");
|
||||
// Route Argument
|
||||
} else {
|
||||
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
|
||||
}
|
||||
}
|
||||
const args = try parseArgsFromPath(route ++ "/:path*", struct { path: []const u8 }, ctx.path);
|
||||
|
||||
var new_ctx = ctx;
|
||||
new_ctx.path = path_unused;
|
||||
new_ctx.path = args.path;
|
||||
|
||||
return next.handle(req, res, new_ctx, {});
|
||||
}
|
||||
};
|
||||
|
@ -491,18 +541,33 @@ test "mount" {
|
|||
|
||||
fn parseArgsFromPath(comptime route: []const u8, comptime Args: type, path: []const u8) !Args {
|
||||
var args: Args = undefined;
|
||||
var path_iter = util.PathIter.from(path);
|
||||
comptime var route_iter = util.PathIter.from(route);
|
||||
var path_iter = PathIter.from(path);
|
||||
comptime var route_iter = PathIter.from(route);
|
||||
var path_unused: []const u8 = path;
|
||||
|
||||
inline while (comptime route_iter.next()) |route_segment| {
|
||||
const path_segment = path_iter.next() orelse return error.RouteMismatch;
|
||||
if (route_segment.len > 0 and route_segment[0] == ':') {
|
||||
const path_segment = path_iter.next() orelse "";
|
||||
if (route_segment[0] == ':') {
|
||||
comptime var name: []const u8 = route_segment[1..];
|
||||
var value: []const u8 = path_segment;
|
||||
|
||||
// route segment is an argument segment
|
||||
if (path_segment.len == 0) return error.RouteMismatch;
|
||||
const A = @TypeOf(@field(args, route_segment[1..]));
|
||||
@field(args, route_segment[1..]) = try parseArgFromPath(A, path_segment);
|
||||
if (comptime route_segment[route_segment.len - 1] == '*') {
|
||||
// waste remaining args
|
||||
while (path_iter.next()) |_| {}
|
||||
name = route_segment[1 .. route_segment.len - 1];
|
||||
value = path_unused;
|
||||
} else {
|
||||
if (path_segment.len == 0) return error.RouteMismatch;
|
||||
}
|
||||
|
||||
const A = @TypeOf(@field(args, name));
|
||||
@field(args, name) = try parseArgFromPath(A, value);
|
||||
} else {
|
||||
// route segment is a literal segment
|
||||
if (!std.ascii.eqlIgnoreCase(route_segment, path_segment)) return error.RouteMismatch;
|
||||
}
|
||||
path_unused = path_iter.rest();
|
||||
}
|
||||
|
||||
if (path_iter.next() != null) return error.RouteMismatch;
|
||||
|
@ -577,6 +642,21 @@ test "ParsePathArgs" {
|
|||
try testCase("/:id/xyz/:str", struct { id: usize, str: []const u8 }, "/3/xyz/abcd", .{ .id = 3, .str = "abcd" });
|
||||
try testCase("/:id", struct { id: util.Uuid }, "/" ++ util.Uuid.nil.toCharArray(), .{ .id = util.Uuid.nil });
|
||||
|
||||
try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc", .{ .arg = "abc" });
|
||||
try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/abc/def", .{ .arg = "abc/def" });
|
||||
try testCase("/xyz/:arg*", struct { arg: []const u8 }, "/xyz/", .{ .arg = "" });
|
||||
|
||||
// Compiler crashes if i keep the args named the same as above.
|
||||
// TODO: Debug this and try to fix it
|
||||
try testCase("/xyz/:bar*", struct { bar: []const u8 }, "/xyz", .{ .bar = "" });
|
||||
|
||||
// It's a quirk that the initial / is left in for these cases. However, it results in a path
|
||||
// that's semantically equivalent so i didn't bother fixing it
|
||||
try testCase("/:foo*", struct { foo: []const u8 }, "/abc", .{ .foo = "/abc" });
|
||||
try testCase("/:foo*", struct { foo: []const u8 }, "/abc/def", .{ .foo = "/abc/def" });
|
||||
try testCase("/:foo*", struct { foo: []const u8 }, "/", .{ .foo = "/" });
|
||||
try testCase("/:foo*", struct { foo: []const u8 }, "", .{ .foo = "" });
|
||||
|
||||
try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/", .{}));
|
||||
try std.testing.expectError(error.RouteMismatch, testCase("/abcd/:id", struct { id: usize }, "/123", .{}));
|
||||
try std.testing.expectError(error.RouteMismatch, testCase("/:id", struct { id: usize }, "/3/id/blahblah", .{ .id = 3 }));
|
||||
|
@ -587,41 +667,51 @@ const BaseContentType = enum {
|
|||
json,
|
||||
url_encoded,
|
||||
octet_stream,
|
||||
multipart_formdata,
|
||||
|
||||
other,
|
||||
};
|
||||
|
||||
fn parseBodyFromRequest(comptime T: type, content_type: BaseContentType, reader: anytype, alloc: std.mem.Allocator) !T {
|
||||
//@compileLog(T);
|
||||
const buf = try reader.readAllAlloc(alloc, 1 << 16);
|
||||
defer alloc.free(buf);
|
||||
fn parseBodyFromRequest(comptime T: type, content_type: ?[]const u8, reader: anytype, alloc: std.mem.Allocator) !T {
|
||||
// Use json by default for now for testing purposes
|
||||
const eff_type = content_type orelse "application/json";
|
||||
const parser_type = matchContentType(eff_type);
|
||||
|
||||
switch (content_type) {
|
||||
switch (parser_type) {
|
||||
.octet_stream, .json => {
|
||||
const buf = try reader.readAllAlloc(alloc, 1 << 16);
|
||||
defer alloc.free(buf);
|
||||
const body = try json_utils.parse(T, buf, alloc);
|
||||
defer json_utils.parseFree(body, alloc);
|
||||
|
||||
return try util.deepClone(alloc, body);
|
||||
},
|
||||
.url_encoded => return query_utils.parseQuery(alloc, T, buf) catch |err| switch (err) {
|
||||
error.NoQuery => error.NoBody,
|
||||
else => err,
|
||||
.url_encoded => {
|
||||
const buf = try reader.readAllAlloc(alloc, 1 << 16);
|
||||
defer alloc.free(buf);
|
||||
return urlencode.parse(alloc, T, buf) catch |err| switch (err) {
|
||||
//error.NoQuery => error.NoBody,
|
||||
else => err,
|
||||
};
|
||||
},
|
||||
.multipart_formdata => {
|
||||
const boundary = fields.getParam(eff_type, "boundary") orelse return error.MissingBoundary;
|
||||
return try @import("./multipart.zig").parseFormData(T, boundary, reader, alloc);
|
||||
},
|
||||
else => return error.UnsupportedMediaType,
|
||||
}
|
||||
}
|
||||
|
||||
// figure out what base parser to use
|
||||
fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
|
||||
if (hdr) |h| {
|
||||
if (std.ascii.eqlIgnoreCase(h, "application/x-www-form-urlencoded")) return .url_encoded;
|
||||
if (std.ascii.eqlIgnoreCase(h, "application/json")) return .json;
|
||||
if (std.ascii.eqlIgnoreCase(h, "application/octet-stream")) return .octet_stream;
|
||||
fn matchContentType(hdr: []const u8) BaseContentType {
|
||||
const trimmed = std.mem.sliceTo(hdr, ';');
|
||||
if (std.ascii.eqlIgnoreCase(trimmed, "application/x-www-form-urlencoded")) return .url_encoded;
|
||||
if (std.ascii.eqlIgnoreCase(trimmed, "application/json")) return .json;
|
||||
if (std.ascii.endsWithIgnoreCase(trimmed, "+json")) return .json;
|
||||
if (std.ascii.eqlIgnoreCase(trimmed, "application/octet-stream")) return .octet_stream;
|
||||
if (std.ascii.eqlIgnoreCase(trimmed, "multipart/form-data")) return .multipart_formdata;
|
||||
|
||||
return .other;
|
||||
}
|
||||
|
||||
return null;
|
||||
return .other;
|
||||
}
|
||||
|
||||
/// Parses a set of body arguments from the request body based on the request's Content-Type
|
||||
|
@ -640,10 +730,8 @@ pub fn ParseBody(comptime Body: type) type {
|
|||
return next.handle(req, res, new_ctx, {});
|
||||
}
|
||||
|
||||
const base_content_type = matchContentType(content_type);
|
||||
|
||||
var stream = req.body orelse return error.NoBody;
|
||||
const body = try parseBodyFromRequest(Body, base_content_type orelse .json, stream.reader(), ctx.allocator);
|
||||
const body = try parseBodyFromRequest(Body, content_type, stream.reader(), ctx.allocator);
|
||||
defer util.deepFree(ctx.allocator, body);
|
||||
|
||||
return next.handle(
|
||||
|
@ -659,12 +747,57 @@ pub fn parseBody(comptime Body: type) ParseBody(Body) {
|
|||
return .{};
|
||||
}
|
||||
|
||||
test "parseBodyFromRequest" {
|
||||
const testCase = struct {
|
||||
fn case(content_type: []const u8, body: []const u8, expected: anytype) !void {
|
||||
var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
|
||||
const result = try parseBodyFromRequest(@TypeOf(expected), content_type, stream.reader(), std.testing.allocator);
|
||||
defer util.deepFree(std.testing.allocator, result);
|
||||
|
||||
try util.testing.expectDeepEqual(expected, result);
|
||||
}
|
||||
}.case;
|
||||
|
||||
const Struct = struct {
|
||||
id: usize,
|
||||
};
|
||||
try testCase("application/json", "{\"id\": 3}", Struct{ .id = 3 });
|
||||
try testCase("application/x-www-form-urlencoded", "id=3", Struct{ .id = 3 });
|
||||
|
||||
//try testCase("multipart/form-data; ",
|
||||
//\\
|
||||
//, Struct{ .id = 3 });
|
||||
}
|
||||
|
||||
test "parseBody" {
|
||||
const Struct = struct {
|
||||
foo: []const u8,
|
||||
};
|
||||
const body =
|
||||
\\{"foo": "bar"}
|
||||
;
|
||||
var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
|
||||
var headers = http.Fields.init(std.testing.allocator);
|
||||
defer headers.deinit();
|
||||
|
||||
try parseBody(Struct).handle(
|
||||
.{ .body = @as(?std.io.StreamSource, stream), .headers = headers },
|
||||
.{},
|
||||
.{ .allocator = std.testing.allocator },
|
||||
struct {
|
||||
fn handle(_: anytype, _: anytype, _: anytype, ctx: anytype, _: void) !void {
|
||||
try util.testing.expectDeepEqual(Struct{ .foo = "bar" }, ctx.body);
|
||||
}
|
||||
}{},
|
||||
);
|
||||
}
|
||||
|
||||
/// Parses query parameters as defined in query.zig
|
||||
pub fn ParseQueryParams(comptime QueryParams: type) type {
|
||||
return struct {
|
||||
pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void {
|
||||
if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {});
|
||||
const query = try query_utils.parseQuery(ctx.allocator, QueryParams, ctx.query_string);
|
||||
const query = try urlencode.parse(ctx.allocator, QueryParams, ctx.query_string);
|
||||
defer util.deepFree(ctx.allocator, query);
|
||||
|
||||
return next.handle(
|
||||
|
|
362
src/http/multipart.zig
Normal file
362
src/http/multipart.zig
Normal file
|
@ -0,0 +1,362 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
const fields = @import("./fields.zig");
|
||||
|
||||
const max_boundary = 70;
|
||||
const read_ahead = max_boundary + 4;
|
||||
|
||||
pub fn MultipartStream(comptime ReaderType: type) type {
|
||||
return struct {
|
||||
const Multipart = @This();
|
||||
|
||||
pub const BaseReader = ReaderType;
|
||||
pub const PartReader = std.io.Reader(*Part, ReaderType.Error, Part.read);
|
||||
|
||||
stream: std.io.PeekStream(.{ .Static = read_ahead }, ReaderType),
|
||||
boundary: []const u8,
|
||||
|
||||
pub fn next(self: *Multipart, alloc: std.mem.Allocator) !?Part {
|
||||
const reader = self.stream.reader();
|
||||
while (true) {
|
||||
try reader.skipUntilDelimiterOrEof('\r');
|
||||
var line_buf: [read_ahead]u8 = undefined;
|
||||
const len = try reader.readAll(line_buf[0 .. self.boundary.len + 3]);
|
||||
const line = line_buf[0..len];
|
||||
if (line.len == 0) return null;
|
||||
if (std.mem.startsWith(u8, line, "\n--") and std.mem.endsWith(u8, line, self.boundary)) {
|
||||
// match, check for end thing
|
||||
var more_buf: [2]u8 = undefined;
|
||||
if (try reader.readAll(&more_buf) != 2) return error.EndOfStream;
|
||||
|
||||
const more = !(more_buf[0] == '-' and more_buf[1] == '-');
|
||||
try self.stream.putBack(&more_buf);
|
||||
try reader.skipUntilDelimiterOrEof('\n');
|
||||
if (more) return try Part.open(self, alloc) else return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const Part = struct {
|
||||
base: ?*Multipart,
|
||||
fields: fields.Fields,
|
||||
|
||||
pub fn open(base: *Multipart, alloc: std.mem.Allocator) !Part {
|
||||
var parsed_fields = try @import("./request/parser.zig").parseHeaders(alloc, base.stream.reader());
|
||||
return .{ .base = base, .fields = parsed_fields };
|
||||
}
|
||||
|
||||
pub fn reader(self: *Part) PartReader {
|
||||
return .{ .context = self };
|
||||
}
|
||||
|
||||
pub fn close(self: *Part) void {
|
||||
self.fields.deinit();
|
||||
}
|
||||
|
||||
pub fn read(self: *Part, buf: []u8) ReaderType.Error!usize {
|
||||
const base = self.base orelse return 0;
|
||||
|
||||
const r = base.stream.reader();
|
||||
|
||||
var count: usize = 0;
|
||||
while (count < buf.len) {
|
||||
const byte = r.readByte() catch |err| switch (err) {
|
||||
error.EndOfStream => {
|
||||
self.base = null;
|
||||
return count;
|
||||
},
|
||||
else => |e| return e,
|
||||
};
|
||||
|
||||
buf[count] = byte;
|
||||
count += 1;
|
||||
if (byte != '\r') continue;
|
||||
|
||||
var line_buf: [read_ahead]u8 = undefined;
|
||||
const line = line_buf[0..try r.readAll(line_buf[0 .. base.boundary.len + 3])];
|
||||
if (!std.mem.startsWith(u8, line, "\n--") or !std.mem.endsWith(u8, line, base.boundary)) {
|
||||
base.stream.putBack(line) catch unreachable;
|
||||
continue;
|
||||
} else {
|
||||
base.stream.putBack(line) catch unreachable;
|
||||
base.stream.putBackByte('\r') catch unreachable;
|
||||
self.base = null;
|
||||
return count - 1;
|
||||
}
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
pub fn openMultipart(boundary: []const u8, reader: anytype) !MultipartStream(@TypeOf(reader)) {
|
||||
if (boundary.len > max_boundary) return error.BoundaryTooLarge;
|
||||
var stream = .{
|
||||
.stream = std.io.peekStream(read_ahead, reader),
|
||||
.boundary = boundary,
|
||||
};
|
||||
|
||||
stream.stream.putBack("\r\n") catch unreachable;
|
||||
return stream;
|
||||
}
|
||||
|
||||
const MultipartFormField = struct {
|
||||
name: []const u8,
|
||||
value: []const u8,
|
||||
|
||||
filename: ?[]const u8 = null,
|
||||
content_type: ?[]const u8 = null,
|
||||
};
|
||||
|
||||
pub const FormFile = struct {
|
||||
data: []const u8,
|
||||
filename: []const u8,
|
||||
content_type: []const u8,
|
||||
};
|
||||
|
||||
pub fn MultipartForm(comptime ReaderType: type) type {
|
||||
return struct {
|
||||
stream: MultipartStream(ReaderType),
|
||||
|
||||
pub fn next(self: *@This(), alloc: std.mem.Allocator) !?MultipartFormField {
|
||||
var part = (try self.stream.next(alloc)) orelse return null;
|
||||
defer part.close();
|
||||
|
||||
const disposition = part.fields.get("Content-Disposition") orelse return error.MissingDisposition;
|
||||
|
||||
if (!std.ascii.eqlIgnoreCase(fields.getParam(disposition, null).?, "form-data")) return error.BadDisposition;
|
||||
const name = try util.deepClone(alloc, fields.getParam(disposition, "name") orelse return error.BadDisposition);
|
||||
errdefer util.deepFree(alloc, name);
|
||||
const filename = try util.deepClone(alloc, fields.getParam(disposition, "filename"));
|
||||
errdefer util.deepFree(alloc, filename);
|
||||
const content_type = try util.deepClone(alloc, part.fields.get("Content-Type"));
|
||||
errdefer util.deepFree(alloc, content_type);
|
||||
|
||||
const value = try part.reader().readAllAlloc(alloc, 1 << 32);
|
||||
|
||||
return MultipartFormField{
|
||||
.name = name,
|
||||
.value = value,
|
||||
|
||||
.filename = filename,
|
||||
.content_type = content_type,
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn openForm(multipart_stream: anytype) MultipartForm(@TypeOf(multipart_stream).BaseReader) {
|
||||
return .{ .stream = multipart_stream };
|
||||
}
|
||||
|
||||
fn Deserializer(comptime Result: type) type {
|
||||
return util.DeserializerContext(Result, MultipartFormField, struct {
|
||||
pub const options = .{ .isScalar = isScalar, .embed_unions = true };
|
||||
|
||||
pub fn isScalar(comptime T: type) bool {
|
||||
if (T == FormFile or T == ?FormFile) return true;
|
||||
return util.serialize.defaultIsScalar(T);
|
||||
}
|
||||
|
||||
pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: MultipartFormField) !T {
|
||||
if (T == FormFile or T == ?FormFile) return try deserializeFormFile(alloc, val);
|
||||
|
||||
if (val.filename != null) return error.FilenameProvidedForNonFile;
|
||||
return try util.serialize.deserializeString(alloc, T, val.value);
|
||||
}
|
||||
|
||||
fn deserializeFormFile(alloc: std.mem.Allocator, val: MultipartFormField) !FormFile {
|
||||
const data = try util.deepClone(alloc, val.value);
|
||||
errdefer util.deepFree(alloc, data);
|
||||
const filename = try util.deepClone(alloc, val.filename orelse "(untitled)");
|
||||
errdefer util.deepFree(alloc, filename);
|
||||
const content_type = try util.deepClone(alloc, val.content_type orelse "application/octet-stream");
|
||||
return FormFile{
|
||||
.data = data,
|
||||
.filename = filename,
|
||||
.content_type = content_type,
|
||||
};
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T {
|
||||
var form = openForm(try openMultipart(boundary, reader));
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
defer {
|
||||
var iter = ds.iterator();
|
||||
while (iter.next()) |pair| {
|
||||
util.deepFree(alloc, pair.value);
|
||||
}
|
||||
}
|
||||
while (true) {
|
||||
var part = (try form.next(alloc)) orelse break;
|
||||
errdefer util.deepFree(alloc, part);
|
||||
|
||||
try ds.setSerializedField(part.name, part);
|
||||
}
|
||||
|
||||
return try ds.finish(alloc);
|
||||
}
|
||||
|
||||
// TODO: Fix these tests
|
||||
test "MultipartStream" {
|
||||
const ExpectedPart = struct {
|
||||
disposition: []const u8,
|
||||
value: []const u8,
|
||||
};
|
||||
const testCase = struct {
|
||||
fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const ExpectedPart) !void {
|
||||
var src = std.io.StreamSource{
|
||||
.const_buffer = std.io.fixedBufferStream(body),
|
||||
};
|
||||
|
||||
var stream = try openMultipart(boundary, src.reader());
|
||||
|
||||
for (expected_parts) |expected| {
|
||||
var part = try stream.next(std.testing.allocator) orelse return error.TestExpectedEqual;
|
||||
defer part.close();
|
||||
|
||||
const dispo = part.fields.get("Content-Disposition") orelse return error.TestExpectedEqual;
|
||||
try std.testing.expectEqualStrings(expected.disposition, dispo);
|
||||
|
||||
var buf: [128]u8 = undefined;
|
||||
const count = try part.reader().read(&buf);
|
||||
try std.testing.expectEqualStrings(expected.value, buf[0..count]);
|
||||
}
|
||||
|
||||
try std.testing.expect(try stream.next(std.testing.allocator) == null);
|
||||
}
|
||||
}.case;
|
||||
|
||||
try testCase("--abc--\r\n", "abc", &.{});
|
||||
try testCase(
|
||||
util.comptimeToCrlf(
|
||||
\\------abcd
|
||||
\\Content-Disposition: form-data; name=first; charset=utf8
|
||||
\\
|
||||
\\content
|
||||
\\------abcd
|
||||
\\content-Disposition: form-data; name=second
|
||||
\\
|
||||
\\no content
|
||||
\\------abcd
|
||||
\\content-disposition: form-data; name=third
|
||||
\\
|
||||
\\
|
||||
\\------abcd--
|
||||
\\
|
||||
),
|
||||
"----abcd",
|
||||
&.{
|
||||
.{ .disposition = "form-data; name=first; charset=utf8", .value = "content" },
|
||||
.{ .disposition = "form-data; name=second", .value = "no content" },
|
||||
.{ .disposition = "form-data; name=third", .value = "" },
|
||||
},
|
||||
);
|
||||
|
||||
try testCase(
|
||||
util.comptimeToCrlf(
|
||||
\\--xyz
|
||||
\\Content-Disposition: uhh
|
||||
\\
|
||||
\\xyz
|
||||
\\--xyz
|
||||
\\Content-disposition: ok
|
||||
\\
|
||||
\\ --xyz
|
||||
\\--xyz--
|
||||
\\
|
||||
),
|
||||
"xyz",
|
||||
&.{
|
||||
.{ .disposition = "uhh", .value = "xyz" },
|
||||
.{ .disposition = "ok", .value = " --xyz" },
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
test "MultipartForm" {
|
||||
const testCase = struct {
|
||||
fn case(comptime body: []const u8, boundary: []const u8, expected_parts: []const MultipartFormField) !void {
|
||||
var src = std.io.StreamSource{
|
||||
.const_buffer = std.io.fixedBufferStream(body),
|
||||
};
|
||||
|
||||
var form = openForm(try openMultipart(boundary, src.reader()));
|
||||
|
||||
for (expected_parts) |expected| {
|
||||
var data = try form.next(std.testing.allocator) orelse return error.TestExpectedEqual;
|
||||
defer util.deepFree(std.testing.allocator, data);
|
||||
|
||||
try util.testing.expectDeepEqual(expected, data);
|
||||
}
|
||||
|
||||
try std.testing.expect(try form.next(std.testing.allocator) == null);
|
||||
}
|
||||
}.case;
|
||||
|
||||
try testCase(
|
||||
util.comptimeToCrlf(
|
||||
\\--abcd
|
||||
\\Content-Disposition: form-data; name=foo
|
||||
\\
|
||||
\\content
|
||||
\\--abcd--
|
||||
\\
|
||||
),
|
||||
"abcd",
|
||||
&.{.{ .name = "foo", .value = "content" }},
|
||||
);
|
||||
try testCase(
|
||||
util.comptimeToCrlf(
|
||||
\\--abcd
|
||||
\\Content-Disposition: form-data; name=foo
|
||||
\\
|
||||
\\content
|
||||
\\--abcd
|
||||
\\Content-Disposition: form-data; name=bar
|
||||
\\Content-Type: blah
|
||||
\\
|
||||
\\abcd
|
||||
\\--abcd
|
||||
\\Content-Disposition: form-data; name=baz; filename="myfile.txt"
|
||||
\\Content-Type: text/plain
|
||||
\\
|
||||
\\ --abcd
|
||||
\\
|
||||
\\--abcd--
|
||||
\\
|
||||
),
|
||||
"abcd",
|
||||
&.{
|
||||
.{ .name = "foo", .value = "content" },
|
||||
.{ .name = "bar", .value = "abcd", .content_type = "blah" },
|
||||
.{
|
||||
.name = "baz",
|
||||
.value = " --abcd\r\n",
|
||||
.content_type = "text/plain",
|
||||
.filename = "myfile.txt",
|
||||
},
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
test "parseFormData" {
|
||||
const body = util.comptimeToCrlf(
|
||||
\\--abcd
|
||||
\\Content-Disposition: form-data; name=foo
|
||||
\\
|
||||
\\content
|
||||
\\--abcd--
|
||||
\\
|
||||
);
|
||||
var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
|
||||
const val = try parseFormData(struct {
|
||||
foo: []const u8,
|
||||
}, "abcd", src.reader(), std.testing.allocator);
|
||||
util.deepFree(std.testing.allocator, val);
|
||||
}
|
|
@ -93,7 +93,7 @@ fn parseProto(reader: anytype) !http.Protocol {
|
|||
};
|
||||
}
|
||||
|
||||
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
|
||||
pub fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
|
||||
var headers = Fields.init(allocator);
|
||||
|
||||
var buf: [4096]u8 = undefined;
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
const parser = @import("./parser.zig");
|
||||
const http = @import("../lib.zig");
|
||||
const t = std.testing;
|
||||
|
@ -30,30 +31,9 @@ const test_case = struct {
|
|||
}
|
||||
};
|
||||
|
||||
fn toCrlf(comptime str: []const u8) []const u8 {
|
||||
comptime {
|
||||
var buf: [str.len * 2]u8 = undefined;
|
||||
|
||||
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
|
||||
|
||||
var buf_len: usize = 0;
|
||||
for (str) |ch| {
|
||||
if (ch == '\n') {
|
||||
buf[buf_len] = '\r';
|
||||
buf_len += 1;
|
||||
}
|
||||
|
||||
buf[buf_len] = ch;
|
||||
buf_len += 1;
|
||||
}
|
||||
|
||||
return buf[0..buf_len];
|
||||
}
|
||||
}
|
||||
|
||||
test "HTTP/1.x parse - No body" {
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\GET / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
|
@ -65,7 +45,7 @@ test "HTTP/1.x parse - No body" {
|
|||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\POST / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
|
@ -77,7 +57,7 @@ test "HTTP/1.x parse - No body" {
|
|||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
|
@ -89,7 +69,7 @@ test "HTTP/1.x parse - No body" {
|
|||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\GET / HTTP/1.0
|
||||
\\
|
||||
\\
|
||||
|
@ -101,7 +81,7 @@ test "HTTP/1.x parse - No body" {
|
|||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\Content-Type: application/json
|
||||
\\
|
||||
|
@ -115,7 +95,7 @@ test "HTTP/1.x parse - No body" {
|
|||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\Content-Type: application/json
|
||||
\\Authorization: bearer <token>
|
||||
|
@ -163,7 +143,7 @@ test "HTTP/1.x parse - No body" {
|
|||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\GET / HTTP/1.2
|
||||
\\
|
||||
\\
|
||||
|
@ -265,7 +245,7 @@ test "HTTP/1.x parse - bad requests" {
|
|||
|
||||
test "HTTP/1.x parse - Headers" {
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
util.comptimeToCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\Content-Type: application/json
|
||||
\\Content-Type: application/xml
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
const http = @import("../lib.zig");
|
||||
|
||||
const Status = http.Status;
|
||||
|
@ -169,25 +170,7 @@ test {
|
|||
_ = _tests;
|
||||
}
|
||||
const _tests = struct {
|
||||
fn toCrlf(comptime str: []const u8) []const u8 {
|
||||
comptime {
|
||||
var buf: [str.len * 2]u8 = undefined;
|
||||
@setEvalBranchQuota(@as(u32, str.len * 2));
|
||||
|
||||
var len: usize = 0;
|
||||
for (str) |ch| {
|
||||
if (ch == '\n') {
|
||||
buf[len] = '\r';
|
||||
len += 1;
|
||||
}
|
||||
|
||||
buf[len] = ch;
|
||||
len += 1;
|
||||
}
|
||||
|
||||
return buf[0..len];
|
||||
}
|
||||
}
|
||||
const toCrlf = util.comptimeToCrlf;
|
||||
|
||||
const test_buffer_size = chunk_size * 4;
|
||||
test "ResponseStream no headers empty body" {
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
test {
|
||||
_ = @import("./request/test_parser.zig");
|
||||
_ = @import("./middleware.zig");
|
||||
_ = @import("./query.zig");
|
||||
}
|
|
@ -1,7 +1,38 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
|
||||
const QueryIter = util.QueryIter;
|
||||
pub const Iter = struct {
|
||||
const Pair = struct {
|
||||
key: []const u8,
|
||||
value: ?[]const u8,
|
||||
};
|
||||
|
||||
iter: std.mem.SplitIterator(u8),
|
||||
|
||||
pub fn from(q: []const u8) Iter {
|
||||
return Iter{
|
||||
.iter = std.mem.split(u8, std.mem.trimLeft(u8, q, "?"), "&"),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn next(self: *Iter) ?Pair {
|
||||
while (true) {
|
||||
const part = self.iter.next() orelse return null;
|
||||
if (part.len == 0) continue;
|
||||
|
||||
const key = std.mem.sliceTo(part, '=');
|
||||
if (key.len == part.len) return Pair{
|
||||
.key = key,
|
||||
.value = null,
|
||||
};
|
||||
|
||||
return Pair{
|
||||
.key = key,
|
||||
.value = part[key.len + 1 ..],
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Parses a set of query parameters described by the struct `T`.
|
||||
///
|
||||
|
@ -67,25 +98,44 @@ const QueryIter = util.QueryIter;
|
|||
/// Would be used to parse a query string like
|
||||
/// `?foo.baz=12345`
|
||||
///
|
||||
pub fn parseQuery(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T {
|
||||
if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
|
||||
var iter = QueryIter.from(query);
|
||||
pub fn parse(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T {
|
||||
var iter = Iter.from(query);
|
||||
|
||||
var deserializer = Deserializer(T){};
|
||||
|
||||
var fields = Intermediary(T){};
|
||||
while (iter.next()) |pair| {
|
||||
// TODO: Hash map
|
||||
inline for (std.meta.fields(Intermediary(T))) |field| {
|
||||
if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) {
|
||||
@field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} };
|
||||
break;
|
||||
}
|
||||
} else std.log.debug("unknown param {s}", .{pair.key});
|
||||
try deserializer.setSerializedField(pair.key, pair.value);
|
||||
}
|
||||
|
||||
return (try parse(alloc, T, "", "", fields)) orelse error.NoQuery;
|
||||
return try deserializer.finish(alloc);
|
||||
}
|
||||
|
||||
pub fn parseQueryFree(alloc: std.mem.Allocator, val: anytype) void {
|
||||
fn Deserializer(comptime Result: type) type {
|
||||
return util.DeserializerContext(Result, ?[]const u8, struct {
|
||||
pub const options = util.serialize.default_options;
|
||||
pub fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, maybe_val: ?[]const u8) !T {
|
||||
const is_optional = comptime std.meta.trait.is(.Optional)(T);
|
||||
if (maybe_val) |val| {
|
||||
if (val.len == 0 and is_optional) return null;
|
||||
|
||||
const decoded = try decodeString(alloc, val);
|
||||
defer alloc.free(decoded);
|
||||
|
||||
return try util.serialize.deserializeString(alloc, T, decoded);
|
||||
} else {
|
||||
// If param is present, but without an associated value
|
||||
return if (is_optional)
|
||||
null
|
||||
else if (T == bool)
|
||||
true
|
||||
else
|
||||
error.InvalidValue;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn parseFree(alloc: std.mem.Allocator, val: anytype) void {
|
||||
util.deepFree(alloc, val);
|
||||
}
|
||||
|
||||
|
@ -110,186 +160,6 @@ fn decodeString(alloc: std.mem.Allocator, val: []const u8) ![]u8 {
|
|||
return list.toOwnedSlice();
|
||||
}
|
||||
|
||||
fn parseScalar(alloc: std.mem.Allocator, comptime T: type, comptime name: []const u8, fields: anytype) !?T {
|
||||
const param = @field(fields, name);
|
||||
return switch (param) {
|
||||
.not_specified => null,
|
||||
.no_value => try parseQueryValue(alloc, T, null),
|
||||
.value => |v| try parseQueryValue(alloc, T, v),
|
||||
};
|
||||
}
|
||||
|
||||
fn parse(
|
||||
alloc: std.mem.Allocator,
|
||||
comptime T: type,
|
||||
comptime prefix: []const u8,
|
||||
comptime name: []const u8,
|
||||
fields: anytype,
|
||||
) !?T {
|
||||
if (comptime isScalar(T)) return parseScalar(alloc, T, prefix ++ "." ++ name, fields);
|
||||
switch (@typeInfo(T)) {
|
||||
.Union => |info| {
|
||||
var result: ?T = null;
|
||||
inline for (info.fields) |field| {
|
||||
const F = field.field_type;
|
||||
|
||||
const maybe_value = try parse(alloc, F, prefix, field.name, fields);
|
||||
if (maybe_value) |value| {
|
||||
if (result != null) return error.DuplicateUnionField;
|
||||
|
||||
result = @unionInit(T, field.name, value);
|
||||
}
|
||||
}
|
||||
std.log.debug("{any}", .{result});
|
||||
return result;
|
||||
},
|
||||
|
||||
.Struct => |info| {
|
||||
var result: T = undefined;
|
||||
var fields_specified: usize = 0;
|
||||
errdefer inline for (info.fields) |field, i| {
|
||||
if (fields_specified < i) util.deepFree(alloc, @field(result, field.name));
|
||||
};
|
||||
|
||||
inline for (info.fields) |field| {
|
||||
const F = field.field_type;
|
||||
|
||||
var maybe_value: ?F = null;
|
||||
if (try parse(alloc, F, prefix ++ "." ++ name, field.name, fields)) |v| {
|
||||
maybe_value = v;
|
||||
} else if (field.default_value) |default| {
|
||||
if (comptime @sizeOf(F) != 0) {
|
||||
maybe_value = try util.deepClone(alloc, @ptrCast(*const F, @alignCast(@alignOf(F), default)).*);
|
||||
} else {
|
||||
maybe_value = std.mem.zeroes(F);
|
||||
}
|
||||
}
|
||||
|
||||
if (maybe_value) |v| {
|
||||
fields_specified += 1;
|
||||
@field(result, field.name) = v;
|
||||
}
|
||||
}
|
||||
|
||||
if (fields_specified == 0) {
|
||||
return null;
|
||||
} else if (fields_specified != info.fields.len) {
|
||||
std.log.debug("{} {s} {s}", .{ T, prefix, name });
|
||||
return error.PartiallySpecifiedStruct;
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
},
|
||||
|
||||
// Only applies to non-scalar optionals
|
||||
.Optional => |info| return try parse(alloc, info.child, prefix, name, fields),
|
||||
|
||||
else => @compileError("tmp"),
|
||||
}
|
||||
}
|
||||
|
||||
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 {
|
||||
comptime {
|
||||
if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix);
|
||||
|
||||
var fields: []const []const u8 = &.{};
|
||||
|
||||
for (std.meta.fields(T)) |f| {
|
||||
const full_name = prefix ++ f.name;
|
||||
|
||||
if (isScalar(f.field_type)) {
|
||||
fields = fields ++ @as([]const []const u8, &.{full_name});
|
||||
} else {
|
||||
const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ ".";
|
||||
fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix);
|
||||
}
|
||||
}
|
||||
|
||||
return fields;
|
||||
}
|
||||
}
|
||||
|
||||
const QueryParam = union(enum) {
|
||||
not_specified: void,
|
||||
no_value: void,
|
||||
value: []const u8,
|
||||
};
|
||||
|
||||
fn Intermediary(comptime T: type) type {
|
||||
const field_names = recursiveFieldPaths(T, "..");
|
||||
|
||||
var fields: [field_names.len]std.builtin.Type.StructField = undefined;
|
||||
for (field_names) |name, i| fields[i] = .{
|
||||
.name = name,
|
||||
.field_type = QueryParam,
|
||||
.default_value = &QueryParam{ .not_specified = {} },
|
||||
.is_comptime = false,
|
||||
.alignment = @alignOf(QueryParam),
|
||||
};
|
||||
|
||||
return @Type(.{ .Struct = .{
|
||||
.layout = .Auto,
|
||||
.fields = &fields,
|
||||
.decls = &.{},
|
||||
.is_tuple = false,
|
||||
} });
|
||||
}
|
||||
|
||||
fn parseQueryValue(alloc: std.mem.Allocator, comptime T: type, maybe_value: ?[]const u8) !T {
|
||||
const is_optional = comptime std.meta.trait.is(.Optional)(T);
|
||||
if (maybe_value) |value| {
|
||||
const Eff = if (is_optional) std.meta.Child(T) else T;
|
||||
|
||||
if (value.len == 0 and is_optional) return null;
|
||||
|
||||
const decoded = try decodeString(alloc, value);
|
||||
errdefer alloc.free(decoded);
|
||||
|
||||
if (comptime std.meta.trait.isZigString(Eff)) return decoded;
|
||||
|
||||
defer alloc.free(decoded);
|
||||
|
||||
const result = if (comptime std.meta.trait.isIntegral(Eff))
|
||||
try std.fmt.parseInt(Eff, decoded, 0)
|
||||
else if (comptime std.meta.trait.isFloat(Eff))
|
||||
try std.fmt.parseFloat(Eff, decoded)
|
||||
else if (comptime std.meta.trait.is(.Enum)(Eff)) blk: {
|
||||
_ = std.ascii.lowerString(decoded, decoded);
|
||||
break :blk std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
|
||||
} else if (Eff == bool) blk: {
|
||||
_ = std.ascii.lowerString(decoded, decoded);
|
||||
break :blk bool_map.get(decoded) orelse return error.InvalidBool;
|
||||
} else if (comptime std.meta.trait.hasFn("parse")(Eff))
|
||||
try Eff.parse(value)
|
||||
else
|
||||
@compileError("Invalid type " ++ @typeName(T));
|
||||
|
||||
return result;
|
||||
} else {
|
||||
// If param is present, but without an associated value
|
||||
return if (is_optional)
|
||||
null
|
||||
else if (T == bool)
|
||||
true
|
||||
else
|
||||
error.InvalidValue;
|
||||
}
|
||||
}
|
||||
|
||||
const bool_map = std.ComptimeStringMap(bool, .{
|
||||
.{ "true", true },
|
||||
.{ "t", true },
|
||||
.{ "yes", true },
|
||||
.{ "y", true },
|
||||
.{ "1", true },
|
||||
|
||||
.{ "false", false },
|
||||
.{ "f", false },
|
||||
.{ "no", false },
|
||||
.{ "n", false },
|
||||
.{ "0", false },
|
||||
});
|
||||
|
||||
fn isScalar(comptime T: type) bool {
|
||||
if (comptime std.meta.trait.isZigString(T)) return true;
|
||||
if (comptime std.meta.trait.isIntegral(T)) return true;
|
||||
|
@ -304,7 +174,7 @@ fn isScalar(comptime T: type) bool {
|
|||
return false;
|
||||
}
|
||||
|
||||
pub fn QueryStringify(comptime Params: type) type {
|
||||
pub fn EncodeStruct(comptime Params: type) type {
|
||||
return struct {
|
||||
params: Params,
|
||||
pub fn format(v: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
|
||||
|
@ -312,8 +182,8 @@ pub fn QueryStringify(comptime Params: type) type {
|
|||
}
|
||||
};
|
||||
}
|
||||
pub fn queryStringify(val: anytype) QueryStringify(@TypeOf(val)) {
|
||||
return QueryStringify(@TypeOf(val)){ .params = val };
|
||||
pub fn encodeStruct(val: anytype) EncodeStruct(@TypeOf(val)) {
|
||||
return EncodeStruct(@TypeOf(val)){ .params = val };
|
||||
}
|
||||
|
||||
fn urlFormatString(writer: anytype, val: []const u8) !void {
|
||||
|
@ -375,11 +245,11 @@ fn formatQuery(comptime prefix: []const u8, comptime name: []const u8, params: a
|
|||
}
|
||||
}
|
||||
|
||||
test "parseQuery" {
|
||||
test "parse" {
|
||||
const testCase = struct {
|
||||
fn case(comptime T: type, expected: T, query_string: []const u8) !void {
|
||||
const result = try parseQuery(std.testing.allocator, T, query_string);
|
||||
defer parseQueryFree(std.testing.allocator, result);
|
||||
const result = try parse(std.testing.allocator, T, query_string);
|
||||
defer parseFree(std.testing.allocator, result);
|
||||
try util.testing.expectDeepEqual(expected, result);
|
||||
}
|
||||
}.case;
|
||||
|
@ -465,14 +335,46 @@ test "parseQuery" {
|
|||
try testCase(SubUnion2, .{ .sub = .{ .foo = 1, .val = .{ .baz = "abc" } } }, "sub.foo=1&sub.baz=abc");
|
||||
}
|
||||
|
||||
test "formatQuery" {
|
||||
try std.testing.expectFmt("", "{}", .{queryStringify(.{})});
|
||||
try std.testing.expectFmt("id=3&", "{}", .{queryStringify(.{ .id = 3 })});
|
||||
try std.testing.expectFmt("id=3&id2=4&", "{}", .{queryStringify(.{ .id = 3, .id2 = 4 })});
|
||||
test "encodeStruct" {
|
||||
try std.testing.expectFmt("", "{}", .{encodeStruct(.{})});
|
||||
try std.testing.expectFmt("id=3&", "{}", .{encodeStruct(.{ .id = 3 })});
|
||||
try std.testing.expectFmt("id=3&id2=4&", "{}", .{encodeStruct(.{ .id = 3, .id2 = 4 })});
|
||||
|
||||
try std.testing.expectFmt("str=foo&", "{}", .{queryStringify(.{ .str = "foo" })});
|
||||
try std.testing.expectFmt("enum_str=foo&", "{}", .{queryStringify(.{ .enum_str = .foo })});
|
||||
try std.testing.expectFmt("str=foo&", "{}", .{encodeStruct(.{ .str = "foo" })});
|
||||
try std.testing.expectFmt("enum_str=foo&", "{}", .{encodeStruct(.{ .enum_str = .foo })});
|
||||
|
||||
try std.testing.expectFmt("boolean=false&", "{}", .{queryStringify(.{ .boolean = false })});
|
||||
try std.testing.expectFmt("boolean=true&", "{}", .{queryStringify(.{ .boolean = true })});
|
||||
try std.testing.expectFmt("boolean=false&", "{}", .{encodeStruct(.{ .boolean = false })});
|
||||
try std.testing.expectFmt("boolean=true&", "{}", .{encodeStruct(.{ .boolean = true })});
|
||||
}
|
||||
|
||||
test "Iter" {
|
||||
const testCase = struct {
|
||||
fn case(str: []const u8, pairs: []const Iter.Pair) !void {
|
||||
var iter = Iter.from(str);
|
||||
for (pairs) |pair| {
|
||||
try util.testing.expectDeepEqual(@as(?Iter.Pair, pair), iter.next());
|
||||
}
|
||||
try std.testing.expect(iter.next() == null);
|
||||
}
|
||||
}.case;
|
||||
|
||||
try testCase("", &.{});
|
||||
try testCase("abc", &.{.{ .key = "abc", .value = null }});
|
||||
try testCase("abc=", &.{.{ .key = "abc", .value = "" }});
|
||||
try testCase("abc=def", &.{.{ .key = "abc", .value = "def" }});
|
||||
try testCase("abc=def&", &.{.{ .key = "abc", .value = "def" }});
|
||||
try testCase("?abc=def&", &.{.{ .key = "abc", .value = "def" }});
|
||||
try testCase("?abc=def&foo&bar=baz&qux=", &.{
|
||||
.{ .key = "abc", .value = "def" },
|
||||
.{ .key = "foo", .value = null },
|
||||
.{ .key = "bar", .value = "baz" },
|
||||
.{ .key = "qux", .value = "" },
|
||||
});
|
||||
try testCase("?abc=def&&foo&bar=baz&&qux=&", &.{
|
||||
.{ .key = "abc", .value = "def" },
|
||||
.{ .key = "foo", .value = null },
|
||||
.{ .key = "bar", .value = "baz" },
|
||||
.{ .key = "qux", .value = "" },
|
||||
});
|
||||
try testCase("&=def&", &.{.{ .key = "", .value = "def" }});
|
||||
}
|
|
@ -267,13 +267,13 @@ pub const helpers = struct {
|
|||
try std.fmt.format(
|
||||
writer,
|
||||
"<{s}://{s}/{s}?{}>; rel=\"{s}\"",
|
||||
.{ @tagName(c.scheme), c.host, path, http.queryStringify(params), rel },
|
||||
.{ @tagName(c.scheme), c.host, path, http.urlencode.encodeStruct(params), rel },
|
||||
);
|
||||
} else {
|
||||
try std.fmt.format(
|
||||
writer,
|
||||
"<{s}?{}>; rel=\"{s}\"",
|
||||
.{ path, http.queryStringify(params), rel },
|
||||
.{ path, http.urlencode.encodeStruct(params), rel },
|
||||
);
|
||||
}
|
||||
// TODO: percent-encode
|
||||
|
|
|
@ -2,6 +2,7 @@ const controllers = @import("../controllers.zig");
|
|||
|
||||
const auth = @import("./api/auth.zig");
|
||||
const communities = @import("./api/communities.zig");
|
||||
const drive = @import("./api/drive.zig");
|
||||
const invites = @import("./api/invites.zig");
|
||||
const users = @import("./api/users.zig");
|
||||
const follows = @import("./api/users/follows.zig");
|
||||
|
@ -26,4 +27,6 @@ pub const routes = .{
|
|||
controllers.apiEndpoint(follows.delete),
|
||||
controllers.apiEndpoint(follows.query_followers),
|
||||
controllers.apiEndpoint(follows.query_following),
|
||||
controllers.apiEndpoint(drive.upload),
|
||||
controllers.apiEndpoint(drive.mkdir),
|
||||
};
|
||||
|
|
144
src/main/controllers/api/drive.zig
Normal file
144
src/main/controllers/api/drive.zig
Normal file
|
@ -0,0 +1,144 @@
|
|||
const api = @import("api");
|
||||
const http = @import("http");
|
||||
const util = @import("util");
|
||||
const controller_utils = @import("../../controllers.zig").helpers;
|
||||
|
||||
const Uuid = util.Uuid;
|
||||
const DateTime = util.DateTime;
|
||||
|
||||
pub const drive_path = "/drive/:path*";
|
||||
pub const DriveArgs = struct {
|
||||
path: []const u8,
|
||||
};
|
||||
|
||||
pub const query = struct {
|
||||
pub const method = .GET;
|
||||
pub const path = drive_path;
|
||||
pub const Args = DriveArgs;
|
||||
|
||||
pub const Query = struct {
|
||||
const OrderBy = enum {
|
||||
created_at,
|
||||
filename,
|
||||
};
|
||||
|
||||
max_items: usize = 20,
|
||||
|
||||
like: ?[]const u8 = null,
|
||||
|
||||
order_by: OrderBy = .created_at,
|
||||
direction: api.Direction = .descending,
|
||||
|
||||
prev: ?struct {
|
||||
id: Uuid,
|
||||
order_val: union(OrderBy) {
|
||||
created_at: DateTime,
|
||||
filename: []const u8,
|
||||
},
|
||||
} = null,
|
||||
|
||||
page_direction: api.PageDirection = .forward,
|
||||
};
|
||||
|
||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
const result = srv.driveQuery(req.args.path, req.query) catch |err| switch (err) {
|
||||
error.NotADirectory => {
|
||||
const meta = try srv.getFile(path);
|
||||
try res.json(.ok, meta);
|
||||
return;
|
||||
},
|
||||
else => |e| return e,
|
||||
};
|
||||
|
||||
try controller_utils.paginate(result, res, req.allocator);
|
||||
}
|
||||
};
|
||||
|
||||
pub const upload = struct {
|
||||
pub const method = .POST;
|
||||
pub const path = drive_path;
|
||||
pub const Args = DriveArgs;
|
||||
|
||||
pub const Body = struct {
|
||||
file: http.FormFile,
|
||||
description: ?[]const u8 = null,
|
||||
sensitive: bool = false,
|
||||
};
|
||||
|
||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
const f = req.body.file;
|
||||
try srv.uploadFile(.{
|
||||
.dir = req.args.path,
|
||||
.filename = f.filename,
|
||||
.description = req.body.description,
|
||||
.content_type = f.content_type,
|
||||
.sensitive = req.body.sensitive,
|
||||
}, f.data);
|
||||
|
||||
// TODO: print meta
|
||||
try res.json(.created, .{});
|
||||
}
|
||||
};
|
||||
|
||||
pub const delete = struct {
|
||||
pub const method = .DELETE;
|
||||
pub const path = drive_path;
|
||||
pub const Args = DriveArgs;
|
||||
|
||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
const info = try srv.driveLookup(req.args.path);
|
||||
if (info == .dir)
|
||||
try srv.driveRmdir(req.args.path)
|
||||
else if (info == .file)
|
||||
try srv.deleteFile(req.args.path);
|
||||
|
||||
return res.json(.ok, .{});
|
||||
}
|
||||
};
|
||||
|
||||
pub const mkdir = struct {
|
||||
pub const method = .MKCOL;
|
||||
pub const path = drive_path;
|
||||
pub const Args = DriveArgs;
|
||||
|
||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
try srv.driveMkdir(req.args.path);
|
||||
|
||||
return res.json(.created, .{});
|
||||
}
|
||||
};
|
||||
|
||||
pub const update = struct {
|
||||
pub const method = .PUT;
|
||||
pub const path = drive_path;
|
||||
pub const Args = DriveArgs;
|
||||
|
||||
pub const Body = struct {
|
||||
description: ?[]const u8 = null,
|
||||
content_type: ?[]const u8 = null,
|
||||
sensitive: ?bool = null,
|
||||
};
|
||||
|
||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
const info = try srv.driveLookup(req.args.path);
|
||||
if (info != .file) return error.NotFile;
|
||||
|
||||
const new_info = try srv.updateFile(path, req.body);
|
||||
try res.json(.ok, new_info);
|
||||
}
|
||||
};
|
||||
|
||||
pub const move = struct {
|
||||
pub const method = .MOVE;
|
||||
pub const path = drive_path;
|
||||
pub const Args = DriveArgs;
|
||||
|
||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
const destination = req.fields.get("Destination") orelse return error.NoDestination;
|
||||
|
||||
try srv.driveMove(req.args.path, destination);
|
||||
|
||||
try res.fields.put("Location", destination);
|
||||
try srv.json(.created, .{});
|
||||
}
|
||||
};
|
|
@ -19,8 +19,9 @@ fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void {
|
|||
const tx = try db.beginOrSavepoint();
|
||||
errdefer tx.rollback();
|
||||
|
||||
var iter = util.SqlStmtIter.from(script);
|
||||
var iter = std.mem.split(u8, script, ";");
|
||||
while (iter.next()) |stmt| {
|
||||
if (stmt.len == 0) continue;
|
||||
try execStmt(tx, stmt, alloc);
|
||||
}
|
||||
|
||||
|
@ -208,22 +209,139 @@ const migrations: []const Migration = &.{
|
|||
.{
|
||||
.name = "files",
|
||||
.up =
|
||||
\\CREATE TABLE drive_file(
|
||||
\\CREATE TABLE file_upload(
|
||||
\\ id UUID NOT NULL PRIMARY KEY,
|
||||
\\
|
||||
\\ filename TEXT NOT NULL,
|
||||
\\ account_owner_id UUID REFERENCES account(id),
|
||||
\\ community_owner_id UUID REFERENCES community(id),
|
||||
\\ created_by UUID REFERENCES account(id),
|
||||
\\ size INTEGER NOT NULL,
|
||||
\\
|
||||
\\ filename TEXT NOT NULL,
|
||||
\\ description TEXT,
|
||||
\\ content_type TEXT,
|
||||
\\ sensitive BOOLEAN NOT NULL,
|
||||
\\
|
||||
\\ is_deleted BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
\\
|
||||
\\ created_at TIMESTAMPTZ NOT NULL,
|
||||
\\ updated_at TIMESTAMPTZ NOT NULL
|
||||
\\);
|
||||
\\
|
||||
\\CREATE TABLE drive_entry(
|
||||
\\ id UUID NOT NULL PRIMARY KEY,
|
||||
\\
|
||||
\\ account_owner_id UUID REFERENCES account(id),
|
||||
\\ community_owner_id UUID REFERENCES community(id),
|
||||
\\
|
||||
\\ name TEXT,
|
||||
\\ parent_directory_id UUID REFERENCES drive_entry(id),
|
||||
\\
|
||||
\\ file_id UUID REFERENCES file_upload(id),
|
||||
\\
|
||||
\\ CHECK(
|
||||
\\ (account_owner_id IS NULL AND community_owner_id IS NOT NULL)
|
||||
\\ OR (account_owner_id IS NOT NULL AND community_owner_id IS NULL)
|
||||
\\ ),
|
||||
\\ CHECK(
|
||||
\\ (name IS NULL AND parent_directory_id IS NULL AND file_id IS NULL)
|
||||
\\ OR (name IS NOT NULL AND parent_directory_id IS NOT NULL)
|
||||
\\ )
|
||||
\\);
|
||||
\\CREATE UNIQUE INDEX drive_entry_uniqueness
|
||||
\\ON drive_entry(
|
||||
\\ name,
|
||||
\\ COALESCE(parent_directory_id, ''),
|
||||
\\ COALESCE(account_owner_id, community_owner_id)
|
||||
\\);
|
||||
,
|
||||
.down = "DROP TABLE drive_file",
|
||||
.down =
|
||||
\\DROP INDEX drive_entry_uniqueness;
|
||||
\\DROP TABLE drive_entry;
|
||||
\\DROP TABLE file_upload;
|
||||
,
|
||||
},
|
||||
.{
|
||||
.name = "drive_entry_path",
|
||||
.up =
|
||||
\\CREATE VIEW drive_entry_path(
|
||||
\\ id,
|
||||
\\ path,
|
||||
\\ account_owner_id,
|
||||
\\ community_owner_id,
|
||||
\\ kind
|
||||
\\) AS WITH RECURSIVE full_path(
|
||||
\\ id,
|
||||
\\ path,
|
||||
\\ account_owner_id,
|
||||
\\ community_owner_id,
|
||||
\\ kind
|
||||
\\) AS (
|
||||
\\ SELECT
|
||||
\\ id,
|
||||
\\ '' AS path,
|
||||
\\ account_owner_id,
|
||||
\\ community_owner_id,
|
||||
\\ 'dir' AS kind
|
||||
\\ FROM drive_entry
|
||||
\\ WHERE parent_directory_id IS NULL
|
||||
\\ UNION ALL
|
||||
\\ SELECT
|
||||
\\ base.id,
|
||||
\\ (dir.path || '/' || base.name) AS path,
|
||||
\\ base.account_owner_id,
|
||||
\\ base.community_owner_id,
|
||||
\\ (CASE WHEN base.file_id IS NULL THEN 'dir' ELSE 'file' END) as kind
|
||||
\\ FROM drive_entry AS base
|
||||
\\ JOIN full_path AS dir ON
|
||||
\\ base.parent_directory_id = dir.id
|
||||
\\ AND base.account_owner_id IS NOT DISTINCT FROM dir.account_owner_id
|
||||
\\ AND base.community_owner_id IS NOT DISTINCT FROM dir.community_owner_id
|
||||
\\)
|
||||
\\SELECT
|
||||
\\ id,
|
||||
\\ (CASE WHEN kind = 'dir' THEN path || '/' ELSE path END) AS path,
|
||||
\\ account_owner_id,
|
||||
\\ community_owner_id,
|
||||
\\ kind
|
||||
\\FROM full_path;
|
||||
,
|
||||
.down =
|
||||
\\DROP VIEW drive_entry_path;
|
||||
,
|
||||
},
|
||||
.{
|
||||
.name = "create drive root directories",
|
||||
.up =
|
||||
\\INSERT INTO drive_entry(
|
||||
\\ id,
|
||||
\\ account_owner_id,
|
||||
\\ community_owner_id,
|
||||
\\ parent_directory_id,
|
||||
\\ name,
|
||||
\\ file_id
|
||||
\\) SELECT
|
||||
\\ id,
|
||||
\\ id AS account_owner_id,
|
||||
\\ NULL AS community_owner_id,
|
||||
\\ NULL AS parent_directory_id,
|
||||
\\ NULL AS name,
|
||||
\\ NULL AS file_id
|
||||
\\FROM account;
|
||||
\\INSERT INTO drive_entry(
|
||||
\\ id,
|
||||
\\ account_owner_id,
|
||||
\\ community_owner_id,
|
||||
\\ parent_directory_id,
|
||||
\\ name,
|
||||
\\ file_id
|
||||
\\) SELECT
|
||||
\\ id,
|
||||
\\ NULL AS account_owner_id,
|
||||
\\ id AS community_owner_id,
|
||||
\\ NULL AS parent_directory_id,
|
||||
\\ NULL AS name,
|
||||
\\ NULL AS file_id
|
||||
\\FROM community;
|
||||
,
|
||||
.down = "",
|
||||
},
|
||||
};
|
||||
|
|
|
@ -88,7 +88,7 @@ pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]con
|
|||
else => |T| switch (@typeInfo(T)) {
|
||||
.Enum => return @tagName(val),
|
||||
.Optional => if (val) |v| try prepareParamText(arena, v) else null,
|
||||
.Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}),
|
||||
.Bool, .Int => try std.fmt.allocPrintZ(arena.allocator(), "{}", .{val}),
|
||||
.Union => loop: inline for (std.meta.fields(T)) |field| {
|
||||
// Have to do this in a roundabout way to satisfy comptime checker
|
||||
const Tag = std.meta.Tag(T);
|
||||
|
|
|
@ -193,6 +193,7 @@ pub const Db = struct {
|
|||
.Null => return self.bindNull(stmt, idx),
|
||||
.Int => return self.bindInt(stmt, idx, std.math.cast(i64, val) orelse unreachable),
|
||||
.Float => return self.bindFloat(stmt, idx, val),
|
||||
.Bool => return self.bindInt(stmt, idx, if (val) 1 else 0),
|
||||
else => @compileError("Unable to serialize type " ++ @typeName(T)),
|
||||
}
|
||||
}
|
||||
|
@ -251,18 +252,20 @@ pub const Results = struct {
|
|||
db: *c.sqlite3,
|
||||
|
||||
pub fn finish(self: Results) void {
|
||||
switch (c.sqlite3_finalize(self.stmt)) {
|
||||
c.SQLITE_OK => {},
|
||||
else => |err| {
|
||||
handleUnexpectedError(self.db, err, self.getGeneratingSql()) catch {};
|
||||
},
|
||||
}
|
||||
_ = c.sqlite3_finalize(self.stmt);
|
||||
}
|
||||
|
||||
pub fn row(self: Results) common.RowError!?Row {
|
||||
return switch (c.sqlite3_step(self.stmt)) {
|
||||
c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db },
|
||||
c.SQLITE_DONE => null,
|
||||
|
||||
c.SQLITE_CONSTRAINT_UNIQUE => return error.UniqueViolation,
|
||||
c.SQLITE_CONSTRAINT_CHECK => return error.CheckViolation,
|
||||
c.SQLITE_CONSTRAINT_NOTNULL => return error.NotNullViolation,
|
||||
c.SQLITE_CONSTRAINT_FOREIGNKEY => return error.ForeignKeyViolation,
|
||||
c.SQLITE_CONSTRAINT => return error.ConstraintViolation,
|
||||
|
||||
else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()),
|
||||
};
|
||||
}
|
||||
|
|
|
@ -144,42 +144,16 @@ fn fieldPtr(ptr: anytype, comptime names: []const []const u8) FieldPtr(@TypeOf(p
|
|||
return fieldPtr(&@field(ptr.*, names[0]), names[1..]);
|
||||
}
|
||||
|
||||
fn isScalar(comptime T: type) bool {
|
||||
if (comptime std.meta.trait.isZigString(T)) return true;
|
||||
if (comptime std.meta.trait.isIntegral(T)) return true;
|
||||
if (comptime std.meta.trait.isFloat(T)) return true;
|
||||
if (comptime std.meta.trait.is(.Enum)(T)) return true;
|
||||
if (T == bool) return true;
|
||||
if (comptime std.meta.trait.hasFn("parse")(T)) return true;
|
||||
|
||||
if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const []const u8) []const []const []const u8 {
|
||||
comptime {
|
||||
var fields: []const []const []const u8 = &.{};
|
||||
|
||||
for (std.meta.fields(T)) |f| {
|
||||
const full_name = prefix ++ [_][]const u8{f.name};
|
||||
if (isScalar(f.field_type)) {
|
||||
fields = fields ++ [_][]const []const u8{full_name};
|
||||
} else {
|
||||
fields = fields ++ recursiveFieldPaths(f.field_type, full_name);
|
||||
}
|
||||
}
|
||||
|
||||
return fields;
|
||||
}
|
||||
}
|
||||
|
||||
// Represents a set of results.
|
||||
// row() must be called until it returns null, or the query may not complete
|
||||
// Must be deallocated by a call to finish()
|
||||
pub fn Results(comptime T: type) type {
|
||||
// would normally make this a declaration of the struct, but it causes the compiler to crash
|
||||
const fields = if (T == void) .{} else recursiveFieldPaths(T, &.{});
|
||||
const fields = if (T == void) .{} else util.serialize.getRecursiveFieldList(
|
||||
T,
|
||||
&.{},
|
||||
util.serialize.default_options,
|
||||
);
|
||||
return struct {
|
||||
const Self = @This();
|
||||
|
||||
|
@ -457,6 +431,7 @@ fn Tx(comptime tx_level: u8) type {
|
|||
pub fn rollback(self: Self) void {
|
||||
(if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| {
|
||||
std.log.err("Failed to rollback transaction: {}", .{err});
|
||||
std.log.err("{any}", .{@errorReturnTrace()});
|
||||
@panic("TODO: more gracefully handle rollback failures");
|
||||
};
|
||||
}
|
||||
|
@ -654,7 +629,7 @@ fn Tx(comptime tx_level: u8) type {
|
|||
}
|
||||
|
||||
fn rollbackUnchecked(self: Self) !void {
|
||||
try self.exec("ROLLBACK", {}, null);
|
||||
try self.execInternal("ROLLBACK", {}, null, false);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -601,3 +601,20 @@ const ControlTokenIter = struct {
|
|||
self.peeked_token = token;
|
||||
}
|
||||
};
|
||||
|
||||
test "template" {
|
||||
const testCase = struct {
|
||||
fn case(comptime tmpl: []const u8, args: anytype, expected: []const u8) !void {
|
||||
var stream = std.io.changeDetectionStream(expected, std.io.null_writer);
|
||||
try execute(stream.writer(), tmpl, args);
|
||||
try std.testing.expect(!stream.changeDetected());
|
||||
}
|
||||
}.case;
|
||||
|
||||
try testCase("", .{}, "");
|
||||
try testCase("abcd", .{}, "abcd");
|
||||
try testCase("{.val}", .{ .val = 3 }, "3");
|
||||
try testCase("{#if .val}1{/if}", .{ .val = true }, "1");
|
||||
try testCase("{#for .vals |$v|}{$v}{/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123");
|
||||
try testCase("{#for .vals |$v|=} {$v} {=/for}", .{ .vals = [_]u8{ 1, 2, 3 } }, "123");
|
||||
}
|
||||
|
|
161
src/util/Url.zig
161
src/util/Url.zig
|
@ -1,161 +0,0 @@
|
|||
const Url = @This();
|
||||
const std = @import("std");
|
||||
|
||||
scheme: []const u8,
|
||||
hostport: []const u8,
|
||||
path: []const u8,
|
||||
query: []const u8,
|
||||
fragment: []const u8,
|
||||
|
||||
pub fn parse(url: []const u8) !Url {
|
||||
const scheme_end = for (url) |ch, i| {
|
||||
if (ch == ':') break i;
|
||||
} else return error.InvalidUrl;
|
||||
|
||||
if (url.len < scheme_end + 3 or url[scheme_end + 1] != '/' or url[scheme_end + 1] != '/') return error.InvalidUrl;
|
||||
|
||||
const hostport_start = scheme_end + 3;
|
||||
const hostport_end = for (url[hostport_start..]) |ch, i| {
|
||||
if (ch == '/' or ch == '?' or ch == '#') break i + hostport_start;
|
||||
} else url.len;
|
||||
|
||||
const path_end = for (url[hostport_end..]) |ch, i| {
|
||||
if (ch == '?' or ch == '#') break i + hostport_end;
|
||||
} else url.len;
|
||||
|
||||
const query_end = if (!(url.len > path_end and url[path_end] == '?'))
|
||||
path_end
|
||||
else for (url[path_end..]) |ch, i| {
|
||||
if (ch == '#') break i + path_end;
|
||||
} else url.len;
|
||||
|
||||
const query = url[path_end..query_end];
|
||||
const fragment = url[query_end..];
|
||||
|
||||
return Url{
|
||||
.scheme = url[0..scheme_end],
|
||||
.hostport = url[hostport_start..hostport_end],
|
||||
.path = url[hostport_end..path_end],
|
||||
.query = if (query.len > 0) query[1..] else query,
|
||||
.fragment = if (fragment.len > 0) fragment[1..] else fragment,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn getQuery(self: Url, param: []const u8) ?[]const u8 {
|
||||
var key_start: usize = 0;
|
||||
std.log.debug("query: {s}", .{self.query});
|
||||
while (key_start < self.query.len) {
|
||||
const key_end = for (self.query[key_start..]) |ch, i| {
|
||||
if (ch == '=') break key_start + i;
|
||||
} else return null;
|
||||
|
||||
const val_start = key_end + 1;
|
||||
const val_end = for (self.query[val_start..]) |ch, i| {
|
||||
if (ch == '&') break val_start + i;
|
||||
} else self.query.len;
|
||||
|
||||
const key = self.query[key_start..key_end];
|
||||
if (std.mem.eql(u8, key, param)) return self.query[val_start..val_end];
|
||||
|
||||
key_start = val_end + 1;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
pub fn strDecode(buf: []u8, str: []const u8) ![]u8 {
|
||||
var str_i: usize = 0;
|
||||
var buf_i: usize = 0;
|
||||
while (str_i < str.len) : ({
|
||||
str_i += 1;
|
||||
buf_i += 1;
|
||||
}) {
|
||||
if (buf_i >= buf.len) return error.NoSpaceLeft;
|
||||
const ch = str[str_i];
|
||||
if (ch == '%') {
|
||||
if (str.len < str_i + 2) return error.BadEscape;
|
||||
|
||||
const hi = try std.fmt.charToDigit(str[str_i + 1], 16);
|
||||
const lo = try std.fmt.charToDIgit(str[str_i + 2], 16);
|
||||
str_i += 2;
|
||||
|
||||
buf[buf_i] = (hi << 4) | lo;
|
||||
} else {
|
||||
buf[buf_i] = str[str_i];
|
||||
}
|
||||
}
|
||||
|
||||
return buf[0..buf_i];
|
||||
}
|
||||
|
||||
fn expectEqualUrl(expected: Url, actual: Url) !void {
|
||||
const t = @import("std").testing;
|
||||
try t.expectEqualStrings(expected.scheme, actual.scheme);
|
||||
try t.expectEqualStrings(expected.hostport, actual.hostport);
|
||||
try t.expectEqualStrings(expected.path, actual.path);
|
||||
try t.expectEqualStrings(expected.query, actual.query);
|
||||
try t.expectEqualStrings(expected.fragment, actual.fragment);
|
||||
}
|
||||
test "Url" {
|
||||
try expectEqualUrl(.{
|
||||
.scheme = "https",
|
||||
.hostport = "example.com",
|
||||
.path = "",
|
||||
.query = "",
|
||||
.fragment = "",
|
||||
}, try Url.parse("https://example.com"));
|
||||
|
||||
try expectEqualUrl(.{
|
||||
.scheme = "https",
|
||||
.hostport = "example.com:1234",
|
||||
.path = "",
|
||||
.query = "",
|
||||
.fragment = "",
|
||||
}, try Url.parse("https://example.com:1234"));
|
||||
|
||||
try expectEqualUrl(.{
|
||||
.scheme = "http",
|
||||
.hostport = "example.com",
|
||||
.path = "/home",
|
||||
.query = "",
|
||||
.fragment = "",
|
||||
}, try Url.parse("http://example.com/home"));
|
||||
|
||||
try expectEqualUrl(.{
|
||||
.scheme = "https",
|
||||
.hostport = "example.com",
|
||||
.path = "",
|
||||
.query = "query=abc",
|
||||
.fragment = "",
|
||||
}, try Url.parse("https://example.com?query=abc"));
|
||||
|
||||
try expectEqualUrl(.{
|
||||
.scheme = "https",
|
||||
.hostport = "example.com",
|
||||
.path = "",
|
||||
.query = "query=abc",
|
||||
.fragment = "",
|
||||
}, try Url.parse("https://example.com?query=abc"));
|
||||
|
||||
try expectEqualUrl(.{
|
||||
.scheme = "https",
|
||||
.hostport = "example.com",
|
||||
.path = "/path/to/resource",
|
||||
.query = "query=abc",
|
||||
.fragment = "123",
|
||||
}, try Url.parse("https://example.com/path/to/resource?query=abc#123"));
|
||||
|
||||
const t = @import("std").testing;
|
||||
try t.expectError(error.InvalidUrl, Url.parse("https:example.com"));
|
||||
try t.expectError(error.InvalidUrl, Url.parse("example.com"));
|
||||
}
|
||||
|
||||
test "Url.getQuery" {
|
||||
const url = try Url.parse("https://example.com?a=xyz&b=jkl");
|
||||
const t = @import("std").testing;
|
||||
|
||||
try t.expectEqualStrings("xyz", url.getQuery("a").?);
|
||||
try t.expectEqualStrings("jkl", url.getQuery("b").?);
|
||||
try t.expect(url.getQuery("c") == null);
|
||||
try t.expect(url.getQuery("xyz") == null);
|
||||
}
|
|
@ -1,106 +0,0 @@
|
|||
const std = @import("std");
|
||||
|
||||
const Hash = std.hash.Wyhash;
|
||||
const View = std.unicode.Utf8View;
|
||||
const toLower = std.ascii.toLower;
|
||||
const isAscii = std.ascii.isASCII;
|
||||
const hash_seed = 1;
|
||||
|
||||
pub fn hash(str: []const u8) u64 {
|
||||
// fallback to regular hash on invalid utf8
|
||||
const view = View.init(str) catch return Hash.hash(hash_seed, str);
|
||||
var iter = view.iterator();
|
||||
|
||||
var h = Hash.init(hash_seed);
|
||||
|
||||
var it = iter.nextCodepointSlice();
|
||||
while (it != null) : (it = iter.nextCodepointSlice()) {
|
||||
if (it.?.len == 1 and isAscii(it.?[0])) {
|
||||
const ch = [1]u8{toLower(it.?[0])};
|
||||
h.update(&ch);
|
||||
} else {
|
||||
h.update(it.?);
|
||||
}
|
||||
}
|
||||
|
||||
return h.final();
|
||||
}
|
||||
|
||||
pub fn eql(a: []const u8, b: []const u8) bool {
|
||||
if (a.len != b.len) return false;
|
||||
|
||||
const va = View.init(a) catch return std.mem.eql(u8, a, b);
|
||||
const vb = View.init(b) catch return false;
|
||||
|
||||
var iter_a = va.iterator();
|
||||
var iter_b = vb.iterator();
|
||||
|
||||
var it_a = iter_a.nextCodepointSlice();
|
||||
var it_b = iter_b.nextCodepointSlice();
|
||||
|
||||
while (it_a != null and it_b != null) : ({
|
||||
it_a = iter_a.nextCodepointSlice();
|
||||
it_b = iter_b.nextCodepointSlice();
|
||||
}) {
|
||||
if (it_a.?.len != it_b.?.len) return false;
|
||||
|
||||
if (it_a.?.len == 1) {
|
||||
if (isAscii(it_a.?[0]) and isAscii(it_b.?[0])) {
|
||||
const ch_a = toLower(it_a.?[0]);
|
||||
const ch_b = toLower(it_b.?[0]);
|
||||
|
||||
if (ch_a != ch_b) return false;
|
||||
} else if (it_a.?[0] != it_b.?[0]) return false;
|
||||
} else if (!std.mem.eql(u8, it_a.?, it_b.?)) return false;
|
||||
}
|
||||
|
||||
return it_a == null and it_b == null;
|
||||
}
|
||||
|
||||
test "case insensitive eql with utf-8 chars" {
|
||||
const t = std.testing;
|
||||
try t.expectEqual(true, eql("abc 💯 def", "aBc 💯 DEF"));
|
||||
try t.expectEqual(false, eql("xyz 💯 ijk", "aBc 💯 DEF"));
|
||||
try t.expectEqual(false, eql("abc 💯 def", "aBc x DEF"));
|
||||
try t.expectEqual(true, eql("💯", "💯"));
|
||||
try t.expectEqual(false, eql("💯", "a"));
|
||||
try t.expectEqual(false, eql("💯", "💯 continues"));
|
||||
try t.expectEqual(false, eql("💯 fsdfs", "💯"));
|
||||
try t.expectEqual(false, eql("💯", ""));
|
||||
try t.expectEqual(false, eql("", "💯"));
|
||||
|
||||
try t.expectEqual(true, eql("abc x def", "aBc x DEF"));
|
||||
try t.expectEqual(false, eql("xyz x ijk", "aBc x DEF"));
|
||||
try t.expectEqual(true, eql("x", "x"));
|
||||
try t.expectEqual(false, eql("x", "a"));
|
||||
try t.expectEqual(false, eql("x", "x continues"));
|
||||
try t.expectEqual(false, eql("x fsdfs", "x"));
|
||||
try t.expectEqual(false, eql("x", ""));
|
||||
try t.expectEqual(false, eql("", "x"));
|
||||
|
||||
try t.expectEqual(true, eql("", ""));
|
||||
}
|
||||
|
||||
test "case insensitive hash with utf-8 chars" {
|
||||
const t = std.testing;
|
||||
try t.expect(hash("abc 💯 def") == hash("aBc 💯 DEF"));
|
||||
try t.expect(hash("xyz 💯 ijk") != hash("aBc 💯 DEF"));
|
||||
try t.expect(hash("abc 💯 def") != hash("aBc x DEF"));
|
||||
try t.expect(hash("💯") == hash("💯"));
|
||||
try t.expect(hash("💯") != hash("a"));
|
||||
try t.expect(hash("💯") != hash("💯 continues"));
|
||||
try t.expect(hash("💯 fsdfs") != hash("💯"));
|
||||
try t.expect(hash("💯") != hash(""));
|
||||
try t.expect(hash("") != hash("💯"));
|
||||
|
||||
try t.expect(hash("abc x def") == hash("aBc x DEF"));
|
||||
try t.expect(hash("xyz x ijk") != hash("aBc x DEF"));
|
||||
try t.expect(hash("x") == hash("x"));
|
||||
try t.expect(hash("x") != hash("a"));
|
||||
try t.expect(hash("x") != hash("x continues"));
|
||||
try t.expect(hash("x fsdfs") != hash("x"));
|
||||
try t.expect(hash("x") != hash(""));
|
||||
try t.expect(hash("") != hash("x"));
|
||||
|
||||
try t.expect(hash("") == hash(""));
|
||||
}
|
|
@ -1,189 +0,0 @@
|
|||
const std = @import("std");
|
||||
|
||||
pub fn Separator(comptime separator: u8) type {
|
||||
return struct {
|
||||
const Self = @This();
|
||||
str: []const u8,
|
||||
pub fn from(str: []const u8) Self {
|
||||
return .{ .str = std.mem.trim(u8, str, &.{separator}) };
|
||||
}
|
||||
|
||||
pub fn next(self: *Self) ?[]const u8 {
|
||||
if (self.str.len == 0) return null;
|
||||
|
||||
const part = std.mem.sliceTo(self.str, separator);
|
||||
self.str = std.mem.trimLeft(u8, self.str[part.len..], &.{separator});
|
||||
|
||||
return part;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub const QueryIter = struct {
|
||||
const Pair = struct {
|
||||
key: []const u8,
|
||||
value: ?[]const u8,
|
||||
};
|
||||
|
||||
iter: Separator('&'),
|
||||
|
||||
pub fn from(q: []const u8) QueryIter {
|
||||
return QueryIter{ .iter = Separator('&').from(std.mem.trimLeft(u8, q, "?")) };
|
||||
}
|
||||
|
||||
pub fn next(self: *QueryIter) ?Pair {
|
||||
const part = self.iter.next() orelse return null;
|
||||
|
||||
const key = std.mem.sliceTo(part, '=');
|
||||
if (key.len == part.len) return Pair{
|
||||
.key = key,
|
||||
.value = null,
|
||||
};
|
||||
|
||||
return Pair{
|
||||
.key = key,
|
||||
.value = part[key.len + 1 ..],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const PathIter = struct {
|
||||
is_first: bool,
|
||||
iter: std.mem.SplitIterator(u8),
|
||||
|
||||
pub fn from(path: []const u8) PathIter {
|
||||
return .{ .is_first = true, .iter = std.mem.split(u8, path, "/") };
|
||||
}
|
||||
|
||||
pub fn next(self: *PathIter) ?[]const u8 {
|
||||
defer self.is_first = false;
|
||||
while (self.iter.next()) |it| if (it.len != 0) {
|
||||
return it;
|
||||
};
|
||||
|
||||
if (self.is_first) return self.iter.rest();
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
pub fn first(self: *PathIter) []const u8 {
|
||||
std.debug.assert(self.is_first);
|
||||
return self.next().?;
|
||||
}
|
||||
|
||||
pub fn rest(self: *PathIter) []const u8 {
|
||||
return self.iter.rest();
|
||||
}
|
||||
};
|
||||
|
||||
test "QueryIter" {
|
||||
const t = @import("std").testing;
|
||||
if (true) return error.SkipZigTest;
|
||||
{
|
||||
var iter = QueryIter.from("");
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
|
||||
{
|
||||
var iter = QueryIter.from("?");
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
|
||||
{
|
||||
var iter = QueryIter.from("?abc");
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "abc",
|
||||
.value = null,
|
||||
}, iter.next().?);
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
|
||||
{
|
||||
var iter = QueryIter.from("?abc=");
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "abc",
|
||||
.value = "",
|
||||
}, iter.next().?);
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
|
||||
{
|
||||
var iter = QueryIter.from("?abc=def");
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "abc",
|
||||
.value = "def",
|
||||
}, iter.next().?);
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
|
||||
{
|
||||
var iter = QueryIter.from("?abc=def&");
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "abc",
|
||||
.value = "def",
|
||||
}, iter.next().?);
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
|
||||
{
|
||||
var iter = QueryIter.from("?abc=def&foo&bar=baz&qux=");
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "abc",
|
||||
.value = "def",
|
||||
}, iter.next().?);
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "foo",
|
||||
.value = null,
|
||||
}, iter.next().?);
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "bar",
|
||||
.value = "baz",
|
||||
}, iter.next().?);
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "qux",
|
||||
.value = "",
|
||||
}, iter.next().?);
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
|
||||
{
|
||||
var iter = QueryIter.from("?=def&");
|
||||
try t.expectEqual(QueryIter.Pair{
|
||||
.key = "",
|
||||
.value = "def",
|
||||
}, iter.next().?);
|
||||
try t.expect(iter.next() == null);
|
||||
try t.expect(iter.next() == null);
|
||||
}
|
||||
}
|
||||
|
||||
test "PathIter /ab/cd/" {
|
||||
const path = "/ab/cd/";
|
||||
var it = PathIter.from(path);
|
||||
try std.testing.expectEqualStrings("ab", it.next().?);
|
||||
try std.testing.expectEqualStrings("cd", it.next().?);
|
||||
try std.testing.expectEqual(@as(?[]const u8, null), it.next());
|
||||
}
|
||||
|
||||
test "PathIter ''" {
|
||||
const path = "";
|
||||
var it = PathIter.from(path);
|
||||
try std.testing.expectEqualStrings("", it.next().?);
|
||||
try std.testing.expectEqual(@as(?[]const u8, null), it.next());
|
||||
}
|
||||
|
||||
test "PathIter ab/c//defg/" {
|
||||
const path = "ab/c//defg/";
|
||||
var it = PathIter.from(path);
|
||||
try std.testing.expectEqualStrings("ab", it.next().?);
|
||||
try std.testing.expectEqualStrings("c", it.next().?);
|
||||
try std.testing.expectEqualStrings("defg", it.next().?);
|
||||
try std.testing.expectEqual(@as(?[]const u8, null), it.next());
|
||||
}
|
|
@ -1,13 +1,10 @@
|
|||
const std = @import("std");
|
||||
const iters = @import("./iters.zig");
|
||||
|
||||
pub const ciutf8 = @import("./ciutf8.zig");
|
||||
pub const Uuid = @import("./Uuid.zig");
|
||||
pub const DateTime = @import("./DateTime.zig");
|
||||
pub const Url = @import("./Url.zig");
|
||||
pub const PathIter = iters.PathIter;
|
||||
pub const QueryIter = iters.QueryIter;
|
||||
pub const SqlStmtIter = iters.Separator(';');
|
||||
pub const serialize = @import("./serialize.zig");
|
||||
pub const Deserializer = serialize.Deserializer;
|
||||
pub const DeserializerContext = serialize.DeserializerContext;
|
||||
|
||||
/// Joins an array of strings, prefixing every entry with `prefix`,
|
||||
/// and putting `separator` in between each pair
|
||||
|
@ -202,6 +199,16 @@ pub fn seedThreadPrng() !void {
|
|||
prng = std.rand.DefaultPrng.init(@bitCast(u64, buf));
|
||||
}
|
||||
|
||||
pub fn comptimeToCrlf(comptime str: []const u8) []const u8 {
|
||||
comptime {
|
||||
@setEvalBranchQuota(str.len * 10);
|
||||
const size = std.mem.replacementSize(u8, str, "\n", "\r\n");
|
||||
var buf: [size]u8 = undefined;
|
||||
_ = std.mem.replace(u8, str, "\n", "\r\n", &buf);
|
||||
return &buf;
|
||||
}
|
||||
}
|
||||
|
||||
pub const testing = struct {
|
||||
pub fn expectDeepEqual(expected: anytype, actual: @TypeOf(expected)) !void {
|
||||
const T = @TypeOf(expected);
|
||||
|
@ -242,3 +249,7 @@ pub const testing = struct {
|
|||
}
|
||||
}
|
||||
};
|
||||
|
||||
test {
|
||||
_ = std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
|
386
src/util/serialize.zig
Normal file
386
src/util/serialize.zig
Normal file
|
@ -0,0 +1,386 @@
|
|||
const std = @import("std");
|
||||
const util = @import("./lib.zig");
|
||||
|
||||
pub const FieldRef = []const []const u8;
|
||||
|
||||
pub fn defaultIsScalar(comptime T: type) bool {
|
||||
if (comptime std.meta.trait.is(.Optional)(T) and defaultIsScalar(std.meta.Child(T))) return true;
|
||||
if (comptime std.meta.trait.isZigString(T)) return true;
|
||||
if (comptime std.meta.trait.isIntegral(T)) return true;
|
||||
if (comptime std.meta.trait.isFloat(T)) return true;
|
||||
if (comptime std.meta.trait.is(.Enum)(T)) return true;
|
||||
if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true;
|
||||
if (comptime std.meta.trait.hasFn("parse")(T)) return true;
|
||||
if (T == bool) return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T {
|
||||
if (comptime std.meta.trait.is(.Optional)(T)) {
|
||||
if (value.len == 0) return null;
|
||||
return try deserializeString(allocator, std.meta.Child(T), value);
|
||||
}
|
||||
|
||||
if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value);
|
||||
if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0);
|
||||
if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value);
|
||||
if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value);
|
||||
|
||||
var buf: [64]u8 = undefined;
|
||||
const lowered = std.ascii.lowerString(&buf, value);
|
||||
|
||||
if (T == bool) return bool_map.get(lowered) orelse return error.InvalidBool;
|
||||
if (comptime std.meta.trait.is(.Enum)(T)) {
|
||||
return std.meta.stringToEnum(T, lowered) orelse return error.InvalidEnumTag;
|
||||
}
|
||||
|
||||
@compileError("Invalid type " ++ @typeName(T));
|
||||
}
|
||||
|
||||
pub fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef {
|
||||
comptime {
|
||||
if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) {
|
||||
@compileError("Cannot embed a union into nothing");
|
||||
}
|
||||
|
||||
if (options.isScalar(T)) return &.{prefix};
|
||||
if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options);
|
||||
|
||||
const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions)
|
||||
prefix[0 .. prefix.len - 1]
|
||||
else
|
||||
prefix;
|
||||
|
||||
var fields: []const FieldRef = &.{};
|
||||
|
||||
for (std.meta.fields(T)) |f| {
|
||||
const new_prefix = eff_prefix ++ &[_][]const u8{f.name};
|
||||
const F = f.field_type;
|
||||
fields = fields ++ getRecursiveFieldList(F, new_prefix, options);
|
||||
}
|
||||
|
||||
return fields;
|
||||
}
|
||||
}
|
||||
|
||||
pub const SerializationOptions = struct {
|
||||
embed_unions: bool,
|
||||
isScalar: fn (type) bool,
|
||||
};
|
||||
|
||||
pub const default_options = SerializationOptions{
|
||||
.embed_unions = true,
|
||||
.isScalar = defaultIsScalar,
|
||||
};
|
||||
|
||||
fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
|
||||
const field_refs = getRecursiveFieldList(Result, &.{}, options);
|
||||
|
||||
var fields: [field_refs.len]std.builtin.Type.StructField = undefined;
|
||||
for (field_refs) |ref, i| {
|
||||
fields[i] = .{
|
||||
.name = util.comptimeJoin(".", ref),
|
||||
.field_type = ?From,
|
||||
.default_value = &@as(?From, null),
|
||||
.is_comptime = false,
|
||||
.alignment = @alignOf(?From),
|
||||
};
|
||||
}
|
||||
|
||||
return @Type(.{ .Struct = .{
|
||||
.layout = .Auto,
|
||||
.fields = &fields,
|
||||
.decls = &.{},
|
||||
.is_tuple = false,
|
||||
} });
|
||||
}
|
||||
|
||||
pub fn Deserializer(comptime Result: type) type {
|
||||
return DeserializerContext(Result, []const u8, struct {
|
||||
const options = default_options;
|
||||
fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T {
|
||||
return try deserializeString(alloc, T, val);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type {
|
||||
return struct {
|
||||
const Data = Intermediary(Result, From, Context.options);
|
||||
|
||||
data: Data = .{},
|
||||
context: Context = .{},
|
||||
|
||||
pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void {
|
||||
const field = std.meta.stringToEnum(std.meta.FieldEnum(Data), key) orelse return error.UnknownField;
|
||||
inline for (comptime std.meta.fieldNames(Data)) |field_name| {
|
||||
@setEvalBranchQuota(10000);
|
||||
const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(Data), field_name);
|
||||
if (field == f) {
|
||||
@field(self.data, field_name) = value;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
unreachable;
|
||||
}
|
||||
|
||||
pub const Iter = struct {
|
||||
data: *const Data,
|
||||
field_index: usize,
|
||||
|
||||
const Item = struct {
|
||||
key: []const u8,
|
||||
value: From,
|
||||
};
|
||||
|
||||
pub fn next(self: *Iter) ?Item {
|
||||
while (self.field_index < std.meta.fields(Data).len) {
|
||||
const idx = self.field_index;
|
||||
self.field_index += 1;
|
||||
inline for (comptime std.meta.fieldNames(Data)) |field, i| {
|
||||
if (i == idx) {
|
||||
const maybe_value = @field(self.data.*, field);
|
||||
if (maybe_value) |value| return Item{ .key = field, .value = value };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn iterator(self: *const @This()) Iter {
|
||||
return .{ .data = &self.data, .field_index = 0 };
|
||||
}
|
||||
|
||||
pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
|
||||
util.deepFree(allocator, val);
|
||||
}
|
||||
|
||||
pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result {
|
||||
return (try self.deserialize(allocator, Result, &.{})) orelse error.MissingField;
|
||||
}
|
||||
|
||||
fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From {
|
||||
//inline for (comptime std.meta.fieldNames(Data)) |f| @compileLog(f.ptr);
|
||||
return @field(self.data, util.comptimeJoin(".", field_ref));
|
||||
}
|
||||
|
||||
fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
|
||||
util.deepFree(allocator, val);
|
||||
}
|
||||
|
||||
fn deserialize(self: *@This(), allocator: std.mem.Allocator, comptime T: type, comptime field_ref: FieldRef) !?T {
|
||||
if (comptime Context.options.isScalar(T)) {
|
||||
return try self.context.deserializeScalar(allocator, T, self.getSerializedField(field_ref) orelse return null);
|
||||
}
|
||||
|
||||
switch (@typeInfo(T)) {
|
||||
// At most one of any union field can be active at a time, and it is embedded
|
||||
// in its parent container
|
||||
.Union => |info| {
|
||||
var result: ?T = null;
|
||||
errdefer if (result) |v| self.deserializeFree(allocator, v);
|
||||
// TODO: errdefer cleanup
|
||||
const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref;
|
||||
inline for (info.fields) |field| {
|
||||
const F = field.field_type;
|
||||
const new_field_ref = union_ref ++ &[_][]const u8{field.name};
|
||||
const maybe_value = try self.deserialize(allocator, F, new_field_ref);
|
||||
if (maybe_value) |value| {
|
||||
// TODO: errdefer cleanup
|
||||
errdefer self.deserializeFree(allocator, value);
|
||||
if (result != null) return error.DuplicateUnionMember;
|
||||
result = @unionInit(T, field.name, value);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
},
|
||||
|
||||
.Struct => |info| {
|
||||
var result: T = undefined;
|
||||
|
||||
var any_explicit = false;
|
||||
var any_missing = false;
|
||||
var fields_alloced = [1]bool{false} ** info.fields.len;
|
||||
errdefer inline for (info.fields) |field, i| {
|
||||
if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name));
|
||||
};
|
||||
inline for (info.fields) |field, i| {
|
||||
const F = field.field_type;
|
||||
const new_field_ref = field_ref ++ &[_][]const u8{field.name};
|
||||
const maybe_value = try self.deserialize(allocator, F, new_field_ref);
|
||||
if (maybe_value) |v| {
|
||||
@field(result, field.name) = v;
|
||||
fields_alloced[i] = true;
|
||||
any_explicit = true;
|
||||
} else if (field.default_value) |ptr| {
|
||||
if (@sizeOf(F) != 0) {
|
||||
const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr));
|
||||
@field(result, field.name) = try util.deepClone(allocator, cast_ptr.*);
|
||||
fields_alloced[i] = true;
|
||||
}
|
||||
} else {
|
||||
any_missing = true;
|
||||
}
|
||||
}
|
||||
if (any_missing) {
|
||||
return if (any_explicit) error.MissingField else null;
|
||||
}
|
||||
|
||||
return result;
|
||||
},
|
||||
|
||||
// Specifically non-scalar optionals
|
||||
.Optional => |info| return try self.deserialize(allocator, info.child, field_ref),
|
||||
|
||||
else => @compileError("Unsupported type"),
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const bool_map = std.ComptimeStringMap(bool, .{
|
||||
.{ "true", true },
|
||||
.{ "t", true },
|
||||
.{ "yes", true },
|
||||
.{ "y", true },
|
||||
.{ "1", true },
|
||||
|
||||
.{ "false", false },
|
||||
.{ "f", false },
|
||||
.{ "no", false },
|
||||
.{ "n", false },
|
||||
.{ "0", false },
|
||||
});
|
||||
|
||||
test "Deserializer" {
|
||||
|
||||
// Happy case - simple
|
||||
{
|
||||
const T = struct { foo: []const u8, bar: bool };
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try ds.setSerializedField("foo", "123");
|
||||
try ds.setSerializedField("bar", "true");
|
||||
|
||||
const val = try ds.finish(std.testing.allocator);
|
||||
defer ds.finishFree(std.testing.allocator, val);
|
||||
try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val);
|
||||
}
|
||||
|
||||
// Returns error if nonexistent field set
|
||||
{
|
||||
const T = struct { foo: []const u8, bar: bool };
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123"));
|
||||
}
|
||||
|
||||
// Substruct dereferencing
|
||||
{
|
||||
const T = struct {
|
||||
foo: struct { bar: bool, baz: bool },
|
||||
};
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try ds.setSerializedField("foo.bar", "true");
|
||||
try ds.setSerializedField("foo.baz", "true");
|
||||
|
||||
const val = try ds.finish(std.testing.allocator);
|
||||
defer ds.finishFree(std.testing.allocator, val);
|
||||
try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true, .baz = true } }, val);
|
||||
}
|
||||
|
||||
// Union embedding
|
||||
{
|
||||
const T = struct {
|
||||
foo: union(enum) { bar: bool, baz: bool },
|
||||
};
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try ds.setSerializedField("bar", "true");
|
||||
|
||||
const val = try ds.finish(std.testing.allocator);
|
||||
defer ds.finishFree(std.testing.allocator, val);
|
||||
try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true } }, val);
|
||||
}
|
||||
|
||||
// Returns error if multiple union fields specified
|
||||
{
|
||||
const T = struct {
|
||||
foo: union(enum) { bar: bool, baz: bool },
|
||||
};
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try ds.setSerializedField("bar", "true");
|
||||
try ds.setSerializedField("baz", "true");
|
||||
|
||||
try std.testing.expectError(error.DuplicateUnionMember, ds.finish(std.testing.allocator));
|
||||
}
|
||||
|
||||
// Uses default values if fields aren't provided
|
||||
{
|
||||
const T = struct { foo: []const u8 = "123", bar: bool = true };
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
|
||||
const val = try ds.finish(std.testing.allocator);
|
||||
defer ds.finishFree(std.testing.allocator, val);
|
||||
try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val);
|
||||
}
|
||||
|
||||
// Returns an error if fields aren't provided and no default exists
|
||||
{
|
||||
const T = struct { foo: []const u8, bar: bool };
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try ds.setSerializedField("foo", "123");
|
||||
|
||||
try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator));
|
||||
}
|
||||
|
||||
// Handles optional containers
|
||||
{
|
||||
const T = struct {
|
||||
foo: ?struct { bar: usize = 3, baz: usize } = null,
|
||||
qux: ?union(enum) { quux: usize } = null,
|
||||
};
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
|
||||
const val = try ds.finish(std.testing.allocator);
|
||||
defer ds.finishFree(std.testing.allocator, val);
|
||||
try util.testing.expectDeepEqual(T{ .foo = null, .qux = null }, val);
|
||||
}
|
||||
|
||||
{
|
||||
const T = struct {
|
||||
foo: ?struct { bar: usize = 3, baz: usize } = null,
|
||||
qux: ?union(enum) { quux: usize } = null,
|
||||
};
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try ds.setSerializedField("foo.baz", "3");
|
||||
try ds.setSerializedField("quux", "3");
|
||||
|
||||
const val = try ds.finish(std.testing.allocator);
|
||||
defer ds.finishFree(std.testing.allocator, val);
|
||||
try util.testing.expectDeepEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, val);
|
||||
}
|
||||
|
||||
{
|
||||
const T = struct {
|
||||
foo: ?struct { bar: usize = 3, baz: usize } = null,
|
||||
qux: ?union(enum) { quux: usize } = null,
|
||||
};
|
||||
|
||||
var ds = Deserializer(T){};
|
||||
try ds.setSerializedField("foo.bar", "3");
|
||||
try ds.setSerializedField("quux", "3");
|
||||
|
||||
try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator));
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue