diff --git a/src/api/lib.zig b/src/api/lib.zig index 98f0af2..ab19097 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -11,6 +11,7 @@ const services = struct { const auth = @import("./services/auth.zig"); const invites = @import("./services/invites.zig"); const notes = @import("./services/notes.zig"); + const follows = @import("./services/follows.zig"); }; pub const RegistrationOptions = struct { @@ -90,6 +91,45 @@ pub const TimelineResult = struct { next_page: TimelineArgs, }; +const FollowQueryArgs = struct { + pub const OrderBy = services.follows.QueryArgs.OrderBy; + pub const Direction = services.follows.QueryArgs.Direction; + pub const PageDirection = services.follows.QueryArgs.PageDirection; + pub const Prev = services.follows.QueryArgs.Prev; + + max_items: usize = 20, + + order_by: OrderBy = .created_at, + + direction: Direction = .descending, + + prev: ?Prev = null, + + page_direction: PageDirection = .forward, + + fn from(args: services.follows.QueryArgs) FollowQueryArgs { + return .{ + .max_items = args.max_items, + .order_by = args.order_by, + .direction = args.direction, + .prev = args.prev, + .page_direction = args.page_direction, + }; + } +}; + +const FollowQueryResult = struct { + items: []services.follows.Follow, + + prev_page: FollowQueryArgs, + next_page: FollowQueryArgs, +}; + +pub const FollowerQueryArgs = FollowQueryArgs; +pub const FollowerQueryResult = FollowQueryResult; +pub const FollowingQueryArgs = FollowQueryArgs; +pub const FollowingQueryResult = FollowQueryResult; + pub fn isAdminSetup(db: sql.Db) !bool { _ = services.communities.adminCommunityId(db) catch |err| switch (err) { error.NotFound => return false, @@ -405,5 +445,45 @@ fn ApiConn(comptime DbConn: type) type { .next_page = TimelineArgs.from(result.next_page), }; } + + pub fn homeTimeline(self: *Self, args: TimelineArgs) !TimelineResult { + if (self.user_id == null) return error.NoToken; + + var all_args = std.mem.zeroInit(services.notes.QueryArgs, args); + all_args.followed_by = self.user_id; + const result = try services.notes.query(self.db, all_args, self.arena.allocator()); + return TimelineResult{ + .items = result.items, + .prev_page = TimelineArgs.from(result.prev_page), + .next_page = TimelineArgs.from(result.next_page), + }; + } + + pub fn queryFollowers(self: *Self, user_id: Uuid, args: FollowerQueryArgs) !FollowerQueryResult { + var all_args = std.mem.zeroInit(services.follows.QueryArgs, args); + all_args.followee_id = user_id; + const result = try services.follows.query(self.db, all_args, self.arena.allocator()); + return FollowerQueryResult{ + .items = result.items, + .prev_page = FollowQueryArgs.from(result.prev_page), + .next_page = FollowQueryArgs.from(result.next_page), + }; + } + + pub fn queryFollowing(self: *Self, user_id: Uuid, args: FollowingQueryArgs) !FollowingQueryResult { + var all_args = std.mem.zeroInit(services.follows.QueryArgs, args); + all_args.followed_by_id = user_id; + const result = try services.follows.query(self.db, all_args, self.arena.allocator()); + return FollowingQueryResult{ + .items = result.items, + .prev_page = FollowQueryArgs.from(result.prev_page), + .next_page = FollowQueryArgs.from(result.next_page), + }; + } + + pub fn follow(self: *Self, followee: Uuid) !void { + const result = try services.follows.create(self.db, self.user_id orelse return error.NoToken, followee, self.arena.allocator()); + defer util.deepFree(self.arena.allocator(), result); + } }; } diff --git a/src/api/services/follows.zig b/src/api/services/follows.zig new file mode 100644 index 0000000..dc4f5c7 --- /dev/null +++ b/src/api/services/follows.zig @@ -0,0 +1,142 @@ +const std = @import("std"); +const util = @import("util"); +const sql = @import("sql"); + +const common = @import("./common.zig"); + +const Uuid = util.Uuid; +const DateTime = util.DateTime; + +pub const Follow = struct { + id: Uuid, + + followed_by_id: Uuid, + followee_id: Uuid, + + created_at: DateTime, +}; + +pub fn create(db: anytype, followed_by_id: Uuid, followee_id: Uuid, alloc: std.mem.Allocator) !void { + if (Uuid.eql(followed_by_id, followee_id)) return error.SelfFollow; + const now = DateTime.now(); + const id = Uuid.randV4(util.getThreadPrng()); + + db.insert("follow", .{ + .id = id, + .followed_by_id = followed_by_id, + .followee_id = followee_id, + .created_at = now, + }, alloc) catch |err| return switch (err) { + error.ForeignKeyViolation => error.NotFound, + error.UniqueViolation => error.NotUnique, + else => error.DatabaseFailure, + }; +} + +const max_max_items = 100; + +pub const QueryArgs = struct { + pub const Direction = common.Direction; + pub const PageDirection = common.PageDirection; + pub const Prev = std.meta.Child(std.meta.fieldInfo(@This(), .prev).field_type); + + pub const OrderBy = enum { + created_at, + }; + + max_items: usize = 20, + + followed_by_id: ?Uuid = null, + followee_id: ?Uuid = null, + + order_by: OrderBy = .created_at, + + direction: Direction = .descending, + + prev: ?struct { + id: Uuid, + order_val: union(OrderBy) { + created_at: DateTime, + }, + } = null, + + page_direction: PageDirection = .forward, +}; + +pub const QueryResult = struct { + items: []Follow, + + prev_page: QueryArgs, + next_page: QueryArgs, +}; + +pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult { + var builder = sql.QueryBuilder.init(alloc); + defer builder.deinit(); + + try builder.appendSlice( + \\SELECT follow.id, follow.followed_by_id, follow.followee_id, follow.created_at + \\FROM follow + \\ + ); + + if (args.followed_by_id != null) try builder.andWhere("follow.followed_by_id = $1"); + if (args.followee_id != null) try builder.andWhere("follow.followee_id = $2"); + + if (args.prev != null) { + try builder.andWhere("(follow.id, follow.created_at)"); + switch (args.page_direction) { + .forward => try builder.appendSlice(" < "), + .backward => try builder.appendSlice(" > "), + } + try builder.appendSlice("($3, $4)"); + } + + try builder.appendSlice( + \\ + \\ORDER BY follow.created_at DESC + \\LIMIT $5 + \\ + ); + + const max_items = if (args.max_items > max_max_items) max_max_items else args.max_items; + const query_args = .{ + args.followed_by_id, + args.followee_id, + if (args.prev) |p| p.id else null, + if (args.prev) |p| p.order_val else null, + max_items, + }; + + const results = try db.queryRowsWithOptions( + Follow, + try builder.terminate(), + query_args, + max_items, + .{ .allocator = alloc, .ignore_unused_arguments = true }, + ); + errdefer util.deepFree(alloc, results); + + var next_page = args; + var prev_page = args; + prev_page.page_direction = .backward; + next_page.page_direction = .forward; + if (results.len != 0) { + prev_page.prev = .{ + .id = results[0].id, + .order_val = .{ .created_at = results[0].created_at }, + }; + + next_page.prev = .{ + .id = results[results.len - 1].id, + .order_val = .{ .created_at = results[results.len - 1].created_at }, + }; + } + // TODO: this will give incorrect links on an empty page + + return QueryResult{ + .items = results, + .next_page = next_page, + .prev_page = prev_page, + }; +} diff --git a/src/api/services/notes.zig b/src/api/services/notes.zig index 2a2d7f0..086a85a 100644 --- a/src/api/services/notes.zig +++ b/src/api/services/notes.zig @@ -70,6 +70,7 @@ pub const QueryArgs = struct { created_before: ?DateTime = null, created_after: ?DateTime = null, community_id: ?Uuid = null, + followed_by: ?Uuid = null, prev: ?struct { id: Uuid, @@ -95,6 +96,12 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul \\ ); + if (args.followed_by != null) try builder.appendSlice( + \\ JOIN follow ON + \\ follow.followed_by_id = $7 AND follow.followee_id = note.author_id + \\ + ); + if (args.created_before != null) try builder.andWhere("note.created_at < $1"); if (args.created_after != null) try builder.andWhere("note.created_at > $2"); if (args.prev != null) { @@ -128,6 +135,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul prev_id, args.community_id, max_items, + args.followed_by, }; }; @@ -138,7 +146,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul max_items, .{ .allocator = alloc, .ignore_unused_arguments = true }, ); - errdefer util.deepFree(results); + errdefer util.deepFree(alloc, results); var next_page = args; var prev_page = args; diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 902711e..4470d75 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -11,6 +11,7 @@ pub const auth = @import("./controllers/auth.zig"); pub const communities = @import("./controllers/communities.zig"); pub const invites = @import("./controllers/invites.zig"); pub const users = @import("./controllers/users.zig"); +pub const follows = @import("./controllers/users/follows.zig"); pub const notes = @import("./controllers/notes.zig"); pub const streaming = @import("./controllers/streaming.zig"); pub const timelines = @import("./controllers/timelines.zig"); @@ -45,6 +46,10 @@ const routes = .{ streaming.streaming, timelines.global, timelines.local, + timelines.home, + follows.create, + follows.query_followers, + follows.query_following, }; pub fn Context(comptime Route: type) type { @@ -90,6 +95,8 @@ pub fn Context(comptime Route: type) type { } } + if (path_iter.next() != null) return null; + return args; } @@ -263,3 +270,42 @@ const json_options = if (builtin.mode == .Debug) }, .string = .{ .String = .{} }, }; + +pub const helpers = struct { + pub fn paginate(community: api.Community, path: []const u8, results: anytype, res: *Response, alloc: std.mem.Allocator) !void { + var link = std.ArrayList(u8).init(alloc); + const link_writer = link.writer(); + defer link.deinit(); + + try writeLink(link_writer, community, path, results.next_page, "next"); + try link_writer.writeByte(','); + try writeLink(link_writer, community, path, results.prev_page, "prev"); + + try res.headers.put("Link", link.items); + + try res.json(.ok, results.items); + } + + fn writeLink( + writer: anytype, + community: api.Community, + path: []const u8, + params: anytype, + rel: []const u8, + ) !void { + // TODO: percent-encode + try std.fmt.format( + writer, + "<{s}://{s}/{s}?", + .{ @tagName(community.scheme), community.host, path }, + ); + + try query_utils.formatQuery(params, writer); + + try std.fmt.format( + writer, + ">; rel=\"{s}\"", + .{rel}, + ); + } +}; diff --git a/src/main/controllers/communities.zig b/src/main/controllers/communities.zig index 3aeec1f..89ed8a5 100644 --- a/src/main/controllers/communities.zig +++ b/src/main/controllers/communities.zig @@ -2,6 +2,7 @@ const std = @import("std"); const api = @import("api"); const util = @import("util"); const query_utils = @import("../query.zig"); +const controller_utils = @import("../controllers.zig").helpers; const QueryArgs = api.CommunityQueryArgs; const Uuid = util.Uuid; @@ -31,39 +32,6 @@ pub const query = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.queryCommunities(req.query); - var link = std.ArrayList(u8).init(req.allocator); - const link_writer = link.writer(); - defer link.deinit(); - - try writeLink(link_writer, srv.community, path, results.next_page, "next"); - try link_writer.writeByte(','); - try writeLink(link_writer, srv.community, path, results.prev_page, "prev"); - - try res.headers.put("Link", link.items); - - try res.json(.ok, results.items); + try controller_utils.paginate(srv.community, path, results, res, req.allocator); } }; - -fn writeLink( - writer: anytype, - community: api.Community, - path: []const u8, - params: anytype, - rel: []const u8, -) !void { - // TODO: percent-encode - try std.fmt.format( - writer, - "<{s}://{s}/{s}?", - .{ @tagName(community.scheme), community.host, path }, - ); - - try query_utils.formatQuery(params, writer); - - try std.fmt.format( - writer, - ">; rel=\"{s}\"", - .{rel}, - ); -} diff --git a/src/main/controllers/timelines.zig b/src/main/controllers/timelines.zig index 9a68da8..2d5ace6 100644 --- a/src/main/controllers/timelines.zig +++ b/src/main/controllers/timelines.zig @@ -1,6 +1,7 @@ const std = @import("std"); const api = @import("api"); const query_utils = @import("../query.zig"); +const controller_utils = @import("../controllers.zig").helpers; pub const global = struct { pub const method = .GET; @@ -10,18 +11,7 @@ pub const global = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.globalTimeline(req.query); - - var link = std.ArrayList(u8).init(req.allocator); - const link_writer = link.writer(); - defer link.deinit(); - - try writeLink(link_writer, srv.community, path, results.next_page, "next"); - try link_writer.writeByte(','); - try writeLink(link_writer, srv.community, path, results.prev_page, "prev"); - - try res.headers.put("Link", link.items); - - try res.json(.ok, results.items); + try controller_utils.paginate(srv.community, path, results, res, req.allocator); } }; @@ -33,41 +23,18 @@ pub const local = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const results = try srv.localTimeline(req.query); - - var link = std.ArrayList(u8).init(req.allocator); - const link_writer = link.writer(); - defer link.deinit(); - - try writeLink(link_writer, srv.community, path, results.next_page, "next"); - try link_writer.writeByte(','); - try writeLink(link_writer, srv.community, path, results.prev_page, "prev"); - - try res.headers.put("Link", link.items); - - try res.json(.ok, results.items); + try controller_utils.paginate(srv.community, path, results, res, req.allocator); } }; -// TOOD: unify with communities.zig -fn writeLink( - writer: anytype, - community: api.Community, - path: []const u8, - params: anytype, - rel: []const u8, -) !void { - // TODO: percent-encode - try std.fmt.format( - writer, - "<{s}://{s}/{s}?", - .{ @tagName(community.scheme), community.host, path }, - ); +pub const home = struct { + pub const method = .GET; + pub const path = "/timelines/home"; - try query_utils.formatQuery(params, writer); + pub const Query = api.TimelineArgs; - try std.fmt.format( - writer, - ">; rel=\"{s}\"", - .{rel}, - ); -} + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const results = try srv.homeTimeline(req.query); + try controller_utils.paginate(srv.community, path, results, res, req.allocator); + } +}; diff --git a/src/main/controllers/users/follows.zig b/src/main/controllers/users/follows.zig new file mode 100644 index 0000000..dcb36b9 --- /dev/null +++ b/src/main/controllers/users/follows.zig @@ -0,0 +1,54 @@ +const api = @import("api"); +const util = @import("util"); +const controller_utils = @import("../../controllers.zig").helpers; + +const Uuid = util.Uuid; + +pub const create = struct { + pub const method = .POST; + pub const path = "/users/:id/follow"; + + pub const Args = struct { + id: Uuid, + }; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + try srv.follow(req.args.id); + + try res.json(.created, .{}); + } +}; + +pub const query_followers = struct { + pub const method = .GET; + pub const path = "/users/:id/followers"; + + pub const Args = struct { + id: Uuid, + }; + + pub const Query = api.FollowingQueryArgs; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const results = try srv.queryFollowers(req.args.id, req.query); + + try controller_utils.paginate(srv.community, path, results, res, req.allocator); + } +}; + +pub const query_following = struct { + pub const method = .GET; + pub const path = "/users/:id/following"; + + pub const Args = struct { + id: Uuid, + }; + + pub const Query = api.FollowerQueryArgs; + + pub fn handler(req: anytype, res: anytype, srv: anytype) !void { + const results = try srv.queryFollowing(req.args.id, req.query); + + try controller_utils.paginate(srv.community, path, results, res, req.allocator); + } +}; diff --git a/src/main/migrations.zig b/src/main/migrations.zig index b0abc02..89a63d3 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -189,4 +189,20 @@ const migrations: []const Migration = &.{ \\DROP TABLE community; , }, + .{ + .name = "follows", + .up = + \\CREATE TABLE follow( + \\ id UUID NOT NULL PRIMARY KEY, + \\ + \\ followed_by_id UUID NOT NULL, + \\ followee_id UUID NOT NULL, + \\ + \\ created_at TIMESTAMPTZ NOT NULL, + \\ + \\ UNIQUE(followed_by_id, followee_id) + \\); + , + .down = "DROP TABLE follow", + }, }; diff --git a/src/main/query.zig b/src/main/query.zig index 857eff0..a51a7ab 100644 --- a/src/main/query.zig +++ b/src/main/query.zig @@ -127,7 +127,11 @@ fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u if (try parse(F, prefix ++ "." ++ name, field.name, fields)) |v| { maybe_value = v; } else if (field.default_value) |default| { - maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*; + if (comptime @sizeOf(F) != 0) { + maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*; + } else { + maybe_value = std.mem.zeroes(F); + } } if (maybe_value) |v| { @@ -139,6 +143,7 @@ fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u if (fields_specified == 0) { return null; } else if (fields_specified != info.fields.len) { + std.log.debug("{} {s} {s}", .{ T, prefix, name }); return error.PartiallySpecifiedStruct; } else { return result;