More refactoring

This commit is contained in:
jaina heartles 2022-10-01 22:18:24 -07:00
parent baa17ccc26
commit 4b5d11b00a
13 changed files with 418 additions and 568 deletions

View File

@ -1,37 +0,0 @@
# Overview
## Packages
- `main`: primary package, has application-specific functionality
* TODO: consider moving controllers and api into different packages
* `controllers/**.zig`:
- Transforms HTTP to/from API calls
- Turns error codes into HTTP statuses
* `api.zig`:
- Makes sure API call is allowed with the given user/host context
- Transforms API models into display models
- `api/**.zig`: Performs action associated with API call
* Transforms DB models into API models
* Data validation
- TODO: the distinction between what goes in `api.zig` and in its submodules is gross. Refactor?
* `migrations.zig`:
- Defines database migrations to apply
- Should be ran on startup
- `util`: utility packages
* Components:
- `Uuid`: UUID utils (random uuid generation, equality, parsing, printing)
* `Uuid.eql`
* `Uuid.randV4`
* UUID's are serialized to their string representation for JSON, db
- `PathIter`: Path segment iterator
- `Url`: URL utils (parsing)
- `ciutf8`: case-insensitive UTF-8 (TODO: Scrap this, replace with ICU library)
- `DateTime`: Time utils
- `deepClone(alloc, orig)`/`deepFree(alloc, to_free)`: Utils for cloning and freeing basic data structs
* Clones/frees any strings/sub structs within the value
- `sql`: SQL library
* Supports 2 engines (SQLite, PostgreSQL)
* `var my_transaction = try db.begin()`
* `const results = try db.query(RowType, "SELECT ...", .{arg_1, ...}, alloc)`
- `http`: HTTP Server
* The API sucks. Needs a refactor

View File

@ -12,37 +12,62 @@ pub const passwords = struct {
const PwHashBuf = [pw_hash_buf_size]u8;
// Returned slice points into buf
fn hashPassword(password: []const u8, alloc: std.mem.Allocator, buf: *PwHashBuf) []const u8 {
return PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, buf) catch unreachable;
}
pub const VerifyError = error{
InvalidLogin,
DbError,
DatabaseFailure,
HashFailure,
};
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
pub fn verify(
db: anytype,
account_id: Uuid,
password: []const u8,
alloc: std.mem.Allocator,
) VerifyError!void {
// TODO: This could be done w/o the dynamically allocated hash buf
const hash = (db.queryRow(
const hash = db.queryRow(
std.meta.Tuple(&.{[]const u8}),
"SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1",
.{user_id},
\\SELECT hashed_password
\\FROM account_password
\\WHERE account_id = $1
\\LIMIT 1
,
.{account_id},
alloc,
) catch return error.DbError) orelse return error.InvalidLogin;
) catch |err| return switch (err) {
error.NoRows => error.InvalidLogin,
else => error.DatabaseFailure,
};
errdefer alloc.free(hash[0]);
PwHash.strVerify(hash[0], password, .{ .allocator = alloc }) catch unreachable;
PwHash.strVerify(
hash[0],
password,
.{ .allocator = alloc },
) catch error.HashFailure;
}
pub const CreateError = error{DbError};
pub fn create(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) CreateError!void {
pub const CreateError = error{ DatabaseFailure, HashFailure };
pub fn create(
db: anytype,
account_id: Uuid,
password: []const u8,
alloc: std.mem.Allocator,
) CreateError!void {
var buf: PwHashBuf = undefined;
const hash = PwHash.strHash(password, .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf) catch unreachable;
const hash = PwHash.strHash(
password,
.{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding },
&buf,
) catch return error.HashFailure;
db.insert("account_password", .{
.user_id = user_id,
.hashed_password = hash,
}) catch return error.DbError;
db.insert(
"account_password",
.{
.account_id = account_id,
.hashed_password = hash,
},
alloc,
) catch return error.DatabaseFailure;
}
};
@ -52,7 +77,7 @@ pub const tokens = struct {
pub const Token = struct {
pub const Value = [token_str_len]u8;
pub const Info = struct {
user_id: Uuid,
account_id: Uuid,
issued_at: DateTime,
};
@ -65,12 +90,12 @@ pub const tokens = struct {
const DbToken = struct {
hash: []const u8,
user_id: Uuid,
account_id: Uuid,
issued_at: DateTime,
};
pub const CreateError = error{DbError};
pub fn create(db: anytype, user_id: Uuid) CreateError!Token {
pub const CreateError = error{DatabaseFailure};
pub fn create(db: anytype, account_id: Uuid) CreateError!Token {
var token: [token_len]u8 = undefined;
std.crypto.random.bytes(&token);
@ -81,7 +106,7 @@ pub const tokens = struct {
db.insert("token", DbToken{
.hash = &hash,
.user_id = user_id,
.account_id = account_id,
.issued_at = issued_at,
}) catch return error.DbError;
@ -89,67 +114,44 @@ pub const tokens = struct {
_ = std.base64.standard.Encoder.encode(&token_enc, &token);
return Token{ .value = token_enc, .info = .{
.user_id = user_id,
.account_id = account_id,
.issued_at = issued_at,
} };
}
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
return if (try db.queryRow(
std.meta.Tuple(&.{ Uuid, DateTime }),
\\SELECT user.id, token.issued_at
\\FROM token JOIN user ON token.user_id = user.id
\\WHERE user.community_id = $1 AND token.hash = $2
\\LIMIT 1
,
.{ community_id, hash },
null,
)) |result|
Token.Info{
.user_id = result[0],
.issued_at = result[1],
}
else
null;
}
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
return if (try db.queryRow(
std.meta.Tuple(&.{ Uuid, DateTime }),
\\SELECT user.id, token.issued_at
\\FROM token JOIN user ON token.user_id = user.id
\\WHERE user.community_id IS NULL AND token.hash = $1
\\LIMIT 1
,
.{hash},
null,
)) |result|
Token.Info{
.user_id = result[0],
.issued_at = result[1],
}
else
null;
}
pub const VerifyError = error{ InvalidToken, DbError };
pub fn verify(db: anytype, token: []const u8, community_id: ?Uuid) VerifyError!Token.Info {
const decoded_len = std.base64.standard.Decoder.calcSizeForSlice(token) catch return error.InvalidToken;
pub const VerifyError = error{ InvalidToken, DatabaseError };
pub fn verify(
db: anytype,
token: []const u8,
community_id: Uuid,
alloc: std.mem.Allocator,
) VerifyError!Token.Info {
const decoded_len = std.base64.standard.Decoder.calcSizeForSlice(
token,
) catch return error.InvalidToken;
if (decoded_len != token_len) return error.InvalidToken;
var decoded: [token_len]u8 = undefined;
std.base64.standard.Decoder.decode(&decoded, token) catch return error.InvalidToken;
std.base64.standard.Decoder.decode(
&decoded,
token,
) catch return error.InvalidToken;
var hash: TokenDigestBuf = undefined;
TokenHash.hash(&decoded, &hash, .{});
const token_info = if (community_id) |id|
lookupUserTokenFromHash(db, &hash, id) catch return error.DbError
else
lookupSystemTokenFromHash(db, &hash) catch return error.DbError;
if (token_info) |info| return info;
return error.InvalidToken;
return db.queryRow(
Token.Info,
\\SELECT account.id, token.issued_at
\\FROM token JOIN account ON token.account_id = account.id
\\WHERE token.hash = $1 AND account.community_id = $2
\\LIMIT 1
,
.{ hash, community_id },
alloc,
) catch |err| switch (err) {
error.NoRows => error.InvalidToken,
else => error.DatabaseFailure,
};
}
};

View File

@ -8,19 +8,18 @@ const getRandom = @import("../api.zig").getRandom;
const Uuid = util.Uuid;
const DateTime = util.DateTime;
const CreateError = error{
InvalidOrigin,
UnsupportedScheme,
CommunityExists,
} || anyerror; // TODO
pub const Scheme = enum {
https,
http,
pub fn jsonStringify(s: Scheme, _: std.json.StringifyOptions, writer: anytype) !void {
return std.fmt.format(writer, "\"{s}\"", .{@tagName(s)});
}
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
pub const Kind = enum {
admin,
local,
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
pub const Community = struct {
@ -35,21 +34,19 @@ pub const Community = struct {
created_at: DateTime,
};
pub const Kind = enum {
admin,
local,
pub fn jsonStringify(val: Kind, _: std.json.StringifyOptions, writer: anytype) !void {
return std.fmt.format(writer, "\"{s}\"", .{@tagName(val)});
}
};
pub const CreateOptions = struct {
name: ?[]const u8 = null,
kind: Kind = .local,
};
pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptions) CreateError!Community {
pub const CreateError = error{
DatabaseFailure,
UnsupportedScheme,
InvalidOrigin,
CommunityExists,
};
pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptions, alloc: std.mem.Allocator) CreateError!Uuid {
const scheme_len = std.mem.indexOfScalar(u8, origin, ':') orelse return error.InvalidOrigin;
const scheme_str = origin[0..scheme_len];
const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme;
@ -75,36 +72,61 @@ pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptio
const id = Uuid.randV4(getRandom());
const community = Community{
// TODO: wrap this in TX
if (db.queryRow(
std.meta.Tuple(&.{Uuid}),
"SELECT id FROM community WHERE host = $1",
.{host},
alloc,
)) |_| {
return error.CommunityExists;
} else |err| switch (err) {
error.NoRows => {},
else => return error.DatabaseFailure,
}
try db.insert("community", .{
.id = id,
.owner_id = owner,
.host = host,
.name = options.name orelse host,
.scheme = scheme,
.kind = options.kind,
.created_at = DateTime.now(),
};
}, alloc);
if ((try db.queryRow(std.meta.Tuple(&.{Uuid}), "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) {
return error.CommunityExists;
}
try db.insert("community", community);
return community;
return id;
}
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community {
return (try db.queryRow(
pub const GetError = error{
NotFound,
DatabaseFailure,
};
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) GetError!Community {
return db.queryRow(
Community,
std.fmt.comptimePrint("SELECT {s} FROM community WHERE host = $1", .{comptime sql.fieldList(Community)}),
std.fmt.comptimePrint(
\\SELECT {s}
\\FROM community
\\WHERE host = $1
\\LIMIT 1
,
.{comptime sql.fieldList(Community)},
),
.{host},
alloc,
)) orelse return error.NotFound;
) catch |err| switch (err) {
error.NoRows => error.NotFound,
else => error.DatabaseFailure,
};
}
pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void {
try db.exec("UPDATE community SET owner_id = $1 WHERE id = $2", .{ new_owner, community_id }, null);
// TODO: check that this actually found/updated the row (needs update to sql lib)
db.exec(
"UPDATE community SET owner_id = $1 WHERE id = $2",
.{ new_owner, community_id },
null,
) catch return error.DatabaseFailure;
}
pub const QueryArgs = struct {
@ -154,19 +176,19 @@ pub const QueryArgs = struct {
} = .forward,
};
const Builder = struct {
const QueryBuilder = struct {
array: std.ArrayList(u8),
where_clauses_appended: usize = 0,
pub fn init(alloc: std.mem.Allocator) Builder {
return Builder{ .array = std.ArrayList(u8).init(alloc) };
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
}
pub fn deinit(self: *const Builder) void {
pub fn deinit(self: *const QueryBuilder) void {
self.array.deinit();
}
pub fn andWhere(self: *Builder, clause: []const u8) !void {
pub fn andWhere(self: *QueryBuilder, clause: []const u8) !void {
if (self.where_clauses_appended == 0) {
try self.array.appendSlice("WHERE ");
} else {
@ -179,15 +201,27 @@ const Builder = struct {
};
const max_max_items = 100;
pub const QueryError = error{
PageArgMismatch,
DatabaseError,
};
// Retrieves up to `args.max_items` Community entries matching the given query
// arguments.
// `args.max_items` is only a request, and fewer entries may be returned.
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Community {
var builder = Builder.init(alloc);
var builder = QueryBuilder.init(alloc);
defer builder.deinit();
try builder.array.appendSlice(
\\SELECT id, owner_id, host, name, scheme, created_at
\\FROM community
\\
std.fmt.comptimePrint(
\\SELECT {s}
\\FROM community
\\
, .{util.comptimeJoin(",", std.meta.fieldNames(Community))}),
);
const max_items = if (args.max_items > max_max_items) max_max_items else args.max_items;
if (args.owner_id != null) try builder.andWhere("owner_id = $1");
@ -268,12 +302,15 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit
}
pub fn adminCommunityId(db: anytype) !Uuid {
const row = (try db.queryRow(
const row = db.queryRow(
std.meta.Tuple(&.{Uuid}),
"SELECT id FROM community WHERE kind = 'admin' LIMIT 1",
.{},
{},
null,
)) orelse return error.NotFound;
) catch |err| return switch (err) {
error.NoRows => error.NotFound,
else => error.DatabaseFailure,
};
return row[0];
}

View File

@ -1,7 +1,6 @@
const std = @import("std");
const builtin = @import("builtin");
const util = @import("util");
const models = @import("../db/models.zig");
const getRandom = @import("../api.zig").getRandom;
const Uuid = util.Uuid;
@ -14,22 +13,14 @@ const code_len = 12;
const Encoder = std.base64.url_safe.Encoder;
const Decoder = std.base64.url_safe.Decoder;
pub const InviteType = enum {
pub const InviteKind = enum {
system,
community_owner,
user,
pub const jsonStringify = defaultJsonStringify(@This());
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
fn defaultJsonStringify(comptime T: type) fn (T, std.json.StringifyOptions, anytype) anyerror!void {
return struct {
pub fn jsonStringify(s: T, _: std.json.StringifyOptions, writer: anytype) !void {
return std.fmt.format(writer, "\"{s}\"", .{@tagName(s)});
}
}.jsonStringify;
}
const InviteCount = u16;
pub const Invite = struct {
id: Uuid,
@ -45,116 +36,102 @@ pub const Invite = struct {
expires_at: ?DateTime,
max_uses: ?InviteCount,
invite_type: InviteType,
invite_kind: InviteKind,
};
const DbModel = struct {
id: Uuid,
created_by: Uuid, // User ID
to_community: ?Uuid,
name: []const u8,
code: []const u8,
created_at: DateTime,
expires_at: ?DateTime,
max_uses: ?InviteCount,
@"type": InviteType,
};
fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 {
const new = try alloc.alloc(u8, str.len);
std.mem.copy(u8, new, str);
return new;
}
pub const InviteOptions = struct {
name: ?[]const u8 = null,
max_uses: ?InviteCount = null,
expires_at: ?DateTime = null,
invite_type: InviteType = .user,
lifespan: ?DateTime.Duration = null,
invite_kind: InviteKind = .user,
};
pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: InviteOptions, alloc: std.mem.Allocator) !Invite {
pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: InviteOptions, alloc: std.mem.Allocator) !Uuid {
const id = Uuid.randV4(getRandom());
var code_bytes: [rand_len]u8 = undefined;
getRandom().bytes(&code_bytes);
const code = try alloc.alloc(u8, code_len);
errdefer alloc.free(code);
defer alloc.free(code);
_ = Encoder.encode(code, &code_bytes);
const name = if (options.name) |name|
try cloneStr(name, alloc)
else
try cloneStr(code, alloc);
errdefer alloc.free(name);
const id = Uuid.randV4(getRandom());
const name = options.name orelse code;
const created_at = DateTime.now();
try db.insert("invite", DbModel{
.id = id,
try db.insert(
"invite",
.{
.id = id,
.created_by = created_by,
.to_community = to_community,
.name = name,
.code = code,
.created_by = created_by,
.to_community = to_community,
.name = name,
.code = code,
.created_at = created_at,
.expires_at = options.expires_at,
.max_uses = options.max_uses,
.created_at = created_at,
.expires_at = if (options.lifespan) |lifespan|
created_at.add(lifespan)
else
null,
.max_uses = options.max_uses,
.invite_kind = options.invite_kind,
},
alloc,
);
.@"type" = options.invite_type,
});
return Invite{
.id = id,
.created_by = created_by,
.to_community = to_community,
.name = name,
.code = code,
.created_at = created_at,
.expires_at = options.expires_at,
.times_used = 0,
.max_uses = options.max_uses,
.invite_type = options.invite_type,
};
return id;
}
pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite {
const code_clone = try cloneStr(code, alloc);
const info = (try db.queryRow(std.meta.Tuple(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType }),
\\SELECT
\\ invite.id, invite.created_by, invite.to_community, invite.name,
\\ invite.created_at, invite.expires_at,
\\ COUNT(local_user.user_id) as uses, invite.max_uses,
\\ invite.type
\\FROM invite LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id
\\WHERE invite.code = $1
pub const GetError = error{
NotFound,
DatabaseFailure,
};
// Helper fn for getting a single invite
fn doGetQuery(
db: anytype,
comptime where: []const u8,
query_args: anytype,
alloc: std.mem.Allocator,
) GetError!Invite {
// Generate list of fields from struct
const field_list = util.comptimeJoinWithPrefix(
",",
"invite.",
std.meta.fieldNames(Invite),
);
// times_used field is not stored directly in the DB, instead
// it is calculated based on the number of accounts that were created
// from it
const query = std.fmt.comptimePrint(
\\SELECT {s}, COUNT(local_account.id) AS times_used
\\FROM invite LEFT OUTER JOIN local_account
\\ ON invite.id = local_account.invite_id
\\WHERE {s}
\\GROUP BY invite.id
, .{code}, alloc)) orelse return error.NotFound;
\\LIMIT 1
,
.{ field_list, where },
);
return Invite{
.id = info[0],
.created_by = info[1],
.to_community = info[2],
.name = info[3],
.code = code_clone,
.created_at = info[4],
.expires_at = info[5],
.times_used = info[6],
.max_uses = info[7],
.invite_type = info[8],
return db.queryRow(Invite, query, query_args, alloc) catch |err| switch (err) {
error.NoRows => error.NotFound,
else => error.DatabaseFailure,
};
}
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Invite {
return doGetQuery(db, "invite.id = $1", .{id}, alloc);
}
pub fn getByCode(db: anytype, code: []const u8, community_id: Uuid, alloc: std.mem.Allocator) GetError!Invite {
return doGetQuery(
db,
"invite.code = $1 AND invite.community_id = $2",
.{ code, community_id },
alloc,
);
}

View File

@ -1,6 +1,6 @@
const std = @import("std");
const util = @import("util");
const auth = @import("./auth.zig");
const sql = @import("sql");
const Uuid = util.Uuid;
const DateTime = util.DateTime;
@ -14,47 +14,42 @@ pub const Note = struct {
created_at: DateTime,
};
const DbModel = struct {
id: Uuid,
author_id: Uuid,
content: []const u8,
created_at: DateTime,
pub const CreateError = error{
DatabaseFailure,
};
pub fn create(
db: anytype,
author: Uuid,
content: []const u8,
) !Uuid {
alloc: std.mem.Allocator,
) CreateError!Uuid {
const id = Uuid.randV4(getRandom());
try db.insert("note", .{
db.insert("note", .{
.id = id,
.author_id = author,
.content = content,
.created_at = DateTime.now(),
});
}, alloc) catch return error.DatabaseFailure;
return id;
}
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note {
const result = (try db.queryRow(
std.meta.Tuple(&.{ Uuid, []const u8, DateTime }),
\\SELECT author_id, content, created_at
\\FROM note
\\WHERE id = $1
\\LIMIT 1
,
pub const GetError = error{
DatabaseFailure,
NotFound,
};
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
return db.queryRow(
Note,
sql.selectStar(Note, "note") ++
\\WHERE id = $1
\\LIMIT 1
,
.{id},
alloc,
)) orelse return error.NotFound;
return Note{
.id = id,
.author_id = result[0],
.content = result[1],
.created_at = result[2],
) catch |err| switch (err) {
error.NoRows => error.NotFound,
else => error.DatabaseFailure,
};
}

View File

@ -6,31 +6,11 @@ const Uuid = util.Uuid;
const DateTime = util.DateTime;
const getRandom = @import("../api.zig").getRandom;
const UserAuthInfo = struct {
password: []const u8,
email: []const u8,
invite_used: ?Uuid,
};
pub const CreateError = error{
UsernameTaken,
DbError,
};
const DbUser = struct {
id: Uuid,
username: []const u8,
community_id: Uuid,
};
const DbLocalUser = struct {
user_id: Uuid,
invite_id: ?Uuid,
email: ?[]const u8,
};
pub const Role = enum {
user,
admin,
@ -42,46 +22,66 @@ pub const CreateOptions = struct {
role: Role = .user,
};
fn lookupByUsernameInternal(db: anytype, username: []const u8, community_id: Uuid) CreateError!?Uuid {
return if (db.queryRow(
pub const LookupError = error{
DatabaseFailure,
};
pub fn lookupByUsername(
db: anytype,
username: []const u8,
community_id: Uuid,
alloc: std.mem.Allocator,
) LookupError!?Uuid {
const row = db.queryRow(
std.meta.Tuple(&.{Uuid}),
"SELECT user.id FROM user WHERE community_id = $1 AND username = $2",
.{ community_id, username },
null,
) catch return error.DbError) |result|
result[0]
else
null;
}
pub fn lookupByUsername(db: anytype, username: []const u8, community_id: Uuid) CreateError!Uuid {
return (lookupByUsernameInternal(db, username, community_id) catch return error.DbError) orelse error.NotFound;
\\SELECT id
\\FROM account
\\WHERE username = $1 AND community_id = $2
\\LIMIT 1
,
.{ username, community_id },
alloc,
) catch |err| return switch (err) {
error.NoRows => null,
else => error.DatabaseFailure,
};
return row[0];
}
// TODO: This fn sucks.
// auth.passwords.create requires that the user exists, but we shouldn't
// hold onto a transaction for the ~0.5s that it takes to hash the password.
// Should probably change this to be specifically about creating the user,
// and then have something in auth responsible for creating local accounts
pub fn create(
db: anytype,
username: []const u8,
password: []const u8,
community_id: Uuid,
options: CreateOptions,
password_alloc: std.mem.Allocator,
alloc: std.mem.Allocator,
) CreateError!Uuid {
const id = Uuid.randV4(getRandom());
if ((try lookupByUsernameInternal(db, username, community_id)) != null) {
return error.UsernameTaken;
}
const tx = db.begin();
errdefer tx.rollback();
db.insert("user", .{
tx.insert("account", .{
.id = id,
.username = username,
.community_id = community_id,
}) catch return error.DbError;
try auth.passwords.create(db, id, password, password_alloc);
db.insert("local_user", .{
.role = options.role,
}, alloc) catch |err| return switch (err) {
error.UniqueViolation => error.UsernameTaken,
else => error.DatabaseFailure,
};
try auth.passwords.create(tx, id, password, alloc);
tx.insert("local_account", .{
.user_id = id,
.invite_id = options.invite_id,
.email = options.email,
}) catch return error.DbError;
}) catch return error.DatabaseFailure;
try tx.commit();
return id;
}
@ -93,13 +93,14 @@ pub const User = struct {
host: []const u8,
community_id: Uuid,
role: Role,
created_at: DateTime,
};
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
const result = (try db.queryRow(
std.meta.Tuple(&.{ []const u8, []const u8, Uuid, DateTime }),
return db.queryRow(
User,
\\SELECT user.username, community.host, community.id, user.created_at
\\FROM user JOIN community ON user.community_id = community.id
\\WHERE user.id = $1
@ -107,13 +108,8 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
,
.{id},
alloc,
)) orelse return error.NotFound;
return User{
.id = id,
.username = result[0],
.host = result[1],
.community_id = result[2],
.created_at = result[3],
) catch |err| switch (err) {
error.NoRows => error.NotFound,
else => error.DatabaseFailure,
};
}

View File

@ -1,170 +0,0 @@
const std = @import("std");
const util = @import("util");
const Uuid = util.Uuid;
const models = @import("./models.zig");
// Clones a struct and its fields to a single layer of depth.
// Caller owns memory, can be freed using free below
// TODO: check that this is a struct, etc etc
fn clone(alloc: std.mem.Allocator, val: anytype) !@TypeOf(val) {
var result: @TypeOf(val) = undefined;
errdefer {
@panic("memory leak in deep clone, fix this");
}
inline for (std.meta.fields(@TypeOf(val))) |f| {
// TODO
if (f.field_type == []u8 or f.field_type == []const u8) {
@field(result, f.name) = try cloneString(alloc, @field(val, f.name));
} else if (f.field_type == Uuid) {
@field(result, f.name) = @field(val, f.name);
} else {
@compileError("unsupported field type " ++ @typeName(f.field_type));
}
}
return result;
}
fn cloneString(alloc: std.mem.Allocator, str: []const u8) ![]const u8 {
var result = try alloc.alloc(u8, str.len);
std.mem.copy(u8, result, str);
return result;
}
// Frees a struct and its fields returned by clone
pub fn free(alloc: std.mem.Allocator, val: anytype) void {
inline for (std.meta.fields(@TypeOf(val))) |f| {
// TODO
if (f.field_type == []u8 or f.field_type == []const u8) {
alloc.free(@field(val, f.name));
} else if (f.field_type == Uuid) {
// nothing
} else {
@compileError("unsupported field type " ++ @typeName(f.field_type));
}
}
}
pub fn Table(comptime T: type) type {
return struct {
const Self = @This();
internal_alloc: std.mem.Allocator,
data: std.AutoHashMap(Uuid, T),
pub fn init(alloc: std.mem.Allocator) !Self {
return Self{
.internal_alloc = alloc,
.data = std.AutoHashMap(Uuid, T).init(alloc),
};
}
pub fn deinit(self: *Self) void {
var iter = self.data.iterator();
while (iter.next()) |it| {
free(self.internal_alloc, it.value_ptr.*);
}
self.data.deinit();
}
pub fn contains(self: *Self, id: Uuid) !bool {
return self.data.contains(id);
}
// returns a copy of the note data from storage. memory is allocated with the provided
// allocator. can be freed using free() above
pub fn get(self: *Self, id: Uuid, alloc: std.mem.Allocator) !?T {
const data = self.data.get(id) orelse return null;
return try clone(alloc, data);
}
pub fn put(self: *Self, data: T) !void {
const copy = try clone(self.internal_alloc, data);
errdefer free(self.internal_alloc, copy);
const key = copy.id;
if (self.data.fetchRemove(key)) |e| {
free(self.internal_alloc, e.value);
}
try self.data.put(key, copy);
}
// TODO
pub fn lock(_: *Self) !void {
return;
}
pub fn unlock(_: *Self) void {
return;
}
};
}
pub const Database = struct {
internal_alloc: std.mem.Allocator,
notes: Table(models.Note),
users: Table(models.User),
pub fn init(alloc: std.mem.Allocator) !Database {
var db = Database{
.internal_alloc = alloc,
.notes = try Table(models.Note).init(alloc),
.users = try Table(models.User).init(alloc),
};
return db;
}
pub fn deinit(self: *Database) void {
self.notes.deinit();
self.users.deinit();
}
};
test "clone" {
const T = struct {
name: []const u8,
value: []const u8,
};
const copy = try clone(std.testing.allocator, T{ .name = "myName", .value = "myValue" });
free(std.testing.allocator, copy);
}
test "db" {
var db = try Database.init(std.testing.allocator);
defer db.deinit();
try db.putNote(.{
.id = "100",
.content = "content",
});
const note = (try db.getNote("100", std.testing.allocator)).?;
free(std.testing.allocator, note);
}
test "db" {
var db = try Database.init(std.testing.allocator);
defer db.deinit();
try db.putNote(.{
.id = "100",
.content = "content",
});
try db.putNote(.{
.id = "100",
.content = "content",
});
try db.putNote(.{
.id = "100",
.content = "content",
});
}

View File

@ -19,7 +19,7 @@ fn firstIndexOf(str: []const u8, char: u8) ?usize {
fn execStmt(tx: sql.Tx, stmt: []const u8, alloc: std.mem.Allocator) !void {
const stmt_null = try std.cstr.addNullByte(alloc, stmt);
defer alloc.free(stmt_null);
try tx.exec(stmt_null, .{}, null);
try tx.exec(stmt_null, {}, null);
}
fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void {
@ -53,7 +53,7 @@ pub fn up(db: *sql.Db) !void {
if (!was_ran) {
std.log.info("Running migration {s}", .{migration.name});
try execScript(db, migration.up, gpa.allocator());
try db.insert("migration", .{ .name = migration.name });
try db.insert("migration", .{ .name = migration.name }, gpa.allocator());
}
}
}

View File

@ -59,7 +59,7 @@ pub const QueryOptions = struct {
// If true, then it will not return an error on the SQLite backend
// if an argument passed does not map to a parameter in the query.
// Has no effect on the postgres backend.
ignore_unknown_parameters: bool = false,
ignore_unused_arguments: bool = false,
// The allocator to use for query preparation and submission.
// All memory allocated with this allocator will be freed before results

View File

@ -1,3 +1,3 @@
usingnamespace @cImport({
pub usingnamespace @cImport({
@cInclude("libpq-fe.h");
});

View File

@ -31,37 +31,6 @@ pub const Config = union(Engine) {
},
};
pub fn fieldList(comptime RowType: type) []const u8 {
comptime {
const fields = std.meta.fieldNames(RowType);
const separator = ", ";
if (fields.len == 0) return "";
var size: usize = 1; // 1 for null terminator
for (fields) |f| size += f.len + separator.len;
size -= separator.len;
var buf = std.mem.zeroes([size]u8);
// can't use std.mem.join because of problems with comptime allocation
// https://github.com/ziglang/zig/issues/5873#issuecomment-1001778218
//var fba = std.heap.FixedBufferAllocator.init(&buf);
//return (std.mem.join(fba.allocator(), separator, fields) catch unreachable) ++ " ";
var buf_idx = 0;
for (fields) |f, i| {
std.mem.copy(u8, buf[buf_idx..], f);
buf_idx += f.len;
if (i != fields.len - 1) std.mem.copy(u8, buf[buf_idx..], separator);
buf_idx += separator.len;
}
return &buf;
}
}
//pub const OpenError = sqlite.OpenError | postgres.OpenError;
const RawResults = union(Engine) {
postgres: postgres.Results,
sqlite: sqlite.Results,
@ -126,12 +95,6 @@ pub fn Results(comptime T: type) type {
self.underlying.finish();
}
// can be used as an optimization to reduce memory reallocation
// only works on postgres
pub fn rowCount(self: Self) ?usize {
return self.underlying.rowCount();
}
// Returns the next row of results, or null if there are no more rows.
// Caller owns all memory allocated. The entire object can be deallocated with a
// call to util.deepFree
@ -229,13 +192,14 @@ const QueryHelper = union(Engine) {
q: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!?RowType {
) QueryRowError!RowType {
var results = try self.query(RowType, q, args, alloc);
defer results.finish();
const row = (try results.row(alloc)) orelse return null;
const row = (try results.row(alloc)) orelse return error.NoRows;
errdefer util.deepFree(alloc, row);
// execute query to completion
var more_rows = false;
while (try results.row(alloc)) |r| {
util.deepFree(alloc, r);
@ -251,6 +215,7 @@ const QueryHelper = union(Engine) {
self: QueryHelper,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
const ValueType = comptime @TypeOf(value);
@ -274,7 +239,7 @@ const QueryHelper = union(Engine) {
inline for (fields) |field, i| {
args_tuple[i] = @field(value, field.name);
}
try self.exec(q, args_tuple, null);
try self.exec(q, args_tuple, alloc);
}
};
@ -347,7 +312,7 @@ pub const Db = struct {
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!?RowType {
) QueryRowError!RowType {
if (self.tx_open) return error.BadTransactionState;
return self.engine.queryRow(RowType, sql, args, alloc);
}
@ -356,9 +321,10 @@ pub const Db = struct {
self: *Db,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
if (self.tx_open) return error.BadTransactionState;
return self.engine.insert(table, value);
return self.engine.insert(table, value, alloc);
}
pub fn sqlEngine(self: *Db) Engine {
@ -369,11 +335,10 @@ pub const Db = struct {
pub fn begin(self: *Db) !Tx {
if (self.tx_open) return error.BadTransactionState;
const tx = Tx{ .db = self };
try tx.exec("BEGIN", {}, null);
try self.exec("BEGIN", {}, null);
self.tx_open = true;
return tx;
return Tx{ .db = self };
}
};
@ -421,7 +386,7 @@ pub const Tx = struct {
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!?RowType {
) QueryRowError!RowType {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.queryRow(RowType, sql, args, alloc);
}
@ -431,15 +396,19 @@ pub const Tx = struct {
self: Tx,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.insert(table, value);
return self.db.engine.insert(table, value, alloc);
}
pub fn sqlEngine(self: Tx) Engine {
return self.db.engine;
}
// Allows relaxing *some* constraints for the lifetime of the transaction.
// You should generally not do this, but it's useful when bootstrapping
// the initial admin community and cluster operator user.
pub fn setConstraintMode(self: Tx, mode: ConstraintMode) QueryError!void {
if (!self.db.tx_open) return error.BadTransactionState;
switch (self.db.engine) {
@ -472,7 +441,14 @@ pub const Tx = struct {
pub fn commit(self: Tx) CommitError!void {
if (!self.db.tx_open) return error.BadTransactionState;
try self.exec("COMMIT", {}, null);
self.exec("COMMIT", {}, null) catch |err| switch (err) {
error.InvalidArgument,
error.OutOfMemory,
error.UndefinedParameter,
=> return error.Unexpected,
else => return err,
};
self.db.tx_open = false;
}
};

View File

@ -3,6 +3,10 @@ const DateTime = @This();
const std = @import("std");
const epoch = std.time.epoch;
pub const Duration = struct {
seconds: i64 = 0,
};
seconds_since_epoch: i64,
pub fn parse(str: []const u8) !DateTime {
@ -11,6 +15,18 @@ pub fn parse(str: []const u8) !DateTime {
return try parseRfc3339(str);
}
pub fn add(self: DateTime, duration: Duration) DateTime {
return DateTime{
.seconds_since_epoch = self.seconds_since_epoch + duration.seconds,
};
}
pub fn sub(self: DateTime, duration: Duration) DateTime {
return DateTime{
.seconds_since_epoch = self.seconds_since_epoch - duration.seconds,
};
}
// TODO: Validate non-numeric aspects of datetime
// TODO: Don't panic on bad string
// TODO: Make seconds optional (see ActivityStreams 2.0 spec §2.3)

View File

@ -6,29 +6,85 @@ pub const DateTime = @import("./DateTime.zig");
pub const PathIter = @import("./PathIter.zig");
pub const Url = @import("./Url.zig");
fn comptimeJoinSlice(comptime separator: []const u8, comptime slices: []const []const u8) []u8 {
/// Joins an array of strings, prefixing every entry with `prefix`,
/// and putting `separator` in between each pair
pub fn comptimeJoinWithPrefix(
comptime separator: []const u8,
comptime prefix: []const u8,
comptime strs: []const []const u8,
) []const u8 {
comptime {
var size: usize = 1; // 1 for null terminator
for (slices) |s| size += s.len + separator.len;
if (slices.len != 0) size -= separator.len;
if (strs.len == 0) return "";
var size: usize = 0;
for (strs) |str| size += prefix.len + str.len + separator.len;
size -= separator.len;
var buf = std.mem.zeroes([size]u8);
var fba = std.heap.fixedBufferAllocator(&buf);
return std.mem.join(fba.allocator(), separator, slices);
// can't use std.mem.join because of problems with comptime allocation
// https://github.com/ziglang/zig/issues/5873#issuecomment-1001778218
//var fba = std.heap.FixedBufferAllocator.init(&buf);
//return (std.mem.join(fba.allocator(), separator, fields) catch unreachable) ++ " ";
var buf_idx = 0;
for (strs) |str, i| {
std.mem.copy(u8, buf[buf_idx..], prefix);
buf_idx += prefix.len;
std.mem.copy(u8, buf[buf_idx..], str);
buf_idx += str.len;
if (i != strs.len - 1) {
std.mem.copy(u8, buf[buf_idx..], separator);
buf_idx += separator.len;
}
}
return &buf;
}
}
pub fn comptimeJoin(comptime separator: []const u8, comptime slices: []const []const u8) *const [comptimeJoinSlice(separator, slices):0]u8 {
const slice = comptimeJoinSlice(separator, slices);
return slice[0..slice.len];
/// Joins an array of strings, putting `separator` in between each pair
pub fn comptimeJoin(
comptime separator: []const u8,
comptime strs: []const []const u8,
) []const u8 {
return comptimeJoinWithPrefix(separator, "", strs);
}
/// Helper function to serialize a runtime enum value as a string inside JSON.
/// To use, add
/// ```
/// pub const jsonStringify = util.jsonSerializeEnumAsString;
/// ```
/// to your enum type.
pub fn jsonSerializeEnumAsString(
enum_value: anytype,
_: std.json.StringifyOptions,
writer: anytype,
) !void {
switch (@typeInfo(@TypeOf(enum_value))) {
.Enum => |info| if (!info.is_exhaustive) @compileError("Enum must be exhaustive"),
else => @compileError("Must be enum type"),
}
return std.fmt.format(writer, "\"{s}\"", .{@tagName(enum_value)});
}
/// Recursively frees a struct/array/slice/etc using the given allocator
/// by freeing any slices or pointers inside. Assumes that every pointer-like
/// object within points to its own allocation that must be free'd separately.
/// Do *not* use on self-referential types or structs that contain duplicate
/// slices.
/// Meant to be the inverse of `deepClone` below
pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {
const T = @TypeOf(val);
switch (@typeInfo(T)) {
.Pointer => |ptr| switch (ptr.size) {
.One => alloc.?.destroy(val),
.One => {
deepFree(alloc, val.*);
alloc.?.destroy(val);
},
.Slice => {
for (val) |v| deepFree(alloc, v);
alloc.?.free(val);
@ -46,7 +102,9 @@ pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {
}
}
// deepClone assumes the value owns any pointers inside it
/// Clones a struct/array/slice/etc and all its submembers.
/// Assumes that there are no self-refrential pointers within and that
/// every pointer should be followed.
pub fn deepClone(alloc: std.mem.Allocator, val: anytype) !@TypeOf(val) {
const T = @TypeOf(val);
var result: T = undefined;