Move to new DB api

This commit is contained in:
jaina heartles 2022-09-14 18:12:07 -07:00
parent db225b6689
commit 33cf0ff87a
18 changed files with 258 additions and 339 deletions

View file

@ -1,9 +1,10 @@
const std = @import("std");
const util = @import("util");
const builtin = @import("builtin");
const sql = @import("sql");
const db = @import("./db.zig");
const models = @import("./db/models.zig");
const migrations = @import("./migrations.zig");
pub const DateTime = util.DateTime;
pub const Uuid = util.Uuid;
const Config = @import("./main.zig").Config;
@ -28,7 +29,7 @@ pub const InviteRequest = struct {
name: ?[]const u8 = null,
expires_at: ?DateTime = null, // TODO: Change this to lifespan
max_uses: ?usize = null,
max_uses: ?u16 = null,
invite_type: Type = .user, // must be user unless the creator is an admin
to_community: ?[]const u8 = null, // only valid on admin community
@ -94,40 +95,41 @@ pub fn getRandom() std.rand.Random {
}
pub const ApiSource = struct {
db: db.Database,
db: sql.Db,
internal_alloc: std.mem.Allocator,
config: Config,
pub const Conn = ApiConn(db.Database);
pub const Conn = ApiConn(sql.Db);
const root_username = "root";
pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8) !ApiSource {
pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: sql.Db) !ApiSource {
var self = ApiSource{
.db = try db.Database.init(cfg.db.sqlite.db_file),
.db = db_conn,
.internal_alloc = alloc,
.config = cfg,
};
try migrations.up(db_conn);
if ((try services.users.lookupByUsername(&self.db, root_username, null)) == null) {
std.log.info("No cluster root user detected. Creating...", .{});
// TODO: Fix this
const password = root_password orelse return error.NeedRootPassword;
std.debug.print("\npassword: {s}\n", .{password});
var arena = std.heap.ArenaAllocator.init(alloc);
defer arena.deinit();
const user_id = try services.users.create(&self.db, root_username, password, null, .{}, arena.allocator());
std.debug.print("Created {s} ID {}", .{ root_username, user_id });
std.log.debug("Created {s} ID {}", .{ root_username, user_id });
}
return self;
}
fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid {
if (try self.db.execRow(
if (try self.db.queryRow(
&.{Uuid},
"SELECT id FROM community WHERE host = ?",
"SELECT id FROM community WHERE host = $1",
.{host},
null,
)) |result| return result[0];
@ -204,9 +206,9 @@ fn ApiConn(comptime DbConn: type) type {
};
pub fn getTokenInfo(self: *Self) !TokenInfo {
if (self.user_id) |user_id| {
const result = (try self.db.execRow(
const result = (try self.db.queryRow(
&.{[]const u8},
"SELECT username FROM user WHERE id = ?",
"SELECT username FROM user WHERE id = $1",
.{user_id},
self.arena.allocator(),
)) orelse {

View file

@ -23,9 +23,9 @@ pub const passwords = struct {
};
pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
// TODO: This could be done w/o the dynamically allocated hash buf
const hash = (db.execRow(
const hash = (db.queryRow(
&.{[]const u8},
"SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1",
"SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1",
.{user_id},
alloc,
) catch return error.DbError) orelse return error.InvalidLogin;
@ -95,11 +95,11 @@ pub const tokens = struct {
}
fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
return if (try db.execRow(
return if (try db.queryRow(
&.{ Uuid, DateTime },
\\SELECT user.id, token.issued_at
\\FROM token JOIN user ON token.user_id = user.id
\\WHERE user.community_id = ? AND token.hash = ?
\\WHERE user.community_id = $1 AND token.hash = $2
\\LIMIT 1
,
.{ community_id, hash },
@ -114,11 +114,11 @@ pub const tokens = struct {
}
fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
return if (try db.execRow(
return if (try db.queryRow(
&.{ Uuid, DateTime },
\\SELECT user.id, token.issued_at
\\FROM token JOIN user ON token.user_id = user.id
\\WHERE user.community_id IS NULL AND token.hash = ?
\\WHERE user.community_id IS NULL AND token.hash = $1
\\LIMIT 1
,
.{hash},

View file

@ -3,8 +3,6 @@ const builtin = @import("builtin");
const util = @import("util");
const models = @import("../db/models.zig");
const DbError = @import("../db.zig").ExecError;
const getRandom = @import("../api.zig").getRandom;
const Uuid = util.Uuid;
@ -14,7 +12,7 @@ const CreateError = error{
InvalidOrigin,
UnsupportedScheme,
CommunityExists,
} || DbError;
} || anyerror; // TODO
pub const Scheme = enum {
https,
@ -76,7 +74,7 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
.created_at = DateTime.now(),
};
if ((try db.execRow(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null)) != null) {
if ((try db.queryRow(&.{Uuid}, "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) {
return error.CommunityExists;
}
@ -94,7 +92,7 @@ fn firstIndexOf(str: []const u8, ch: u8) ?usize {
}
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community {
const result = (try db.execRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = ?", .{host}, alloc)) orelse return error.NotFound;
const result = (try db.queryRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = $1", .{host}, alloc)) orelse return error.NotFound;
return Community{
.id = result[0],
@ -107,7 +105,7 @@ pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Commu
}
pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void {
_ = try db.execRow(&.{i64}, "UPDATE community SET owner_id = ? WHERE id = ?", .{ new_owner, community_id }, null);
try db.exec("UPDATE community SET owner_id = $1 WHERE id = $2", .{ new_owner, community_id }, null);
}
pub const QueryArgs = struct {
@ -238,7 +236,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit
max_items,
};
var results = try db.exec(
var results = try db.query(
&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime },
builder.array.items,
query_args,

View file

@ -2,7 +2,6 @@ const std = @import("std");
const builtin = @import("builtin");
const util = @import("util");
const models = @import("../db/models.zig");
const DbError = @import("../db.zig").ExecError;
const getRandom = @import("../api.zig").getRandom;
const Uuid = util.Uuid;
@ -31,6 +30,7 @@ fn defaultJsonStringify(comptime T: type) fn (T, std.json.StringifyOptions, anyt
}.jsonStringify;
}
const InviteCount = u16;
pub const Invite = struct {
id: Uuid,
@ -40,10 +40,10 @@ pub const Invite = struct {
code: []const u8,
created_at: DateTime,
times_used: usize,
times_used: InviteCount,
expires_at: ?DateTime,
max_uses: ?usize,
max_uses: ?InviteCount,
invite_type: InviteType,
};
@ -59,7 +59,7 @@ const DbModel = struct {
created_at: DateTime,
expires_at: ?DateTime,
max_uses: ?usize,
max_uses: ?InviteCount,
@"type": InviteType,
};
@ -72,7 +72,7 @@ fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 {
pub const InviteOptions = struct {
name: ?[]const u8 = null,
max_uses: ?usize = null,
max_uses: ?InviteCount = null,
expires_at: ?DateTime = null,
invite_type: InviteType = .user,
};
@ -130,14 +130,14 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit
pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite {
const code_clone = try cloneStr(code, alloc);
const info = (try db.execRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, usize, ?usize, InviteType },
const info = (try db.queryRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType },
\\SELECT
\\ invite.id, invite.created_by, invite.to_community, invite.name,
\\ invite.created_at, invite.expires_at,
\\ COUNT(local_user.user_id) as uses, invite.max_uses,
\\ invite.type
\\FROM invite LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id
\\WHERE invite.code = ?
\\WHERE invite.code = $1
\\GROUP BY invite.id
, .{code}, alloc)) orelse return error.NotFound;

View file

@ -40,11 +40,11 @@ pub fn create(
}
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note {
const result = (try db.execRow(
const result = (try db.queryRow(
&.{ Uuid, []const u8, DateTime },
\\SELECT author_id, content, created_at
\\FROM note
\\WHERE id = ?
\\WHERE id = $1
\\LIMIT 1
,
.{id},

View file

@ -37,9 +37,9 @@ pub const CreateOptions = struct {
};
fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
return if (try db.execRow(
return if (try db.queryRow(
&.{Uuid},
"SELECT user.id FROM user WHERE community_id IS NULL AND username = ?",
"SELECT user.id FROM user WHERE community_id IS NULL AND username = $1",
.{username},
null,
)) |result|
@ -49,9 +49,9 @@ fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
}
fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid {
return if (try db.execRow(
return if (try db.queryRow(
&.{Uuid},
"SELECT user.id FROM user WHERE community_id = ? AND username = ?",
"SELECT user.id FROM user WHERE community_id = $1 AND username = $2",
.{ community_id, username },
null,
)) |result|
@ -107,11 +107,11 @@ pub const User = struct {
};
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
const result = (try db.execRow(
const result = (try db.queryRow(
&.{ []const u8, []const u8, Uuid, DateTime },
\\SELECT user.username, community.host, community.id, user.created_at
\\FROM user JOIN community ON user.community_id = community.id
\\WHERE user.id = ?
\\WHERE user.id = $1
\\LIMIT 1
,
.{id},

View file

@ -14,6 +14,7 @@ pub const login = struct {
pub const path = "/auth/login";
pub const method = .POST;
pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void {
std.debug.print("{s}", .{ctx.request.body.?});
const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx);
defer utils.freeRequestBody(credentials, ctx.alloc);

View file

@ -1,214 +0,0 @@
const std = @import("std");
const sql = @import("sql");
const models = @import("./db/models.zig");
const migrations = @import("./db/migrations.zig");
const util = @import("util");
const Uuid = util.Uuid;
const DateTime = util.DateTime;
const String = []const u8;
const comptimePrint = std.fmt.comptimePrint;
fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
var result: RowTuple = undefined;
// TODO: undo allocations on failure
inline for (std.meta.fields(RowTuple)) |f, i| {
@field(result, f.name) = try getAlloc(row, f.field_type, i, allocator);
}
return result;
}
pub fn ResultSet(comptime result_types: []const type) type {
return struct {
pub const Row = std.meta.Tuple(result_types);
_stmt: sql.PreparedStmt,
err: ?ExecError = null,
pub fn finish(self: *@This()) void {
self._stmt.finalize();
}
pub fn row(self: *@This(), allocator: ?std.mem.Allocator) ?Row {
const sql_result = self._stmt.step() catch |err| {
self.err = err;
return null;
};
if (sql_result) |sql_row| {
return readRow(Row, sql_row, allocator) catch |err| {
self.err = err;
return null;
};
} else return null;
}
};
}
// Binds a value to a parameter in the query. Use this instead of string
// concatenation to avoid injection attacks;
// If a given type is not supported by this function, you can add support by
// declaring a method with the given signature:
// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
// TODO define what error set this ^ should return
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val);
return switch (@TypeOf(val)) {
i64 => stmt.bindI64(idx, val),
Uuid => stmt.bindUuid(idx, val),
DateTime => stmt.bindDateTime(idx, val),
@TypeOf(null) => stmt.bindNull(idx),
else => |T| switch (@typeInfo(T)) {
.Optional => if (val) |v| bind(stmt, idx, v) else stmt.bindNull(idx),
.Enum => stmt.bindText(idx, @tagName(val)),
.Struct, .Union, .Opaque => if (@hasDecl(T, "bindToSql"))
val.bindToSql(stmt, idx)
else
@compileError("unsupported type " ++ @typeName(T)),
.Int => stmt.bindI64(idx, @intCast(i64, val)),
else => @compileError("unsupported type " ++ @typeName(T)),
},
};
}
// Gets a value from the row, allocating memory if necessary.
// If a given type is not supported by this function, you can add support by
// declaring a method with the given signature:
// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T
// TODO define what error set this ^ should return
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T {
return switch (T) {
[]u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired),
i64 => row.getI64(idx),
Uuid => row.getUuid(idx),
DateTime => row.getDateTime(idx),
else => switch (@typeInfo(T)) {
.Optional => if (try row.isNull(idx))
null
else
try getAlloc(row, std.meta.Child(T), idx, alloc),
.Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql"))
T.getFromSql(row, idx, alloc)
else
@compileError("unknown type " ++ @typeName(T)),
.Enum => try getEnum(row, T, idx),
.Int => @intCast(T, try row.getI64(idx)),
//else => unreachable,
else => @compileError("unknown type " ++ @typeName(T)),
},
};
}
fn maxTagLen(comptime T: type) usize {
var max: usize = 0;
for (std.meta.fields(T)) |f| {
if (f.name.len > max) {
max = f.name.len;
}
}
return max;
}
fn getEnum(row: sql.Row, comptime T: type, idx: u15) !T {
var tag_buf: [maxTagLen(T)]u8 = undefined;
const tag_name = try row.getText(idx, &tag_buf);
inline for (std.meta.fields(T)) |tag| {
if (std.mem.eql(u8, tag_name, tag.name)) return @intToEnum(T, tag.value);
}
return error.UnknownTag;
}
pub const ExecError = sql.PrepareError || sql.RowGetError || sql.BindError || std.mem.Allocator.Error || error{ AllocatorRequired, UnknownTag };
pub const Database = struct {
db: sql.Sqlite,
pub fn init(file_path: [:0]const u8) !Database {
var db = try sql.Sqlite.open(file_path);
errdefer db.close();
try migrations.up(&db);
return Database{ .db = db };
}
pub fn deinit(self: *Database) void {
self.db.close();
}
pub fn exec(
self: *Database,
comptime result_types: []const type,
comptime q: []const u8,
args: anytype,
) ExecError!ResultSet(result_types) {
std.log.debug("executing sql:\n===\n{s}\n===", .{q});
const stmt = try self.db.prepare(q);
errdefer stmt.finalize();
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
try bind(stmt, @intCast(u15, i + 1), @field(args, field.name));
}
return ResultSet(result_types){
._stmt = stmt,
};
}
pub fn execRow(
self: *Database,
comptime result_types: []const type,
comptime q: []const u8,
args: anytype,
allocator: ?std.mem.Allocator,
) ExecError!?ResultSet(result_types).Row {
var results = try self.exec(result_types, q, args);
defer results.finish();
const row = results.row(allocator);
std.log.debug("done exec", .{});
if (row) |r| return r;
if (results.err) |err| {
std.log.debug("{}", .{err});
std.log.debug("{?}", .{@errorReturnTrace()});
return err;
}
return null;
}
fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
comptime {
const joiner = ",";
var result: []const u8 = "";
inline for (std.meta.fields(T)) |f| {
result = result ++ joiner ++ (placeholder orelse f.name);
}
return "(" ++ result[joiner.len..] ++ ")";
}
}
pub fn insert(
self: *Database,
comptime table: []const u8,
value: anytype,
) ExecError!void {
const ValueType = comptime @TypeOf(value);
const table_spec = comptime table ++ build_field_list(ValueType, null);
const value_spec = comptime build_field_list(ValueType, "?");
const q = comptime std.fmt.comptimePrint(
"INSERT INTO {s} VALUES {s}",
.{ table_spec, value_spec },
);
_ = try self.execRow(&.{}, q, value, null);
}
};

View file

@ -1,5 +1,6 @@
const std = @import("std");
const builtin = @import("builtin");
const sql = @import("sql");
const http = @import("http");
const util = @import("util");
@ -82,11 +83,7 @@ pub const RequestServer = struct {
pub const Config = struct {
cluster_host: []const u8,
db: struct {
sqlite: struct {
db_file: [:0]const u8,
},
},
db: sql.Config,
root_password: ?[]const u8 = null,
};
@ -105,7 +102,8 @@ const root_password_envvar = "CLUSTER_ROOT_PASSWORD";
pub fn main() anyerror!void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
var cfg = try loadConfig(gpa.allocator());
var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar)) catch |err| switch (err) {
var db_conn = try sql.Db.open(cfg.db);
var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), db_conn) catch |err| switch (err) {
error.NeedRootPassword => {
std.log.err(
"No root user created and no password specified. Please provide the password for the root user by the ${s} environment variable for initial startup. This only needs to be done once",

View file

@ -1,8 +1,9 @@
const std = @import("std");
const sql = @import("sql");
const DateTime = @import("util").DateTime;
pub const Migration = struct {
name: []const u8,
name: [:0]const u8,
up: []const u8,
down: []const u8,
};
@ -15,54 +16,44 @@ fn firstIndexOf(str: []const u8, char: u8) ?usize {
return null;
}
fn execStmt(db: *sql.Sqlite, stmt_sql: []const u8) !void {
const stmt = try db.prepare(stmt_sql);
defer stmt.finalize();
while (try stmt.step()) |_| {}
fn execStmt(tx: sql.Tx, stmt: []const u8, alloc: std.mem.Allocator) !void {
const stmt_null = try std.cstr.addNullByte(alloc, stmt);
defer alloc.free(stmt_null);
try tx.exec(stmt_null, .{}, null);
}
fn execScript(db: *sql.Sqlite, script: []const u8) !void {
try execStmt(db, "BEGIN;");
errdefer {
_ = execStmt(db, "ROLLBACK;") catch unreachable;
}
fn execScript(db: sql.Db, script: []const u8, alloc: std.mem.Allocator) !void {
const tx = try db.begin();
errdefer tx.rollback();
var remaining = script;
while (firstIndexOf(remaining, ';')) |last| {
try execStmt(db, remaining[0 .. last + 1]);
try execStmt(tx, remaining[0 .. last + 1], alloc);
remaining = remaining[last + 1 ..];
}
if (remaining.len > 1) try execStmt(tx, remaining, alloc);
try execStmt(db, "COMMIT;");
try tx.commit();
}
fn wasMigrationRan(db: *sql.Sqlite, name: []const u8) !bool {
const stmt = try db.prepare("SELECT COUNT(*) FROM migration WHERE name = ?;");
defer stmt.finalize();
try stmt.bindText(1, name);
const result = (try stmt.step()).?;
const count = try result.getI64(0);
return count != 0;
fn wasMigrationRan(db: sql.Db, name: []const u8, alloc: std.mem.Allocator) !bool {
const row = (try db.queryRow(&.{i32}, "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false;
return row[0] != 0;
}
fn markMigrationAsRan(db: *sql.Sqlite, name: []const u8) !void {
const stmt = try db.prepare("INSERT INTO migration(name) VALUES(?);");
defer stmt.finalize();
try stmt.bindText(1, name);
_ = try stmt.step();
}
pub fn up(db: *sql.Sqlite) !void {
try execScript(db, create_migration_table);
pub fn up(db: sql.Db) !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
std.log.info("Running migrations...", .{});
try execScript(db, create_migration_table, gpa.allocator());
for (migrations) |migration| {
if (!try wasMigrationRan(db, migration.name)) {
try execScript(db, migration.up);
try markMigrationAsRan(db, migration.name);
const was_ran = try wasMigrationRan(db, migration.name, gpa.allocator());
if (!was_ran) {
std.log.info("Running migration {s}", .{migration.name});
try execScript(db, migration.up, gpa.allocator());
try db.insert("migration", .{ .name = migration.name });
}
}
}
@ -71,7 +62,7 @@ const create_migration_table =
\\CREATE TABLE IF NOT EXISTS
\\migration(
\\ name TEXT NOT NULL PRIMARY KEY,
\\ applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ applied_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\);
;
@ -85,7 +76,7 @@ const migrations: []const Migration = &.{
\\ id TEXT NOT NULL PRIMARY KEY,
\\ username TEXT NOT NULL,
\\
\\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\);
\\
\\CREATE TABLE local_user(
@ -115,7 +106,7 @@ const migrations: []const Migration = &.{
\\ content TEXT NOT NULL,
\\ author_id TEXT NOT NULL REFERENCES user(id),
\\
\\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\);
,
.down = "DROP TABLE note;",
@ -129,7 +120,7 @@ const migrations: []const Migration = &.{
\\ user_id TEXT NOT NULL REFERENCES user(id),
\\ note_id TEXT NOT NULL REFERENCES note(id),
\\
\\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\);
,
.down = "DROP TABLE reaction;",
@ -141,7 +132,7 @@ const migrations: []const Migration = &.{
\\ hash TEXT NOT NULL PRIMARY KEY,
\\ user_id TEXT NOT NULL REFERENCES local_user(id),
\\
\\ issued_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ issued_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\);
,
.down = "DROP TABLE token;",
@ -158,8 +149,8 @@ const migrations: []const Migration = &.{
\\
\\ max_uses INTEGER,
\\
\\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
\\ expires_at DATETIME,
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
\\ expires_at TIMESTAMPTZ,
\\
\\ type TEXT NOT NULL CHECK (type in ('system', 'community_owner', 'user'))
\\);
@ -181,7 +172,7 @@ const migrations: []const Migration = &.{
\\ host TEXT NOT NULL UNIQUE,
\\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')),
\\
\\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\);
\\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id);
\\ALTER TABLE invite ADD COLUMN to_community TEXT REFERENCES community(id);

View file

@ -5,10 +5,30 @@ const Uuid = util.Uuid;
const DateTime = util.DateTime;
const Allocator = std.mem.Allocator;
pub const QueryOptions = struct {
// If true, then it will not return an error on the SQLite backend
// if an argument passed does not map to a parameter in the query.
// Has no effect on the postgres backend.
ignore_unknown_parameters: bool = false,
// The allocator to use for query preparation and submission.
// All memory allocated with this allocator will be freed before results
// are retrieved.
// Some types (enums with constant representation, null terminated strings)
// do not require allocators for prep. If an allocator is needed but not
// provided, `error.AllocatorRequired` will be returned.
// Only used with the postgres backend.
prep_allocator: ?Allocator = null,
};
// Turns a value into its appropriate textual value (or null)
// as appropriate using the given arena allocator
pub fn prepareParamText(arena: std.heap.ArenaAllocator, val: anytype) !?[:0]const u8 {
if (comptime std.meta.trait.isZigString(@TypeOf(val))) return val;
pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]const u8 {
if (comptime std.meta.trait.isZigString(@TypeOf(val))) {
if (comptime std.meta.sentinel(@TypeOf(val))) |s| if (comptime s == 0) return val;
return try std.cstr.addNullByte(arena.allocator(), val);
}
return switch (@TypeOf(val)) {
[:0]u8, [:0]const u8 => val,
@ -36,7 +56,9 @@ pub fn parseValueNotNull(alloc: ?Allocator, comptime T: type, str: []const u8) !
[]u8, []const u8 => if (alloc) |a| util.deepClone(a, str) else return error.AllocatorRequired,
else => switch (@typeInfo(T)) {
.Enum => parseEnum(T, str),
.Int => std.fmt.parseInt(T, str, 0),
.Enum => std.meta.stringToEnum(T, str) orelse return error.InvalidValue,
.Optional => try parseValueNotNull(alloc, std.meta.Child(T), str),
else => @compileError("Type " ++ @typeName(T) ++ " not supported"),
},

View file

@ -3,8 +3,11 @@ const util = @import("util");
const postgres = @import("./postgres.zig");
const sqlite = @import("./sqlite.zig");
const common = @import("./common.zig");
const Allocator = std.mem.Allocator;
pub const QueryOptions = common.QueryOptions;
pub const Type = enum {
postgres,
sqlite,
@ -12,10 +15,10 @@ pub const Type = enum {
pub const Config = union(Type) {
postgres: struct {
conn_str: [:0]const u8,
pg_conn_str: [:0]const u8,
},
sqlite: struct {
file_path: [:0]const u8,
sqlite_file_path: [:0]const u8,
},
};
@ -106,12 +109,12 @@ pub const Db = struct {
return switch (cfg) {
.postgres => |postgres_cfg| Db{
.underlying = .{
.postgres = try postgres.Db.open(postgres_cfg.conn_str),
.postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
},
},
.sqlite => |lite_cfg| Db{
.underlying = .{
.sqlite = try sqlite.Db.open(lite_cfg.file_path),
.sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path),
},
},
};
@ -124,6 +127,17 @@ pub const Db = struct {
}
}
pub fn queryWithOptions(
self: Db,
comptime result_types: []const type,
sql: [:0]const u8,
args: anytype,
opt: QueryOptions,
) !Results(result_types) {
// Create fake transaction to use its functions
return (Tx{ .underlying = self.underlying }).queryWithOptions(result_types, sql, args, opt);
}
pub fn query(
self: Db,
comptime result_types: []const type,
@ -153,7 +167,7 @@ pub const Db = struct {
alloc: ?Allocator,
) !?Results(result_types).RowTuple {
// Create fake transaction to use its functions
return (Tx{ .underlying = self.underlying }).exec(sql, args, alloc);
return (Tx{ .underlying = self.underlying }).queryRow(result_types, sql, args, alloc);
}
pub fn insert(
@ -182,14 +196,24 @@ pub const Tx = struct {
self: Tx,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
opt: QueryOptions,
) !RawResults {
return switch (self.underlying) {
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, alloc) },
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args) },
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt.prep_allocator) },
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
};
}
pub fn queryWithOptions(
self: Tx,
comptime result_types: []const type,
sql: [:0]const u8,
args: anytype,
options: QueryOptions,
) !Results(result_types) {
return Results(result_types){ .underlying = try self.queryInternal(sql, args, options) };
}
// Executes a query and returns the result set
pub fn query(
self: Tx,
@ -198,7 +222,7 @@ pub const Tx = struct {
args: anytype,
alloc: ?Allocator,
) !Results(result_types) {
return Results(result_types){ .unerlying = try self.queryInternal(sql, args, alloc) };
return self.queryWithOptions(result_types, sql, args, .{ .prep_allocator = alloc });
}
// Executes a query without returning results
@ -208,7 +232,7 @@ pub const Tx = struct {
args: anytype,
alloc: ?Allocator,
) !void {
(try self.queryInternal(sql, args, alloc)).finish();
_ = try self.queryRow(&.{}, sql, args, alloc);
}
// Runs a query and returns a single row
@ -242,13 +266,28 @@ pub const Tx = struct {
value: anytype,
) !void {
const ValueType = comptime @TypeOf(value);
const table_spec = comptime table ++ build_field_list(ValueType, null);
const value_spec = comptime build_field_list(ValueType, "?");
const fields = std.meta.fields(ValueType);
comptime var types: [fields.len]type = undefined;
comptime var table_spec: []const u8 = table ++ "(";
comptime var value_spec: []const u8 = "(";
inline for (fields) |field, i| {
types[i] = field.field_type;
table_spec = comptime (table_spec ++ field.name ++ ",");
value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1});
}
table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")";
value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")";
const q = comptime std.fmt.comptimePrint(
"INSERT INTO {s} VALUES {s}",
.{ table_spec, value_spec },
);
try self.exec(q, value, null);
var args_tuple: std.meta.Tuple(&types) = undefined;
inline for (fields) |field, i| {
args_tuple[i] = @field(value, field.name);
}
try self.exec(q, args_tuple, null);
}
pub fn rollback(self: Tx) void {
@ -261,15 +300,3 @@ pub const Tx = struct {
try self.exec("COMMIT", .{}, null);
}
};
fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
comptime {
const joiner = ",";
var result: []const u8 = "";
inline for (std.meta.fields(T)) |f| {
result = result ++ joiner ++ (placeholder orelse f.name);
}
return "(" ++ result[joiner.len..] ++ ")";
}
}

View file

@ -90,7 +90,7 @@ pub const Db = struct {
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
defer arena.deinit();
const params = try arena.allocator().alloc(?[*]const u8, args.len);
inline for (args) |a, i| params[i] = if (try common.prepareParamText(arena, a)) |slice| slice.ptr else null;
inline for (args) |a, i| params[i] = if (try common.prepareParamText(&arena, a)) |slice| slice.ptr else null;
break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text);
} else {

View file

@ -103,7 +103,7 @@ pub const Db = struct {
}
}
pub fn exec(self: Db, sql: []const u8, args: anytype) !Results {
pub fn exec(self: Db, sql: []const u8, args: anytype, opts: common.QueryOptions) !Results {
var stmt: ?*c.sqlite3_stmt = undefined;
switch (c.sqlite3_prepare_v2(self.db, sql.ptr, @intCast(c_int, sql.len), &stmt, null)) {
c.SQLITE_OK => {},
@ -134,7 +134,7 @@ pub const Db = struct {
return handleUnexpectedError(self.db, err, sql);
},
}
} else unreachable;
} else if (!opts.ignore_unknown_parameters) return error.UnknownParameter;
}
return Results{ .stmt = stmt.?, .db = self.db };
@ -216,7 +216,7 @@ pub const Row = struct {
@intCast(T, c.sqlite3_column_int64(self.stmt, idx))
else
self.getFromString(T, idx, alloc),
.Optional => self.getNotNull(std.meta.Child(T), idx, alloc),
.Optional => try self.getNotNull(std.meta.Child(T), idx, alloc),
else => self.getFromString(T, idx, alloc),
},
};

41
src/sql/test.zig Normal file
View file

@ -0,0 +1,41 @@
const sql = @import("./lib.zig");
const std = @import("std");
const Uuid = @import("util").Uuid;
const alloc = std.testing.allocator;
pub fn main() !void {
const db = try sql.Db.open(.{
//.postgres = .{
//.conn_str = "postgresql://localhost",
//},
.sqlite = .{
.file_path = "./test.db",
},
});
defer db.close();
const tx = try db.begin();
try tx.commit();
tx.rollback();
try tx.commit();
}
test {
const db = try sql.Db.open(.{
.sqlite = .{
.file_path = "./test.db",
},
});
defer db.close();
var results = try db.query(&.{[]const u8}, "SELECT $1 as id", .{"abcdefg"}, alloc);
defer results.finish();
const row = (try results.row(alloc)) orelse unreachable;
defer alloc.free(row[0]);
try std.testing.expectEqualStrings("abcdefg", row[0]);
std.log.info("value: {s}", .{row[0]});
}

View file

@ -1,9 +1,48 @@
const DateTime = @This();
const std = @import("std");
const epoch = std.time.epoch;
seconds_since_epoch: i64,
pub fn parse(str: []const u8) !DateTime {
// TODO: Try other formats
return try parseRfc3339(str);
}
// TODO: Validate non-numeric aspects of datetime
// TODO: Don't panic on bad string
// TODO: Make seconds optional (see ActivityStreams 2.0 spec §2.3)
// TODO: Handle times before 1970
pub fn parseRfc3339(str: []const u8) !DateTime {
const year_num = try std.fmt.parseInt(u16, str[0..4], 10);
const month_num = try std.fmt.parseInt(std.meta.Tag(epoch.Month), str[5..7], 10);
const day_num = @as(i64, try std.fmt.parseInt(u9, str[8..10], 10));
const hour_num = @as(i64, try std.fmt.parseInt(u5, str[11..13], 10));
const minute_num = @as(i64, try std.fmt.parseInt(u6, str[14..15], 10));
const second_num = @as(i64, try std.fmt.parseInt(u6, str[16..17], 10));
const is_leap_year = epoch.isLeapYear(year_num);
const leap_days_preceding_epoch = comptime epoch.epoch_year / 4 - epoch.epoch_year / 100 + epoch.epoch_year / 400;
const leap_days_preceding_year = year_num / 4 - year_num / 100 + year_num / 400 - leap_days_preceding_epoch - if (is_leap_year) @as(i64, 1) else 0;
const epoch_day = (year_num - epoch.epoch_year) * 365 + leap_days_preceding_year + year_day: {
var days_preceding_month: i64 = 0;
var month_i: i64 = 1;
while (month_i < month_num) : (month_i += 1) {
days_preceding_month += epoch.getDaysInMonth(if (is_leap_year) .leap else .not_leap, @intToEnum(epoch.Month, month_i));
}
break :year_day days_preceding_month + day_num;
};
const day_second = (hour_num * 60 + minute_num) * 60 + second_num;
return DateTime{
.seconds_since_epoch = epoch_day * epoch.secs_per_day + day_second,
};
}
pub fn now() DateTime {
return .{ .seconds_since_epoch = std.time.timestamp() };
}
@ -40,10 +79,24 @@ pub fn second(value: DateTime) u6 {
return value.epochSeconds().getDaySeconds().getSecondsIntoMinute();
}
const array_len = 20;
pub fn toCharArray(value: DateTime) [array_len + 1]u8 {
var buf: [array_len]u8 = undefined;
_ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable;
return buf;
}
pub fn toCharArrayZ(value: DateTime) [array_len + 1:0]u8 {
var buf: [array_len + 1:0]u8 = undefined;
_ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable;
return buf;
}
pub fn format(value: DateTime, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
return std.fmt.format(
writer,
"{}-{}-{} {}:{}:{}",
"{:0>4}-{:0>2}-{:0>2}T{:0>2}:{:0>2}:{:0>2}Z",
.{ value.year(), value.month().numeric(), value.day(), value.hour(), value.minute(), value.second() },
);
}

View file

@ -58,7 +58,7 @@ pub const ParseError = error{
};
pub fn parse(str: []const u8) ParseError!Uuid {
if (str.len != string_len) return error.InvalidLength;
if (str.len != string_len and (str.len != string_len + 1 or str[str.len - 1] != 0)) return error.InvalidLength;
var data: [16]u8 = undefined;
var str_i: usize = 0;

View file

@ -24,11 +24,11 @@ pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {
else => @compileError("Many and C-style pointers not supported by deepfree"),
},
.Optional => if (val) |v| deepFree(alloc, v) else {},
.Struct => |struct_info| for (struct_info.fields) |field| deepFree(alloc, @field(val, field.name)),
.Struct => |struct_info| inline for (struct_info.fields) |field| deepFree(alloc, @field(val, field.name)),
.Union, .ErrorUnion => @compileError("TODO: Unions not yet supported by deepFree"),
.Array => for (val) |v| deepFree(alloc, v),
.Int, .Float, .Bool, .Void, .Type => {},
.Enum, .Int, .Float, .Bool, .Void, .Type => {},
else => @compileError("Type " ++ @typeName(T) ++ " not supported by deepFree"),
}
@ -87,7 +87,7 @@ pub fn deepClone(alloc: std.mem.Allocator, val: anytype) !@TypeOf(val) {
count += 1;
}
},
.Int, .Float, .Bool, .Void, .Type => {
.Enum, .Int, .Float, .Bool, .Void, .Type => {
result = val;
},
else => @compileError("Type " ++ @typeName(T) ++ " not supported"),