Refactoring

This commit is contained in:
jaina heartles 2022-09-29 14:52:01 -07:00
parent 33d1834f19
commit c42039c559
13 changed files with 369 additions and 180 deletions

37
src/OVERVIEW.md Normal file
View file

@ -0,0 +1,37 @@
# Overview
## Packages
- `main`: primary package, has application-specific functionality
* TODO: consider moving controllers and api into different packages
* `controllers/**.zig`:
- Transforms HTTP to/from API calls
- Turns error codes into HTTP statuses
* `api.zig`:
- Makes sure API call is allowed with the given user/host context
- Transforms API models into display models
- `api/**.zig`: Performs action associated with API call
* Transforms DB models into API models
* Data validation
- TODO: the distinction between what goes in `api.zig` and in its submodules is gross. Refactor?
* `migrations.zig`:
- Defines database migrations to apply
- Should be ran on startup
- `util`: utility packages
* Components:
- `Uuid`: UUID utils (random uuid generation, equality, parsing, printing)
* `Uuid.eql`
* `Uuid.randV4`
* UUID's are serialized to their string representation for JSON, db
- `PathIter`: Path segment iterator
- `Url`: URL utils (parsing)
- `ciutf8`: case-insensitive UTF-8 (TODO: Scrap this, replace with ICU library)
- `DateTime`: Time utils
- `deepClone(alloc, orig)`/`deepFree(alloc, to_free)`: Utils for cloning and freeing basic data structs
* Clones/frees any strings/sub structs within the value
- `sql`: SQL library
* Supports 2 engines (SQLite, PostgreSQL)
* `var my_transaction = try db.begin()`
* `const results = try db.query(RowType, "SELECT ...", .{arg_1, ...}, alloc)`
- `http`: HTTP Server
* The API sucks. Needs a refactor

View file

@ -94,6 +94,42 @@ pub fn getRandom() std.rand.Random {
return prng.random();
}
pub fn isAdminSetup(db: *sql.Db) !bool {
_ = services.communities.adminCommunityId(db) catch |err| switch (err) {
error.NotFound => return false,
else => return err,
};
return true;
}
pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) !void {
const tx = try db.begin();
errdefer tx.rollback();
var arena = std.heap.ArenaAllocator.init(allocator);
defer arena.deinit();
try tx.setConstraintMode(.deferred);
const community = try services.communities.create(
tx,
origin,
Uuid.nil,
.{ .name = "Cluster Admin", .kind = .admin },
);
const user = try services.users.create(tx, username, password, community.id, .{ .role = .admin }, arena.allocator());
try services.communities.transferOwnership(tx, community.id, user);
try tx.commit();
std.log.info(
"Created admin user {s} (id {}) with cluster admin origin {s} (id {})",
.{ username, user, origin, community.id },
);
}
pub const ApiSource = struct {
db: *sql.Db,
internal_alloc: std.mem.Allocator,
@ -103,68 +139,43 @@ pub const ApiSource = struct {
const root_username = "root";
pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: *sql.Db) !ApiSource {
var self = ApiSource{
pub fn init(alloc: std.mem.Allocator, cfg: Config, db_conn: *sql.Db) !ApiSource {
return ApiSource{
.db = db_conn,
.internal_alloc = alloc,
.config = cfg,
};
try migrations.up(db_conn);
if ((try services.users.lookupByUsername(self.db, root_username, null)) == null) {
std.log.info("No cluster root user detected. Creating...", .{});
// TODO: Fix this
const password = root_password orelse return error.NeedRootPassword;
var arena = std.heap.ArenaAllocator.init(alloc);
defer arena.deinit();
const user_id = try services.users.create(self.db, root_username, password, null, .{}, arena.allocator());
std.log.debug("Created {s} ID {}", .{ root_username, user_id });
}
return self;
}
fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid {
if (try self.db.queryRow(
&.{Uuid},
"SELECT id FROM community WHERE host = $1",
.{host},
null,
)) |result| return result[0];
// Test for cluster admin community
if (util.ciutf8.eql(self.config.cluster_host, host)) {
return null;
}
return error.NoCommunity;
}
pub fn connectUnauthorized(self: *ApiSource, host: []const u8, alloc: std.mem.Allocator) !Conn {
const community_id = try self.getCommunityFromHost(host);
var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit();
const community = try services.communities.getByHost(self.db, host, arena.allocator());
return Conn{
.db = self.db,
.internal_alloc = self.internal_alloc,
.user_id = null,
.community_id = community_id,
.arena = std.heap.ArenaAllocator.init(alloc),
.community = community,
.arena = arena,
};
}
pub fn connectToken(self: *ApiSource, host: []const u8, token: []const u8, alloc: std.mem.Allocator) !Conn {
const community_id = try self.getCommunityFromHost(host);
var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit();
const token_info = try services.auth.tokens.verify(self.db, token, community_id);
const community = try services.communities.getByHost(self.db, host, arena.allocator());
const token_info = try services.auth.tokens.verify(self.db, token, community.id);
return Conn{
.db = self.db,
.internal_alloc = self.internal_alloc,
.user_id = token_info.user_id,
.community_id = community_id,
.arena = std.heap.ArenaAllocator.init(alloc),
.community = community,
.arena = arena,
};
}
};
@ -176,7 +187,7 @@ fn ApiConn(comptime DbConn: type) type {
db: DbConn,
internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers
user_id: ?Uuid,
community_id: ?Uuid,
community: services.communities.Community,
arena: std.heap.ArenaAllocator,
pub fn close(self: *Self) void {
@ -185,11 +196,11 @@ fn ApiConn(comptime DbConn: type) type {
fn isAdmin(self: *Self) bool {
// TODO
return self.user_id != null and self.community_id == null;
return self.user_id != null and self.community.kind == .admin;
}
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse {
const user_id = (try services.users.lookupByUsername(self.db, username, self.community_id)) orelse return error.InvalidLogin;
const user_id = (try services.users.lookupByUsername(self.db, username, self.community.id)) orelse return error.InvalidLogin;
try services.auth.passwords.verify(self.db, user_id, password, self.internal_alloc);
const token = try services.auth.tokens.create(self.db, user_id);
@ -207,7 +218,7 @@ fn ApiConn(comptime DbConn: type) type {
pub fn getTokenInfo(self: *Self) !TokenInfo {
if (self.user_id) |user_id| {
const result = (try self.db.queryRow(
&.{[]const u8},
std.meta.Tuple(&.{[]const u8}),
"SELECT username FROM user WHERE id = $1",
.{user_id},
self.arena.allocator(),
@ -225,7 +236,7 @@ fn ApiConn(comptime DbConn: type) type {
return error.PermissionDenied;
}
return services.communities.create(self.db, origin, null);
return services.communities.create(self.db, origin, self.user_id.?, .{});
}
pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite {
@ -234,13 +245,13 @@ fn ApiConn(comptime DbConn: type) type {
const community_id = if (options.to_community) |host| blk: {
// You can only specify a different community if you're on the admin domain
if (self.community_id != null) return error.WrongCommunity;
if (self.community.kind != .admin) return error.WrongCommunity;
// Only admins can invite on the admin domain
if (!self.isAdmin()) return error.PermissionDenied;
break :blk (try services.communities.getByHost(self.db, host, self.arena.allocator())).id;
} else self.community_id;
} else self.community.id;
// Users can only make user invites
if (options.invite_type != .user and !self.isAdmin()) return error.PermissionDenied;
@ -257,19 +268,19 @@ fn ApiConn(comptime DbConn: type) type {
std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code });
const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator());
if (!Uuid.eql(invite.to_community, self.community_id)) return error.NotFound;
if (!Uuid.eql(invite.to_community, self.community.id)) return error.NotFound;
if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired;
if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired;
if (self.community_id == null) @panic("Unimplmented");
if (self.community.kind == .admin) @panic("Unimplmented");
const user_id = try services.users.create(self.db, request.username, request.password, self.community_id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc);
const user_id = try services.users.create(self.db, request.username, request.password, self.community.id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc);
switch (invite.invite_type) {
.user => {},
.system => @panic("System user invites unimplemented"),
.community_owner => {
try services.communities.transferOwnership(self.db, self.community_id.?, user_id);
try services.communities.transferOwnership(self.db, self.community.id, user_id);
},
}
@ -283,7 +294,7 @@ fn ApiConn(comptime DbConn: type) type {
const user = try services.users.get(self.db, user_id, self.arena.allocator());
if (self.user_id == null) {
if (!Uuid.eql(self.community_id, user.community_id)) return error.NotFound;
if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound;
}
return UserResponse{
@ -295,7 +306,7 @@ fn ApiConn(comptime DbConn: type) type {
}
pub fn createNote(self: *Self, content: []const u8) !NoteResponse {
if (self.community_id == null) return error.WrongCommunity;
if (self.community.kind == .admin) return error.WrongCommunity;
const user_id = self.user_id orelse return error.TokenRequired;
const note_id = try services.notes.create(self.db, user_id, content);
@ -312,7 +323,7 @@ fn ApiConn(comptime DbConn: type) type {
// Only serve community-specific notes on unauthenticated requests
if (self.user_id == null) {
if (!Uuid.eql(self.community_id, user.community_id)) return error.NotFound;
if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound;
}
return NoteResponse{

View file

@ -24,7 +24,7 @@ pub const passwords = struct {
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
// TODO: This could be done w/o the dynamically allocated hash buf
const hash = (db.queryRow(
&.{[]const u8},
std.meta.Tuple(&.{[]const u8}),
"SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1",
.{user_id},
alloc,
@ -96,7 +96,7 @@ pub const tokens = struct {
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
return if (try db.queryRow(
&.{ Uuid, DateTime },
std.meta.Tuple(&.{ Uuid, DateTime }),
\\SELECT user.id, token.issued_at
\\FROM token JOIN user ON token.user_id = user.id
\\WHERE user.community_id = $1 AND token.hash = $2
@ -115,7 +115,7 @@ pub const tokens = struct {
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
return if (try db.queryRow(
&.{ Uuid, DateTime },
std.meta.Tuple(&.{ Uuid, DateTime }),
\\SELECT user.id, token.issued_at
\\FROM token JOIN user ON token.user_id = user.id
\\WHERE user.community_id IS NULL AND token.hash = $1

View file

@ -1,7 +1,7 @@
const std = @import("std");
const builtin = @import("builtin");
const util = @import("util");
const models = @import("../db/models.zig");
const sql = @import("sql");
const getRandom = @import("../api.zig").getRandom;
@ -26,21 +26,31 @@ pub const Scheme = enum {
pub const Community = struct {
id: Uuid,
owner_id: ?Uuid,
owner_id: Uuid,
host: []const u8,
name: []const u8,
scheme: Scheme,
kind: Kind,
created_at: DateTime,
};
fn freeCommunity(alloc: std.mem.Allocator, c: Community) void {
alloc.free(c.host);
alloc.free(c.name);
}
pub const Kind = enum {
admin,
local,
pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Community {
const scheme_len = firstIndexOf(origin, ':') orelse return error.InvalidOrigin;
pub fn jsonStringify(val: Kind, _: std.json.StringifyOptions, writer: anytype) !void {
return std.fmt.format(writer, "\"{s}\"", .{@tagName(val)});
}
};
pub const CreateOptions = struct {
name: ?[]const u8 = null,
kind: Kind = .local,
};
pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptions) CreateError!Community {
const scheme_len = std.mem.indexOfScalar(u8, origin, ':') orelse return error.InvalidOrigin;
const scheme_str = origin[0..scheme_len];
const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme;
@ -55,10 +65,10 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
// community cannot use non-default ports (except for testing)
// NOTE: Do not add, say localhost and localhost:80 or bugs may happen.
// Avoid using non-default ports unless a test can't be conducted without it.
if (firstIndexOf(host, ':') != null and builtin.mode != .Debug) return error.InvalidOrigin;
if (std.mem.indexOfScalar(u8, host, ':') != null and builtin.mode != .Debug) return error.InvalidOrigin;
// community cannot be hosted on a path
if (firstIndexOf(host, '/') != null) return error.InvalidOrigin;
if (std.mem.indexOfScalar(u8, host, '/') != null) return error.InvalidOrigin;
// Require TLS on production builds
if (scheme != .https and builtin.mode != .Debug) return error.UnsupportedScheme;
@ -67,14 +77,15 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
const community = Community{
.id = id,
.owner_id = null,
.owner_id = owner,
.host = host,
.name = name orelse host,
.name = options.name orelse host,
.scheme = scheme,
.kind = options.kind,
.created_at = DateTime.now(),
};
if ((try db.queryRow(&.{Uuid}, "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) {
if ((try db.queryRow(std.meta.Tuple(&.{Uuid}), "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) {
return error.CommunityExists;
}
@ -83,25 +94,13 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
return community;
}
fn firstIndexOf(str: []const u8, ch: u8) ?usize {
for (str) |c, i| {
if (c == ch) return i;
}
return null;
}
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community {
const result = (try db.queryRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = $1", .{host}, alloc)) orelse return error.NotFound;
return Community{
.id = result[0],
.owner_id = result[1],
.host = result[2],
.name = result[3],
.scheme = result[4],
.created_at = result[5],
};
return (try db.queryRow(
Community,
std.fmt.comptimePrint("SELECT {s} FROM community WHERE host = $1", .{comptime sql.fieldList(Community)}),
.{host},
alloc,
)) orelse return error.NotFound;
}
pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void {
@ -247,7 +246,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit
errdefer alloc.free(result_buf);
var count: usize = 0;
errdefer for (result_buf[0..count]) |c| freeCommunity(alloc, c);
errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c);
for (result_buf) |*c| {
const row = results.row(alloc) orelse break;
@ -267,3 +266,14 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit
return result_buf[0..count];
}
pub fn adminCommunityId(db: anytype) !Uuid {
const row = (try db.queryRow(
std.meta.Tuple(&.{Uuid}),
"SELECT id FROM community WHERE kind = 'admin' LIMIT 1",
.{},
null,
)) orelse return error.NotFound;
return row[0];
}

View file

@ -130,7 +130,7 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit
pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite {
const code_clone = try cloneStr(code, alloc);
const info = (try db.queryRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType },
const info = (try db.queryRow(std.meta.Tuple(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType }),
\\SELECT
\\ invite.id, invite.created_by, invite.to_community, invite.name,
\\ invite.created_at, invite.expires_at,

View file

@ -41,7 +41,7 @@ pub fn create(
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note {
const result = (try db.queryRow(
&.{ Uuid, []const u8, DateTime },
std.meta.Tuple(&.{ Uuid, []const u8, DateTime }),
\\SELECT author_id, content, created_at
\\FROM note
\\WHERE id = $1

View file

@ -21,7 +21,7 @@ const DbUser = struct {
id: Uuid,
username: []const u8,
community_id: ?Uuid,
community_id: Uuid,
};
const DbLocalUser = struct {
@ -31,52 +31,43 @@ const DbLocalUser = struct {
email: ?[]const u8,
};
pub const Role = enum {
user,
admin,
};
pub const CreateOptions = struct {
invite_id: ?Uuid = null,
email: ?[]const u8 = null,
role: Role = .user,
};
fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
return if (try db.queryRow(
&.{Uuid},
"SELECT user.id FROM user WHERE community_id IS NULL AND username = $1",
.{username},
null,
)) |result|
result[0]
else
null;
}
fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid {
return if (try db.queryRow(
&.{Uuid},
fn lookupByUsernameInternal(db: anytype, username: []const u8, community_id: Uuid) CreateError!?Uuid {
return if (db.queryRow(
std.meta.Tuple(&.{Uuid}),
"SELECT user.id FROM user WHERE community_id = $1 AND username = $2",
.{ community_id, username },
null,
)) |result|
) catch return error.DbError) |result|
result[0]
else
null;
}
pub fn lookupByUsername(db: anytype, username: []const u8, community_id: ?Uuid) !?Uuid {
return if (community_id) |id|
lookupUserByUsername(db, username, id) catch return error.DbError
else
lookupSystemUserByUsername(db, username) catch return error.DbError;
pub fn lookupByUsername(db: anytype, username: []const u8, community_id: Uuid) CreateError!Uuid {
return (lookupByUsernameInternal(db, username, community_id) catch return error.DbError) orelse error.NotFound;
}
pub fn create(
db: anytype,
username: []const u8,
password: []const u8,
community_id: ?Uuid,
community_id: Uuid,
options: CreateOptions,
password_alloc: std.mem.Allocator,
) CreateError!Uuid {
const id = Uuid.randV4(getRandom());
if ((try lookupByUsername(db, username, community_id)) != null) {
if ((try lookupByUsernameInternal(db, username, community_id)) != null) {
return error.UsernameTaken;
}
@ -108,7 +99,7 @@ pub const User = struct {
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
const result = (try db.queryRow(
&.{ []const u8, []const u8, Uuid, DateTime },
std.meta.Tuple(&.{ []const u8, []const u8, Uuid, DateTime }),
\\SELECT user.username, community.host, community.id, user.created_at
\\FROM user JOIN community ON user.community_id = community.id
\\WHERE user.id = $1

View file

@ -5,7 +5,7 @@ const http = @import("http");
const util = @import("util");
pub const api = @import("./api.zig");
const models = @import("./db/models.zig");
const migrations = @import("./migrations.zig");
const Uuid = util.Uuid;
const c = @import("./controllers.zig");
@ -98,21 +98,49 @@ fn loadConfig(alloc: std.mem.Allocator) !Config {
return std.json.parse(Config, &ts, .{ .allocator = alloc });
}
const root_password_envvar = "CLUSTER_ROOT_PASSWORD";
const admin_origin_envvar = "CLUSTER_ADMIN_ORIGIN";
const admin_username_envvar = "CLUSTER_ADMIN_USERNAME";
const admin_password_envvar = "CLUSTER_ADMIN_PASSWORD";
fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void {
const origin = std.os.getenv(admin_origin_envvar) orelse return error.MissingArgument;
const username = std.os.getenv(admin_username_envvar) orelse return error.MissingArgument;
const password = std.os.getenv(admin_password_envvar) orelse return error.MissingArgument;
try api.setupAdmin(db, origin, username, password, alloc);
}
fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void {
try migrations.up(db);
if (!try api.isAdminSetup(db)) {
std.log.info("Performing first-time admin creation...", .{});
runAdminSetup(db, alloc) catch |err| switch (err) {
error.MissingArgument => {
std.log.err(
\\First time setup required but arguments not provided.
\\Please provide the following arguments via environment variable:
\\- {s}: The origin to serve the cluster admin panel at (ex: https://admin.example.com)
\\- {s}: The username for the initial cluster operator
\\- {s}: The password for the initial cluster operator
,
.{ admin_origin_envvar, admin_username_envvar, admin_password_envvar },
);
std.os.exit(1);
},
else => return err,
};
}
}
pub fn main() anyerror!void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
var cfg = try loadConfig(gpa.allocator());
var db_conn = try sql.Db.open(cfg.db);
var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), &db_conn) catch |err| switch (err) {
error.NeedRootPassword => {
std.log.err(
"No root user created and no password specified. Please provide the password for the root user by the ${s} environment variable for initial startup. This only needs to be done once",
.{root_password_envvar},
);
return err;
},
else => return err,
};
try prepareDb(&db_conn, gpa.allocator());
var api_src = try api.ApiSource.init(gpa.allocator(), cfg, &db_conn);
var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg);
api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp()));
return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);

View file

@ -38,7 +38,7 @@ fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void {
}
fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool {
const row = (try db.queryRow(&.{i32}, "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false;
const row = (try db.queryRow(std.meta.Tuple(&.{i32}), "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false;
return row[0] != 0;
}
@ -152,7 +152,7 @@ const migrations: []const Migration = &.{
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
\\ expires_at TIMESTAMPTZ,
\\
\\ type TEXT NOT NULL CHECK (type in ('system', 'community_owner', 'user'))
\\ type TEXT NOT NULL CHECK (type in ('system_user', 'community_owner', 'user'))
\\);
\\ALTER TABLE local_user ADD COLUMN invite_id TEXT REFERENCES invite(id);
,
@ -171,14 +171,15 @@ const migrations: []const Migration = &.{
\\ name TEXT NOT NULL,
\\ host TEXT NOT NULL UNIQUE,
\\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')),
\\ kind TEXT NOT NULL CHECK (kind in ('admin', 'local')),
\\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\);
\\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id);
\\ALTER TABLE invite ADD COLUMN to_community TEXT REFERENCES community(id);
\\ALTER TABLE invite ADD COLUMN community_id TEXT REFERENCES community(id);
,
.down =
\\ALTER TABLE invite DROP COLUMN to_community;
\\ALTER TABLE invite DROP COLUMN community_id;
\\ALTER TABLE user DROP COLUMN community_id;
\\DROP TABLE community;
,

View file

@ -22,6 +22,41 @@ pub const Config = union(Engine) {
},
};
pub const QueryError = error{
OutOfMemory,
ConnectionLost,
};
pub fn fieldList(comptime RowType: type) []const u8 {
comptime {
const fields = std.meta.fieldNames(RowType);
const separator = ", ";
if (fields.len == 0) return "";
var size: usize = 1; // 1 for null terminator
for (fields) |f| size += f.len + separator.len;
size -= separator.len;
var buf = std.mem.zeroes([size]u8);
// can't use std.mem.join because of problems with comptime allocation
// https://github.com/ziglang/zig/issues/5873#issuecomment-1001778218
//var fba = std.heap.FixedBufferAllocator.init(&buf);
//return (std.mem.join(fba.allocator(), separator, fields) catch unreachable) ++ " ";
var buf_idx = 0;
for (fields) |f, i| {
std.mem.copy(u8, buf[buf_idx..], f);
buf_idx += f.len;
if (i != fields.len - 1) std.mem.copy(u8, buf[buf_idx..], separator);
buf_idx += separator.len;
}
return &buf;
}
}
//pub const OpenError = sqlite.OpenError | postgres.OpenError;
const RawResults = union(Engine) {
postgres: postgres.Results,
@ -33,17 +68,50 @@ const RawResults = union(Engine) {
.sqlite => |lite| lite.finish(),
}
}
fn columnCount(self: RawResults) u15 {
return switch (self) {
.postgres => |pg| pg.columnCount(),
.sqlite => |lite| lite.columnCount(),
};
}
fn columnNameToIndex(self: RawResults, name: []const u8) !u15 {
return try switch (self) {
.postgres => |pg| pg.columnNameToIndex(name),
.sqlite => |lite| lite.columnNameToIndex(name),
};
}
fn row(self: *RawResults) !?Row {
return switch (self.*) {
.postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else null,
.sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else null,
};
}
};
// Represents a set of results.
// row() must be called until it returns null, or the query may not complete
// Must be deallocated by a call to finish()
pub fn Results(comptime result_types: []const type) type {
pub fn Results(comptime T: type) type {
// would normally make this a declaration of the struct, but it causes the compiler to crash
const fields = std.meta.fields(T);
return struct {
const Self = @This();
const RowTuple = std.meta.Tuple(result_types);
underlying: RawResults,
column_indices: [fields.len]u15,
fn from(underlying: RawResults) !Self {
return Self{ .underlying = underlying, .column_indices = blk: {
var indices: [fields.len]u15 = undefined;
inline for (fields) |f, i| {
indices[i] = if (!std.meta.trait.isTuple(T)) try underlying.columnNameToIndex(f.name) else i;
}
break :blk indices;
} };
}
pub fn finish(self: Self) void {
self.underlying.finish();
@ -52,31 +120,30 @@ pub fn Results(comptime result_types: []const type) type {
// can be used as an optimization to reduce memory reallocation
// only works on postgres
pub fn rowCount(self: Self) ?usize {
return switch (self.underlying) {
.postgres => |pg| pg.rowCount(),
.sqlite => null, // not possible without repeating the query
};
return self.underlying.rowCount();
}
pub fn row(self: *Self, alloc: ?Allocator) !?RowTuple {
const row_val = switch (self.underlying) {
.postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else return null,
.sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else return null,
};
// Returns the next row of results, or null if there are no more rows.
// Caller owns all memory allocated. The entire object can be deallocated with a
// call to util.deepFree
pub fn row(self: *Self, alloc: ?Allocator) !?T {
if (try self.underlying.row()) |row_val| {
var result: T = undefined;
var fields_allocated: usize = 0;
errdefer inline for (fields) |f, i| {
// Iteration bounds must be defined at comptime (inline for) but the number of fields we could
// successfully allocate is defined at runtime. So we iterate over the entire field array and
// conditionally deallocate fields in the loop.
if (i < fields_allocated) util.deepFree(alloc, @field(result, f.name));
};
var result: RowTuple = undefined;
var fields_allocated = [_]bool{false} ** result.len;
errdefer {
inline for (result_types) |_, i| {
if (fields_allocated[i]) util.deepFree(alloc, result[i]);
inline for (fields) |f, i| {
@field(result, f.name) = try row_val.get(f.field_type, self.column_indices[i], alloc);
fields_allocated += 1;
}
}
inline for (result_types) |T, i| {
result[i] = try row_val.get(T, i, alloc);
fields_allocated[i] = true;
}
return result;
return result;
} else return null;
}
};
}
@ -136,26 +203,26 @@ pub const Db = struct {
pub fn queryWithOptions(
self: *Db,
comptime result_types: []const type,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
opt: QueryOptions,
) !Results(result_types) {
) !Results(RowType) {
if (self.tx_open) return error.TransactionOpen;
// Create fake transaction to use its functions
return (Tx{ .db = self }).queryWithOptions(result_types, sql, args, opt);
return (Tx{ .db = self }).queryWithOptions(RowType, sql, args, opt);
}
pub fn query(
self: *Db,
comptime result_types: []const type,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) !Results(result_types) {
) !Results(RowType) {
if (self.tx_open) return error.TransactionOpen;
// Create fake transaction to use its functions
return (Tx{ .db = self }).query(result_types, sql, args, alloc);
return (Tx{ .db = self }).query(RowType, sql, args, alloc);
}
pub fn exec(
@ -171,14 +238,14 @@ pub const Db = struct {
pub fn queryRow(
self: *Db,
comptime result_types: []const type,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) !?Results(result_types).RowTuple {
) !?RowType {
if (self.tx_open) return error.TransactionOpen;
// Create fake transaction to use its functions
return (Tx{ .db = self }).queryRow(result_types, sql, args, alloc);
return (Tx{ .db = self }).queryRow(RowType, sql, args, alloc);
}
pub fn insert(
@ -222,23 +289,23 @@ pub const Tx = struct {
pub fn queryWithOptions(
self: Tx,
comptime result_types: []const type,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
options: QueryOptions,
) !Results(result_types) {
return Results(result_types){ .underlying = try self.queryInternal(sql, args, options) };
) !Results(RowType) {
return Results(RowType).from(try self.queryInternal(sql, args, options));
}
// Executes a query and returns the result set
pub fn query(
self: Tx,
comptime result_types: []const type,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) !Results(result_types) {
return self.queryWithOptions(result_types, sql, args, .{ .prep_allocator = alloc });
) !Results(RowType) {
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
}
// Executes a query without returning results
@ -248,19 +315,20 @@ pub const Tx = struct {
args: anytype,
alloc: ?Allocator,
) !void {
_ = try self.queryRow(&.{}, sql, args, alloc);
_ = try self.queryRow(std.meta.Tuple(&.{}), sql, args, alloc);
}
// Runs a query and returns a single row
pub fn queryRow(
self: Tx,
comptime result_types: []const type,
comptime RowType: type,
q: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) !?Results(result_types).RowTuple {
var results = try self.query(result_types, q, args, alloc);
) !?RowType {
var results = try self.query(RowType, q, args, alloc);
defer results.finish();
@compileLog(args);
const row = (try results.row(alloc)) orelse return null;
errdefer util.deepFree(alloc, row);

View file

@ -26,6 +26,16 @@ pub const Results = struct {
};
}
pub fn columnCount(self: Results) u15 {
return @intCast(u15, c.PQnfields(self.result));
}
pub fn columnNameToIndex(self: Results, name: []const u8) !u15 {
const idx = c.PQfnumber(self.result, name.ptr);
if (idx == -1) return error.ColumnNotFound;
return @intCast(u15, idx);
}
pub fn finish(self: Results) void {
c.PQclear(self.result);
}
@ -89,8 +99,8 @@ pub const Db = struct {
if (comptime args.len > 0) {
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
defer arena.deinit();
const params = try arena.allocator().alloc(?[*]const u8, args.len);
inline for (args) |a, i| params[i] = if (try common.prepareParamText(&arena, a)) |slice| slice.ptr else null;
const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
inline for (args) |arg, i| params[i] = if (try common.prepareParamText(&arena, arg)) |slice| slice.ptr else null;
break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text);
} else {

View file

@ -184,7 +184,6 @@ pub const Results = struct {
return switch (c.sqlite3_step(self.stmt)) {
c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db },
c.SQLITE_DONE => null,
else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()),
};
}
@ -193,6 +192,28 @@ pub const Results = struct {
const ptr = c.sqlite3_sql(self.stmt) orelse return null;
return ptr[0..std.mem.len(ptr)];
}
pub fn columnCount(self: Results) u15 {
return @intCast(u15, c.sqlite3_column_count(self.stmt));
}
pub fn columnName(self: Results, idx: u15) ![]const u8 {
return if (c.sqlite3_column_name(self.stmt, idx)) |ptr|
ptr[0..std.mem.len(ptr)]
else
return error.OutOfMemory;
}
pub fn columnNameToIndex(self: Results, name: []const u8) !u15 {
var i: u15 = 0;
const count = self.columnCount();
while (i < count) : (i += 1) {
const column = try self.columnName(i);
if (std.mem.eql(u8, name, column)) return i;
}
return error.ColumnNotFound;
}
};
pub const Row = struct {

View file

@ -6,10 +6,22 @@ pub const DateTime = @import("./DateTime.zig");
pub const PathIter = @import("./PathIter.zig");
pub const Url = @import("./Url.zig");
pub fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 {
var new = try alloc.alloc(u8, str.len);
std.mem.copy(u8, new, str);
return new;
fn comptimeJoinSlice(comptime separator: []const u8, comptime slices: []const []const u8) []u8 {
comptime {
var size: usize = 1; // 1 for null terminator
for (slices) |s| size += s.len + separator.len;
if (slices.len != 0) size -= separator.len;
var buf = std.mem.zeroes([size]u8);
var fba = std.heap.fixedBufferAllocator(&buf);
return std.mem.join(fba.allocator(), separator, slices);
}
}
pub fn comptimeJoin(comptime separator: []const u8, comptime slices: []const []const u8) *const [comptimeJoinSlice(separator, slices):0]u8 {
const slice = comptimeJoinSlice(separator, slices);
return slice[0..slice.len];
}
pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {