This commit is contained in:
jaina heartles 2022-07-30 13:40:20 -07:00
parent c1f8b8f0e2
commit c31633cade
2 changed files with 63 additions and 30 deletions

View file

@ -9,7 +9,10 @@ const DateTime = util.DateTime;
const String = []const u8; const String = []const u8;
const comptimePrint = std.fmt.comptimePrint; const comptimePrint = std.fmt.comptimePrint;
fn baseTypeName(comptime T: type) []const u8 { // Stuff in here introduces compile errors in seemingly unused parts of
// the stdlib. Maybe just broken on this version of zig?
const broken = struct {
fn baseTypeName(comptime T: type) []const u8 {
comptime { comptime {
const name = @typeName(T); const name = @typeName(T);
const start = for (name) |_, i| { const start = for (name) |_, i| {
@ -18,11 +21,14 @@ fn baseTypeName(comptime T: type) []const u8 {
return name[start..]; return name[start..];
} }
} }
fn tableNameBroken(comptime T: type) String {
return util.case.pascalToSnake(baseTypeName(T));
}
};
fn tableName(comptime T: type) String { fn tableName(comptime T: type) String {
//return util.case.pascalToSnake(baseTypeName(T));
return switch (T) { return switch (T) {
models.Note => "note", models.Note => "note",
models.Actor => "actor", models.Actor => "actor",
@ -35,6 +41,8 @@ fn tableName(comptime T: type) String {
}; };
} }
// Combines an array/tuple of strings into a single string, with a copy of
// joiner in between each one
fn join(comptime vals: anytype, comptime joiner: String) String { fn join(comptime vals: anytype, comptime joiner: String) String {
comptime { comptime {
if (vals.len == 0) return ""; if (vals.len == 0) return "";
@ -48,10 +56,11 @@ fn join(comptime vals: anytype, comptime joiner: String) String {
} }
} }
// Select query builder struct
const Query = struct { const Query = struct {
select: []const String, select: []const String, // the fields to grab
from: String, from: String, // what table to query
where: String = "id = ?", where: String, // conditions on records to query
order_by: ?[]const String = null, order_by: ?[]const String = null,
group_by: ?[]const String = null, group_by: ?[]const String = null,
limit: ?usize = null, limit: ?usize = null,
@ -71,10 +80,11 @@ const Query = struct {
} }
}; };
// Insert query builder struct
const Insert = struct { const Insert = struct {
into: String, into: String, // the table to modify
columns: []const String, columns: []const String, // the columns to provide
count: usize = 1, count: usize = 1, // the number of records to insert
pub fn str(comptime self: Insert) String { pub fn str(comptime self: Insert) String {
comptime { comptime {
@ -91,18 +101,20 @@ const Insert = struct {
} }
}; };
fn filterOut(comptime vals: []const String, comptime to_ignore: []const String) []const String { // treats the inputs as sets and performs set subtraction. Assumes that elements do not appear
// multiple times.
fn setSubtract(comptime lhs: []const String, comptime rhs: []const String) []const String {
comptime { comptime {
var result: [vals.len]String = undefined; var result: [lhs.len]String = undefined;
var count = 0; var count = 0;
for (vals) |v| { for (lhs) |l| {
const keep = for (to_ignore) |x| { const keep = for (rhs) |r| {
if (std.mem.eql(u8, x, v)) break false; if (std.mem.eql(u8, l, r)) break false;
} else true; } else true;
if (keep) { if (keep) {
result[count] = v; result[count] = l;
count += 1; count += 1;
} }
} }
@ -111,12 +123,19 @@ fn filterOut(comptime vals: []const String, comptime to_ignore: []const String)
} }
} }
// returns all fields of T except for those in a specific set
fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const String { fn fieldsExcept(comptime T: type, comptime to_ignore: []const String) []const String {
comptime { comptime {
return filterOut(std.meta.fieldNames(T), to_ignore); return setSubtract(std.meta.fieldNames(T), to_ignore);
} }
} }
// Binds a value to a parameter in the query. Use this instead of string
// concatenation to avoid injection attacks;
// If a given type is not supported by this function, you can add support by
// declaring a method with the given signature:
// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
// TODO define what error set this ^ should return
fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void { fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
return switch (@TypeOf(val)) { return switch (@TypeOf(val)) {
[]u8, []const u8 => stmt.bindText(idx, val), []u8, []const u8 => stmt.bindText(idx, val),
@ -132,6 +151,11 @@ fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
}; };
} }
// Gets a value from the row, allocating memory if necessary.
// If a given type is not supported by this function, you can add support by
// declaring a method with the given signature:
// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T
// TODO define what error set this ^ should return
fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator) !T { fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: std.mem.Allocator) !T {
return switch (T) { return switch (T) {
[]u8, []const u8 => row.getTextAlloc(idx, alloc), []u8, []const u8 => row.getTextAlloc(idx, alloc),
@ -165,6 +189,8 @@ pub const Database = struct {
self.db.close(); self.db.close();
} }
// Returns the first row that satisfies an equality check on the
// field specified
pub fn getBy( pub fn getBy(
self: *Database, self: *Database,
comptime T: type, comptime T: type,
@ -197,6 +223,8 @@ pub const Database = struct {
return result; return result;
} }
// Returns an array of all rows that satisfy an equality check
// TODO: paginate this
pub fn getWhereEq( pub fn getWhereEq(
self: *Database, self: *Database,
comptime T: type, comptime T: type,
@ -232,6 +260,8 @@ pub const Database = struct {
return results.toOwnedSlice(); return results.toOwnedSlice();
} }
// Returns the number of rows that satisfy an equality check on
// one of their fields
pub fn countWhereEq( pub fn countWhereEq(
self: *Database, self: *Database,
comptime T: type, comptime T: type,
@ -254,7 +284,7 @@ pub const Database = struct {
return @intCast(usize, try row.getI64(0)); return @intCast(usize, try row.getI64(0));
} }
// TODO: don't super like this query // Returns whether a row with the given value exists.
pub fn existsWhereEq( pub fn existsWhereEq(
self: *Database, self: *Database,
comptime T: type, comptime T: type,
@ -262,6 +292,7 @@ pub const Database = struct {
val: std.meta.fieldInfo(T, field).field_type, val: std.meta.fieldInfo(T, field).field_type,
) !bool { ) !bool {
const field_name = std.meta.fieldInfo(T, field).name; const field_name = std.meta.fieldInfo(T, field).name;
// TODO: don't like this query
const q = comptime (Query{ const q = comptime (Query{
.select = &.{"COUNT(1)"}, .select = &.{"COUNT(1)"},
.from = tableName(T), .from = tableName(T),
@ -278,6 +309,8 @@ pub const Database = struct {
return (try row.getI64(0)) > 0; return (try row.getI64(0)) > 0;
} }
// Inserts a row into the database
// TODO: consider making this generic?
pub fn insert(self: *Database, comptime T: type, val: T) !void { pub fn insert(self: *Database, comptime T: type, val: T) !void {
const fields = comptime std.meta.fieldNames(T); const fields = comptime std.meta.fieldNames(T);
const q = comptime (Insert{ const q = comptime (Insert{

View file

@ -20,7 +20,7 @@ pub const case = struct {
// converts a string from PascalCase to snake_case at comptime. // converts a string from PascalCase to snake_case at comptime.
// only works with ascii characters // only works with ascii characters
pub fn PascalToSnake(comptime str: []const u8) Return: { pub fn pascalToSnake(comptime str: []const u8) Return: {
break :Return if (str.len == 0) break :Return if (str.len == 0)
*const [0:0]u8 *const [0:0]u8
else else
@ -50,11 +50,11 @@ pub const case = struct {
}; };
test "pascalToSnake" { test "pascalToSnake" {
try std.testing.expectEqual("", case.PascalToSnake("")); try std.testing.expectEqual("", case.pascalToSnake(""));
try std.testing.expectEqual("abc", case.PascalToSnake("Abc")); try std.testing.expectEqual("abc", case.pascalToSnake("Abc"));
try std.testing.expectEqual("a_bc", case.PascalToSnake("ABc")); try std.testing.expectEqual("a_bc", case.pascalToSnake("ABc"));
try std.testing.expectEqual("a_b_c", case.PascalToSnake("ABC")); try std.testing.expectEqual("a_b_c", case.pascalToSnake("ABC"));
try std.testing.expectEqual("ab_c", case.PascalToSnake("AbC")); try std.testing.expectEqual("ab_c", case.pascalToSnake("AbC"));
} }
test { test {