Refactoring
This commit is contained in:
parent
33d1834f19
commit
c42039c559
13 changed files with 369 additions and 180 deletions
37
src/OVERVIEW.md
Normal file
37
src/OVERVIEW.md
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
# Overview
|
||||||
|
|
||||||
|
## Packages
|
||||||
|
- `main`: primary package, has application-specific functionality
|
||||||
|
* TODO: consider moving controllers and api into different packages
|
||||||
|
* `controllers/**.zig`:
|
||||||
|
- Transforms HTTP to/from API calls
|
||||||
|
- Turns error codes into HTTP statuses
|
||||||
|
* `api.zig`:
|
||||||
|
- Makes sure API call is allowed with the given user/host context
|
||||||
|
- Transforms API models into display models
|
||||||
|
- `api/**.zig`: Performs action associated with API call
|
||||||
|
* Transforms DB models into API models
|
||||||
|
* Data validation
|
||||||
|
- TODO: the distinction between what goes in `api.zig` and in its submodules is gross. Refactor?
|
||||||
|
* `migrations.zig`:
|
||||||
|
- Defines database migrations to apply
|
||||||
|
- Should be ran on startup
|
||||||
|
- `util`: utility packages
|
||||||
|
* Components:
|
||||||
|
- `Uuid`: UUID utils (random uuid generation, equality, parsing, printing)
|
||||||
|
* `Uuid.eql`
|
||||||
|
* `Uuid.randV4`
|
||||||
|
* UUID's are serialized to their string representation for JSON, db
|
||||||
|
- `PathIter`: Path segment iterator
|
||||||
|
- `Url`: URL utils (parsing)
|
||||||
|
- `ciutf8`: case-insensitive UTF-8 (TODO: Scrap this, replace with ICU library)
|
||||||
|
- `DateTime`: Time utils
|
||||||
|
- `deepClone(alloc, orig)`/`deepFree(alloc, to_free)`: Utils for cloning and freeing basic data structs
|
||||||
|
* Clones/frees any strings/sub structs within the value
|
||||||
|
- `sql`: SQL library
|
||||||
|
* Supports 2 engines (SQLite, PostgreSQL)
|
||||||
|
* `var my_transaction = try db.begin()`
|
||||||
|
* `const results = try db.query(RowType, "SELECT ...", .{arg_1, ...}, alloc)`
|
||||||
|
- `http`: HTTP Server
|
||||||
|
* The API sucks. Needs a refactor
|
||||||
|
|
119
src/main/api.zig
119
src/main/api.zig
|
@ -94,6 +94,42 @@ pub fn getRandom() std.rand.Random {
|
||||||
return prng.random();
|
return prng.random();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn isAdminSetup(db: *sql.Db) !bool {
|
||||||
|
_ = services.communities.adminCommunityId(db) catch |err| switch (err) {
|
||||||
|
error.NotFound => return false,
|
||||||
|
else => return err,
|
||||||
|
};
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) !void {
|
||||||
|
const tx = try db.begin();
|
||||||
|
errdefer tx.rollback();
|
||||||
|
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||||
|
defer arena.deinit();
|
||||||
|
|
||||||
|
try tx.setConstraintMode(.deferred);
|
||||||
|
|
||||||
|
const community = try services.communities.create(
|
||||||
|
tx,
|
||||||
|
origin,
|
||||||
|
Uuid.nil,
|
||||||
|
.{ .name = "Cluster Admin", .kind = .admin },
|
||||||
|
);
|
||||||
|
|
||||||
|
const user = try services.users.create(tx, username, password, community.id, .{ .role = .admin }, arena.allocator());
|
||||||
|
|
||||||
|
try services.communities.transferOwnership(tx, community.id, user);
|
||||||
|
|
||||||
|
try tx.commit();
|
||||||
|
|
||||||
|
std.log.info(
|
||||||
|
"Created admin user {s} (id {}) with cluster admin origin {s} (id {})",
|
||||||
|
.{ username, user, origin, community.id },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
pub const ApiSource = struct {
|
pub const ApiSource = struct {
|
||||||
db: *sql.Db,
|
db: *sql.Db,
|
||||||
internal_alloc: std.mem.Allocator,
|
internal_alloc: std.mem.Allocator,
|
||||||
|
@ -103,68 +139,43 @@ pub const ApiSource = struct {
|
||||||
|
|
||||||
const root_username = "root";
|
const root_username = "root";
|
||||||
|
|
||||||
pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: *sql.Db) !ApiSource {
|
pub fn init(alloc: std.mem.Allocator, cfg: Config, db_conn: *sql.Db) !ApiSource {
|
||||||
var self = ApiSource{
|
return ApiSource{
|
||||||
.db = db_conn,
|
.db = db_conn,
|
||||||
.internal_alloc = alloc,
|
.internal_alloc = alloc,
|
||||||
.config = cfg,
|
.config = cfg,
|
||||||
};
|
};
|
||||||
|
|
||||||
try migrations.up(db_conn);
|
|
||||||
|
|
||||||
if ((try services.users.lookupByUsername(self.db, root_username, null)) == null) {
|
|
||||||
std.log.info("No cluster root user detected. Creating...", .{});
|
|
||||||
|
|
||||||
// TODO: Fix this
|
|
||||||
const password = root_password orelse return error.NeedRootPassword;
|
|
||||||
var arena = std.heap.ArenaAllocator.init(alloc);
|
|
||||||
defer arena.deinit();
|
|
||||||
const user_id = try services.users.create(self.db, root_username, password, null, .{}, arena.allocator());
|
|
||||||
std.log.debug("Created {s} ID {}", .{ root_username, user_id });
|
|
||||||
}
|
|
||||||
|
|
||||||
return self;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid {
|
|
||||||
if (try self.db.queryRow(
|
|
||||||
&.{Uuid},
|
|
||||||
"SELECT id FROM community WHERE host = $1",
|
|
||||||
.{host},
|
|
||||||
null,
|
|
||||||
)) |result| return result[0];
|
|
||||||
|
|
||||||
// Test for cluster admin community
|
|
||||||
if (util.ciutf8.eql(self.config.cluster_host, host)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return error.NoCommunity;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn connectUnauthorized(self: *ApiSource, host: []const u8, alloc: std.mem.Allocator) !Conn {
|
pub fn connectUnauthorized(self: *ApiSource, host: []const u8, alloc: std.mem.Allocator) !Conn {
|
||||||
const community_id = try self.getCommunityFromHost(host);
|
var arena = std.heap.ArenaAllocator.init(alloc);
|
||||||
|
errdefer arena.deinit();
|
||||||
|
|
||||||
|
const community = try services.communities.getByHost(self.db, host, arena.allocator());
|
||||||
|
|
||||||
return Conn{
|
return Conn{
|
||||||
.db = self.db,
|
.db = self.db,
|
||||||
.internal_alloc = self.internal_alloc,
|
.internal_alloc = self.internal_alloc,
|
||||||
.user_id = null,
|
.user_id = null,
|
||||||
.community_id = community_id,
|
.community = community,
|
||||||
.arena = std.heap.ArenaAllocator.init(alloc),
|
.arena = arena,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn connectToken(self: *ApiSource, host: []const u8, token: []const u8, alloc: std.mem.Allocator) !Conn {
|
pub fn connectToken(self: *ApiSource, host: []const u8, token: []const u8, alloc: std.mem.Allocator) !Conn {
|
||||||
const community_id = try self.getCommunityFromHost(host);
|
var arena = std.heap.ArenaAllocator.init(alloc);
|
||||||
|
errdefer arena.deinit();
|
||||||
|
|
||||||
const token_info = try services.auth.tokens.verify(self.db, token, community_id);
|
const community = try services.communities.getByHost(self.db, host, arena.allocator());
|
||||||
|
|
||||||
|
const token_info = try services.auth.tokens.verify(self.db, token, community.id);
|
||||||
|
|
||||||
return Conn{
|
return Conn{
|
||||||
.db = self.db,
|
.db = self.db,
|
||||||
.internal_alloc = self.internal_alloc,
|
.internal_alloc = self.internal_alloc,
|
||||||
.user_id = token_info.user_id,
|
.user_id = token_info.user_id,
|
||||||
.community_id = community_id,
|
.community = community,
|
||||||
.arena = std.heap.ArenaAllocator.init(alloc),
|
.arena = arena,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -176,7 +187,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
db: DbConn,
|
db: DbConn,
|
||||||
internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers
|
internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers
|
||||||
user_id: ?Uuid,
|
user_id: ?Uuid,
|
||||||
community_id: ?Uuid,
|
community: services.communities.Community,
|
||||||
arena: std.heap.ArenaAllocator,
|
arena: std.heap.ArenaAllocator,
|
||||||
|
|
||||||
pub fn close(self: *Self) void {
|
pub fn close(self: *Self) void {
|
||||||
|
@ -185,11 +196,11 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
|
|
||||||
fn isAdmin(self: *Self) bool {
|
fn isAdmin(self: *Self) bool {
|
||||||
// TODO
|
// TODO
|
||||||
return self.user_id != null and self.community_id == null;
|
return self.user_id != null and self.community.kind == .admin;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse {
|
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse {
|
||||||
const user_id = (try services.users.lookupByUsername(self.db, username, self.community_id)) orelse return error.InvalidLogin;
|
const user_id = (try services.users.lookupByUsername(self.db, username, self.community.id)) orelse return error.InvalidLogin;
|
||||||
try services.auth.passwords.verify(self.db, user_id, password, self.internal_alloc);
|
try services.auth.passwords.verify(self.db, user_id, password, self.internal_alloc);
|
||||||
|
|
||||||
const token = try services.auth.tokens.create(self.db, user_id);
|
const token = try services.auth.tokens.create(self.db, user_id);
|
||||||
|
@ -207,7 +218,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
pub fn getTokenInfo(self: *Self) !TokenInfo {
|
pub fn getTokenInfo(self: *Self) !TokenInfo {
|
||||||
if (self.user_id) |user_id| {
|
if (self.user_id) |user_id| {
|
||||||
const result = (try self.db.queryRow(
|
const result = (try self.db.queryRow(
|
||||||
&.{[]const u8},
|
std.meta.Tuple(&.{[]const u8}),
|
||||||
"SELECT username FROM user WHERE id = $1",
|
"SELECT username FROM user WHERE id = $1",
|
||||||
.{user_id},
|
.{user_id},
|
||||||
self.arena.allocator(),
|
self.arena.allocator(),
|
||||||
|
@ -225,7 +236,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
return error.PermissionDenied;
|
return error.PermissionDenied;
|
||||||
}
|
}
|
||||||
|
|
||||||
return services.communities.create(self.db, origin, null);
|
return services.communities.create(self.db, origin, self.user_id.?, .{});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite {
|
pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite {
|
||||||
|
@ -234,13 +245,13 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
|
|
||||||
const community_id = if (options.to_community) |host| blk: {
|
const community_id = if (options.to_community) |host| blk: {
|
||||||
// You can only specify a different community if you're on the admin domain
|
// You can only specify a different community if you're on the admin domain
|
||||||
if (self.community_id != null) return error.WrongCommunity;
|
if (self.community.kind != .admin) return error.WrongCommunity;
|
||||||
|
|
||||||
// Only admins can invite on the admin domain
|
// Only admins can invite on the admin domain
|
||||||
if (!self.isAdmin()) return error.PermissionDenied;
|
if (!self.isAdmin()) return error.PermissionDenied;
|
||||||
|
|
||||||
break :blk (try services.communities.getByHost(self.db, host, self.arena.allocator())).id;
|
break :blk (try services.communities.getByHost(self.db, host, self.arena.allocator())).id;
|
||||||
} else self.community_id;
|
} else self.community.id;
|
||||||
|
|
||||||
// Users can only make user invites
|
// Users can only make user invites
|
||||||
if (options.invite_type != .user and !self.isAdmin()) return error.PermissionDenied;
|
if (options.invite_type != .user and !self.isAdmin()) return error.PermissionDenied;
|
||||||
|
@ -257,19 +268,19 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code });
|
std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code });
|
||||||
const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator());
|
const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator());
|
||||||
|
|
||||||
if (!Uuid.eql(invite.to_community, self.community_id)) return error.NotFound;
|
if (!Uuid.eql(invite.to_community, self.community.id)) return error.NotFound;
|
||||||
if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired;
|
if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired;
|
||||||
if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired;
|
if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired;
|
||||||
|
|
||||||
if (self.community_id == null) @panic("Unimplmented");
|
if (self.community.kind == .admin) @panic("Unimplmented");
|
||||||
|
|
||||||
const user_id = try services.users.create(self.db, request.username, request.password, self.community_id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc);
|
const user_id = try services.users.create(self.db, request.username, request.password, self.community.id, .{ .invite_id = invite.id, .email = request.email }, self.internal_alloc);
|
||||||
|
|
||||||
switch (invite.invite_type) {
|
switch (invite.invite_type) {
|
||||||
.user => {},
|
.user => {},
|
||||||
.system => @panic("System user invites unimplemented"),
|
.system => @panic("System user invites unimplemented"),
|
||||||
.community_owner => {
|
.community_owner => {
|
||||||
try services.communities.transferOwnership(self.db, self.community_id.?, user_id);
|
try services.communities.transferOwnership(self.db, self.community.id, user_id);
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -283,7 +294,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
const user = try services.users.get(self.db, user_id, self.arena.allocator());
|
const user = try services.users.get(self.db, user_id, self.arena.allocator());
|
||||||
|
|
||||||
if (self.user_id == null) {
|
if (self.user_id == null) {
|
||||||
if (!Uuid.eql(self.community_id, user.community_id)) return error.NotFound;
|
if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound;
|
||||||
}
|
}
|
||||||
|
|
||||||
return UserResponse{
|
return UserResponse{
|
||||||
|
@ -295,7 +306,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn createNote(self: *Self, content: []const u8) !NoteResponse {
|
pub fn createNote(self: *Self, content: []const u8) !NoteResponse {
|
||||||
if (self.community_id == null) return error.WrongCommunity;
|
if (self.community.kind == .admin) return error.WrongCommunity;
|
||||||
const user_id = self.user_id orelse return error.TokenRequired;
|
const user_id = self.user_id orelse return error.TokenRequired;
|
||||||
|
|
||||||
const note_id = try services.notes.create(self.db, user_id, content);
|
const note_id = try services.notes.create(self.db, user_id, content);
|
||||||
|
@ -312,7 +323,7 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
|
|
||||||
// Only serve community-specific notes on unauthenticated requests
|
// Only serve community-specific notes on unauthenticated requests
|
||||||
if (self.user_id == null) {
|
if (self.user_id == null) {
|
||||||
if (!Uuid.eql(self.community_id, user.community_id)) return error.NotFound;
|
if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound;
|
||||||
}
|
}
|
||||||
|
|
||||||
return NoteResponse{
|
return NoteResponse{
|
||||||
|
|
|
@ -24,7 +24,7 @@ pub const passwords = struct {
|
||||||
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
|
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
|
||||||
// TODO: This could be done w/o the dynamically allocated hash buf
|
// TODO: This could be done w/o the dynamically allocated hash buf
|
||||||
const hash = (db.queryRow(
|
const hash = (db.queryRow(
|
||||||
&.{[]const u8},
|
std.meta.Tuple(&.{[]const u8}),
|
||||||
"SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1",
|
"SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1",
|
||||||
.{user_id},
|
.{user_id},
|
||||||
alloc,
|
alloc,
|
||||||
|
@ -96,7 +96,7 @@ pub const tokens = struct {
|
||||||
|
|
||||||
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
|
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
|
||||||
return if (try db.queryRow(
|
return if (try db.queryRow(
|
||||||
&.{ Uuid, DateTime },
|
std.meta.Tuple(&.{ Uuid, DateTime }),
|
||||||
\\SELECT user.id, token.issued_at
|
\\SELECT user.id, token.issued_at
|
||||||
\\FROM token JOIN user ON token.user_id = user.id
|
\\FROM token JOIN user ON token.user_id = user.id
|
||||||
\\WHERE user.community_id = $1 AND token.hash = $2
|
\\WHERE user.community_id = $1 AND token.hash = $2
|
||||||
|
@ -115,7 +115,7 @@ pub const tokens = struct {
|
||||||
|
|
||||||
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
|
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
|
||||||
return if (try db.queryRow(
|
return if (try db.queryRow(
|
||||||
&.{ Uuid, DateTime },
|
std.meta.Tuple(&.{ Uuid, DateTime }),
|
||||||
\\SELECT user.id, token.issued_at
|
\\SELECT user.id, token.issued_at
|
||||||
\\FROM token JOIN user ON token.user_id = user.id
|
\\FROM token JOIN user ON token.user_id = user.id
|
||||||
\\WHERE user.community_id IS NULL AND token.hash = $1
|
\\WHERE user.community_id IS NULL AND token.hash = $1
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const util = @import("util");
|
const util = @import("util");
|
||||||
const models = @import("../db/models.zig");
|
const sql = @import("sql");
|
||||||
|
|
||||||
const getRandom = @import("../api.zig").getRandom;
|
const getRandom = @import("../api.zig").getRandom;
|
||||||
|
|
||||||
|
@ -26,21 +26,31 @@ pub const Scheme = enum {
|
||||||
pub const Community = struct {
|
pub const Community = struct {
|
||||||
id: Uuid,
|
id: Uuid,
|
||||||
|
|
||||||
owner_id: ?Uuid,
|
owner_id: Uuid,
|
||||||
host: []const u8,
|
host: []const u8,
|
||||||
name: []const u8,
|
name: []const u8,
|
||||||
|
|
||||||
scheme: Scheme,
|
scheme: Scheme,
|
||||||
|
kind: Kind,
|
||||||
created_at: DateTime,
|
created_at: DateTime,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn freeCommunity(alloc: std.mem.Allocator, c: Community) void {
|
pub const Kind = enum {
|
||||||
alloc.free(c.host);
|
admin,
|
||||||
alloc.free(c.name);
|
local,
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Community {
|
pub fn jsonStringify(val: Kind, _: std.json.StringifyOptions, writer: anytype) !void {
|
||||||
const scheme_len = firstIndexOf(origin, ':') orelse return error.InvalidOrigin;
|
return std.fmt.format(writer, "\"{s}\"", .{@tagName(val)});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const CreateOptions = struct {
|
||||||
|
name: ?[]const u8 = null,
|
||||||
|
kind: Kind = .local,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptions) CreateError!Community {
|
||||||
|
const scheme_len = std.mem.indexOfScalar(u8, origin, ':') orelse return error.InvalidOrigin;
|
||||||
const scheme_str = origin[0..scheme_len];
|
const scheme_str = origin[0..scheme_len];
|
||||||
const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme;
|
const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme;
|
||||||
|
|
||||||
|
@ -55,10 +65,10 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
|
||||||
// community cannot use non-default ports (except for testing)
|
// community cannot use non-default ports (except for testing)
|
||||||
// NOTE: Do not add, say localhost and localhost:80 or bugs may happen.
|
// NOTE: Do not add, say localhost and localhost:80 or bugs may happen.
|
||||||
// Avoid using non-default ports unless a test can't be conducted without it.
|
// Avoid using non-default ports unless a test can't be conducted without it.
|
||||||
if (firstIndexOf(host, ':') != null and builtin.mode != .Debug) return error.InvalidOrigin;
|
if (std.mem.indexOfScalar(u8, host, ':') != null and builtin.mode != .Debug) return error.InvalidOrigin;
|
||||||
|
|
||||||
// community cannot be hosted on a path
|
// community cannot be hosted on a path
|
||||||
if (firstIndexOf(host, '/') != null) return error.InvalidOrigin;
|
if (std.mem.indexOfScalar(u8, host, '/') != null) return error.InvalidOrigin;
|
||||||
|
|
||||||
// Require TLS on production builds
|
// Require TLS on production builds
|
||||||
if (scheme != .https and builtin.mode != .Debug) return error.UnsupportedScheme;
|
if (scheme != .https and builtin.mode != .Debug) return error.UnsupportedScheme;
|
||||||
|
@ -67,14 +77,15 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
|
||||||
|
|
||||||
const community = Community{
|
const community = Community{
|
||||||
.id = id,
|
.id = id,
|
||||||
.owner_id = null,
|
.owner_id = owner,
|
||||||
.host = host,
|
.host = host,
|
||||||
.name = name orelse host,
|
.name = options.name orelse host,
|
||||||
.scheme = scheme,
|
.scheme = scheme,
|
||||||
|
.kind = options.kind,
|
||||||
.created_at = DateTime.now(),
|
.created_at = DateTime.now(),
|
||||||
};
|
};
|
||||||
|
|
||||||
if ((try db.queryRow(&.{Uuid}, "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) {
|
if ((try db.queryRow(std.meta.Tuple(&.{Uuid}), "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) {
|
||||||
return error.CommunityExists;
|
return error.CommunityExists;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,25 +94,13 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
|
||||||
return community;
|
return community;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn firstIndexOf(str: []const u8, ch: u8) ?usize {
|
|
||||||
for (str) |c, i| {
|
|
||||||
if (c == ch) return i;
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community {
|
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community {
|
||||||
const result = (try db.queryRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = $1", .{host}, alloc)) orelse return error.NotFound;
|
return (try db.queryRow(
|
||||||
|
Community,
|
||||||
return Community{
|
std.fmt.comptimePrint("SELECT {s} FROM community WHERE host = $1", .{comptime sql.fieldList(Community)}),
|
||||||
.id = result[0],
|
.{host},
|
||||||
.owner_id = result[1],
|
alloc,
|
||||||
.host = result[2],
|
)) orelse return error.NotFound;
|
||||||
.name = result[3],
|
|
||||||
.scheme = result[4],
|
|
||||||
.created_at = result[5],
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void {
|
pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void {
|
||||||
|
@ -247,7 +246,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit
|
||||||
errdefer alloc.free(result_buf);
|
errdefer alloc.free(result_buf);
|
||||||
|
|
||||||
var count: usize = 0;
|
var count: usize = 0;
|
||||||
errdefer for (result_buf[0..count]) |c| freeCommunity(alloc, c);
|
errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c);
|
||||||
|
|
||||||
for (result_buf) |*c| {
|
for (result_buf) |*c| {
|
||||||
const row = results.row(alloc) orelse break;
|
const row = results.row(alloc) orelse break;
|
||||||
|
@ -267,3 +266,14 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit
|
||||||
|
|
||||||
return result_buf[0..count];
|
return result_buf[0..count];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn adminCommunityId(db: anytype) !Uuid {
|
||||||
|
const row = (try db.queryRow(
|
||||||
|
std.meta.Tuple(&.{Uuid}),
|
||||||
|
"SELECT id FROM community WHERE kind = 'admin' LIMIT 1",
|
||||||
|
.{},
|
||||||
|
null,
|
||||||
|
)) orelse return error.NotFound;
|
||||||
|
|
||||||
|
return row[0];
|
||||||
|
}
|
||||||
|
|
|
@ -130,7 +130,7 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit
|
||||||
|
|
||||||
pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite {
|
pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite {
|
||||||
const code_clone = try cloneStr(code, alloc);
|
const code_clone = try cloneStr(code, alloc);
|
||||||
const info = (try db.queryRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType },
|
const info = (try db.queryRow(std.meta.Tuple(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType }),
|
||||||
\\SELECT
|
\\SELECT
|
||||||
\\ invite.id, invite.created_by, invite.to_community, invite.name,
|
\\ invite.id, invite.created_by, invite.to_community, invite.name,
|
||||||
\\ invite.created_at, invite.expires_at,
|
\\ invite.created_at, invite.expires_at,
|
||||||
|
|
|
@ -41,7 +41,7 @@ pub fn create(
|
||||||
|
|
||||||
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note {
|
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note {
|
||||||
const result = (try db.queryRow(
|
const result = (try db.queryRow(
|
||||||
&.{ Uuid, []const u8, DateTime },
|
std.meta.Tuple(&.{ Uuid, []const u8, DateTime }),
|
||||||
\\SELECT author_id, content, created_at
|
\\SELECT author_id, content, created_at
|
||||||
\\FROM note
|
\\FROM note
|
||||||
\\WHERE id = $1
|
\\WHERE id = $1
|
||||||
|
|
|
@ -21,7 +21,7 @@ const DbUser = struct {
|
||||||
id: Uuid,
|
id: Uuid,
|
||||||
|
|
||||||
username: []const u8,
|
username: []const u8,
|
||||||
community_id: ?Uuid,
|
community_id: Uuid,
|
||||||
};
|
};
|
||||||
|
|
||||||
const DbLocalUser = struct {
|
const DbLocalUser = struct {
|
||||||
|
@ -31,52 +31,43 @@ const DbLocalUser = struct {
|
||||||
email: ?[]const u8,
|
email: ?[]const u8,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const Role = enum {
|
||||||
|
user,
|
||||||
|
admin,
|
||||||
|
};
|
||||||
|
|
||||||
pub const CreateOptions = struct {
|
pub const CreateOptions = struct {
|
||||||
invite_id: ?Uuid = null,
|
invite_id: ?Uuid = null,
|
||||||
email: ?[]const u8 = null,
|
email: ?[]const u8 = null,
|
||||||
|
role: Role = .user,
|
||||||
};
|
};
|
||||||
|
|
||||||
fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
|
fn lookupByUsernameInternal(db: anytype, username: []const u8, community_id: Uuid) CreateError!?Uuid {
|
||||||
return if (try db.queryRow(
|
return if (db.queryRow(
|
||||||
&.{Uuid},
|
std.meta.Tuple(&.{Uuid}),
|
||||||
"SELECT user.id FROM user WHERE community_id IS NULL AND username = $1",
|
|
||||||
.{username},
|
|
||||||
null,
|
|
||||||
)) |result|
|
|
||||||
result[0]
|
|
||||||
else
|
|
||||||
null;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid {
|
|
||||||
return if (try db.queryRow(
|
|
||||||
&.{Uuid},
|
|
||||||
"SELECT user.id FROM user WHERE community_id = $1 AND username = $2",
|
"SELECT user.id FROM user WHERE community_id = $1 AND username = $2",
|
||||||
.{ community_id, username },
|
.{ community_id, username },
|
||||||
null,
|
null,
|
||||||
)) |result|
|
) catch return error.DbError) |result|
|
||||||
result[0]
|
result[0]
|
||||||
else
|
else
|
||||||
null;
|
null;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn lookupByUsername(db: anytype, username: []const u8, community_id: ?Uuid) !?Uuid {
|
pub fn lookupByUsername(db: anytype, username: []const u8, community_id: Uuid) CreateError!Uuid {
|
||||||
return if (community_id) |id|
|
return (lookupByUsernameInternal(db, username, community_id) catch return error.DbError) orelse error.NotFound;
|
||||||
lookupUserByUsername(db, username, id) catch return error.DbError
|
|
||||||
else
|
|
||||||
lookupSystemUserByUsername(db, username) catch return error.DbError;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create(
|
pub fn create(
|
||||||
db: anytype,
|
db: anytype,
|
||||||
username: []const u8,
|
username: []const u8,
|
||||||
password: []const u8,
|
password: []const u8,
|
||||||
community_id: ?Uuid,
|
community_id: Uuid,
|
||||||
options: CreateOptions,
|
options: CreateOptions,
|
||||||
password_alloc: std.mem.Allocator,
|
password_alloc: std.mem.Allocator,
|
||||||
) CreateError!Uuid {
|
) CreateError!Uuid {
|
||||||
const id = Uuid.randV4(getRandom());
|
const id = Uuid.randV4(getRandom());
|
||||||
if ((try lookupByUsername(db, username, community_id)) != null) {
|
if ((try lookupByUsernameInternal(db, username, community_id)) != null) {
|
||||||
return error.UsernameTaken;
|
return error.UsernameTaken;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +99,7 @@ pub const User = struct {
|
||||||
|
|
||||||
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
|
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
|
||||||
const result = (try db.queryRow(
|
const result = (try db.queryRow(
|
||||||
&.{ []const u8, []const u8, Uuid, DateTime },
|
std.meta.Tuple(&.{ []const u8, []const u8, Uuid, DateTime }),
|
||||||
\\SELECT user.username, community.host, community.id, user.created_at
|
\\SELECT user.username, community.host, community.id, user.created_at
|
||||||
\\FROM user JOIN community ON user.community_id = community.id
|
\\FROM user JOIN community ON user.community_id = community.id
|
||||||
\\WHERE user.id = $1
|
\\WHERE user.id = $1
|
||||||
|
|
|
@ -5,7 +5,7 @@ const http = @import("http");
|
||||||
const util = @import("util");
|
const util = @import("util");
|
||||||
|
|
||||||
pub const api = @import("./api.zig");
|
pub const api = @import("./api.zig");
|
||||||
const models = @import("./db/models.zig");
|
const migrations = @import("./migrations.zig");
|
||||||
const Uuid = util.Uuid;
|
const Uuid = util.Uuid;
|
||||||
const c = @import("./controllers.zig");
|
const c = @import("./controllers.zig");
|
||||||
|
|
||||||
|
@ -98,21 +98,49 @@ fn loadConfig(alloc: std.mem.Allocator) !Config {
|
||||||
return std.json.parse(Config, &ts, .{ .allocator = alloc });
|
return std.json.parse(Config, &ts, .{ .allocator = alloc });
|
||||||
}
|
}
|
||||||
|
|
||||||
const root_password_envvar = "CLUSTER_ROOT_PASSWORD";
|
const admin_origin_envvar = "CLUSTER_ADMIN_ORIGIN";
|
||||||
|
const admin_username_envvar = "CLUSTER_ADMIN_USERNAME";
|
||||||
|
const admin_password_envvar = "CLUSTER_ADMIN_PASSWORD";
|
||||||
|
|
||||||
|
fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void {
|
||||||
|
const origin = std.os.getenv(admin_origin_envvar) orelse return error.MissingArgument;
|
||||||
|
const username = std.os.getenv(admin_username_envvar) orelse return error.MissingArgument;
|
||||||
|
const password = std.os.getenv(admin_password_envvar) orelse return error.MissingArgument;
|
||||||
|
|
||||||
|
try api.setupAdmin(db, origin, username, password, alloc);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void {
|
||||||
|
try migrations.up(db);
|
||||||
|
|
||||||
|
if (!try api.isAdminSetup(db)) {
|
||||||
|
std.log.info("Performing first-time admin creation...", .{});
|
||||||
|
|
||||||
|
runAdminSetup(db, alloc) catch |err| switch (err) {
|
||||||
|
error.MissingArgument => {
|
||||||
|
std.log.err(
|
||||||
|
\\First time setup required but arguments not provided.
|
||||||
|
\\Please provide the following arguments via environment variable:
|
||||||
|
\\- {s}: The origin to serve the cluster admin panel at (ex: https://admin.example.com)
|
||||||
|
\\- {s}: The username for the initial cluster operator
|
||||||
|
\\- {s}: The password for the initial cluster operator
|
||||||
|
,
|
||||||
|
.{ admin_origin_envvar, admin_username_envvar, admin_password_envvar },
|
||||||
|
);
|
||||||
|
std.os.exit(1);
|
||||||
|
},
|
||||||
|
else => return err,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn main() anyerror!void {
|
pub fn main() anyerror!void {
|
||||||
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
|
||||||
var cfg = try loadConfig(gpa.allocator());
|
var cfg = try loadConfig(gpa.allocator());
|
||||||
var db_conn = try sql.Db.open(cfg.db);
|
var db_conn = try sql.Db.open(cfg.db);
|
||||||
var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), &db_conn) catch |err| switch (err) {
|
try prepareDb(&db_conn, gpa.allocator());
|
||||||
error.NeedRootPassword => {
|
|
||||||
std.log.err(
|
var api_src = try api.ApiSource.init(gpa.allocator(), cfg, &db_conn);
|
||||||
"No root user created and no password specified. Please provide the password for the root user by the ${s} environment variable for initial startup. This only needs to be done once",
|
|
||||||
.{root_password_envvar},
|
|
||||||
);
|
|
||||||
return err;
|
|
||||||
},
|
|
||||||
else => return err,
|
|
||||||
};
|
|
||||||
var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg);
|
var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg);
|
||||||
api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp()));
|
api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp()));
|
||||||
return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);
|
return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);
|
||||||
|
|
|
@ -38,7 +38,7 @@ fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool {
|
fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool {
|
||||||
const row = (try db.queryRow(&.{i32}, "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false;
|
const row = (try db.queryRow(std.meta.Tuple(&.{i32}), "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false;
|
||||||
return row[0] != 0;
|
return row[0] != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ const migrations: []const Migration = &.{
|
||||||
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
\\ expires_at TIMESTAMPTZ,
|
\\ expires_at TIMESTAMPTZ,
|
||||||
\\
|
\\
|
||||||
\\ type TEXT NOT NULL CHECK (type in ('system', 'community_owner', 'user'))
|
\\ type TEXT NOT NULL CHECK (type in ('system_user', 'community_owner', 'user'))
|
||||||
\\);
|
\\);
|
||||||
\\ALTER TABLE local_user ADD COLUMN invite_id TEXT REFERENCES invite(id);
|
\\ALTER TABLE local_user ADD COLUMN invite_id TEXT REFERENCES invite(id);
|
||||||
,
|
,
|
||||||
|
@ -171,14 +171,15 @@ const migrations: []const Migration = &.{
|
||||||
\\ name TEXT NOT NULL,
|
\\ name TEXT NOT NULL,
|
||||||
\\ host TEXT NOT NULL UNIQUE,
|
\\ host TEXT NOT NULL UNIQUE,
|
||||||
\\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')),
|
\\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')),
|
||||||
|
\\ kind TEXT NOT NULL CHECK (kind in ('admin', 'local')),
|
||||||
\\
|
\\
|
||||||
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
\\);
|
\\);
|
||||||
\\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id);
|
\\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id);
|
||||||
\\ALTER TABLE invite ADD COLUMN to_community TEXT REFERENCES community(id);
|
\\ALTER TABLE invite ADD COLUMN community_id TEXT REFERENCES community(id);
|
||||||
,
|
,
|
||||||
.down =
|
.down =
|
||||||
\\ALTER TABLE invite DROP COLUMN to_community;
|
\\ALTER TABLE invite DROP COLUMN community_id;
|
||||||
\\ALTER TABLE user DROP COLUMN community_id;
|
\\ALTER TABLE user DROP COLUMN community_id;
|
||||||
\\DROP TABLE community;
|
\\DROP TABLE community;
|
||||||
,
|
,
|
||||||
|
|
150
src/sql/lib.zig
150
src/sql/lib.zig
|
@ -22,6 +22,41 @@ pub const Config = union(Engine) {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub const QueryError = error{
|
||||||
|
OutOfMemory,
|
||||||
|
ConnectionLost,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn fieldList(comptime RowType: type) []const u8 {
|
||||||
|
comptime {
|
||||||
|
const fields = std.meta.fieldNames(RowType);
|
||||||
|
const separator = ", ";
|
||||||
|
|
||||||
|
if (fields.len == 0) return "";
|
||||||
|
|
||||||
|
var size: usize = 1; // 1 for null terminator
|
||||||
|
for (fields) |f| size += f.len + separator.len;
|
||||||
|
size -= separator.len;
|
||||||
|
|
||||||
|
var buf = std.mem.zeroes([size]u8);
|
||||||
|
|
||||||
|
// can't use std.mem.join because of problems with comptime allocation
|
||||||
|
// https://github.com/ziglang/zig/issues/5873#issuecomment-1001778218
|
||||||
|
//var fba = std.heap.FixedBufferAllocator.init(&buf);
|
||||||
|
//return (std.mem.join(fba.allocator(), separator, fields) catch unreachable) ++ " ";
|
||||||
|
|
||||||
|
var buf_idx = 0;
|
||||||
|
for (fields) |f, i| {
|
||||||
|
std.mem.copy(u8, buf[buf_idx..], f);
|
||||||
|
buf_idx += f.len;
|
||||||
|
if (i != fields.len - 1) std.mem.copy(u8, buf[buf_idx..], separator);
|
||||||
|
buf_idx += separator.len;
|
||||||
|
}
|
||||||
|
|
||||||
|
return &buf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//pub const OpenError = sqlite.OpenError | postgres.OpenError;
|
//pub const OpenError = sqlite.OpenError | postgres.OpenError;
|
||||||
const RawResults = union(Engine) {
|
const RawResults = union(Engine) {
|
||||||
postgres: postgres.Results,
|
postgres: postgres.Results,
|
||||||
|
@ -33,17 +68,50 @@ const RawResults = union(Engine) {
|
||||||
.sqlite => |lite| lite.finish(),
|
.sqlite => |lite| lite.finish(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn columnCount(self: RawResults) u15 {
|
||||||
|
return switch (self) {
|
||||||
|
.postgres => |pg| pg.columnCount(),
|
||||||
|
.sqlite => |lite| lite.columnCount(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
fn columnNameToIndex(self: RawResults, name: []const u8) !u15 {
|
||||||
|
return try switch (self) {
|
||||||
|
.postgres => |pg| pg.columnNameToIndex(name),
|
||||||
|
.sqlite => |lite| lite.columnNameToIndex(name),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
fn row(self: *RawResults) !?Row {
|
||||||
|
return switch (self.*) {
|
||||||
|
.postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else null,
|
||||||
|
.sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else null,
|
||||||
|
};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Represents a set of results.
|
// Represents a set of results.
|
||||||
// row() must be called until it returns null, or the query may not complete
|
// row() must be called until it returns null, or the query may not complete
|
||||||
// Must be deallocated by a call to finish()
|
// Must be deallocated by a call to finish()
|
||||||
pub fn Results(comptime result_types: []const type) type {
|
pub fn Results(comptime T: type) type {
|
||||||
|
// would normally make this a declaration of the struct, but it causes the compiler to crash
|
||||||
|
const fields = std.meta.fields(T);
|
||||||
return struct {
|
return struct {
|
||||||
const Self = @This();
|
const Self = @This();
|
||||||
const RowTuple = std.meta.Tuple(result_types);
|
|
||||||
|
|
||||||
underlying: RawResults,
|
underlying: RawResults,
|
||||||
|
column_indices: [fields.len]u15,
|
||||||
|
|
||||||
|
fn from(underlying: RawResults) !Self {
|
||||||
|
return Self{ .underlying = underlying, .column_indices = blk: {
|
||||||
|
var indices: [fields.len]u15 = undefined;
|
||||||
|
inline for (fields) |f, i| {
|
||||||
|
indices[i] = if (!std.meta.trait.isTuple(T)) try underlying.columnNameToIndex(f.name) else i;
|
||||||
|
}
|
||||||
|
break :blk indices;
|
||||||
|
} };
|
||||||
|
}
|
||||||
|
|
||||||
pub fn finish(self: Self) void {
|
pub fn finish(self: Self) void {
|
||||||
self.underlying.finish();
|
self.underlying.finish();
|
||||||
|
@ -52,31 +120,30 @@ pub fn Results(comptime result_types: []const type) type {
|
||||||
// can be used as an optimization to reduce memory reallocation
|
// can be used as an optimization to reduce memory reallocation
|
||||||
// only works on postgres
|
// only works on postgres
|
||||||
pub fn rowCount(self: Self) ?usize {
|
pub fn rowCount(self: Self) ?usize {
|
||||||
return switch (self.underlying) {
|
return self.underlying.rowCount();
|
||||||
.postgres => |pg| pg.rowCount(),
|
|
||||||
.sqlite => null, // not possible without repeating the query
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn row(self: *Self, alloc: ?Allocator) !?RowTuple {
|
// Returns the next row of results, or null if there are no more rows.
|
||||||
const row_val = switch (self.underlying) {
|
// Caller owns all memory allocated. The entire object can be deallocated with a
|
||||||
.postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else return null,
|
// call to util.deepFree
|
||||||
.sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else return null,
|
pub fn row(self: *Self, alloc: ?Allocator) !?T {
|
||||||
};
|
if (try self.underlying.row()) |row_val| {
|
||||||
|
var result: T = undefined;
|
||||||
|
var fields_allocated: usize = 0;
|
||||||
|
errdefer inline for (fields) |f, i| {
|
||||||
|
// Iteration bounds must be defined at comptime (inline for) but the number of fields we could
|
||||||
|
// successfully allocate is defined at runtime. So we iterate over the entire field array and
|
||||||
|
// conditionally deallocate fields in the loop.
|
||||||
|
if (i < fields_allocated) util.deepFree(alloc, @field(result, f.name));
|
||||||
|
};
|
||||||
|
|
||||||
var result: RowTuple = undefined;
|
inline for (fields) |f, i| {
|
||||||
var fields_allocated = [_]bool{false} ** result.len;
|
@field(result, f.name) = try row_val.get(f.field_type, self.column_indices[i], alloc);
|
||||||
errdefer {
|
fields_allocated += 1;
|
||||||
inline for (result_types) |_, i| {
|
|
||||||
if (fields_allocated[i]) util.deepFree(alloc, result[i]);
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
inline for (result_types) |T, i| {
|
|
||||||
result[i] = try row_val.get(T, i, alloc);
|
|
||||||
fields_allocated[i] = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
} else return null;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -136,26 +203,26 @@ pub const Db = struct {
|
||||||
|
|
||||||
pub fn queryWithOptions(
|
pub fn queryWithOptions(
|
||||||
self: *Db,
|
self: *Db,
|
||||||
comptime result_types: []const type,
|
comptime RowType: type,
|
||||||
sql: [:0]const u8,
|
sql: [:0]const u8,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
opt: QueryOptions,
|
opt: QueryOptions,
|
||||||
) !Results(result_types) {
|
) !Results(RowType) {
|
||||||
if (self.tx_open) return error.TransactionOpen;
|
if (self.tx_open) return error.TransactionOpen;
|
||||||
// Create fake transaction to use its functions
|
// Create fake transaction to use its functions
|
||||||
return (Tx{ .db = self }).queryWithOptions(result_types, sql, args, opt);
|
return (Tx{ .db = self }).queryWithOptions(RowType, sql, args, opt);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn query(
|
pub fn query(
|
||||||
self: *Db,
|
self: *Db,
|
||||||
comptime result_types: []const type,
|
comptime RowType: type,
|
||||||
sql: [:0]const u8,
|
sql: [:0]const u8,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
alloc: ?Allocator,
|
alloc: ?Allocator,
|
||||||
) !Results(result_types) {
|
) !Results(RowType) {
|
||||||
if (self.tx_open) return error.TransactionOpen;
|
if (self.tx_open) return error.TransactionOpen;
|
||||||
// Create fake transaction to use its functions
|
// Create fake transaction to use its functions
|
||||||
return (Tx{ .db = self }).query(result_types, sql, args, alloc);
|
return (Tx{ .db = self }).query(RowType, sql, args, alloc);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn exec(
|
pub fn exec(
|
||||||
|
@ -171,14 +238,14 @@ pub const Db = struct {
|
||||||
|
|
||||||
pub fn queryRow(
|
pub fn queryRow(
|
||||||
self: *Db,
|
self: *Db,
|
||||||
comptime result_types: []const type,
|
comptime RowType: type,
|
||||||
sql: [:0]const u8,
|
sql: [:0]const u8,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
alloc: ?Allocator,
|
alloc: ?Allocator,
|
||||||
) !?Results(result_types).RowTuple {
|
) !?RowType {
|
||||||
if (self.tx_open) return error.TransactionOpen;
|
if (self.tx_open) return error.TransactionOpen;
|
||||||
// Create fake transaction to use its functions
|
// Create fake transaction to use its functions
|
||||||
return (Tx{ .db = self }).queryRow(result_types, sql, args, alloc);
|
return (Tx{ .db = self }).queryRow(RowType, sql, args, alloc);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn insert(
|
pub fn insert(
|
||||||
|
@ -222,23 +289,23 @@ pub const Tx = struct {
|
||||||
|
|
||||||
pub fn queryWithOptions(
|
pub fn queryWithOptions(
|
||||||
self: Tx,
|
self: Tx,
|
||||||
comptime result_types: []const type,
|
comptime RowType: type,
|
||||||
sql: [:0]const u8,
|
sql: [:0]const u8,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
options: QueryOptions,
|
options: QueryOptions,
|
||||||
) !Results(result_types) {
|
) !Results(RowType) {
|
||||||
return Results(result_types){ .underlying = try self.queryInternal(sql, args, options) };
|
return Results(RowType).from(try self.queryInternal(sql, args, options));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Executes a query and returns the result set
|
// Executes a query and returns the result set
|
||||||
pub fn query(
|
pub fn query(
|
||||||
self: Tx,
|
self: Tx,
|
||||||
comptime result_types: []const type,
|
comptime RowType: type,
|
||||||
sql: [:0]const u8,
|
sql: [:0]const u8,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
alloc: ?Allocator,
|
alloc: ?Allocator,
|
||||||
) !Results(result_types) {
|
) !Results(RowType) {
|
||||||
return self.queryWithOptions(result_types, sql, args, .{ .prep_allocator = alloc });
|
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Executes a query without returning results
|
// Executes a query without returning results
|
||||||
|
@ -248,19 +315,20 @@ pub const Tx = struct {
|
||||||
args: anytype,
|
args: anytype,
|
||||||
alloc: ?Allocator,
|
alloc: ?Allocator,
|
||||||
) !void {
|
) !void {
|
||||||
_ = try self.queryRow(&.{}, sql, args, alloc);
|
_ = try self.queryRow(std.meta.Tuple(&.{}), sql, args, alloc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runs a query and returns a single row
|
// Runs a query and returns a single row
|
||||||
pub fn queryRow(
|
pub fn queryRow(
|
||||||
self: Tx,
|
self: Tx,
|
||||||
comptime result_types: []const type,
|
comptime RowType: type,
|
||||||
q: [:0]const u8,
|
q: [:0]const u8,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
alloc: ?Allocator,
|
alloc: ?Allocator,
|
||||||
) !?Results(result_types).RowTuple {
|
) !?RowType {
|
||||||
var results = try self.query(result_types, q, args, alloc);
|
var results = try self.query(RowType, q, args, alloc);
|
||||||
defer results.finish();
|
defer results.finish();
|
||||||
|
@compileLog(args);
|
||||||
|
|
||||||
const row = (try results.row(alloc)) orelse return null;
|
const row = (try results.row(alloc)) orelse return null;
|
||||||
errdefer util.deepFree(alloc, row);
|
errdefer util.deepFree(alloc, row);
|
||||||
|
|
|
@ -26,6 +26,16 @@ pub const Results = struct {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn columnCount(self: Results) u15 {
|
||||||
|
return @intCast(u15, c.PQnfields(self.result));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn columnNameToIndex(self: Results, name: []const u8) !u15 {
|
||||||
|
const idx = c.PQfnumber(self.result, name.ptr);
|
||||||
|
if (idx == -1) return error.ColumnNotFound;
|
||||||
|
return @intCast(u15, idx);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn finish(self: Results) void {
|
pub fn finish(self: Results) void {
|
||||||
c.PQclear(self.result);
|
c.PQclear(self.result);
|
||||||
}
|
}
|
||||||
|
@ -89,8 +99,8 @@ pub const Db = struct {
|
||||||
if (comptime args.len > 0) {
|
if (comptime args.len > 0) {
|
||||||
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
|
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
const params = try arena.allocator().alloc(?[*]const u8, args.len);
|
const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
|
||||||
inline for (args) |a, i| params[i] = if (try common.prepareParamText(&arena, a)) |slice| slice.ptr else null;
|
inline for (args) |arg, i| params[i] = if (try common.prepareParamText(&arena, arg)) |slice| slice.ptr else null;
|
||||||
|
|
||||||
break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text);
|
break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -184,7 +184,6 @@ pub const Results = struct {
|
||||||
return switch (c.sqlite3_step(self.stmt)) {
|
return switch (c.sqlite3_step(self.stmt)) {
|
||||||
c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db },
|
c.SQLITE_ROW => Row{ .stmt = self.stmt, .db = self.db },
|
||||||
c.SQLITE_DONE => null,
|
c.SQLITE_DONE => null,
|
||||||
|
|
||||||
else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()),
|
else => |err| handleUnexpectedError(self.db, err, self.getGeneratingSql()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -193,6 +192,28 @@ pub const Results = struct {
|
||||||
const ptr = c.sqlite3_sql(self.stmt) orelse return null;
|
const ptr = c.sqlite3_sql(self.stmt) orelse return null;
|
||||||
return ptr[0..std.mem.len(ptr)];
|
return ptr[0..std.mem.len(ptr)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn columnCount(self: Results) u15 {
|
||||||
|
return @intCast(u15, c.sqlite3_column_count(self.stmt));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn columnName(self: Results, idx: u15) ![]const u8 {
|
||||||
|
return if (c.sqlite3_column_name(self.stmt, idx)) |ptr|
|
||||||
|
ptr[0..std.mem.len(ptr)]
|
||||||
|
else
|
||||||
|
return error.OutOfMemory;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn columnNameToIndex(self: Results, name: []const u8) !u15 {
|
||||||
|
var i: u15 = 0;
|
||||||
|
const count = self.columnCount();
|
||||||
|
while (i < count) : (i += 1) {
|
||||||
|
const column = try self.columnName(i);
|
||||||
|
if (std.mem.eql(u8, name, column)) return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
return error.ColumnNotFound;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Row = struct {
|
pub const Row = struct {
|
||||||
|
|
|
@ -6,10 +6,22 @@ pub const DateTime = @import("./DateTime.zig");
|
||||||
pub const PathIter = @import("./PathIter.zig");
|
pub const PathIter = @import("./PathIter.zig");
|
||||||
pub const Url = @import("./Url.zig");
|
pub const Url = @import("./Url.zig");
|
||||||
|
|
||||||
pub fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 {
|
fn comptimeJoinSlice(comptime separator: []const u8, comptime slices: []const []const u8) []u8 {
|
||||||
var new = try alloc.alloc(u8, str.len);
|
comptime {
|
||||||
std.mem.copy(u8, new, str);
|
var size: usize = 1; // 1 for null terminator
|
||||||
return new;
|
for (slices) |s| size += s.len + separator.len;
|
||||||
|
if (slices.len != 0) size -= separator.len;
|
||||||
|
|
||||||
|
var buf = std.mem.zeroes([size]u8);
|
||||||
|
var fba = std.heap.fixedBufferAllocator(&buf);
|
||||||
|
|
||||||
|
return std.mem.join(fba.allocator(), separator, slices);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn comptimeJoin(comptime separator: []const u8, comptime slices: []const []const u8) *const [comptimeJoinSlice(separator, slices):0]u8 {
|
||||||
|
const slice = comptimeJoinSlice(separator, slices);
|
||||||
|
return slice[0..slice.len];
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {
|
pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {
|
||||||
|
|
Loading…
Reference in a new issue