diff --git a/src/main/api.zig b/src/main/api.zig index 6259994..47a5def 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -73,13 +73,9 @@ pub const ApiServer = struct { pub fn createUser(self: *ApiServer, info: CreateInfo(models.User)) !models.User { const id = Uuid.randV4(self.prng.random()); - // check for handle dupes - //while (iter.next()) |it| { - //if (std.mem.eql(u8, it.value_ptr.handle, info.handle)) { - //return error.DuplicateHandle; - //} - //} - // TODO: check for id dupes + if (try self.db.existsWhereEq(models.User, .handle, info.handle)) { + return error.HandleNotAvailable; + } const user = reify(models.User, id, info); try self.db.insert(models.User, user); @@ -88,10 +84,10 @@ pub const ApiServer = struct { } pub fn getNote(self: *ApiServer, id: Uuid, alloc: std.mem.Allocator) !?models.Note { - return self.db.getById(models.Note, id, alloc); + return self.db.getBy(models.Note, .id, id, alloc); } pub fn getUser(self: *ApiServer, id: Uuid, alloc: std.mem.Allocator) !?models.User { - return self.db.getById(models.User, id, alloc); + return self.db.getBy(models.User, .id, id, alloc); } }; diff --git a/src/main/controllers.zig b/src/main/controllers.zig index cab0584..cbc94fe 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -65,7 +65,10 @@ pub fn createUser(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) const info = try utils.parseRequestBody(api.CreateInfo(models.User), ctx, srv.alloc); defer utils.freeRequestBody(info, srv.alloc); - const user = try srv.api.createUser(info); + const user = srv.api.createUser(info) catch |err| switch (err) { + error.HandleNotAvailable => return try utils.respondJson(ctx, .bad_request, .{ .@"error" = "handle not available" }, srv.alloc), + else => return err, + }; try utils.respondJson(ctx, .created, user, srv.alloc); } diff --git a/src/main/db.zig b/src/main/db.zig index e303c54..54d55d4 100644 --- a/src/main/db.zig +++ b/src/main/db.zig @@ -30,13 +30,14 @@ const Query = struct { select: []const String, from: String, where: String = "id = ?", - limit: usize = 1, + limit: ?usize = null, pub fn str(comptime self: Query) String { comptime { + const limit_expr = if (self.limit == null) "" else std.fmt.comptimePrint(" LIMIT {}", .{self.limit}); return std.fmt.comptimePrint( - "SELECT {s} FROM {s} WHERE {s} LIMIT {};", - .{ join(self.select, ", "), self.from, self.where, self.limit }, + "SELECT {s} FROM {s} WHERE {s}{s};", + .{ join(self.select, ", "), self.from, self.where, limit_expr }, ); } } @@ -123,36 +124,6 @@ pub const Database = struct { self.db.close(); } - pub fn getByIdOld(self: *Database, comptime T: type, id: Uuid, alloc: std.mem.Allocator) !?T { - const fields = comptime fieldsExcept(T, &.{"id"}); - const q = comptime (Query{ - .select = fields, - .from = tableName(T), - .where = "id = ?", - .limit = 1, - }).str(); - - var stmt = try self.db.prepare(q); - defer stmt.finalize(); - - const id_str = id.toCharArray(); - try stmt.bindText(1, &id_str); - - const row = (try stmt.step()) orelse return null; - var val: T = undefined; - val.id = id; - - inline for (fields) |f, i| { - @field(val, f) = switch (@TypeOf(@field(val, f))) { - // TODO: Handle allocation failures gracefully - []const u8 => row.getTextAlloc(i, alloc) catch unreachable, - else => @compileError("unknown type"), - }; - } - - return val; - } - pub fn getById(self: *Database, comptime T: type, id: Uuid, alloc: std.mem.Allocator) !?T { return self.getBy(T, .id, id, alloc); } @@ -193,6 +164,52 @@ pub const Database = struct { return result; } + pub fn countWhereEq( + self: *Database, + comptime T: type, + comptime field: std.meta.FieldEnum(T), + val: std.meta.fieldInfo(T, field).field_type, + ) !usize { + const field_name = std.meta.fieldInfo(T, field).name; + const q = comptime (Query{ + .select = &.{"COUNT()"}, + .from = tableName(T), + .where = field_name ++ " = ?", + }).str(); + + var stmt = try self.db.prepare(q); + defer stmt.finalize(); + + try stmt.bind(1, val); + + const row = (try stmt.step()) orelse unreachable; + return @intCast(usize, try row.getI64(0)); + } + + // TODO: don't super like this query + pub fn existsWhereEq( + self: *Database, + comptime T: type, + comptime field: std.meta.FieldEnum(T), + val: std.meta.fieldInfo(T, field).field_type, + ) !bool { + const field_name = std.meta.fieldInfo(T, field).name; + const q = comptime (Query{ + .select = &.{"COUNT(1)"}, + .from = tableName(T), + .where = field_name ++ " = ?", + .limit = 1, + }).str(); + + var stmt = try self.db.prepare(q); + defer stmt.finalize(); + + try stmt.bind(1, val); + + const row = (try stmt.step()) orelse unreachable; + return (try row.getI64(0)) > 0; + } + pub fn getNoteById(self: *Database, id: Uuid, alloc: std.mem.Allocator) !?models.Note { return self.getById(models.Note, id, alloc); }