diff --git a/src/main/api.zig b/src/main/api.zig index 129fff2..72aef3d 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -176,22 +176,18 @@ pub const ApiSource = struct { } pub fn connectUnauthorized(self: *ApiSource, host: []const u8, alloc: std.mem.Allocator) !Conn { - const community_id = blk: { - const result = try self.db.execRow2(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null); - if (result) |r| break :blk r[0]; - - break :blk null; - }; - - if (community_id == null and !util.ciutf8.eql(self.config.cluster_host, host)) { - return error.NoCommunity; - } + const community_id = (try self.db.execRow2( + &.{Uuid}, + "SELECT id FROM community WHERE host = ?", + .{host}, + null, + )) orelse return error.NoCommunity; return Conn{ .db = self.db, .internal_alloc = self.internal_alloc, .as_user = null, - .on_community = community_id, + .on_community = community_id[0], .arena = std.heap.ArenaAllocator.init(alloc), }; } @@ -209,23 +205,21 @@ pub const ApiSource = struct { var hash: models.ByteArray(models.Token.hash_len) = undefined; models.Token.HashFn.hash(&decoded, &hash.data, .{}); - const db_token = (try self.db.getBy(models.Token, .hash, hash, conn.arena.allocator())) orelse return error.InvalidToken; - //const token_result = try self.db.execRow2( - //&.{Uuid}, - //\\SELECT user.id - //\\FROM token - //\\ JOIN user ON token.user_id = user.id - //\\ JOIN community ON - //); - //const token_result = (try self.db.execRow2( - //&.{Uuid}, - //"SELECT id FROM token WHERE hash = ?", - //.{hash}, - //null, - //)) orelse return error.InvalidToken; + const token_result = (try self.db.execRow2( + &.{Uuid}, + \\SELECT user.id + \\FROM token + \\ JOIN user ON token.user_id = user.id + \\ JOIN community ON user.community_id = community.id + \\ JOIN local_user ON local_user.user_id = user.id + \\WHERE token.hash = ? + \\LIMIT 1 + , + .{hash}, + null, + )) orelse return error.InvalidToken; - //conn.as_user = token_result[0]; - conn.as_user = db_token.user_id; + conn.as_user = token_result[0]; return conn; } @@ -238,7 +232,7 @@ fn ApiConn(comptime DbConn: type) type { db: DbConn, internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers as_user: ?Uuid, - on_community: ?Uuid, + on_community: Uuid, arena: std.heap.ArenaAllocator, pub fn close(self: *Self) void { @@ -312,6 +306,25 @@ fn ApiConn(comptime DbConn: type) type { }; } + const TokenInfo = struct { + username: []const u8, + }; + pub fn getTokenInfo(self: *Self) !TokenInfo { + if (self.as_user) |user_id| { + const result = (try self.db.execRow2( + &.{[]const u8}, + "SELECT username FROM user WHERE id = ?", + .{user_id}, + self.arena.allocator(), + )) orelse { + return error.UserNotFound; + }; + return TokenInfo{ .username = result[0] }; + } + + return error.Unauthorized; + } + const TokenResult = struct { info: models.Token, value: [token_len]u8, diff --git a/src/main/controllers/auth.zig b/src/main/controllers/auth.zig index bba4229..ffe7a0b 100644 --- a/src/main/controllers/auth.zig +++ b/src/main/controllers/auth.zig @@ -10,21 +10,6 @@ const utils = @import("../controllers.zig").utils; const RequestServer = root.RequestServer; const RouteArgs = http.RouteArgs; -pub fn register(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { - const info = try utils.parseRequestBody(RegistrationInfo, ctx); - defer utils.freeRequestBody(info, ctx.alloc); - - var api = try utils.getApiConn(srv, ctx); - defer api.close(); - - const user = api.register(info) catch |err| switch (err) { - error.UsernameUnavailable => return utils.respondError(ctx, .bad_request, "Username Unavailable"), - else => return err, - }; - - try utils.respondJson(ctx, .created, user); -} - pub fn login(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx); defer utils.freeRequestBody(credentials, ctx.alloc); @@ -39,3 +24,14 @@ pub fn login(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void try utils.respondJson(ctx, .ok, token); } + +pub fn verifyLogin(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { + var api = try utils.getApiConn(srv, ctx); + defer api.close(); + + // The self-hosted compiler doesn't like inferring this error set. + // do this for now + const info = api.getTokenInfo() catch unreachable; + + try utils.respondJson(ctx, .ok, info); +} diff --git a/src/main/db.zig b/src/main/db.zig index 1dcc521..c102dce 100644 --- a/src/main/db.zig +++ b/src/main/db.zig @@ -263,7 +263,9 @@ pub const Database = struct { defer results.finish(); const row = results.row(allocator); - return row orelse (results.err orelse null); + if (row) |r| return r; + if (results.err) |err| return err; + return null; } fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 { diff --git a/src/main/main.zig b/src/main/main.zig index b776bcf..bfaa7c5 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -14,13 +14,13 @@ const Route = Router.Route; const RouteArgs = http.RouteArgs; const router = Router{ .routes = &[_]Route{ - Route.new(.GET, "/healthcheck", c.healthcheck), + Route.new(.GET, "/healthcheck", &c.healthcheck), //Route.new(.POST, "/users", c.users.create), //Route.new(.POST, "/auth/register", &c.auth.register), Route.new(.POST, "/login", &c.auth.login), - //Route.new(.GET, "/current-login", &c.auth.verifyLogin), + Route.new(.GET, "/login", &c.auth.verifyLogin), //Route.new(.POST, "/notes", &c.notes.create), //Route.new(.GET, "/notes/:id", &c.notes.get),