215 lines
6.9 KiB
Zig
215 lines
6.9 KiB
Zig
const std = @import("std");
|
|
const sql = @import("sql");
|
|
const models = @import("./db/models.zig");
|
|
const migrations = @import("./db/migrations.zig");
|
|
const util = @import("util");
|
|
|
|
const Uuid = util.Uuid;
|
|
const DateTime = util.DateTime;
|
|
const String = []const u8;
|
|
const comptimePrint = std.fmt.comptimePrint;
|
|
|
|
/// Decodes every column of `row` into a `RowTuple`; field i is read from
/// column i via `getAlloc`. Text columns are copied using `allocator`.
///
/// If decoding a later column fails, the text already copied for earlier
/// columns is freed before the error is returned. Values produced by a type's
/// custom `getFromSql` are not released on that path (TODO).
fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
    const fields = std.meta.fields(RowTuple);
    var result: RowTuple = undefined;

    // Number of leading fields of `result` that currently hold a decoded
    // value; the errdefer only releases those, never the undefined rest.
    var decoded: usize = 0;
    errdefer {
        // Without an allocator no text can have been copied (getAlloc fails
        // with AllocatorRequired first), so there is nothing to free.
        if (allocator) |alloc| {
            inline for (fields) |f, i| {
                if (i < decoded) freeOwnedText(alloc, f.field_type, @field(result, f.name));
            }
        }
    }

    inline for (fields) |f, i| {
        @field(result, f.name) = try getAlloc(row, f.field_type, i, allocator);
        decoded = i + 1;
    }

    return result;
}

/// Frees `value` when it is a text slice, or a non-null optional text slice.
/// These are the types `getAlloc` fills via `row.getTextAlloc`. This is a
/// no-op for every other type.
/// NOTE(review): assumes getTextAlloc returns exactly the allocated slice
/// (no hidden sentinel byte) — confirm against the sql module.
fn freeOwnedText(alloc: std.mem.Allocator, comptime T: type, value: T) void {
    switch (T) {
        []u8, []const u8 => alloc.free(value),
        ?[]u8, ?[]const u8 => if (value) |text| alloc.free(text),
        else => {},
    }
}
|
|
|
|
/// Builds the cursor type returned by `Database.exec`. It wraps a prepared
/// statement whose rows are decoded into tuples of `result_types`.
pub fn ResultSet(comptime result_types: []const type) type {
    return struct {
        const Self = @This();

        /// One decoded row: field i holds column i as `result_types[i]`.
        pub const Row = std.meta.Tuple(result_types);

        _stmt: sql.PreparedStmt,

        /// Error that stopped `row` from producing a value, if any. `row`
        /// returns null both on failure and when the rows run out, so callers
        /// inspect this field to tell the two apart.
        err: ?ExecError = null,

        /// Finalizes the underlying prepared statement.
        pub fn finish(self: *Self) void {
            self._stmt.finalize();
        }

        /// Steps the statement and decodes the next row. Text columns are
        /// copied using `allocator`. Returns null when no row is available or
        /// on failure; in the failure case the error is stored in `err`.
        pub fn row(self: *Self, allocator: ?std.mem.Allocator) ?Row {
            const next = self._stmt.step() catch |step_err| return self.recordError(step_err);
            // Running out of rows is not an error; `err` is left untouched.
            const sql_row = next orelse return null;

            return readRow(Row, sql_row, allocator) catch |read_err| self.recordError(read_err);
        }

        /// Remembers `e` in `err` and yields the "no row" result.
        fn recordError(self: *Self, e: ExecError) ?Row {
            self.err = e;
            return null;
        }
    };
}
|
|
|
|
/// Binds `val` to parameter `idx` of `stmt` (`Database.exec` passes 1-based
/// positions). Use this instead of string concatenation to avoid injection attacks.
///
/// Supported values:
/// - strings
/// - `i64` and other integer types, including comptime-known integer literals
/// - `Uuid` and `DateTime`
/// - `null`
/// - optionals of a supported type (null binds SQL NULL)
/// - enums (bound as their tag name)
///
/// Runtime integers that do not fit in an `i64` trip `@intCast`'s safety check.
///
/// If a given type is not supported by this function, you can add support by
/// declaring a method with the given signature:
///     pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
/// TODO define what error set this ^ should return
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
    if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val);

    return switch (@TypeOf(val)) {
        i64 => stmt.bindI64(idx, val),
        Uuid => stmt.bindUuid(idx, val),
        DateTime => stmt.bindDateTime(idx, val),
        @TypeOf(null) => stmt.bindNull(idx),
        else => |T| switch (@typeInfo(T)) {
            .Optional => if (val) |v| bind(stmt, idx, v) else stmt.bindNull(idx),
            .Enum => stmt.bindText(idx, @tagName(val)),
            .Struct, .Union, .Opaque => if (@hasDecl(T, "bindToSql"))
                val.bindToSql(stmt, idx)
            else
                @compileError("unsupported type " ++ @typeName(T)),
            .Int => stmt.bindI64(idx, @intCast(i64, val)),
            // Integer literals in `args` (e.g. `.{5}`) arrive as comptime_int.
            // They coerce to i64, and an out-of-range literal is a compile error.
            .ComptimeInt => stmt.bindI64(idx, val),
            else => @compileError("unsupported type " ++ @typeName(T)),
        },
    };
}
|
|
|
|
/// Gets the value of column `idx` of `row` as a `T`, allocating memory if necessary.
///
/// Supported `T`:
/// - `[]u8` / `[]const u8`: copied with `alloc`; returns
///   `error.AllocatorRequired` when `alloc` is null
/// - `i64` and other integer types
/// - `Uuid` and `DateTime`
/// - enums, read as tag-name text
/// - optionals of any supported type, where SQL NULL becomes `null`
///
/// If a given type is not supported by this function, you can add support by
/// declaring a method with the given signature:
///     pub fn getFromSql(row: sql.Row, idx: u15, alloc: ?std.mem.Allocator) !T
/// The allocator is forwarded unchanged and may be null.
/// TODO define what error set this ^ should return
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T {
    return switch (T) {
        []u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired),
        i64 => row.getI64(idx),
        Uuid => row.getUuid(idx),
        DateTime => row.getDateTime(idx),

        else => switch (@typeInfo(T)) {
            .Optional => if (try row.isNull(idx))
                null
            else
                try getAlloc(row, std.meta.Child(T), idx, alloc),

            .Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql"))
                T.getFromSql(row, idx, alloc)
            else
                @compileError("unknown type " ++ @typeName(T)),

            .Enum => try getEnum(row, T, idx),

            // Narrows the stored 64-bit value. A value that does not fit in `T`
            // trips @intCast's safety check (undefined behavior in ReleaseFast).
            .Int => @intCast(T, try row.getI64(idx)),

            else => @compileError("unknown type " ++ @typeName(T)),
        },
    };
}
|
|
|
|
/// Returns the length in bytes of the longest field name of `T` (0 if `T`
/// has no fields). `getEnum` uses it, at comptime, to size its tag buffer.
fn maxTagLen(comptime T: type) usize {
    const members = std.meta.fields(T);
    var longest: usize = 0;
    var idx: usize = 0;
    while (idx < members.len) : (idx += 1) {
        const name_len = members[idx].name.len;
        longest = if (name_len > longest) name_len else longest;
    }
    return longest;
}
|
|
|
|
/// Reads column `idx` of `row` as text and returns the tag of enum `T` whose
/// name equals that text. Returns `error.UnknownTag` when no tag matches.
fn getEnum(row: sql.Row, comptime T: type, idx: u15) !T {
    // Stack buffer just large enough for the longest tag name of T.
    var name_buf: [maxTagLen(T)]u8 = undefined;
    const name = try row.getText(idx, &name_buf);

    return std.meta.stringToEnum(T, name) orelse return error.UnknownTag;
}
|
|
|
|
/// Error set returned by `Database.exec`, `Database.execRow` and `Database.insert`.
/// It covers statement preparation, parameter binding, column reading and
/// allocation failures, plus two errors raised in this file:
/// - `AllocatorRequired`: a text column was read with a null allocator.
/// - `UnknownTag`: an enum column held a name matching no tag.
pub const ExecError = sql.PrepareError || sql.RowGetError || sql.BindError || std.mem.Allocator.Error || error{ AllocatorRequired, UnknownTag };
|
|
|
|
/// Connection to the application's SQLite database, plus helpers for running
/// parameterized queries against it.
pub const Database = struct {
    db: sql.Sqlite,

    /// Opens the database file at `file_path` and runs `migrations.up` on it.
    /// If the migrations fail, the connection is closed before the error is returned.
    pub fn init(file_path: [:0]const u8) !Database {
        var db = try sql.Sqlite.open(file_path);
        errdefer db.close();

        try migrations.up(&db);

        return Database{ .db = db };
    }

    /// Closes the underlying connection.
    pub fn deinit(self: *Database) void {
        self.db.close();
    }

    /// Prepares `q` and binds each field of `args` (a tuple or struct) to the
    /// statement's parameters 1..N in field order. Returns a `ResultSet` whose
    /// rows decode into `result_types`.
    ///
    /// The caller must call `finish()` on the returned result set. If binding
    /// fails, the statement is finalized here before the error is returned.
    pub fn exec(
        self: *Database,
        comptime result_types: []const type,
        comptime q: []const u8,
        args: anytype,
    ) ExecError!ResultSet(result_types) {
        std.log.debug("executing sql:\n===\n{s}\n===", .{q});

        const stmt = try self.db.prepare(q);
        errdefer stmt.finalize();

        inline for (std.meta.fields(@TypeOf(args))) |field, i| {
            // SQL parameter positions are 1-based.
            try bind(stmt, @intCast(u15, i + 1), @field(args, field.name));
        }

        return ResultSet(result_types){
            ._stmt = stmt,
        };
    }

    /// Runs `q` via `exec` and returns its first row, or null if it produced
    /// no rows. Any further rows are ignored, and the statement is always
    /// finalized before returning. Text columns are copied with `allocator`.
    /// Ownership presumably passes to the caller (TODO confirm).
    pub fn execRow(
        self: *Database,
        comptime result_types: []const type,
        comptime q: []const u8,
        args: anytype,
        allocator: ?std.mem.Allocator,
    ) ExecError!?ResultSet(result_types).Row {
        var results = try self.exec(result_types, q, args);
        defer results.finish();

        const row = results.row(allocator);
        std.log.debug("done exec", .{});
        if (row) |r| return r;

        // `row` returns null both when there are no rows and on failure;
        // `results.err` distinguishes the two.
        if (results.err) |err| {
            std.log.debug("{}", .{err});
            std.log.debug("{?}", .{@errorReturnTrace()});
            return err;
        }

        return null;
    }

    /// Returns a parenthesized, comma-separated list with one entry per field
    /// of `T`. Each entry is the field name, or `placeholder` when it is non-null.
    /// For example: `(id,name)` or `(?,?)`. Evaluated entirely at comptime.
    fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
        comptime {
            const fields = std.meta.fields(T);
            // An empty T would otherwise fail below with an opaque
            // out-of-bounds slice error; report the real problem instead.
            if (fields.len == 0) {
                @compileError("cannot build a field list for " ++ @typeName(T) ++ ": it has no fields");
            }

            const joiner = ",";
            var result: []const u8 = "";
            inline for (fields) |f| {
                result = result ++ joiner ++ (placeholder orelse f.name);
            }

            // Drop the leading joiner.
            return "(" ++ result[joiner.len..] ++ ")";
        }
    }

    /// Inserts `value` as a new row of `table`, mapping each field of `value`
    /// to the column of the same name. The statement text is built at comptime:
    /// `INSERT INTO table(f1,...) VALUES (?,...)`. The field values are bound as
    /// parameters, never spliced into the SQL.
    pub fn insert(
        self: *Database,
        comptime table: []const u8,
        value: anytype,
    ) ExecError!void {
        const ValueType = comptime @TypeOf(value);
        const table_spec = comptime table ++ build_field_list(ValueType, null);
        const value_spec = comptime build_field_list(ValueType, "?");
        const q = comptime std.fmt.comptimePrint(
            "INSERT INTO {s} VALUES {s}",
            .{ table_spec, value_spec },
        );

        _ = try self.execRow(&.{}, q, value, null);
    }
};
|