Refactor api calls

This commit is contained in:
jaina heartles 2022-07-25 19:07:05 -07:00
parent 47e157e31b
commit a020199773
7 changed files with 234 additions and 186 deletions

View file

@ -66,15 +66,7 @@ fn reify(comptime T: type, id: Uuid, val: CreateInfo(T)) T {
return result; return result;
} }
pub const ApiContext = struct { pub const NoteCreateInfo = struct {
user_context: struct {
user: models.LocalUser,
},
alloc: std.mem.Allocator,
};
pub const NoteCreate = struct {
content: []const u8, content: []const u8,
}; };
@ -84,24 +76,41 @@ pub const RegistrationInfo = struct {
email: ?[]const u8, email: ?[]const u8,
}; };
pub const LoginResult = struct {
user_id: Uuid,
token: [token_str_len]u8,
issued_at: DateTime,
};
threadlocal var prng: std.rand.DefaultPrng = undefined; threadlocal var prng: std.rand.DefaultPrng = undefined;
pub fn initThreadPrng(seed: u64) void { pub fn initThreadPrng(seed: u64) void {
prng = std.rand.DefaultPrng.init(seed +% std.Thread.getCurrentId()); prng = std.rand.DefaultPrng.init(seed +% std.Thread.getCurrentId());
} }
pub const ApiServer = struct { pub const ApiSource = struct {
db: db.Database, db: db.Database,
internal_alloc: std.mem.Allocator, internal_alloc: std.mem.Allocator,
pub fn init(alloc: std.mem.Allocator) !ApiServer { pub const Conn = ApiConn(db.Database);
return ApiServer{
pub fn init(alloc: std.mem.Allocator) !ApiSource {
return ApiSource{
.db = try db.Database.init(), .db = try db.Database.init(),
.internal_alloc = alloc, .internal_alloc = alloc,
}; };
} }
pub fn makeApiContext(self: *ApiServer, token: []const u8, alloc: std.mem.Allocator) !ApiContext { pub fn connectUnauthorized(self: *ApiSource, alloc: std.mem.Allocator) !Conn {
return Conn{
.db = self.db,
.internal_alloc = self.internal_alloc,
.as_user = null,
.arena = std.heap.ArenaAllocator.init(alloc),
};
}
pub fn connectToken(self: *ApiSource, token: []const u8, alloc: std.mem.Allocator) !Conn {
const decoded_len = std.base64.standard.Decoder.calcSizeForSlice(token) catch return error.InvalidToken; const decoded_len = std.base64.standard.Decoder.calcSizeForSlice(token) catch return error.InvalidToken;
if (decoded_len != token_len) return error.InvalidToken; if (decoded_len != token_len) return error.InvalidToken;
@ -111,146 +120,176 @@ pub const ApiServer = struct {
var hash: models.ByteArray(models.Token.hash_len) = undefined; var hash: models.ByteArray(models.Token.hash_len) = undefined;
models.Token.HashFn.hash(&decoded, &hash.data, .{}); models.Token.HashFn.hash(&decoded, &hash.data, .{});
const db_token = (try self.db.getBy(models.Token, .hash, hash, alloc)) orelse return error.InvalidToken; var arena = std.heap.ArenaAllocator.init(alloc);
const user = (try self.db.getBy(models.LocalUser, .id, db_token.user_id, alloc)) orelse return error.InvalidToken; const db_token = (try self.db.getBy(models.Token, .hash, hash, arena.allocator())) orelse return error.InvalidToken;
return ApiContext{ return Conn{
.user_context = .{ .db = self.db,
.user = user, .internal_alloc = self.internal_alloc,
}, .as_user = db_token.user_id,
.arena = arena,
.alloc = alloc,
}; };
} }
pub fn createNoteUser(self: *ApiServer, info: NoteCreate, ctx: ApiContext) !models.Note {
const id = Uuid.randV4(prng.random());
// TODO: check for dupes
const note = models.Note{
.id = id,
.author_id = ctx.user_context.user.actor_id.?,
.content = info.content,
.created_at = DateTime.now(),
};
try self.db.insert(models.Note, note);
return note;
}
pub fn register(self: *ApiServer, info: RegistrationInfo) !models.Actor {
const actor_id = Uuid.randV4(prng.random());
const user_id = Uuid.randV4(prng.random());
// TODO: transaction?
if (try self.db.existsWhereEq(models.LocalUser, .username, info.username)) {
return error.UsernameUnavailable;
}
if (try self.db.existsWhereEq(models.Actor, .handle, info.username)) {
return error.InconsistentDb;
}
// use internal alloc because necessary buffer is *big*
var buf: [pw_hash_buf_size]u8 = undefined;
const hash = try PwHash.strHash(info.password, .{ .allocator = self.internal_alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf);
const now = DateTime.now();
const actor = models.Actor{
.id = actor_id,
.handle = info.username,
.created_at = now,
};
const user = models.LocalUser{
.id = user_id,
.actor_id = actor_id,
.username = info.username,
.email = info.email,
.hashed_password = hash,
.password_changed_at = now,
.created_at = now,
};
try self.db.insert(models.Actor, actor);
try self.db.insert(models.LocalUser, user);
// TODO: return token instead
return actor;
}
const LoginResult = struct {
user_id: Uuid,
token: [token_str_len]u8,
issued_at: DateTime,
};
pub fn login(self: *ApiServer, username: []const u8, password: []const u8, alloc: std.mem.Allocator) !LoginResult {
// TODO: This gives away the existence of a user through a timing side channel. is that acceptable?
const user_info = (try self.db.getBy(models.LocalUser, .username, username, alloc)) orelse return error.InvalidLogin;
defer free(alloc, user_info);
const Hash = std.crypto.pwhash.scrypt;
Hash.strVerify(user_info.hashed_password, password, .{ .allocator = alloc }) catch |err| switch (err) {
error.PasswordVerificationFailed => return error.InvalidLogin,
else => return err,
};
const token = try self.createToken(user_info);
var token_enc: [token_str_len]u8 = undefined;
_ = std.base64.standard.Encoder.encode(&token_enc, &token.value);
return LoginResult{
.user_id = user_info.id,
.token = token_enc,
.issued_at = token.info.issued_at,
};
//return (try self.db.getBy(models.Actor, .id, user_info.actor_id.?, alloc)) orelse unreachable;
}
const TokenResult = struct {
info: models.Token,
value: [token_len]u8,
};
fn createToken(self: *ApiServer, user: models.LocalUser) !TokenResult {
var token: [token_len]u8 = undefined;
std.crypto.random.bytes(&token);
var hash: [models.Token.hash_len]u8 = undefined;
models.Token.HashFn.hash(&token, &hash, .{});
const db_token = models.Token{
.id = Uuid.randV4(prng.random()),
.hash = .{ .data = hash },
.user_id = user.id,
.issued_at = DateTime.now(),
};
try self.db.insert(models.Token, db_token);
return TokenResult{
.info = db_token,
.value = token,
};
}
pub fn getNote(self: *ApiServer, id: Uuid, alloc: std.mem.Allocator) !?models.Note {
return self.db.getBy(models.Note, .id, id, alloc);
}
pub fn getActor(self: *ApiServer, id: Uuid, alloc: std.mem.Allocator) !?models.Actor {
return self.db.getBy(models.Actor, .id, id, alloc);
}
pub fn getActorByHandle(self: *ApiServer, handle: []const u8, alloc: std.mem.Allocator) !?models.Actor {
return self.db.getBy(models.Actor, .handle, handle, alloc);
}
pub fn react(self: *ApiServer, note_id: Uuid, ctx: ApiContext) !void {
const id = Uuid.randV4(prng.random());
try self.db.insert(models.Reaction, .{ .id = id, .note_id = note_id, .reactor_id = ctx.user_context.user.actor_id.?, .created_at = DateTime.now() });
}
pub fn listReacts(self: *ApiServer, note_id: Uuid, ctx: ApiContext) ![]models.Reaction {
return try self.db.getWhereEq(models.Reaction, .note_id, note_id, ctx.alloc);
}
}; };
fn ApiConn(comptime DbConn: type) type {
return struct {
const Self = @This();
db: DbConn,
internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers
as_user: ?Uuid,
arena: std.heap.ArenaAllocator,
pub fn close(self: *Self) void {
self.arena.deinit();
}
fn getAuthenticatedUser(self: *Self) !models.LocalUser {
if (self.as_user) |user_id| {
const local_user = try self.db.getBy(models.LocalUser, .id, user_id, self.arena.allocator());
if (local_user == null) return error.UserNotFound;
return local_user.?;
} else {
return error.NotAuthenticated;
}
}
fn getAuthenticatedActor(self: *Self) !models.Actor {
const user = try self.getAuthenticatedUser();
if (user.actor_id) |actor_id| {
const actor = try self.db.getBy(models.Actor, .id, actor_id, self.arena);
return actor.?;
} else {
return error.NoActor;
}
}
pub fn createNote(self: *Self, info: NoteCreateInfo) !models.Note {
const id = Uuid.randV4(prng.random());
const user = try self.getAuthenticatedUser();
const note = models.Note{
.id = id,
.author_id = user.actor_id orelse return error.NotAuthorized,
.content = info.content,
.created_at = DateTime.now(),
};
try self.db.insert(models.Note, note);
return note;
}
pub fn getNote(self: *Self, id: Uuid) !?models.Note {
return self.db.getBy(models.Note, .id, id, self.arena.allocator());
}
pub fn getActor(self: *Self, id: Uuid) !?models.Actor {
return self.db.getBy(models.Actor, .id, id, self.arena.allocator());
}
pub fn getActorByHandle(self: *Self, handle: []const u8) !?models.Actor {
return self.db.getBy(models.Actor, .handle, handle, self.arena.allocator());
}
pub fn react(self: *Self, note_id: Uuid) !void {
const id = Uuid.randV4(prng.random());
const user = try self.getAuthenticatedUser();
try self.db.insert(models.Reaction, .{ .id = id, .note_id = note_id, .reactor_id = user.actor_id orelse return error.NotAuthorized, .created_at = DateTime.now() });
}
pub fn listReacts(self: *Self, note_id: Uuid) ![]models.Reaction {
return try self.db.getWhereEq(models.Reaction, .note_id, note_id, self.arena.allocator());
}
pub fn register(self: *Self, info: RegistrationInfo) !models.Actor {
const actor_id = Uuid.randV4(prng.random());
const user_id = Uuid.randV4(prng.random());
// TODO: transaction?
if (try self.db.existsWhereEq(models.LocalUser, .username, info.username)) {
return error.UsernameUnavailable;
}
if (try self.db.existsWhereEq(models.Actor, .handle, info.username)) {
return error.InconsistentDb;
}
// use internal alloc because necessary buffer is *big*
var buf: [pw_hash_buf_size]u8 = undefined;
const hash = try PwHash.strHash(info.password, .{ .allocator = self.internal_alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, &buf);
const now = DateTime.now();
const actor = models.Actor{
.id = actor_id,
.handle = info.username,
.created_at = now,
};
const user = models.LocalUser{
.id = user_id,
.actor_id = actor_id,
.username = info.username,
.email = info.email,
.hashed_password = hash,
.password_changed_at = now,
.created_at = now,
};
try self.db.insert(models.Actor, actor);
try self.db.insert(models.LocalUser, user);
// TODO: return token instead
return actor;
}
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResult {
// TODO: This gives away the existence of a user through a timing side channel. is that acceptable?
const user_info = (try self.db.getBy(models.LocalUser, .username, username, self.arena.allocator())) orelse return error.InvalidLogin;
//defer free(self.arena.allocator(), user_info);
const Hash = std.crypto.pwhash.scrypt;
Hash.strVerify(user_info.hashed_password, password, .{ .allocator = self.internal_alloc }) catch |err| switch (err) {
error.PasswordVerificationFailed => return error.InvalidLogin,
else => return err,
};
const token = try self.createToken(user_info);
var token_enc: [token_str_len]u8 = undefined;
_ = std.base64.standard.Encoder.encode(&token_enc, &token.value);
return LoginResult{
.user_id = user_info.id,
.token = token_enc,
.issued_at = token.info.issued_at,
};
}
const TokenResult = struct {
info: models.Token,
value: [token_len]u8,
};
fn createToken(self: *Self, user: models.LocalUser) !TokenResult {
var token: [token_len]u8 = undefined;
std.crypto.random.bytes(&token);
var hash: [models.Token.hash_len]u8 = undefined;
models.Token.HashFn.hash(&token, &hash, .{});
const db_token = models.Token{
.id = Uuid.randV4(prng.random()),
.hash = .{ .data = hash },
.user_id = user.id,
.issued_at = DateTime.now(),
};
try self.db.insert(models.Token, db_token);
return TokenResult{
.info = db_token,
.value = token,
};
}
};
}

View file

@ -57,13 +57,21 @@ pub const utils = struct {
std.json.parseFree(@TypeOf(value), value, .{ .allocator = alloc }); std.json.parseFree(@TypeOf(value), value, .{ .allocator = alloc });
} }
pub fn getApiContext(srv: *RequestServer, ctx: *http.server.Context) !api.ApiContext { pub fn getApiConn(srv: *RequestServer, ctx: *http.server.Context) !api.ApiSource.Conn {
const header = ctx.request.headers.get("authorization") orelse "(null)"; return authorizeApiConn(srv, ctx) catch |err| switch (err) {
error.NoToken => srv.api.connectUnauthorized(ctx.alloc),
error.InvalidToken => return error.InvalidToken,
else => @panic("TODO"), // doing this to resolve some sort of compiler analysis dependency issue
};
}
fn authorizeApiConn(srv: *RequestServer, ctx: *http.server.Context) !api.ApiSource.Conn {
const header = ctx.request.headers.get("authorization") orelse return error.NoToken;
if (header.len < ("bearer ").len) return error.InvalidToken;
const token = header[("bearer ").len..]; const token = header[("bearer ").len..];
return try srv.api.makeApiContext(token, ctx.alloc); return try srv.api.connectToken(token, ctx.alloc);
// TODO: defer api.free(ctx.alloc, user_ctx);
} }
}; };

View file

@ -1,6 +1,5 @@
const root = @import("root"); const root = @import("root");
const http = @import("http"); const http = @import("http");
const api = @import("../api.zig");
const Uuid = @import("util").Uuid; const Uuid = @import("util").Uuid;
const utils = @import("../controllers.zig").utils; const utils = @import("../controllers.zig").utils;
@ -11,8 +10,10 @@ const RouteArgs = http.RouteArgs;
pub fn get(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { pub fn get(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void {
const id_str = args.get("id") orelse return error.NotFound; const id_str = args.get("id") orelse return error.NotFound;
const id = Uuid.parse(id_str) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); const id = Uuid.parse(id_str) catch return utils.respondError(ctx, .bad_request, "Invalid UUID");
const user = (try srv.api.getActor(id, ctx.alloc)) orelse return utils.respondError(ctx, .not_found, "Note not found"); var api = try utils.getApiConn(srv, ctx);
defer api.free(ctx.alloc, user); defer api.close();
const user = (try api.getActor(id)) orelse return utils.respondError(ctx, .not_found, "Note not found");
try utils.respondJson(ctx, .ok, user); try utils.respondJson(ctx, .ok, user);
} }

View file

@ -2,19 +2,22 @@ const std = @import("std");
const root = @import("root"); const root = @import("root");
const builtin = @import("builtin"); const builtin = @import("builtin");
const http = @import("http"); const http = @import("http");
const api = @import("../api.zig");
const Uuid = @import("util").Uuid; const Uuid = @import("util").Uuid;
const RegistrationInfo = @import("../api.zig").RegistrationInfo;
const utils = @import("../controllers.zig").utils; const utils = @import("../controllers.zig").utils;
const RequestServer = root.RequestServer; const RequestServer = root.RequestServer;
const RouteArgs = http.RouteArgs; const RouteArgs = http.RouteArgs;
pub fn register(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { pub fn register(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void {
const info = try utils.parseRequestBody(api.RegistrationInfo, ctx); const info = try utils.parseRequestBody(RegistrationInfo, ctx);
defer utils.freeRequestBody(info, ctx.alloc); defer utils.freeRequestBody(info, ctx.alloc);
const user = srv.api.register(info) catch |err| switch (err) { var api = try utils.getApiConn(srv, ctx);
defer api.close();
const user = api.register(info) catch |err| switch (err) {
error.UsernameUnavailable => return try utils.respondError(ctx, .bad_request, "Username Unavailable"), error.UsernameUnavailable => return try utils.respondError(ctx, .bad_request, "Username Unavailable"),
else => return err, else => return err,
}; };
@ -22,18 +25,14 @@ pub fn register(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !v
try utils.respondJson(ctx, .created, user); try utils.respondJson(ctx, .created, user);
} }
pub fn authenticate(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void {
const user_ctx = try utils.getApiContext(srv, ctx);
// TODO: defer api.free(ctx.alloc, user_ctx);
try utils.respondJson(ctx, .ok, user_ctx.user_context);
}
pub fn login(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { pub fn login(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void {
const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx); const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx);
defer utils.freeRequestBody(credentials, ctx.alloc); defer utils.freeRequestBody(credentials, ctx.alloc);
const token = srv.api.login(credentials.username, credentials.password, ctx.alloc) catch |err| switch (err) { var api = try utils.getApiConn(srv, ctx);
defer api.close();
const token = api.login(credentials.username, credentials.password) catch |err| switch (err) {
error.PasswordVerificationFailed => return utils.respondError(ctx, .bad_request, "Invalid Login"), error.PasswordVerificationFailed => return utils.respondError(ctx, .bad_request, "Invalid Login"),
else => return err, else => return err,
}; };

View file

@ -1,9 +1,9 @@
const root = @import("root"); const root = @import("root");
const http = @import("http"); const http = @import("http");
const api = @import("../api.zig");
const Uuid = @import("util").Uuid; const Uuid = @import("util").Uuid;
const utils = @import("../controllers.zig").utils; const utils = @import("../controllers.zig").utils;
const NoteCreateInfo = @import("../api.zig").NoteCreateInfo;
const RequestServer = root.RequestServer; const RequestServer = root.RequestServer;
const RouteArgs = http.RouteArgs; const RouteArgs = http.RouteArgs;
@ -11,12 +11,13 @@ const RouteArgs = http.RouteArgs;
pub const reacts = @import("./notes/reacts.zig"); pub const reacts = @import("./notes/reacts.zig");
pub fn create(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { pub fn create(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void {
const user_context = try utils.getApiContext(srv, ctx); const info = try utils.parseRequestBody(NoteCreateInfo, ctx);
// TODO: defer free ApiContext
const info = try utils.parseRequestBody(api.NoteCreate, ctx);
defer utils.freeRequestBody(info, ctx.alloc); defer utils.freeRequestBody(info, ctx.alloc);
const note = try srv.api.createNoteUser(info, user_context); var api = try utils.getApiConn(srv, ctx);
defer api.close();
const note = try api.createNote(info);
try utils.respondJson(ctx, .created, note); try utils.respondJson(ctx, .created, note);
} }
@ -24,8 +25,10 @@ pub fn create(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !voi
pub fn get(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { pub fn get(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void {
const id_str = args.get("id") orelse return error.NotFound; const id_str = args.get("id") orelse return error.NotFound;
const id = Uuid.parse(id_str) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); const id = Uuid.parse(id_str) catch return utils.respondError(ctx, .bad_request, "Invalid UUID");
const note = (try srv.api.getNote(id, ctx.alloc)) orelse return utils.respondError(ctx, .not_found, "Note not found"); var api = try utils.getApiConn(srv, ctx);
defer api.free(ctx.alloc, note); defer api.close();
const note = (try api.getNote(id)) orelse return utils.respondError(ctx, .not_found, "Note not found");
try utils.respondJson(ctx, .ok, note); try utils.respondJson(ctx, .ok, note);
} }

View file

@ -1,6 +1,5 @@
const root = @import("root"); const root = @import("root");
const http = @import("http"); const http = @import("http");
const api = @import("../../api.zig");
const Uuid = @import("util").Uuid; const Uuid = @import("util").Uuid;
const utils = @import("../../controllers.zig").utils; const utils = @import("../../controllers.zig").utils;
@ -9,23 +8,23 @@ const RequestServer = root.RequestServer;
const RouteArgs = http.RouteArgs; const RouteArgs = http.RouteArgs;
pub fn create(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { pub fn create(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void {
const user_context = try utils.getApiContext(srv, ctx); var api = try utils.getApiConn(srv, ctx);
// TODO: defer free ApiContext defer api.close();
const note_id = args.get("id") orelse return error.NotFound; const note_id = args.get("id") orelse return error.NotFound;
const id = Uuid.parse(note_id) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); const id = Uuid.parse(note_id) catch return utils.respondError(ctx, .bad_request, "Invalid UUID");
try srv.api.react(id, user_context); try api.react(id);
try utils.respondJson(ctx, .created, .{}); try utils.respondJson(ctx, .created, .{});
} }
pub fn list(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void { pub fn list(srv: *RequestServer, ctx: *http.server.Context, args: RouteArgs) !void {
const user_context = try utils.getApiContext(srv, ctx); var api = try utils.getApiConn(srv, ctx);
// TODO: defer free ApiContext defer api.close();
const note_id = args.get("id") orelse return error.NotFound; const note_id = args.get("id") orelse return error.NotFound;
const id = Uuid.parse(note_id) catch return utils.respondError(ctx, .bad_request, "Invalid UUID"); const id = Uuid.parse(note_id) catch return utils.respondError(ctx, .bad_request, "Invalid UUID");
const reacts = try srv.api.listReacts(id, user_context); const reacts = try api.listReacts(id);
try utils.respondJson(ctx, .ok, .{ .items = reacts }); try utils.respondJson(ctx, .ok, .{ .items = reacts });
} }

View file

@ -17,7 +17,6 @@ const router = Router{
Route.new(.POST, "/auth/register", c.auth.register), Route.new(.POST, "/auth/register", c.auth.register),
Route.new(.POST, "/auth/login", c.auth.login), Route.new(.POST, "/auth/login", c.auth.login),
Route.new(.GET, "/auth/authenticate", c.auth.authenticate),
Route.new(.POST, "/notes", c.notes.create), Route.new(.POST, "/notes", c.notes.create),
Route.new(.GET, "/notes/:id", c.notes.get), Route.new(.GET, "/notes/:id", c.notes.get),
@ -31,12 +30,12 @@ const router = Router{
pub const RequestServer = struct { pub const RequestServer = struct {
alloc: std.mem.Allocator, alloc: std.mem.Allocator,
api: api.ApiServer, api: api.ApiSource,
fn init(alloc: std.mem.Allocator) !RequestServer { fn init(alloc: std.mem.Allocator) !RequestServer {
return RequestServer{ return RequestServer{
.alloc = alloc, .alloc = alloc,
.api = try api.ApiServer.init(alloc), .api = try api.ApiSource.init(alloc),
}; };
} }