This commit is contained in:
jaina heartles 2022-07-30 13:40:20 -07:00
parent c1f8b8f0e2
commit c31633cade
2 changed files with 63 additions and 30 deletions

View file

@ -9,20 +9,26 @@ const DateTime = util.DateTime;
const String = []const u8;
const comptimePrint = std.fmt.comptimePrint;
fn baseTypeName(comptime T: type) []const u8 {
comptime {
const name = @typeName(T);
const start = for (name) |_, i| {
if (name[name.len - i] == '.') break name.len - i;
} else 0;
// Stuff in here introduces compile errors in seemingly unused parts of
// the stdlib. Maybe just broken on this version of zig?
const broken = struct {
fn baseTypeName(comptime T: type) []const u8 {
comptime {
const name = @typeName(T);
const start = for (name) |_, i| {
if (name[name.len - i] == '.') break name.len - i;
} else 0;
return name[start..];
return name[start..];
}
}
}
fn tableNameBroken(comptime T: type) String {
return util.case.pascalToSnake(baseTypeName(T));
}
};
fn tableName(comptime T: type) String {
//return util.case.pascalToSnake(baseTypeName(T));
return switch (T) {
models.Note => "note",
models.Actor => "actor",
@ -35,6 +41,8 @@ fn tableName(comptime T: type) String {
};
}
// Combines an array/tuple of strings into a single string, with a copy of
// joiner in between each one
fn join(comptime vals: anytype, comptime joiner: String) String {
comptime {
if (vals.len == 0) return "";
@ -48,10 +56,11 @@ fn join(comptime vals: anytype, comptime joiner: String) String {
}
}
// Select query builder struct
const Query = struct {
select: []const String,
from: String,
where: String = "id = ?",
select: []const String, // the fields to grab
from: String, // what table to query
where: String, // conditions on records to query
order_by: ?[]const String = null,
group_by: ?[]const String = null,
limit: ?usize = null,
@ -71,10 +80,11 @@ const Query = struct {
}
};
// Insert query builder struct
const Insert = struct {
into: String,
columns: []const String,
count: usize = 1,
into: String, // the table to modify
columns: []const String, // the columns to provide
count: usize = 1, // the number of records to insert
pub fn str(comptime self: Insert) String {
comptime {
@ -91,18 +101,20 @@ const Insert = struct {
}
};
fn filterOut(comptime vals: []const String, comptime to_ignore: []const String) []const String {
// treats the inputs as sets and performs set subtraction. Assumes that elements do not appear
// multiple times.
fn setSubtract(comptime lhs: []const String, comptime rhs: []const String) []const String {
comptime {
var result: [vals.len]String = undefined;
var result: [lhs.len]String = undefined;
var count = 0;
for (vals) |v| {
const keep = for (to_ignore) |x| {
if (std.mem.eql(u8, x, v)) break false;
for (lhs) |l| {
const keep = for (rhs) |r| {
if (std.mem.eql(u8, l, r)) break false;
} else true;
if (keep) {
result[count] = v;
result[count] = l;
count += 1;
}
}
@ -111,12 +123,19 @@ fn filterOut(comptime vals: []const String, comptime to_ignore: []const String)
}
}
// returns all fields of T except for those in a specific set
fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const String {
comptime {
return filterOut(std.meta.fieldNames(T), to_ignore);
return setSubtract(std.meta.fieldNames(T), to_ignore);
}
}
// Binds a value to a parameter in the query. Use this instead of string
// concatenation to avoid injection attacks;
// If a given type is not supported by this function, you can add support by
// declaring a method with the given signature:
// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
// TODO define what error set this ^ should return
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
return switch (@TypeOf(val)) {
[]u8, []const u8 => stmt.bindText(idx, val),
@ -132,6 +151,11 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
};
}
// Gets a value from the row, allocating memory if necessary.
// If a given type is not supported by this function, you can add support by
// declaring a method with the given signature:
// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T
// TODO define what error set this ^ should return
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator) !T {
return switch (T) {
[]u8, []const u8 => row.getTextAlloc(idx, alloc),
@ -165,6 +189,8 @@ pub const Database = struct {
self.db.close();
}
// Returns the first row that satisfies an equality check on the
// field specified
pub fn getBy(
self: *Database,
comptime T: type,
@ -197,6 +223,8 @@ pub const Database = struct {
return result;
}
// Returns an array of all rows that satisfy an equality check
// TODO: paginate this
pub fn getWhereEq(
self: *Database,
comptime T: type,
@ -232,6 +260,8 @@ pub const Database = struct {
return results.toOwnedSlice();
}
// Returns the number of rows that satisfy an equality check on
// one of their fields
pub fn countWhereEq(
self: *Database,
comptime T: type,
@ -254,7 +284,7 @@ pub const Database = struct {
return @intCast(usize, try row.getI64(0));
}
// TODO: don't super like this query
// Returns whether a row with the given value exists.
pub fn existsWhereEq(
self: *Database,
comptime T: type,
@ -262,6 +292,7 @@ pub const Database = struct {
val: std.meta.fieldInfo(T, field).field_type,
) !bool {
const field_name = std.meta.fieldInfo(T, field).name;
// TODO: don't like this query
const q = comptime (Query{
.select = &.{"COUNT(1)"},
.from = tableName(T),
@ -278,6 +309,8 @@ pub const Database = struct {
return (try row.getI64(0)) > 0;
}
// Inserts a row into the database
// TODO: consider making this generic?
pub fn insert(self: *Database, comptime T: type, val: T) !void {
const fields = comptime std.meta.fieldNames(T);
const q = comptime (Insert{

View file

@ -20,7 +20,7 @@ pub const case = struct {
// converts a string from PascalCase to snake_case at comptime.
// only works with ascii characters
pub fn PascalToSnake(comptime str: []const u8) Return: {
pub fn pascalToSnake(comptime str: []const u8) Return: {
break :Return if (str.len == 0)
*const [0:0]u8
else
@ -50,11 +50,11 @@ pub const case = struct {
};
test "pascalToSnake" {
try std.testing.expectEqual("", case.PascalToSnake(""));
try std.testing.expectEqual("abc", case.PascalToSnake("Abc"));
try std.testing.expectEqual("a_bc", case.PascalToSnake("ABc"));
try std.testing.expectEqual("a_b_c", case.PascalToSnake("ABC"));
try std.testing.expectEqual("ab_c", case.PascalToSnake("AbC"));
try std.testing.expectEqual("", case.pascalToSnake(""));
try std.testing.expectEqual("abc", case.pascalToSnake("Abc"));
try std.testing.expectEqual("a_bc", case.pascalToSnake("ABc"));
try std.testing.expectEqual("a_b_c", case.pascalToSnake("ABC"));
try std.testing.expectEqual("ab_c", case.pascalToSnake("AbC"));
}
test {