From c42039c559ff43029705fa6c61911507f0777321 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Thu, 29 Sep 2022 14:52:01 -0700 Subject: [PATCH] Refactoring --- src/OVERVIEW.md | 37 +++++++++ src/main/api.zig | 119 ++++++++++++++------------- src/main/api/auth.zig | 6 +- src/main/api/communities.zig | 74 +++++++++-------- src/main/api/invites.zig | 2 +- src/main/api/notes.zig | 2 +- src/main/api/users.zig | 41 ++++------ src/main/main.zig | 52 +++++++++--- src/main/migrations.zig | 9 ++- src/sql/lib.zig | 150 +++++++++++++++++++++++++---------- src/sql/postgres.zig | 14 +++- src/sql/sqlite.zig | 23 +++++- src/util/lib.zig | 20 ++++- 13 files changed, 369 insertions(+), 180 deletions(-) create mode 100644 src/OVERVIEW.md diff --git a/src/OVERVIEW.md b/src/OVERVIEW.md new file mode 100644 index 0000000..5c2fe42 --- /dev/null +++ b/src/OVERVIEW.md @@ -0,0 +1,37 @@ +# Overview + +## Packages +- `main`: primary package, has application-specific functionality + * TODO: consider moving controllers and api into different packages + * `controllers/**.zig`: + - Transforms HTTP to/from API calls + - Turns error codes into HTTP statuses + * `api.zig`: + - Makes sure API call is allowed with the given user/host context + - Transforms API models into display models + - `api/**.zig`: Performs action associated with API call + * Transforms DB models into API models + * Data validation + - TODO: the distinction between what goes in `api.zig` and in its submodules is gross. Refactor? + * `migrations.zig`: + - Defines database migrations to apply + - Should be ran on startup +- `util`: utility packages + * Components: + - `Uuid`: UUID utils (random uuid generation, equality, parsing, printing) + * `Uuid.eql` + * `Uuid.randV4` + * UUID's are serialized to their string representation for JSON, db + - `PathIter`: Path segment iterator + - `Url`: URL utils (parsing) + - `ciutf8`: case-insensitive UTF-8 (TODO: Scrap this, replace with ICU library) + - `DateTime`: Time utils + - `deepClone(alloc, orig)`/`deepFree(alloc, to_free)`: Utils for cloning and freeing basic data structs + * Clones/frees any strings/sub structs within the value +- `sql`: SQL library + * Supports 2 engines (SQLite, PostgreSQL) + * `var my_transaction = try db.begin()` + * `const results = try db.query(RowType, "SELECT ...", .{arg_1, ...}, alloc)` +- `http`: HTTP Server + * The API sucks. Needs a refactor + diff --git a/src/main/api.zig b/src/main/api.zig index ca9b119..2166683 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -94,6 +94,42 @@ pub fn getRandom() std.rand.Random { return prng.random(); } +pub fn isAdminSetup(db: *sql.Db) !bool { + _ = services.communities.adminCommunityId(db) catch |err| switch (err) { + error.NotFound => return false, + else => return err, + }; + + return true; +} + +pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) !void { + const tx = try db.begin(); + errdefer tx.rollback(); + var arena = std.heap.ArenaAllocator.init(allocator); + defer arena.deinit(); + + try tx.setConstraintMode(.deferred); + + const community = try services.communities.create( + tx, + origin, + Uuid.nil, + .{ .name = "Cluster Admin", .kind = .admin }, + ); + + const user = try services.users.create(tx, username, password, community.id, .{ .role = .admin }, arena.allocator()); + + try services.communities.transferOwnership(tx, community.id, user); + + try tx.commit(); + + std.log.info( + "Created admin user {s} (id {}) with cluster admin origin {s} (id {})", + .{ username, user, origin, community.id }, + ); +} + pub const ApiSource = struct { db: *sql.Db, internal_alloc: std.mem.Allocator, @@ -103,68 +139,43 @@ pub const ApiSource = struct { const root_username = "root"; - pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: *sql.Db) !ApiSource { - var self = ApiSource{ + pub fn init(alloc: std.mem.Allocator, cfg: Config, db_conn: *sql.Db) !ApiSource { + return ApiSource{ .db = db_conn, .internal_alloc = alloc, .config = cfg, }; - - try migrations.up(db_conn); - - if ((try services.users.lookupByUsername(self.db, root_username, null)) == null) { - std.log.info("No cluster root user detected. Creating...", .{}); - - // TODO: Fix this - const password = root_password orelse return error.NeedRootPassword; - var arena = std.heap.ArenaAllocator.init(alloc); - defer arena.deinit(); - const user_id = try services.users.create(self.db, root_username, password, null, .{}, arena.allocator()); - std.log.debug("Created {s} ID {}", .{ root_username, user_id }); - } - - return self; - } - - fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid { - if (try self.db.queryRow( - &.{Uuid}, - "SELECT id FROM community WHERE host = $1", - .{host}, - null, - )) |result| return result[0]; - - // Test for cluster admin community - if (util.ciutf8.eql(self.config.cluster_host, host)) { - return null; - } - - return error.NoCommunity; } pub fn connectUnauthorized(self: *ApiSource, host: []const u8, alloc: std.mem.Allocator) !Conn { - const community_id = try self.getCommunityFromHost(host); + var arena = std.heap.ArenaAllocator.init(alloc); + errdefer arena.deinit(); + + const community = try services.communities.getByHost(self.db, host, arena.allocator()); return Conn{ .db = self.db, .internal_alloc = self.internal_alloc, .user_id = null, - .community_id = community_id, - .arena = std.heap.ArenaAllocator.init(alloc), + .community = community, + .arena = arena, }; } pub fn connectToken(self: *ApiSource, host: []const u8, token: []const u8, alloc: std.mem.Allocator) !Conn { - const community_id = try self.getCommunityFromHost(host); + var arena = std.heap.ArenaAllocator.init(alloc); + errdefer arena.deinit(); - const token_info = try services.auth.tokens.verify(self.db, token, community_id); + const community = try services.communities.getByHost(self.db, host, arena.allocator()); + + const token_info = try services.auth.tokens.verify(self.db, token, community.id); return Conn{ .db = self.db, .internal_alloc = self.internal_alloc, .user_id = token_info.user_id, - .community_id = community_id, - .arena = std.heap.ArenaAllocator.init(alloc), + .community = community, + .arena = arena, }; } }; @@ -176,7 +187,7 @@ fn ApiConn(comptime DbConn: type) type { db: DbConn, internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers user_id: ?Uuid, - community_id: ?Uuid, + community: services.communities.Community, arena: std.heap.ArenaAllocator, pub fn close(self: *Self) void { @@ -185,11 +196,11 @@ fn ApiConn(comptime DbConn: type) type { fn isAdmin(self: *Self) bool { // TODO - return self.user_id != null and self.community_id == null; + return self.user_id != null and self.community.kind == .admin; } pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse { - const user_id = (try services.users.lookupByUsername(self.db, username, self.community_id)) orelse return error.InvalidLogin; + const user_id = (try services.users.lookupByUsername(self.db, username, self.community.id)) orelse return error.InvalidLogin; try services.auth.passwords.verify(self.db, user_id, password, self.internal_alloc); const token = try services.auth.tokens.create(self.db, user_id); @@ -207,7 +218,7 @@ fn ApiConn(comptime DbConn: type) type { pub fn getTokenInfo(self: *Self) !TokenInfo { if (self.user_id) |user_id| { const result = (try self.db.queryRow( - &.{[]const u8}, + std.meta.Tuple(&.{[]const u8}), "SELECT username FROM user WHERE id = $1", .{user_id}, self.arena.allocator(), @@ -225,7 +236,7 @@ fn ApiConn(comptime DbConn: type) type { return error.PermissionDenied; } - return services.communities.create(self.db, origin, null); + return services.communities.create(self.db, origin, self.user_id.?, .{}); } pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite { @@ -234,13 +245,13 @@ fn ApiConn(comptime DbConn: type) type { const community_id = if (options.to_community) |host| blk: { // You can only specify a different community if you're on the admin domain - if (self.community_id != null) return error.WrongCommunity; + if (self.community.kind != .admin) return error.WrongCommunity; // Only admins can invite on the admin domain if (!self.isAdmin()) return error.PermissionDenied; break :blk (try services.communities.getByHost(self.db, host, self.arena.allocator())).id; - } else self.community_id; + } else self.community.id; // Users can only make user invites if (options.invite_type != .user and !self.isAdmin()) return error.PermissionDenied; @@ -257,19 +268,19 @@ fn ApiConn(comptime DbConn: type) type { std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code }); const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator()); - if (!Uuid.eql(invite.to_community, self.community_id)) return error.NotFound; + if (!Uuid.eql(invite.to_community, self.community.id)) return error.NotFound; if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired; if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired; - if (self.community_id == null) @panic("Unimplmented"); + if (self.community.kind == .admin) @panic("Unimplmented"); - const user_id = try services.users.create(self.db, request.username, request.password, self.community_id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc); + const user_id = try services.users.create(self.db, request.username, request.password, self.community.id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc); switch (invite.invite_type) { .user => {}, .system => @panic("System user invites unimplemented"), .community_owner => { - try services.communities.transferOwnership(self.db, self.community_id.?, user_id); + try services.communities.transferOwnership(self.db, self.community.id, user_id); }, } @@ -283,7 +294,7 @@ fn ApiConn(comptime DbConn: type) type { const user = try services.users.get(self.db, user_id, self.arena.allocator()); if (self.user_id == null) { - if (!Uuid.eql(self.community_id, user.community_id)) return error.NotFound; + if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound; } return UserResponse{ @@ -295,7 +306,7 @@ fn ApiConn(comptime DbConn: type) type { } pub fn createNote(self: *Self, content: []const u8) !NoteResponse { - if (self.community_id == null) return error.WrongCommunity; + if (self.community.kind == .admin) return error.WrongCommunity; const user_id = self.user_id orelse return error.TokenRequired; const note_id = try services.notes.create(self.db, user_id, content); @@ -312,7 +323,7 @@ fn ApiConn(comptime DbConn: type) type { // Only serve community-specific notes on unauthenticated requests if (self.user_id == null) { - if (!Uuid.eql(self.community_id, user.community_id)) return error.NotFound; + if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound; } return NoteResponse{ diff --git a/src/main/api/auth.zig b/src/main/api/auth.zig index b6722af..50a809d 100644 --- a/src/main/api/auth.zig +++ b/src/main/api/auth.zig @@ -24,7 +24,7 @@ pub const passwords = struct { pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void { // TODO: This could be done w/o the dynamically allocated hash buf const hash = (db.queryRow( - &.{[]const u8}, + std.meta.Tuple(&.{[]const u8}), "SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1", .{user_id}, alloc, @@ -96,7 +96,7 @@ pub const tokens = struct { fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info { return if (try db.queryRow( - &.{ Uuid, DateTime }, + std.meta.Tuple(&.{ Uuid, DateTime }), \\SELECT user.id, token.issued_at \\FROM token JOIN user ON token.user_id = user.id \\WHERE user.community_id = $1 AND token.hash = $2 @@ -115,7 +115,7 @@ pub const tokens = struct { fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info { return if (try db.queryRow( - &.{ Uuid, DateTime }, + std.meta.Tuple(&.{ Uuid, DateTime }), \\SELECT user.id, token.issued_at \\FROM token JOIN user ON token.user_id = user.id \\WHERE user.community_id IS NULL AND token.hash = $1 diff --git a/src/main/api/communities.zig b/src/main/api/communities.zig index e7cbbdb..16a14fb 100644 --- a/src/main/api/communities.zig +++ b/src/main/api/communities.zig @@ -1,7 +1,7 @@ const std = @import("std"); const builtin = @import("builtin"); const util = @import("util"); -const models = @import("../db/models.zig"); +const sql = @import("sql"); const getRandom = @import("../api.zig").getRandom; @@ -26,21 +26,31 @@ pub const Scheme = enum { pub const Community = struct { id: Uuid, - owner_id: ?Uuid, + owner_id: Uuid, host: []const u8, name: []const u8, scheme: Scheme, + kind: Kind, created_at: DateTime, }; -fn freeCommunity(alloc: std.mem.Allocator, c: Community) void { - alloc.free(c.host); - alloc.free(c.name); -} +pub const Kind = enum { + admin, + local, -pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Community { - const scheme_len = firstIndexOf(origin, ':') orelse return error.InvalidOrigin; + pub fn jsonStringify(val: Kind, _: std.json.StringifyOptions, writer: anytype) !void { + return std.fmt.format(writer, "\"{s}\"", .{@tagName(val)}); + } +}; + +pub const CreateOptions = struct { + name: ?[]const u8 = null, + kind: Kind = .local, +}; + +pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptions) CreateError!Community { + const scheme_len = std.mem.indexOfScalar(u8, origin, ':') orelse return error.InvalidOrigin; const scheme_str = origin[0..scheme_len]; const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme; @@ -55,10 +65,10 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co // community cannot use non-default ports (except for testing) // NOTE: Do not add, say localhost and localhost:80 or bugs may happen. // Avoid using non-default ports unless a test can't be conducted without it. - if (firstIndexOf(host, ':') != null and builtin.mode != .Debug) return error.InvalidOrigin; + if (std.mem.indexOfScalar(u8, host, ':') != null and builtin.mode != .Debug) return error.InvalidOrigin; // community cannot be hosted on a path - if (firstIndexOf(host, '/') != null) return error.InvalidOrigin; + if (std.mem.indexOfScalar(u8, host, '/') != null) return error.InvalidOrigin; // Require TLS on production builds if (scheme != .https and builtin.mode != .Debug) return error.UnsupportedScheme; @@ -67,14 +77,15 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co const community = Community{ .id = id, - .owner_id = null, + .owner_id = owner, .host = host, - .name = name orelse host, + .name = options.name orelse host, .scheme = scheme, + .kind = options.kind, .created_at = DateTime.now(), }; - if ((try db.queryRow(&.{Uuid}, "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) { + if ((try db.queryRow(std.meta.Tuple(&.{Uuid}), "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) { return error.CommunityExists; } @@ -83,25 +94,13 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co return community; } -fn firstIndexOf(str: []const u8, ch: u8) ?usize { - for (str) |c, i| { - if (c == ch) return i; - } - - return null; -} - pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community { - const result = (try db.queryRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = $1", .{host}, alloc)) orelse return error.NotFound; - - return Community{ - .id = result[0], - .owner_id = result[1], - .host = result[2], - .name = result[3], - .scheme = result[4], - .created_at = result[5], - }; + return (try db.queryRow( + Community, + std.fmt.comptimePrint("SELECT {s} FROM community WHERE host = $1", .{comptime sql.fieldList(Community)}), + .{host}, + alloc, + )) orelse return error.NotFound; } pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void { @@ -247,7 +246,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit errdefer alloc.free(result_buf); var count: usize = 0; - errdefer for (result_buf[0..count]) |c| freeCommunity(alloc, c); + errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c); for (result_buf) |*c| { const row = results.row(alloc) orelse break; @@ -267,3 +266,14 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit return result_buf[0..count]; } + +pub fn adminCommunityId(db: anytype) !Uuid { + const row = (try db.queryRow( + std.meta.Tuple(&.{Uuid}), + "SELECT id FROM community WHERE kind = 'admin' LIMIT 1", + .{}, + null, + )) orelse return error.NotFound; + + return row[0]; +} diff --git a/src/main/api/invites.zig b/src/main/api/invites.zig index f07d286..d991e45 100644 --- a/src/main/api/invites.zig +++ b/src/main/api/invites.zig @@ -130,7 +130,7 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite { const code_clone = try cloneStr(code, alloc); - const info = (try db.queryRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType }, + const info = (try db.queryRow(std.meta.Tuple(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType }), \\SELECT \\ invite.id, invite.created_by, invite.to_community, invite.name, \\ invite.created_at, invite.expires_at, diff --git a/src/main/api/notes.zig b/src/main/api/notes.zig index 8379e1e..ba10cc1 100644 --- a/src/main/api/notes.zig +++ b/src/main/api/notes.zig @@ -41,7 +41,7 @@ pub fn create( pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note { const result = (try db.queryRow( - &.{ Uuid, []const u8, DateTime }, + std.meta.Tuple(&.{ Uuid, []const u8, DateTime }), \\SELECT author_id, content, created_at \\FROM note \\WHERE id = $1 diff --git a/src/main/api/users.zig b/src/main/api/users.zig index 3f8c440..4b8532a 100644 --- a/src/main/api/users.zig +++ b/src/main/api/users.zig @@ -21,7 +21,7 @@ const DbUser = struct { id: Uuid, username: []const u8, - community_id: ?Uuid, + community_id: Uuid, }; const DbLocalUser = struct { @@ -31,52 +31,43 @@ const DbLocalUser = struct { email: ?[]const u8, }; +pub const Role = enum { + user, + admin, +}; + pub const CreateOptions = struct { invite_id: ?Uuid = null, email: ?[]const u8 = null, + role: Role = .user, }; -fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid { - return if (try db.queryRow( - &.{Uuid}, - "SELECT user.id FROM user WHERE community_id IS NULL AND username = $1", - .{username}, - null, - )) |result| - result[0] - else - null; -} - -fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid { - return if (try db.queryRow( - &.{Uuid}, +fn lookupByUsernameInternal(db: anytype, username: []const u8, community_id: Uuid) CreateError!?Uuid { + return if (db.queryRow( + std.meta.Tuple(&.{Uuid}), "SELECT user.id FROM user WHERE community_id = $1 AND username = $2", .{ community_id, username }, null, - )) |result| + ) catch return error.DbError) |result| result[0] else null; } -pub fn lookupByUsername(db: anytype, username: []const u8, community_id: ?Uuid) !?Uuid { - return if (community_id) |id| - lookupUserByUsername(db, username, id) catch return error.DbError - else - lookupSystemUserByUsername(db, username) catch return error.DbError; +pub fn lookupByUsername(db: anytype, username: []const u8, community_id: Uuid) CreateError!Uuid { + return (lookupByUsernameInternal(db, username, community_id) catch return error.DbError) orelse error.NotFound; } pub fn create( db: anytype, username: []const u8, password: []const u8, - community_id: ?Uuid, + community_id: Uuid, options: CreateOptions, password_alloc: std.mem.Allocator, ) CreateError!Uuid { const id = Uuid.randV4(getRandom()); - if ((try lookupByUsername(db, username, community_id)) != null) { + if ((try lookupByUsernameInternal(db, username, community_id)) != null) { return error.UsernameTaken; } @@ -108,7 +99,7 @@ pub const User = struct { pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User { const result = (try db.queryRow( - &.{ []const u8, []const u8, Uuid, DateTime }, + std.meta.Tuple(&.{ []const u8, []const u8, Uuid, DateTime }), \\SELECT user.username, community.host, community.id, user.created_at \\FROM user JOIN community ON user.community_id = community.id \\WHERE user.id = $1 diff --git a/src/main/main.zig b/src/main/main.zig index e03d54c..85d6ce1 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -5,7 +5,7 @@ const http = @import("http"); const util = @import("util"); pub const api = @import("./api.zig"); -const models = @import("./db/models.zig"); +const migrations = @import("./migrations.zig"); const Uuid = util.Uuid; const c = @import("./controllers.zig"); @@ -98,21 +98,49 @@ fn loadConfig(alloc: std.mem.Allocator) !Config { return std.json.parse(Config, &ts, .{ .allocator = alloc }); } -const root_password_envvar = "CLUSTER_ROOT_PASSWORD"; +const admin_origin_envvar = "CLUSTER_ADMIN_ORIGIN"; +const admin_username_envvar = "CLUSTER_ADMIN_USERNAME"; +const admin_password_envvar = "CLUSTER_ADMIN_PASSWORD"; + +fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void { + const origin = std.os.getenv(admin_origin_envvar) orelse return error.MissingArgument; + const username = std.os.getenv(admin_username_envvar) orelse return error.MissingArgument; + const password = std.os.getenv(admin_password_envvar) orelse return error.MissingArgument; + + try api.setupAdmin(db, origin, username, password, alloc); +} + +fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void { + try migrations.up(db); + + if (!try api.isAdminSetup(db)) { + std.log.info("Performing first-time admin creation...", .{}); + + runAdminSetup(db, alloc) catch |err| switch (err) { + error.MissingArgument => { + std.log.err( + \\First time setup required but arguments not provided. + \\Please provide the following arguments via environment variable: + \\- {s}: The origin to serve the cluster admin panel at (ex: https://admin.example.com) + \\- {s}: The username for the initial cluster operator + \\- {s}: The password for the initial cluster operator + , + .{ admin_origin_envvar, admin_username_envvar, admin_password_envvar }, + ); + std.os.exit(1); + }, + else => return err, + }; + } +} + pub fn main() anyerror!void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var cfg = try loadConfig(gpa.allocator()); var db_conn = try sql.Db.open(cfg.db); - var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), &db_conn) catch |err| switch (err) { - error.NeedRootPassword => { - std.log.err( - "No root user created and no password specified. Please provide the password for the root user by the ${s} environment variable for initial startup. This only needs to be done once", - .{root_password_envvar}, - ); - return err; - }, - else => return err, - }; + try prepareDb(&db_conn, gpa.allocator()); + + var api_src = try api.ApiSource.init(gpa.allocator(), cfg, &db_conn); var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp())); return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); diff --git a/src/main/migrations.zig b/src/main/migrations.zig index 83fba1d..a649fa3 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -38,7 +38,7 @@ fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void { } fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { - const row = (try db.queryRow(&.{i32}, "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false; + const row = (try db.queryRow(std.meta.Tuple(&.{i32}), "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false; return row[0] != 0; } @@ -152,7 +152,7 @@ const migrations: []const Migration = &.{ \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, \\ expires_at TIMESTAMPTZ, \\ - \\ type TEXT NOT NULL CHECK (type in ('system', 'community_owner', 'user')) + \\ type TEXT NOT NULL CHECK (type in ('system_user', 'community_owner', 'user')) \\); \\ALTER TABLE local_user ADD COLUMN invite_id TEXT REFERENCES invite(id); , @@ -171,14 +171,15 @@ const migrations: []const Migration = &.{ \\ name TEXT NOT NULL, \\ host TEXT NOT NULL UNIQUE, \\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')), + \\ kind TEXT NOT NULL CHECK (kind in ('admin', 'local')), \\ \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); \\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id); - \\ALTER TABLE invite ADD COLUMN to_community TEXT REFERENCES community(id); + \\ALTER TABLE invite ADD COLUMN community_id TEXT REFERENCES community(id); , .down = - \\ALTER TABLE invite DROP COLUMN to_community; + \\ALTER TABLE invite DROP COLUMN community_id; \\ALTER TABLE user DROP COLUMN community_id; \\DROP TABLE community; , diff --git a/src/sql/lib.zig b/src/sql/lib.zig index e122f8c..8108c44 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -22,6 +22,41 @@ pub const Config = union(Engine) { }, }; +pub const QueryError = error{ + OutOfMemory, + ConnectionLost, +}; + +pub fn fieldList(comptime RowType: type) []const u8 { + comptime { + const fields = std.meta.fieldNames(RowType); + const separator = ", "; + + if (fields.len == 0) return ""; + + var size: usize = 1; // 1 for null terminator + for (fields) |f| size += f.len + separator.len; + size -= separator.len; + + var buf = std.mem.zeroes([size]u8); + + // can't use std.mem.join because of problems with comptime allocation + // https://github.com/ziglang/zig/issues/5873#issuecomment-1001778218 + //var fba = std.heap.FixedBufferAllocator.init(&buf); + //return (std.mem.join(fba.allocator(), separator, fields) catch unreachable) ++ " "; + + var buf_idx = 0; + for (fields) |f, i| { + std.mem.copy(u8, buf[buf_idx..], f); + buf_idx += f.len; + if (i != fields.len - 1) std.mem.copy(u8, buf[buf_idx..], separator); + buf_idx += separator.len; + } + + return &buf; + } +} + //pub const OpenError = sqlite.OpenError | postgres.OpenError; const RawResults = union(Engine) { postgres: postgres.Results, @@ -33,17 +68,50 @@ const RawResults = union(Engine) { .sqlite => |lite| lite.finish(), } } + + fn columnCount(self: RawResults) u15 { + return switch (self) { + .postgres => |pg| pg.columnCount(), + .sqlite => |lite| lite.columnCount(), + }; + } + + fn columnNameToIndex(self: RawResults, name: []const u8) !u15 { + return try switch (self) { + .postgres => |pg| pg.columnNameToIndex(name), + .sqlite => |lite| lite.columnNameToIndex(name), + }; + } + + fn row(self: *RawResults) !?Row { + return switch (self.*) { + .postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else null, + .sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else null, + }; + } }; // Represents a set of results. // row() must be called until it returns null, or the query may not complete // Must be deallocated by a call to finish() -pub fn Results(comptime result_types: []const type) type { +pub fn Results(comptime T: type) type { + // would normally make this a declaration of the struct, but it causes the compiler to crash + const fields = std.meta.fields(T); return struct { const Self = @This(); - const RowTuple = std.meta.Tuple(result_types); underlying: RawResults, + column_indices: [fields.len]u15, + + fn from(underlying: RawResults) !Self { + return Self{ .underlying = underlying, .column_indices = blk: { + var indices: [fields.len]u15 = undefined; + inline for (fields) |f, i| { + indices[i] = if (!std.meta.trait.isTuple(T)) try underlying.columnNameToIndex(f.name) else i; + } + break :blk indices; + } }; + } pub fn finish(self: Self) void { self.underlying.finish(); @@ -52,31 +120,30 @@ pub fn Results(comptime result_types: []const type) type { // can be used as an optimization to reduce memory reallocation // only works on postgres pub fn rowCount(self: Self) ?usize { - return switch (self.underlying) { - .postgres => |pg| pg.rowCount(), - .sqlite => null, // not possible without repeating the query - }; + return self.underlying.rowCount(); } - pub fn row(self: *Self, alloc: ?Allocator) !?RowTuple { - const row_val = switch (self.underlying) { - .postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else return null, - .sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else return null, - }; + // Returns the next row of results, or null if there are no more rows. + // Caller owns all memory allocated. The entire object can be deallocated with a + // call to util.deepFree + pub fn row(self: *Self, alloc: ?Allocator) !?T { + if (try self.underlying.row()) |row_val| { + var result: T = undefined; + var fields_allocated: usize = 0; + errdefer inline for (fields) |f, i| { + // Iteration bounds must be defined at comptime (inline for) but the number of fields we could + // successfully allocate is defined at runtime. So we iterate over the entire field array and + // conditionally deallocate fields in the loop. + if (i < fields_allocated) util.deepFree(alloc, @field(result, f.name)); + }; - var result: RowTuple = undefined; - var fields_allocated = [_]bool{false} ** result.len; - errdefer { - inline for (result_types) |_, i| { - if (fields_allocated[i]) util.deepFree(alloc, result[i]); + inline for (fields) |f, i| { + @field(result, f.name) = try row_val.get(f.field_type, self.column_indices[i], alloc); + fields_allocated += 1; } - } - inline for (result_types) |T, i| { - result[i] = try row_val.get(T, i, alloc); - fields_allocated[i] = true; - } - return result; + return result; + } else return null; } }; } @@ -136,26 +203,26 @@ pub const Db = struct { pub fn queryWithOptions( self: *Db, - comptime result_types: []const type, + comptime RowType: type, sql: [:0]const u8, args: anytype, opt: QueryOptions, - ) !Results(result_types) { + ) !Results(RowType) { if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .db = self }).queryWithOptions(result_types, sql, args, opt); + return (Tx{ .db = self }).queryWithOptions(RowType, sql, args, opt); } pub fn query( self: *Db, - comptime result_types: []const type, + comptime RowType: type, sql: [:0]const u8, args: anytype, alloc: ?Allocator, - ) !Results(result_types) { + ) !Results(RowType) { if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .db = self }).query(result_types, sql, args, alloc); + return (Tx{ .db = self }).query(RowType, sql, args, alloc); } pub fn exec( @@ -171,14 +238,14 @@ pub const Db = struct { pub fn queryRow( self: *Db, - comptime result_types: []const type, + comptime RowType: type, sql: [:0]const u8, args: anytype, alloc: ?Allocator, - ) !?Results(result_types).RowTuple { + ) !?RowType { if (self.tx_open) return error.TransactionOpen; // Create fake transaction to use its functions - return (Tx{ .db = self }).queryRow(result_types, sql, args, alloc); + return (Tx{ .db = self }).queryRow(RowType, sql, args, alloc); } pub fn insert( @@ -222,23 +289,23 @@ pub const Tx = struct { pub fn queryWithOptions( self: Tx, - comptime result_types: []const type, + comptime RowType: type, sql: [:0]const u8, args: anytype, options: QueryOptions, - ) !Results(result_types) { - return Results(result_types){ .underlying = try self.queryInternal(sql, args, options) }; + ) !Results(RowType) { + return Results(RowType).from(try self.queryInternal(sql, args, options)); } // Executes a query and returns the result set pub fn query( self: Tx, - comptime result_types: []const type, + comptime RowType: type, sql: [:0]const u8, args: anytype, alloc: ?Allocator, - ) !Results(result_types) { - return self.queryWithOptions(result_types, sql, args, .{ .prep_allocator = alloc }); + ) !Results(RowType) { + return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc }); } // Executes a query without returning results @@ -248,19 +315,20 @@ pub const Tx = struct { args: anytype, alloc: ?Allocator, ) !void { - _ = try self.queryRow(&.{}, sql, args, alloc); + _ = try self.queryRow(std.meta.Tuple(&.{}), sql, args, alloc); } // Runs a query and returns a single row pub fn queryRow( self: Tx, - comptime result_types: []const type, + comptime RowType: type, q: [:0]const u8, args: anytype, alloc: ?Allocator, - ) !?Results(result_types).RowTuple { - var results = try self.query(result_types, q, args, alloc); + ) !?RowType { + var results = try self.query(RowType, q, args, alloc); defer results.finish(); + @compileLog(args); const row = (try results.row(alloc)) orelse return null; errdefer util.deepFree(alloc, row); diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index a7a2fd1..5fe217f 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -26,6 +26,16 @@ pub const Results = struct { }; } + pub fn columnCount(self: Results) u15 { + return @intCast(u15, c.PQnfields(self.result)); + } + + pub fn columnNameToIndex(self: Results, name: []const u8) !u15 { + const idx = c.PQfnumber(self.result, name.ptr); + if (idx == -1) return error.ColumnNotFound; + return @intCast(u15, idx); + } + pub fn finish(self: Results) void { c.PQclear(self.result); } @@ -89,8 +99,8 @@ pub const Db = struct { if (comptime args.len > 0) { var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired); defer arena.deinit(); - const params = try arena.allocator().alloc(?[*]const u8, args.len); - inline for (args) |a, i| params[i] = if (try common.prepareParamText(&arena, a)) |slice| slice.ptr else null; + const params = try arena.allocator().alloc(?[*:0]const u8, args.len); + inline for (args) |arg, i| params[i] = if (try common.prepareParamText(&arena, arg)) |slice| slice.ptr else null; break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text); } else { diff --git a/src/sql/sqlite.zig b/src/sql/sqlite.zig index 9a8de0e..77fc457 100644 --- a/src/sql/sqlite.zig +++ b/src/sql/sqlite.zig @@ -184,7 +184,6 @@ pub const Results = struct { return switch (c.sqlite3_step(self.stmt)) { c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db }, c.SQLITE_DONE => null, - else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()), }; } @@ -193,6 +192,28 @@ pub const Results = struct { const ptr = c.sqlite3_sql(self.stmt) orelse return null; return ptr[0..std.mem.len(ptr)]; } + + pub fn columnCount(self: Results) u15 { + return @intCast(u15, c.sqlite3_column_count(self.stmt)); + } + + pub fn columnName(self: Results, idx: u15) ![]const u8 { + return if (c.sqlite3_column_name(self.stmt, idx)) |ptr| + ptr[0..std.mem.len(ptr)] + else + return error.OutOfMemory; + } + + pub fn columnNameToIndex(self: Results, name: []const u8) !u15 { + var i: u15 = 0; + const count = self.columnCount(); + while (i < count) : (i += 1) { + const column = try self.columnName(i); + if (std.mem.eql(u8, name, column)) return i; + } + + return error.ColumnNotFound; + } }; pub const Row = struct { diff --git a/src/util/lib.zig b/src/util/lib.zig index cd2eddc..a87f791 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -6,10 +6,22 @@ pub const DateTime = @import("./DateTime.zig"); pub const PathIter = @import("./PathIter.zig"); pub const Url = @import("./Url.zig"); -pub fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 { - var new = try alloc.alloc(u8, str.len); - std.mem.copy(u8, new, str); - return new; +fn comptimeJoinSlice(comptime separator: []const u8, comptime slices: []const []const u8) []u8 { + comptime { + var size: usize = 1; // 1 for null terminator + for (slices) |s| size += s.len + separator.len; + if (slices.len != 0) size -= separator.len; + + var buf = std.mem.zeroes([size]u8); + var fba = std.heap.fixedBufferAllocator(&buf); + + return std.mem.join(fba.allocator(), separator, slices); + } +} + +pub fn comptimeJoin(comptime separator: []const u8, comptime slices: []const []const u8) *const [comptimeJoinSlice(separator, slices):0]u8 { + const slice = comptimeJoinSlice(separator, slices); + return slice[0..slice.len]; } pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {