Customize allocation strategy per endpoint

This commit is contained in:
jaina heartles 2022-11-20 15:42:34 -08:00
parent 252c12403a
commit e1c0d2942c
2 changed files with 52 additions and 41 deletions

View file

@ -186,34 +186,28 @@ pub const ApiSource = struct {
} }
pub fn connectUnauthorized(self: *ApiSource, host: []const u8, alloc: std.mem.Allocator) !Conn { pub fn connectUnauthorized(self: *ApiSource, host: []const u8, alloc: std.mem.Allocator) !Conn {
var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit();
const db = try self.db_conn_pool.acquire(); const db = try self.db_conn_pool.acquire();
errdefer db.releaseConnection(); errdefer db.releaseConnection();
const community = try services.communities.getByHost(db, host, arena.allocator()); const community = try services.communities.getByHost(db, host, alloc);
return Conn{ return Conn{
.db = db, .db = db,
.user_id = null, .user_id = null,
.community = community, .community = community,
.arena = arena, .allocator = alloc,
}; };
} }
pub fn connectToken(self: *ApiSource, host: []const u8, token: []const u8, alloc: std.mem.Allocator) !Conn { pub fn connectToken(self: *ApiSource, host: []const u8, token: []const u8, alloc: std.mem.Allocator) !Conn {
var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit();
const db = try self.db_conn_pool.acquire(); const db = try self.db_conn_pool.acquire();
errdefer db.releaseConnection(); errdefer db.releaseConnection();
const community = try services.communities.getByHost(db, host, arena.allocator()); const community = try services.communities.getByHost(db, host, alloc);
const token_info = try services.auth.verifyToken( const token_info = try services.auth.verifyToken(
db, db,
token, token,
community.id, community.id,
arena.allocator(), alloc,
); );
return Conn{ return Conn{
@ -221,7 +215,7 @@ pub const ApiSource = struct {
.token_info = token_info, .token_info = token_info,
.user_id = token_info.user_id, .user_id = token_info.user_id,
.community = community, .community = community,
.arena = arena, .allocator = alloc,
}; };
} }
}; };
@ -234,10 +228,10 @@ fn ApiConn(comptime DbConn: type) type {
token_info: ?services.auth.TokenInfo = null, token_info: ?services.auth.TokenInfo = null,
user_id: ?Uuid = null, user_id: ?Uuid = null,
community: services.communities.Community, community: services.communities.Community,
arena: std.heap.ArenaAllocator, allocator: std.mem.Allocator,
pub fn close(self: *Self) void { pub fn close(self: *Self) void {
self.arena.deinit(); util.deepFree(self.allocator, self.community);
self.db.releaseConnection(); self.db.releaseConnection();
} }
@ -252,7 +246,7 @@ fn ApiConn(comptime DbConn: type) type {
username, username,
self.community.id, self.community.id,
password, password,
self.arena.allocator(), self.allocator,
); );
} }
@ -266,11 +260,12 @@ fn ApiConn(comptime DbConn: type) type {
}; };
pub fn verifyAuthorization(self: *Self) !AuthorizationInfo { pub fn verifyAuthorization(self: *Self) !AuthorizationInfo {
if (self.token_info) |info| { if (self.token_info) |info| {
const user = try services.actors.get(self.db, info.user_id, self.arena.allocator()); const user = try services.actors.get(self.db, info.user_id, self.allocator);
defer util.deepFree(self.allocator, user);
return AuthorizationInfo{ return AuthorizationInfo{
.id = user.id, .id = user.id,
.username = user.username, .username = try util.deepClone(self.allocator, user.username),
.community_id = self.community.id, .community_id = self.community.id,
.host = self.community.host, .host = self.community.host,
@ -292,13 +287,13 @@ fn ApiConn(comptime DbConn: type) type {
tx, tx,
origin, origin,
.{}, .{},
self.arena.allocator(), self.allocator,
); );
const community = services.communities.get( const community = services.communities.get(
tx, tx,
community_id, community_id,
self.arena.allocator(), self.allocator,
) catch |err| return switch (err) { ) catch |err| return switch (err) {
error.NotFound => error.DatabaseError, error.NotFound => error.DatabaseError,
else => |err2| err2, else => |err2| err2,
@ -328,17 +323,18 @@ fn ApiConn(comptime DbConn: type) type {
.lifespan = options.lifespan, .lifespan = options.lifespan,
.max_uses = options.max_uses, .max_uses = options.max_uses,
.kind = options.kind, .kind = options.kind,
}, self.arena.allocator()); }, self.allocator);
return try services.invites.get(self.db, invite_id, self.arena.allocator()); return try services.invites.get(self.db, invite_id, self.allocator);
} }
pub fn register(self: *Self, username: []const u8, password: []const u8, opt: RegistrationOptions) !UserResponse { pub fn register(self: *Self, username: []const u8, password: []const u8, opt: RegistrationOptions) !UserResponse {
std.log.debug("registering user {s} with code {?s}", .{ username, opt.invite_code }); const tx = try self.db.beginOrSavepoint();
const maybe_invite = if (opt.invite_code) |code| const maybe_invite = if (opt.invite_code) |code|
try services.invites.getByCode(self.db, code, self.community.id, self.arena.allocator()) try services.invites.getByCode(tx, code, self.community.id, self.allocator)
else else
null; null;
defer if (maybe_invite) |inv| util.deepFree(self.allocator, inv);
if (maybe_invite) |invite| { if (maybe_invite) |invite| {
if (!Uuid.eql(invite.community_id, self.community.id)) return error.WrongCommunity; if (!Uuid.eql(invite.community_id, self.community.id)) return error.WrongCommunity;
@ -351,7 +347,7 @@ fn ApiConn(comptime DbConn: type) type {
if (self.community.kind == .admin) @panic("Unimplmented"); if (self.community.kind == .admin) @panic("Unimplmented");
const user_id = try services.auth.register( const user_id = try services.auth.register(
self.db, tx,
username, username,
password, password,
self.community.id, self.community.id,
@ -359,14 +355,14 @@ fn ApiConn(comptime DbConn: type) type {
.invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null,
.email = opt.email, .email = opt.email,
}, },
self.arena.allocator(), self.allocator,
); );
switch (invite_kind) { switch (invite_kind) {
.user => {}, .user => {},
.system => @panic("System user invites unimplemented"), .system => @panic("System user invites unimplemented"),
.community_owner => { .community_owner => {
try services.communities.transferOwnership(self.db, self.community.id, user_id); try services.communities.transferOwnership(tx, self.community.id, user_id);
}, },
} }
@ -377,7 +373,7 @@ fn ApiConn(comptime DbConn: type) type {
} }
pub fn getUser(self: *Self, user_id: Uuid) !UserResponse { pub fn getUser(self: *Self, user_id: Uuid) !UserResponse {
const user = try services.actors.get(self.db, user_id, self.arena.allocator()); const user = try services.actors.get(self.db, user_id, self.allocator);
if (self.user_id == null) { if (self.user_id == null) {
if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound; if (!Uuid.eql(self.community.id, user.community_id)) return error.NotFound;
@ -397,7 +393,7 @@ fn ApiConn(comptime DbConn: type) type {
// Only authenticated users can post // Only authenticated users can post
const user_id = self.user_id orelse return error.TokenRequired; const user_id = self.user_id orelse return error.TokenRequired;
const note_id = try services.notes.create(self.db, user_id, content, self.arena.allocator()); const note_id = try services.notes.create(self.db, user_id, content, self.allocator);
return self.getNote(note_id) catch |err| switch (err) { return self.getNote(note_id) catch |err| switch (err) {
error.NotFound => error.Unexpected, error.NotFound => error.Unexpected,
@ -406,8 +402,8 @@ fn ApiConn(comptime DbConn: type) type {
} }
pub fn getNote(self: *Self, note_id: Uuid) !NoteResponse { pub fn getNote(self: *Self, note_id: Uuid) !NoteResponse {
const note = try services.notes.get(self.db, note_id, self.arena.allocator()); const note = try services.notes.get(self.db, note_id, self.allocator);
const user = try services.actors.get(self.db, note.author_id, self.arena.allocator()); const user = try services.actors.get(self.db, note.author_id, self.allocator);
// Only serve community-specific notes on unauthenticated requests // Only serve community-specific notes on unauthenticated requests
if (self.user_id == null) { if (self.user_id == null) {
@ -428,12 +424,12 @@ fn ApiConn(comptime DbConn: type) type {
pub fn queryCommunities(self: *Self, args: services.communities.QueryArgs) !CommunityQueryResult { pub fn queryCommunities(self: *Self, args: services.communities.QueryArgs) !CommunityQueryResult {
if (!self.isAdmin()) return error.PermissionDenied; if (!self.isAdmin()) return error.PermissionDenied;
return try services.communities.query(self.db, args, self.arena.allocator()); return try services.communities.query(self.db, args, self.allocator);
} }
pub fn globalTimeline(self: *Self, args: TimelineArgs) !TimelineResult { pub fn globalTimeline(self: *Self, args: TimelineArgs) !TimelineResult {
const all_args = std.mem.zeroInit(NoteQueryArgs, args); const all_args = std.mem.zeroInit(NoteQueryArgs, args);
const result = try services.notes.query(self.db, all_args, self.arena.allocator()); const result = try services.notes.query(self.db, all_args, self.allocator);
return TimelineResult{ return TimelineResult{
.items = result.items, .items = result.items,
.prev_page = TimelineArgs.from(result.prev_page), .prev_page = TimelineArgs.from(result.prev_page),
@ -444,7 +440,7 @@ fn ApiConn(comptime DbConn: type) type {
pub fn localTimeline(self: *Self, args: TimelineArgs) !TimelineResult { pub fn localTimeline(self: *Self, args: TimelineArgs) !TimelineResult {
var all_args = std.mem.zeroInit(NoteQueryArgs, args); var all_args = std.mem.zeroInit(NoteQueryArgs, args);
all_args.community_id = self.community.id; all_args.community_id = self.community.id;
const result = try services.notes.query(self.db, all_args, self.arena.allocator()); const result = try services.notes.query(self.db, all_args, self.allocator);
return TimelineResult{ return TimelineResult{
.items = result.items, .items = result.items,
.prev_page = TimelineArgs.from(result.prev_page), .prev_page = TimelineArgs.from(result.prev_page),
@ -457,7 +453,7 @@ fn ApiConn(comptime DbConn: type) type {
var all_args = std.mem.zeroInit(services.notes.QueryArgs, args); var all_args = std.mem.zeroInit(services.notes.QueryArgs, args);
all_args.followed_by = self.user_id; all_args.followed_by = self.user_id;
const result = try services.notes.query(self.db, all_args, self.arena.allocator()); const result = try services.notes.query(self.db, all_args, self.allocator);
return TimelineResult{ return TimelineResult{
.items = result.items, .items = result.items,
.prev_page = TimelineArgs.from(result.prev_page), .prev_page = TimelineArgs.from(result.prev_page),
@ -468,7 +464,7 @@ fn ApiConn(comptime DbConn: type) type {
pub fn queryFollowers(self: *Self, user_id: Uuid, args: FollowerQueryArgs) !FollowerQueryResult { pub fn queryFollowers(self: *Self, user_id: Uuid, args: FollowerQueryArgs) !FollowerQueryResult {
var all_args = std.mem.zeroInit(services.follows.QueryArgs, args); var all_args = std.mem.zeroInit(services.follows.QueryArgs, args);
all_args.followee_id = user_id; all_args.followee_id = user_id;
const result = try services.follows.query(self.db, all_args, self.arena.allocator()); const result = try services.follows.query(self.db, all_args, self.allocator);
return FollowerQueryResult{ return FollowerQueryResult{
.items = result.items, .items = result.items,
.prev_page = FollowQueryArgs.from(result.prev_page), .prev_page = FollowQueryArgs.from(result.prev_page),
@ -479,7 +475,7 @@ fn ApiConn(comptime DbConn: type) type {
pub fn queryFollowing(self: *Self, user_id: Uuid, args: FollowingQueryArgs) !FollowingQueryResult { pub fn queryFollowing(self: *Self, user_id: Uuid, args: FollowingQueryArgs) !FollowingQueryResult {
var all_args = std.mem.zeroInit(services.follows.QueryArgs, args); var all_args = std.mem.zeroInit(services.follows.QueryArgs, args);
all_args.followed_by_id = user_id; all_args.followed_by_id = user_id;
const result = try services.follows.query(self.db, all_args, self.arena.allocator()); const result = try services.follows.query(self.db, all_args, self.allocator);
return FollowingQueryResult{ return FollowingQueryResult{
.items = result.items, .items = result.items,
.prev_page = FollowQueryArgs.from(result.prev_page), .prev_page = FollowQueryArgs.from(result.prev_page),
@ -488,13 +484,13 @@ fn ApiConn(comptime DbConn: type) type {
} }
pub fn follow(self: *Self, followee: Uuid) !void { pub fn follow(self: *Self, followee: Uuid) !void {
const result = try services.follows.create(self.db, self.user_id orelse return error.NoToken, followee, self.arena.allocator()); const result = try services.follows.create(self.db, self.user_id orelse return error.NoToken, followee, self.allocator);
defer util.deepFree(self.arena.allocator(), result); defer util.deepFree(self.allocator, result);
} }
pub fn unfollow(self: *Self, followee: Uuid) !void { pub fn unfollow(self: *Self, followee: Uuid) !void {
const result = try services.follows.delete(self.db, self.user_id orelse return error.NoToken, followee, self.arena.allocator()); const result = try services.follows.delete(self.db, self.user_id orelse return error.NoToken, followee, self.allocator);
defer util.deepFree(self.arena.allocator(), result); defer util.deepFree(self.allocator, result);
} }
pub fn getClusterMeta(self: *Self) !ClusterMeta { pub fn getClusterMeta(self: *Self) !ClusterMeta {
@ -510,7 +506,7 @@ fn ApiConn(comptime DbConn: type) type {
\\ community.kind != 'admin' \\ community.kind != 'admin'
, ,
.{}, .{},
self.arena.allocator(), self.allocator,
); );
} }
}; };

View file

@ -120,6 +120,11 @@ fn matchContentType(hdr: ?[]const u8) ?BaseContentType {
return null; return null;
} }
pub const AllocationStrategy = enum {
arena,
normal,
};
pub fn Context(comptime Route: type) type { pub fn Context(comptime Route: type) type {
return struct { return struct {
const Self = @This(); const Self = @This();
@ -134,6 +139,11 @@ pub fn Context(comptime Route: type) type {
// leave it as a simple string instead of void // leave it as a simple string instead of void
pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void;
const allocation_strategy: AllocationStrategy = if (@hasDecl(Route, "allocation_strategy"))
Route.AllocationStrategy
else
.arena;
base_request: *http.Request, base_request: *http.Request,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
@ -168,11 +178,16 @@ pub fn Context(comptime Route: type) type {
api_source: *api.ApiSource, api_source: *api.ApiSource,
req: *http.Request, req: *http.Request,
res: *Response, res: *Response,
alloc: std.mem.Allocator, base_allocator: std.mem.Allocator,
args: Args, args: Args,
) !void { ) !void {
const base_content_type = matchContentType(req.headers.get("Content-Type")); const base_content_type = matchContentType(req.headers.get("Content-Type"));
var arena = if (allocation_strategy == .arena)
std.heap.ArenaAllocator.init(base_allocator)
else {};
const alloc = if (allocation_strategy == .arena) arena.allocator() else base_allocator;
const body = if (Body != void) blk: { const body = if (Body != void) blk: {
var stream = req.body orelse return error.NoBody; var stream = req.body orelse return error.NoBody;
break :blk try parseBody(Body, base_content_type orelse .json, stream.reader(), alloc); break :blk try parseBody(Body, base_content_type orelse .json, stream.reader(), alloc);