Provide page links in communities query

This commit is contained in:
jaina heartles 2022-10-11 19:19:34 -07:00
parent 37e0192c4f
commit 46e71d8b44
10 changed files with 180 additions and 50 deletions

View File

@ -53,6 +53,7 @@ pub const NoteResponse = struct {
created_at: DateTime,
};
pub const Community = services.communities.Community;
pub const CommunityQueryArgs = services.communities.QueryArgs;
pub const CommunityQueryResult = services.communities.QueryResult;

View File

@ -46,6 +46,7 @@ pub fn register(
tx.insert("password", .{
.account_id = id,
.hash = hash,
.changed_at = DateTime.now(),
}, alloc) catch return error.DatabaseFailure;
tx.commitOrRelease() catch return error.DatabaseFailure;
@ -126,6 +127,7 @@ pub fn login(
tx.insert("token", .{
.account_id = info.account_id,
.hash = token_hash,
.issued_at = DateTime.now(),
}, alloc) catch return error.DatabaseFailure;
tx.commit() catch return error.DatabaseFailure;

View File

@ -90,6 +90,7 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st
.name = options.name orelse host,
.scheme = scheme,
.kind = options.kind,
.created_at = DateTime.now(),
}, alloc) catch return error.DatabaseFailure;
return id;
@ -284,13 +285,16 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
_ = try builder.array.appendSlice("($5, $6)");
}
const direction_string = switch (args.direction) {
.ascending => " ASC ",
.descending => " DESC ",
};
_ = try builder.array.appendSlice("\nORDER BY ");
_ = try builder.array.appendSlice(@tagName(args.order_by));
_ = try builder.array.appendSlice(direction_string);
_ = try builder.array.appendSlice(", id ");
_ = try builder.array.appendSlice(switch (args.direction) {
.ascending => "ASC",
.descending => "DESC",
});
_ = try builder.array.appendSlice(direction_string);
_ = try builder.array.appendSlice("\nLIMIT $7");
@ -340,9 +344,8 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
.id = result_buf[count - 1].id,
.order_val = getOrderVal(result_buf[count - 1], args.order_by),
};
} else {
prev_page.prev = null;
}
// TODO: This will give an incorrect previous page link on an empty page
return QueryResult{
.items = result_buf[0..count],

View File

@ -89,6 +89,7 @@ pub fn create(
.username = username,
.community_id = community_id,
.kind = kind,
.created_at = DateTime.now(),
}, alloc) catch |err| return switch (err) {
error.UniqueViolation => error.UsernameTaken,
else => error.DatabaseFailure,

View File

@ -110,17 +110,25 @@ pub fn Context(comptime Route: type) type {
return true;
}
fn errorHandler(response: *Response, status: http.Status) void {
response.status(status) catch unreachable;
fn errorHandler(response: *Response, status: http.Status, err: anytype) void {
std.log.err("Error occured on handler {s} {s}", .{ @tagName(Route.method), Route.path });
std.log.err("{}", .{err});
const result = if (builtin.mode == .Debug)
response.err(status, @errorName(err), {})
else
response.status(status);
_ = result catch |err2| {
std.log.err("Error printing response: {}", .{err2});
};
}
fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void {
self.parseBody(req) catch return errorHandler(response, .bad_request);
self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err);
defer self.freeBody();
self.parseQuery() catch return errorHandler(response, .bad_request);
self.parseQuery() catch |err| return errorHandler(response, .bad_request, err);
var api_conn = self.getApiConn(api_source) catch return errorHandler(response, .internal_server_error); // TODO
var api_conn = self.getApiConn(api_source) catch |err| return errorHandler(response, .internal_server_error, err);
defer api_conn.close();
self.handle(response, &api_conn);

View File

@ -1,3 +1,4 @@
const std = @import("std");
const api = @import("api");
const util = @import("util");
@ -57,6 +58,8 @@ pub const query = struct {
// What direction to scan the page window
page_direction: PageDirection = .forward,
pub const format = formatQueryParams;
};
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
@ -95,6 +98,109 @@ pub const query = struct {
const results = try srv.queryCommunities(query_args);
try res.json(.ok, results);
var link = std.ArrayList(u8).init(req.allocator);
const link_writer = link.writer();
defer link.deinit();
const next_page = queryArgsToControllerQuery(results.next_page);
const prev_page = queryArgsToControllerQuery(results.prev_page);
try writeLink(link_writer, srv.community, path, next_page, "next");
try link_writer.writeByte(',');
try writeLink(link_writer, srv.community, path, prev_page, "prev");
try res.headers.put("Link", link.items);
try res.json(.ok, results.items);
}
};
fn writeLink(
writer: anytype,
community: api.Community,
path: []const u8,
params: anytype,
rel: []const u8,
) !void {
// TODO: percent-encode
try std.fmt.format(
writer,
"<{s}://{s}/{s}?",
.{ @tagName(community.scheme), community.host, path },
);
try std.fmt.format(writer, "{}", .{params});
try std.fmt.format(
writer,
">; rel=\"{s}\"",
.{rel},
);
}
fn formatQueryParams(
params: anytype,
comptime fmt: []const u8,
opt: std.fmt.FormatOptions,
writer: anytype,
) !void {
if (comptime std.meta.trait.is(.Pointer)(@TypeOf(params))) {
return formatQueryParams(params.*, fmt, opt, writer);
}
return formatRecursive("", params, writer);
}
fn formatRecursive(comptime prefix: []const u8, params: anytype, writer: anytype) !void {
inline for (std.meta.fields(@TypeOf(params))) |field| {
const val = @field(params, field.name);
const is_optional = comptime std.meta.trait.is(.Optional)(field.field_type);
const present = if (comptime is_optional) val != null else true;
if (present) {
const unwrapped = if (is_optional) val.? else val;
// TODO: percent-encode this
_ = try switch (@TypeOf(unwrapped)) {
[]const u8 => blk: {
break :blk std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, unwrapped });
},
else => |U| blk: {
if (comptime std.meta.trait.isContainer(U) and std.meta.trait.hasFn("format")(U)) {
break :blk std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped });
}
break :blk switch (@typeInfo(U)) {
.Enum => std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, @tagName(unwrapped) }),
.Struct => formatRecursive(field.name ++ ".", unwrapped, writer),
else => std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }),
};
},
};
}
}
}
fn queryArgsToControllerQuery(args: QueryArgs) query.Query {
var result = query.Query{
.max_items = args.max_items,
.owner_id = args.owner_id,
.like = args.like,
.created_before = args.created_before,
.created_after = args.created_after,
.order_by = args.order_by,
.direction = args.direction,
.prev = .{},
.page_direction = args.page_direction,
};
if (args.prev) |prev| {
result.prev = .{
.id = prev.id,
.name = if (prev.order_val == .name) prev.order_val.name else null,
.host = if (prev.order_val == .host) prev.order_val.host else null,
.created_at = if (prev.order_val == .created_at) prev.order_val.created_at else null,
};
}
return result;
}

View File

@ -1,6 +1,7 @@
const std = @import("std");
const sql = @import("sql");
const DateTime = @import("util").DateTime;
const util = @import("util");
const DateTime = util.DateTime;
pub const Migration = struct {
name: [:0]const u8,
@ -8,14 +9,6 @@ pub const Migration = struct {
down: []const u8,
};
fn firstIndexOf(str: []const u8, char: u8) ?usize {
for (str) |ch, i| {
if (ch == char) return i;
}
return null;
}
fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void {
const stmt_null = try std.cstr.addNullByte(alloc, stmt);
defer alloc.free(stmt_null);
@ -23,18 +16,10 @@ fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void {
}
fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void {
const tx = try db.begin();
errdefer tx.rollback();
var remaining = script;
while (firstIndexOf(remaining, ';')) |last| {
try execStmt(tx, remaining[0 .. last + 1], alloc);
remaining = remaining[last + 1 ..];
var iter = util.SqlStmtIter.from(script);
while (iter.next()) |stmt| {
try execStmt(db, stmt, alloc);
}
if (remaining.len > 1) try execStmt(tx, remaining, alloc);
try tx.commit();
}
fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool {
@ -56,12 +41,18 @@ pub fn up(db: anytype) !void {
try execScript(db, create_migration_table, gpa.allocator());
for (migrations) |migration| {
const was_ran = try wasMigrationRan(db, migration.name, gpa.allocator());
const tx = try db.begin();
errdefer tx.rollback();
const was_ran = try wasMigrationRan(tx, migration.name, gpa.allocator());
if (!was_ran) {
std.log.info("Running migration {s}", .{migration.name});
try execScript(db, migration.up, gpa.allocator());
try db.insert("migration", .{ .name = migration.name }, gpa.allocator());
try execScript(tx, migration.up, gpa.allocator());
try tx.insert("migration", .{
.name = migration.name,
.applied_at = DateTime.now(),
}, gpa.allocator());
}
try tx.commit();
}
}
@ -69,7 +60,7 @@ const create_migration_table =
\\CREATE TABLE IF NOT EXISTS
\\migration(
\\ name TEXT NOT NULL PRIMARY KEY,
\\ applied_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ applied_at TIMESTAMPTZ NOT NULL
\\);
;
@ -84,7 +75,7 @@ const migrations: []const Migration = &.{
\\ username TEXT NOT NULL,
\\
\\ kind TEXT NOT NULL CHECK (kind IN ('admin', 'user')),
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL
\\);
\\
\\CREATE TABLE local_account(
@ -96,7 +87,8 @@ const migrations: []const Migration = &.{
\\CREATE TABLE password(
\\ account_id UUID NOT NULL PRIMARY KEY REFERENCES account(id),
\\
\\ hash BLOB NOT NULL
\\ hash BLOB NOT NULL,
\\ changed_at TIMESTAMPTZ NOT NULL
\\);
,
.down =
@ -114,7 +106,7 @@ const migrations: []const Migration = &.{
\\ content TEXT NOT NULL,
\\ author_id UUID NOT NULL REFERENCES account(id),
\\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL
\\);
,
.down = "DROP TABLE note;",
@ -128,7 +120,7 @@ const migrations: []const Migration = &.{
\\ account_id UUID NOT NULL REFERENCES account(id),
\\ note_id UUID NOT NULL REFERENCES note(id),
\\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL
\\);
,
.down = "DROP TABLE reaction;",
@ -140,7 +132,7 @@ const migrations: []const Migration = &.{
\\ hash TEXT NOT NULL PRIMARY KEY,
\\ account_id UUID NOT NULL REFERENCES local_account(id),
\\
\\ issued_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ issued_at TIMESTAMPTZ NOT NULL
\\);
,
.down = "DROP TABLE token;",
@ -157,7 +149,7 @@ const migrations: []const Migration = &.{
\\
\\ max_uses INTEGER,
\\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
\\ created_at TIMESTAMPTZ NOT NULL,
\\ expires_at TIMESTAMPTZ,
\\
\\ kind TEXT NOT NULL CHECK (kind in ('system_user', 'community_owner', 'user'))
@ -181,7 +173,7 @@ const migrations: []const Migration = &.{
\\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')),
\\ kind TEXT NOT NULL CHECK (kind in ('admin', 'local')),
\\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\ created_at TIMESTAMPTZ NOT NULL
\\);
\\ALTER TABLE account ADD COLUMN community_id UUID REFERENCES community(id);
\\ALTER TABLE invite ADD COLUMN community_id UUID REFERENCES community(id);

View File

@ -121,6 +121,7 @@ pub const Db = struct {
// of 0, and we must not bind the argument.
const name = std.fmt.comptimePrint("${}", .{i + 1});
const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name);
std.log.debug("param {s} got index {}", .{ name, db_idx });
if (db_idx != 0)
try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg)
else if (!opts.ignore_unused_arguments)
@ -139,12 +140,24 @@ pub const Db = struct {
const T = @TypeOf(val);
switch (@typeInfo(T)) {
.Union => inline for (std.meta.fields(T)) |field| {
const Tag = std.meta.Tag(T);
const tag = @field(Tag, field.name);
.Union => {
const arr = if (@hasDecl(T, "toCharArray"))
val.toCharArray()
else if (@hasDecl(T, "toCharArrayZ"))
val.toCharArrayZ()
else {
inline for (std.meta.fields(T)) |field| {
const Tag = std.meta.Tag(T);
const tag = @field(Tag, field.name);
if (val == tag) return try self.bindArgument(stmt, idx, @field(val, field.name));
} else unreachable,
if (val == tag) return try self.bindArgument(stmt, idx, @field(val, field.name));
}
unreachable;
};
const len = std.mem.len(&arr);
return self.bindString(stmt, idx, arr[0..len]);
},
.Struct => {
const arr = if (@hasDecl(T, "toCharArray"))
@ -154,6 +167,7 @@ pub const Db = struct {
else
@compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string");
std.log.debug("binding type {any}: {s}", .{ T, arr });
const len = std.mem.len(&arr);
return self.bindString(stmt, idx, arr[0..len]);
},
@ -180,6 +194,8 @@ pub const Db = struct {
return error.BindException;
};
std.log.debug("binding string {s} to idx {}", .{ str, idx });
switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) {
c.SQLITE_OK => {},
else => |result| {

View File

@ -107,13 +107,13 @@ const array_len = 20;
pub fn toCharArray(value: DateTime) [array_len]u8 {
var buf: [array_len]u8 = undefined;
_ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable;
_ = std.fmt.bufPrint(&buf, "{}", .{value}) catch unreachable;
return buf;
}
pub fn toCharArrayZ(value: DateTime) [array_len + 1:0]u8 {
var buf: [array_len + 1:0]u8 = undefined;
_ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable;
_ = std.fmt.bufPrintZ(&buf, "{}", .{value}) catch unreachable;
return buf;
}

View File

@ -7,6 +7,7 @@ pub const DateTime = @import("./DateTime.zig");
pub const Url = @import("./Url.zig");
pub const PathIter = iters.PathIter;
pub const QueryIter = iters.QueryIter;
pub const SqlStmtIter = iters.Separator(';');
/// Joins an array of strings, prefixing every entry with `prefix`,
/// and putting `separator` in between each pair