fediglam/src/main/db.zig

215 lines
6.9 KiB
Zig

const std = @import("std");
const sql = @import("sql");
const models = @import("./db/models.zig");
const migrations = @import("./db/migrations.zig");
const util = @import("util");
const Uuid = util.Uuid;
const DateTime = util.DateTime;
const String = []const u8;
const comptimePrint = std.fmt.comptimePrint;
// Decodes a result row into a value of tuple type `RowTuple`: column i is read
// into field i via getAlloc, using that field's declared type.
// `allocator` is forwarded unchanged to getAlloc; types that need heap memory
// (e.g. text) fail with error.AllocatorRequired when it is null.
fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
    // TODO: undo allocations on failure
    var tuple: RowTuple = undefined;
    const fields = comptime std.meta.fields(RowTuple);
    inline for (fields) |field, column| {
        const value = try getAlloc(row, field.field_type, column, allocator);
        @field(tuple, field.name) = value;
    }
    return tuple;
}
// Returns a result-set type wrapping a prepared statement whose rows decode
// into a tuple of `result_types` (one tuple field per selected column).
// The result set holds the statement; call `finish` to finalize it.
pub fn ResultSet(comptime result_types: []const type) type {
    return struct {
        pub const Row = std.meta.Tuple(result_types);

        _stmt: sql.PreparedStmt,
        // Set by `row` when stepping or decoding fails.
        err: ?ExecError = null,

        // Finalizes the underlying prepared statement.
        pub fn finish(self: *@This()) void {
            self._stmt.finalize();
        }

        // Steps the statement and decodes the next row.
        // Returns null both when no rows remain and on failure; in the failure
        // case the error is stored in `err`, so check it to tell them apart.
        pub fn row(self: *@This(), allocator: ?std.mem.Allocator) ?Row {
            const next = self._stmt.step() catch |step_err| {
                self.err = step_err;
                return null;
            };
            const sql_row = next orelse return null;
            return readRow(Row, sql_row, allocator) catch |read_err| {
                self.err = read_err;
                return null;
            };
        }
    };
}
// Binds a value to a parameter in the query. Use this instead of string
// concatenation to avoid injection attacks.
//
// Supported types: strings (anything std.meta.trait.isZigString accepts),
// i64 and other integer types (including integer literals of type
// comptime_int), Uuid, DateTime, enums (bound as their tag name), null, and
// optionals of any supported type (a null optional binds SQL NULL).
//
// If a given type is not supported by this function, you can add support by
// declaring a method with the given signature:
// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
// TODO define what error set this ^ should return
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
    if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val);
    return switch (@TypeOf(val)) {
        i64 => stmt.bindI64(idx, val),
        Uuid => stmt.bindUuid(idx, val),
        DateTime => stmt.bindDateTime(idx, val),
        @TypeOf(null) => stmt.bindNull(idx),
        else => |T| switch (@typeInfo(T)) {
            .Optional => if (val) |v| bind(stmt, idx, v) else stmt.bindNull(idx),
            .Enum => stmt.bindText(idx, @tagName(val)),
            .Struct, .Union, .Opaque => if (@hasDecl(T, "bindToSql"))
                val.bindToSql(stmt, idx)
            else
                @compileError("unsupported type " ++ @typeName(T)),
            // NOTE(review): @intCast traps in safe builds (UB in ReleaseFast)
            // for unsigned values above maxInt(i64), e.g. large u64s.
            .Int => stmt.bindI64(idx, @intCast(i64, val)),
            // Integer literals passed straight through `args` (e.g. `.{ 5 }`)
            // have type comptime_int. The coercion happens at compile time,
            // so literals that do not fit in an i64 are rejected by the compiler.
            .ComptimeInt => stmt.bindI64(idx, @as(i64, val)),
            else => @compileError("unsupported type " ++ @typeName(T)),
        },
    };
}
// Gets the value of column `idx` from the row as type T, allocating memory if
// necessary.
//
// Supported types:
//   - []u8 / []const u8: read with row.getTextAlloc using `alloc`; returns
//     error.AllocatorRequired when `alloc` is null.
//   - i64 and other integer types.
//     NOTE(review): narrower integer types use @intCast, which traps in safe
//     builds if the stored value does not fit in T.
//   - Uuid and DateTime.
//   - Enums, stored as their tag name (see getEnum).
//   - Optionals of any supported type; SQL NULL maps to null.
//
// If a given type is not supported by this function, you can add support by
// declaring a method with the signature below. The allocator is forwarded
// as-is, so the parameter must be optional:
// pub fn getFromSql(row: sql.Row, idx: u15, alloc: ?std.mem.Allocator) !T
// TODO define what error set this ^ should return
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T {
    return switch (T) {
        []u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired),
        i64 => row.getI64(idx),
        Uuid => row.getUuid(idx),
        DateTime => row.getDateTime(idx),
        else => switch (@typeInfo(T)) {
            .Optional => if (try row.isNull(idx))
                null
            else
                try getAlloc(row, std.meta.Child(T), idx, alloc),
            .Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql"))
                T.getFromSql(row, idx, alloc)
            else
                @compileError("unknown type " ++ @typeName(T)),
            .Enum => try getEnum(row, T, idx),
            .Int => @intCast(T, try row.getI64(idx)),
            else => @compileError("unknown type " ++ @typeName(T)),
        },
    };
}
// Returns the length in bytes of the longest field name of `T` (0 when `T`
// has no fields). getEnum uses this to size its buffer for a tag read from SQL.
fn maxTagLen(comptime T: type) usize {
    comptime {
        var longest: usize = 0;
        for (std.meta.fields(T)) |field| {
            if (longest < field.name.len) longest = field.name.len;
        }
        return longest;
    }
}
// Reads column `idx` as text and returns the tag of enum `T` with that exact
// name. Returns error.UnknownTag if the text matches no tag.
// The text is read into a local buffer sized to the longest tag name of `T`.
// NOTE(review): what happens when the stored text is longer than that buffer
// depends on row.getText. Confirm that it errors rather than truncating,
// because a truncated value could spuriously match a shorter tag.
fn getEnum(row: sql.Row, comptime T: type, idx: u15) !T {
    var tag_buf: [maxTagLen(T)]u8 = undefined;
    const tag_name = try row.getText(idx, &tag_buf);
    // std.meta.stringToEnum performs the same exact-name match as a manual loop
    // over the tags, without depending on the enum-from-int builtin.
    if (std.meta.stringToEnum(T, tag_name)) |tag| return tag;
    return error.UnknownTag;
}
// Every error the query helpers in this file can produce:
//   - the local errors raised by getAlloc (AllocatorRequired) and getEnum
//     (UnknownTag);
//   - allocator failure;
//   - errors from binding parameters, reading rows, and preparing statements.
pub const ExecError = error{ AllocatorRequired, UnknownTag } ||
    std.mem.Allocator.Error ||
    sql.BindError ||
    sql.RowGetError ||
    sql.PrepareError;
// Wrapper around a SQLite connection. It runs migrations when opened and
// offers typed helpers for running parameterized queries.
pub const Database = struct {
    db: sql.Sqlite,

    // Opens the database at `file_path` and applies migrations via
    // migrations.up. If migrating fails, the connection is closed again.
    pub fn init(file_path: [:0]const u8) !Database {
        var db = try sql.Sqlite.open(file_path);
        errdefer db.close();
        try migrations.up(&db);
        return Database{ .db = db };
    }

    // Closes the underlying connection.
    pub fn deinit(self: *Database) void {
        self.db.close();
    }

    // Prepares `q` and binds each field of `args`, in declaration order, to
    // the statement's parameters. Parameter indices start at 1, hence `i + 1`.
    // Rows are decoded as tuples of `result_types`. The caller must call
    // `finish()` on the returned result set to finalize the statement.
    // If any bind fails, the statement is finalized here before the error is
    // returned.
    pub fn exec(
        self: *Database,
        comptime result_types: []const type,
        comptime q: []const u8,
        args: anytype,
    ) ExecError!ResultSet(result_types) {
        std.log.debug("executing sql:\n===\n{s}\n===", .{q});
        const stmt = try self.db.prepare(q);
        errdefer stmt.finalize();
        inline for (std.meta.fields(@TypeOf(args))) |field, i| {
            try bind(stmt, @intCast(u15, i + 1), @field(args, field.name));
        }
        return ResultSet(result_types){
            ._stmt = stmt,
        };
    }

    // Runs `q` and returns its first row, or null if the query produced no
    // rows. Any further rows are ignored, and the statement is always
    // finalized before returning.
    // Values allocated with `allocator` (e.g. text columns) are returned to
    // the caller, who is responsible for freeing them. Pass null when no
    // result column needs allocation.
    pub fn execRow(
        self: *Database,
        comptime result_types: []const type,
        comptime q: []const u8,
        args: anytype,
        allocator: ?std.mem.Allocator,
    ) ExecError!?ResultSet(result_types).Row {
        var results = try self.exec(result_types, q, args);
        defer results.finish();
        if (results.row(allocator)) |r| return r;
        // row() returns null both at end of results and on failure; the
        // stored error tells the two cases apart.
        if (results.err) |err| {
            std.log.debug("execRow failed: {s}", .{@errorName(err)});
            return err;
        }
        return null;
    }

    // Builds a parenthesized, comma-separated list with one entry per field
    // of `T`. Each entry is the field name, or `placeholder` when one is given.
    // For example, for `struct { id: i64, name: []const u8 }`:
    //   build_field_list(T, null) == "(id,name)"
    //   build_field_list(T, "?")  == "(?,?)"
    fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
        comptime {
            const fields = std.meta.fields(T);
            // Without this check, an empty T fails at `result[joiner.len..]`
            // with an unhelpful out-of-bounds compile error.
            if (fields.len == 0) {
                @compileError("cannot build a field list for " ++ @typeName(T) ++ ": it has no fields");
            }
            const joiner = ",";
            var result: []const u8 = "";
            inline for (fields) |f| {
                result = result ++ joiner ++ (placeholder orelse f.name);
            }
            return "(" ++ result[joiner.len..] ++ ")";
        }
    }

    // Inserts `value` as a new row of `table`. Each field of `value` is written
    // to the column of the same name.
    // `table` and the field names are spliced into the SQL text at compile
    // time. The field values themselves are always bound as parameters.
    pub fn insert(
        self: *Database,
        comptime table: []const u8,
        value: anytype,
    ) ExecError!void {
        const ValueType = comptime @TypeOf(value);
        const table_spec = comptime table ++ build_field_list(ValueType, null);
        const value_spec = comptime build_field_list(ValueType, "?");
        const q = comptime std.fmt.comptimePrint(
            "INSERT INTO {s} VALUES {s}",
            .{ table_spec, value_spec },
        );
        _ = try self.execRow(&.{}, q, value, null);
    }
};