From a0201997732601f477b696d9623ee8805202895d Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Mon, 25 Jul 2022 19:07:05 -0700 Subject: [PATCH] Refactor api calls --- src/main/api.zig | 341 ++++++++++++++------------ src/main/controllers.zig | 16 +- src/main/controllers/actors.zig | 7 +- src/main/controllers/auth.zig | 21 +- src/main/controllers/notes.zig | 17 +- src/main/controllers/notes/reacts.zig | 13 +- src/main/main.zig | 5 +- 7 files changed, 234 insertions(+), 186 deletions(-) diff --git a/src/main/api.zig b/src/main/api.zig index b2b81b0..a846b97 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -66,15 +66,7 @@ fn reify(comptime T: type, id: Uuid, val: CreateInfo(T)) T { return result; } -pub const ApiContext = struct { - user_context: struct { - user: models.LocalUser, - }, - - alloc: std.mem.Allocator, -}; - -pub const NoteCreate = struct { +pub const NoteCreateInfo = struct { content: []const u8, }; @@ -84,24 +76,41 @@ pub const RegistrationInfo = struct { email: ?[]const u8, }; +pub const LoginResult = struct { + user_id: Uuid, + token: [token_str_len]u8, + issued_at: DateTime, +}; + threadlocal var prng: std.rand.DefaultPrng = undefined; pub fn initThreadPrng(seed: u64) void { prng = std.rand.DefaultPrng.init(seed +% std.Thread.getCurrentId()); } -pub const ApiServer = struct { +pub const ApiSource = struct { db: db.Database, internal_alloc: std.mem.Allocator, - pub fn init(alloc: std.mem.Allocator) !ApiServer { - return ApiServer{ + pub const Conn = ApiConn(db.Database); + + pub fn init(alloc: std.mem.Allocator) !ApiSource { + return ApiSource{ .db = try db.Database.init(), .internal_alloc = alloc, }; } - pub fn makeApiContext(self: *ApiServer, token: []const u8, alloc: std.mem.Allocator) !ApiContext { + pub fn connectUnauthorized(self: *ApiSource, alloc: std.mem.Allocator) !Conn { + return Conn{ + .db = self.db, + .internal_alloc = self.internal_alloc, + .as_user = null, + .arena = std.heap.ArenaAllocator.init(alloc), + }; + } + + pub fn connectToken(self: *ApiSource, token: []const u8, alloc: std.mem.Allocator) !Conn { const decoded_len = std.base64.standard.Decoder.calcSizeForSlice(token) catch return error.InvalidToken; if (decoded_len != token_len) return error.InvalidToken; @@ -111,146 +120,176 @@ pub const ApiServer = struct { var hash: models.ByteArray(models.Token.hash_len) = undefined; models.Token.HashFn.hash(&decoded, &hash.data, .{}); - const db_token = (try self.db.getBy(models.Token, .hash, hash, alloc)) orelse return error.InvalidToken; + var arena = std.heap.ArenaAllocator.init(alloc); - const user = (try self.db.getBy(models.LocalUser, .id, db_token.user_id, alloc)) orelse return error.InvalidToken; + const db_token = (try self.db.getBy(models.Token, .hash, hash, arena.allocator())) orelse return error.InvalidToken; - return ApiContext{ - .user_context = .{ - .user = user, - }, - - .alloc = alloc, + return Conn{ + .db = self.db, + .internal_alloc = self.internal_alloc, + .as_user = db_token.user_id, + .arena = arena, }; } - - pub fn createNoteUser(self: *ApiServer, info: NoteCreate, ctx: ApiContext) !models.Note { - const id = Uuid.randV4(prng.random()); - // TODO: check for dupes - - const note = models.Note{ - .id = id, - .author_id = ctx.user_context.user.actor_id.?, - .content = info.content, - - .created_at = DateTime.now(), - }; - try self.db.insert(models.Note, note); - - return note; - } - - pub fn register(self: *ApiServer, info: RegistrationInfo) !models.Actor { - const actor_id = Uuid.randV4(prng.random()); - const user_id = Uuid.randV4(prng.random()); - // TODO: transaction? - - if (try self.db.existsWhereEq(models.LocalUser, .username, info.username)) { - return error.UsernameUnavailable; - } - - if (try self.db.existsWhereEq(models.Actor, .handle, info.username)) { - return error.InconsistentDb; - } - - // use internal alloc because necessary buffer is *big* - var buf: [pw_hash_buf_size]u8 = undefined; - const hash = try PwHash.strHash(info.password, .{ .allocator = self.internal_alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf); - const now = DateTime.now(); - - const actor = models.Actor{ - .id = actor_id, - .handle = info.username, - .created_at = now, - }; - const user = models.LocalUser{ - .id = user_id, - .actor_id = actor_id, - .username = info.username, - .email = info.email, - .hashed_password = hash, - .password_changed_at = now, - .created_at = now, - }; - try self.db.insert(models.Actor, actor); - try self.db.insert(models.LocalUser, user); - - // TODO: return token instead - return actor; - } - - const LoginResult = struct { - user_id: Uuid, - token: [token_str_len]u8, - issued_at: DateTime, - }; - pub fn login(self: *ApiServer, username: []const u8, password: []const u8, alloc: std.mem.Allocator) !LoginResult { - // TODO: This gives away the existence of a user through a timing side channel. is that acceptable? - const user_info = (try self.db.getBy(models.LocalUser, .username, username, alloc)) orelse return error.InvalidLogin; - defer free(alloc, user_info); - - const Hash = std.crypto.pwhash.scrypt; - Hash.strVerify(user_info.hashed_password, password, .{ .allocator = alloc }) catch |err| switch (err) { - error.PasswordVerificationFailed => return error.InvalidLogin, - else => return err, - }; - - const token = try self.createToken(user_info); - - var token_enc: [token_str_len]u8 = undefined; - _ = std.base64.standard.Encoder.encode(&token_enc, &token.value); - - return LoginResult{ - .user_id = user_info.id, - .token = token_enc, - .issued_at = token.info.issued_at, - }; - //return (try self.db.getBy(models.Actor, .id, user_info.actor_id.?, alloc)) orelse unreachable; - } - - const TokenResult = struct { - info: models.Token, - value: [token_len]u8, - }; - fn createToken(self: *ApiServer, user: models.LocalUser) !TokenResult { - var token: [token_len]u8 = undefined; - std.crypto.random.bytes(&token); - - var hash: [models.Token.hash_len]u8 = undefined; - models.Token.HashFn.hash(&token, &hash, .{}); - - const db_token = models.Token{ - .id = Uuid.randV4(prng.random()), - .hash = .{ .data = hash }, - .user_id = user.id, - .issued_at = DateTime.now(), - }; - - try self.db.insert(models.Token, db_token); - return TokenResult{ - .info = db_token, - .value = token, - }; - } - - pub fn getNote(self: *ApiServer, id: Uuid, alloc: std.mem.Allocator) !?models.Note { - return self.db.getBy(models.Note, .id, id, alloc); - } - - pub fn getActor(self: *ApiServer, id: Uuid, alloc: std.mem.Allocator) !?models.Actor { - return self.db.getBy(models.Actor, .id, id, alloc); - } - - pub fn getActorByHandle(self: *ApiServer, handle: []const u8, alloc: std.mem.Allocator) !?models.Actor { - return self.db.getBy(models.Actor, .handle, handle, alloc); - } - - pub fn react(self: *ApiServer, note_id: Uuid, ctx: ApiContext) !void { - const id = Uuid.randV4(prng.random()); - try self.db.insert(models.Reaction, .{ .id = id, .note_id = note_id, .reactor_id = ctx.user_context.user.actor_id.?, .created_at = DateTime.now() }); - } - - pub fn listReacts(self: *ApiServer, note_id: Uuid, ctx: ApiContext) ![]models.Reaction { - return try self.db.getWhereEq(models.Reaction, .note_id, note_id, ctx.alloc); - } }; + +fn ApiConn(comptime DbConn: type) type { + return struct { + const Self = @This(); + + db: DbConn, + internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers + as_user: ?Uuid, + arena: std.heap.ArenaAllocator, + + pub fn close(self: *Self) void { + self.arena.deinit(); + } + + fn getAuthenticatedUser(self: *Self) !models.LocalUser { + if (self.as_user) |user_id| { + const local_user = try self.db.getBy(models.LocalUser, .id, user_id, self.arena.allocator()); + if (local_user == null) return error.UserNotFound; + + return local_user.?; + } else { + return error.NotAuthenticated; + } + } + + fn getAuthenticatedActor(self: *Self) !models.Actor { + const user = try self.getAuthenticatedUser(); + if (user.actor_id) |actor_id| { + const actor = try self.db.getBy(models.Actor, .id, actor_id, self.arena); + return actor.?; + } else { + return error.NoActor; + } + } + + pub fn createNote(self: *Self, info: NoteCreateInfo) !models.Note { + const id = Uuid.randV4(prng.random()); + const user = try self.getAuthenticatedUser(); + + const note = models.Note{ + .id = id, + .author_id = user.actor_id orelse return error.NotAuthorized, + .content = info.content, + + .created_at = DateTime.now(), + }; + try self.db.insert(models.Note, note); + + return note; + } + + pub fn getNote(self: *Self, id: Uuid) !?models.Note { + return self.db.getBy(models.Note, .id, id, self.arena.allocator()); + } + + pub fn getActor(self: *Self, id: Uuid) !?models.Actor { + return self.db.getBy(models.Actor, .id, id, self.arena.allocator()); + } + + pub fn getActorByHandle(self: *Self, handle: []const u8) !?models.Actor { + return self.db.getBy(models.Actor, .handle, handle, self.arena.allocator()); + } + + pub fn react(self: *Self, note_id: Uuid) !void { + const id = Uuid.randV4(prng.random()); + const user = try self.getAuthenticatedUser(); + try self.db.insert(models.Reaction, .{ .id = id, .note_id = note_id, .reactor_id = user.actor_id orelse return error.NotAuthorized, .created_at = DateTime.now() }); + } + + pub fn listReacts(self: *Self, note_id: Uuid) ![]models.Reaction { + return try self.db.getWhereEq(models.Reaction, .note_id, note_id, self.arena.allocator()); + } + + pub fn register(self: *Self, info: RegistrationInfo) !models.Actor { + const actor_id = Uuid.randV4(prng.random()); + const user_id = Uuid.randV4(prng.random()); + // TODO: transaction? + + if (try self.db.existsWhereEq(models.LocalUser, .username, info.username)) { + return error.UsernameUnavailable; + } + + if (try self.db.existsWhereEq(models.Actor, .handle, info.username)) { + return error.InconsistentDb; + } + + // use internal alloc because necessary buffer is *big* + var buf: [pw_hash_buf_size]u8 = undefined; + const hash = try PwHash.strHash(info.password, .{ .allocator = self.internal_alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf); + const now = DateTime.now(); + + const actor = models.Actor{ + .id = actor_id, + .handle = info.username, + .created_at = now, + }; + const user = models.LocalUser{ + .id = user_id, + .actor_id = actor_id, + .username = info.username, + .email = info.email, + .hashed_password = hash, + .password_changed_at = now, + .created_at = now, + }; + try self.db.insert(models.Actor, actor); + try self.db.insert(models.LocalUser, user); + + // TODO: return token instead + return actor; + } + + pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult { + // TODO: This gives away the existence of a user through a timing side channel. is that acceptable? + const user_info = (try self.db.getBy(models.LocalUser, .username, username, self.arena.allocator())) orelse return error.InvalidLogin; + //defer free(self.arena.allocator(), user_info); + + const Hash = std.crypto.pwhash.scrypt; + Hash.strVerify(user_info.hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) { + error.PasswordVerificationFailed => return error.InvalidLogin, + else => return err, + }; + + const token = try self.createToken(user_info); + + var token_enc: [token_str_len]u8 = undefined; + _ = std.base64.standard.Encoder.encode(&token_enc, &token.value); + + return LoginResult{ + .user_id = user_info.id, + .token = token_enc, + .issued_at = token.info.issued_at, + }; + } + + const TokenResult = struct { + info: models.Token, + value: [token_len]u8, + }; + fn createToken(self: *Self, user: models.LocalUser) !TokenResult { + var token: [token_len]u8 = undefined; + std.crypto.random.bytes(&token); + + var hash: [models.Token.hash_len]u8 = undefined; + models.Token.HashFn.hash(&token, &hash, .{}); + + const db_token = models.Token{ + .id = Uuid.randV4(prng.random()), + .hash = .{ .data = hash }, + .user_id = user.id, + .issued_at = DateTime.now(), + }; + + try self.db.insert(models.Token, db_token); + return TokenResult{ + .info = db_token, + .value = token, + }; + } + }; +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 7a4e714..9411cc7 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -57,13 +57,21 @@ pub const utils = struct { std.json.parseFree(@TypeOf(value), value, .{ .allocator = alloc }); } - pub fn getApiContext(srv: *RequestServer, ctx: *http.server.Context) !api.ApiContext { - const header = ctx.request.headers.get("authorization") orelse "(null)"; + pub fn getApiConn(srv: *RequestServer, ctx: *http.server.Context) !api.ApiSource.Conn { + return authorizeApiConn(srv, ctx) catch |err| switch (err) { + error.NoToken => srv.api.connectUnauthorized(ctx.alloc), + error.InvalidToken => return error.InvalidToken, + else => @panic("TODO"), // doing this to resolve some sort of compiler analysis dependency issue + }; + } + fn authorizeApiConn(srv: *RequestServer, ctx: *http.server.Context) !api.ApiSource.Conn { + const header = ctx.request.headers.get("authorization") orelse return error.NoToken; + + if (header.len < ("bearer ").len) return error.InvalidToken; const token = header[("bearer ").len..]; - return try srv.api.makeApiContext(token, ctx.alloc); - // TODO: defer api.free(ctx.alloc, user_ctx); + return try srv.api.connectToken(token, ctx.alloc); } }; diff --git a/src/main/controllers/actors.zig b/src/main/controllers/actors.zig index 6bfa292..3663972 100644 --- a/src/main/controllers/actors.zig +++ b/src/main/controllers/actors.zig @@ -1,6 +1,5 @@ const root = @import("root"); const http = @import("http"); -const api = @import("../api.zig"); const Uuid = @import("util").Uuid; const utils = @import("../controllers.zig").utils; @@ -11,8 +10,10 @@ const RouteArgs = http.RouteArgs; pub fn get(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { const id_str = args.get("id") orelse return error.NotFound; const id = Uuid.parse(id_str) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); - const user = (try srv.api.getActor(id, ctx.alloc)) orelse return utils.respondError(ctx, .not_found, "Note not found"); - defer api.free(ctx.alloc, user); + var api = try utils.getApiConn(srv, ctx); + defer api.close(); + + const user = (try api.getActor(id)) orelse return utils.respondError(ctx, .not_found, "Note not found"); try utils.respondJson(ctx, .ok, user); } diff --git a/src/main/controllers/auth.zig b/src/main/controllers/auth.zig index 8976ae2..5c78715 100644 --- a/src/main/controllers/auth.zig +++ b/src/main/controllers/auth.zig @@ -2,19 +2,22 @@ const std = @import("std"); const root = @import("root"); const builtin = @import("builtin"); const http = @import("http"); -const api = @import("../api.zig"); const Uuid = @import("util").Uuid; +const RegistrationInfo = @import("../api.zig").RegistrationInfo; const utils = @import("../controllers.zig").utils; const RequestServer = root.RequestServer; const RouteArgs = http.RouteArgs; pub fn register(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { - const info = try utils.parseRequestBody(api.RegistrationInfo, ctx); + const info = try utils.parseRequestBody(RegistrationInfo, ctx); defer utils.freeRequestBody(info, ctx.alloc); - const user = srv.api.register(info) catch |err| switch (err) { + var api = try utils.getApiConn(srv, ctx); + defer api.close(); + + const user = api.register(info) catch |err| switch (err) { error.UsernameUnavailable => return try utils.respondError(ctx, .bad_request, "Username Unavailable"), else => return err, }; @@ -22,18 +25,14 @@ pub fn register(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !v try utils.respondJson(ctx, .created, user); } -pub fn authenticate(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { - const user_ctx = try utils.getApiContext(srv, ctx); - // TODO: defer api.free(ctx.alloc, user_ctx); - - try utils.respondJson(ctx, .ok, user_ctx.user_context); -} - pub fn login(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx); defer utils.freeRequestBody(credentials, ctx.alloc); - const token = srv.api.login(credentials.username, credentials.password, ctx.alloc) catch |err| switch (err) { + var api = try utils.getApiConn(srv, ctx); + defer api.close(); + + const token = api.login(credentials.username, credentials.password) catch |err| switch (err) { error.PasswordVerificationFailed => return utils.respondError(ctx, .bad_request, "Invalid Login"), else => return err, }; diff --git a/src/main/controllers/notes.zig b/src/main/controllers/notes.zig index dce6637..534c910 100644 --- a/src/main/controllers/notes.zig +++ b/src/main/controllers/notes.zig @@ -1,9 +1,9 @@ const root = @import("root"); const http = @import("http"); -const api = @import("../api.zig"); const Uuid = @import("util").Uuid; const utils = @import("../controllers.zig").utils; +const NoteCreateInfo = @import("../api.zig").NoteCreateInfo; const RequestServer = root.RequestServer; const RouteArgs = http.RouteArgs; @@ -11,12 +11,13 @@ const RouteArgs = http.RouteArgs; pub const reacts = @import("./notes/reacts.zig"); pub fn create(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { - const user_context = try utils.getApiContext(srv, ctx); - // TODO: defer free ApiContext - const info = try utils.parseRequestBody(api.NoteCreate, ctx); + const info = try utils.parseRequestBody(NoteCreateInfo, ctx); defer utils.freeRequestBody(info, ctx.alloc); - const note = try srv.api.createNoteUser(info, user_context); + var api = try utils.getApiConn(srv, ctx); + defer api.close(); + + const note = try api.createNote(info); try utils.respondJson(ctx, .created, note); } @@ -24,8 +25,10 @@ pub fn create(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !voi pub fn get(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { const id_str = args.get("id") orelse return error.NotFound; const id = Uuid.parse(id_str) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); - const note = (try srv.api.getNote(id, ctx.alloc)) orelse return utils.respondError(ctx, .not_found, "Note not found"); - defer api.free(ctx.alloc, note); + var api = try utils.getApiConn(srv, ctx); + defer api.close(); + + const note = (try api.getNote(id)) orelse return utils.respondError(ctx, .not_found, "Note not found"); try utils.respondJson(ctx, .ok, note); } diff --git a/src/main/controllers/notes/reacts.zig b/src/main/controllers/notes/reacts.zig index 37c36c9..0d13b8a 100644 --- a/src/main/controllers/notes/reacts.zig +++ b/src/main/controllers/notes/reacts.zig @@ -1,6 +1,5 @@ const root = @import("root"); const http = @import("http"); -const api = @import("../../api.zig"); const Uuid = @import("util").Uuid; const utils = @import("../../controllers.zig").utils; @@ -9,23 +8,23 @@ const RequestServer = root.RequestServer; const RouteArgs = http.RouteArgs; pub fn create(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { - const user_context = try utils.getApiContext(srv, ctx); - // TODO: defer free ApiContext + var api = try utils.getApiConn(srv, ctx); + defer api.close(); const note_id = args.get("id") orelse return error.NotFound; const id = Uuid.parse(note_id) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); - try srv.api.react(id, user_context); + try api.react(id); try utils.respondJson(ctx, .created, .{}); } pub fn list(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { - const user_context = try utils.getApiContext(srv, ctx); - // TODO: defer free ApiContext + var api = try utils.getApiConn(srv, ctx); + defer api.close(); const note_id = args.get("id") orelse return error.NotFound; const id = Uuid.parse(note_id) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); - const reacts = try srv.api.listReacts(id, user_context); + const reacts = try api.listReacts(id); try utils.respondJson(ctx, .ok, .{ .items = reacts }); } diff --git a/src/main/main.zig b/src/main/main.zig index 32eb90d..ecfb6e8 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -17,7 +17,6 @@ const router = Router{ Route.new(.POST, "/auth/register", c.auth.register), Route.new(.POST, "/auth/login", c.auth.login), - Route.new(.GET, "/auth/authenticate", c.auth.authenticate), Route.new(.POST, "/notes", c.notes.create), Route.new(.GET, "/notes/:id", c.notes.get), @@ -31,12 +30,12 @@ const router = Router{ pub const RequestServer = struct { alloc: std.mem.Allocator, - api: api.ApiServer, + api: api.ApiSource, fn init(alloc: std.mem.Allocator) !RequestServer { return RequestServer{ .alloc = alloc, - .api = try api.ApiServer.init(alloc), + .api = try api.ApiSource.init(alloc), }; }