Compare commits
17 commits
e90d9daf77
...
4e81441a0d
Author | SHA1 | Date | |
---|---|---|---|
4e81441a0d | |||
bfd73b7a1f | |||
8694516180 | |||
5630c6160f | |||
139fc92854 | |||
83da3c914b | |||
7d13b2546b | |||
181e57a631 | |||
cfdc8c5761 | |||
cbf98c1cf3 | |||
d2151ae326 | |||
2d464f0820 | |||
438c72b7e9 | |||
b46e82746c | |||
4cb574bc91 | |||
5d94f6874d | |||
1ba1b18c39 |
28 changed files with 1101 additions and 673 deletions
12
build.zig
12
build.zig
|
@ -53,16 +53,16 @@ pub fn build(b: *std.build.Builder) void {
|
|||
exe.linkSystemLibrary("pq");
|
||||
exe.linkLibC();
|
||||
|
||||
const util_tests = b.addTest("src/util/lib.zig");
|
||||
const http_tests = b.addTest("src/http/lib.zig");
|
||||
const sql_tests = b.addTest("src/sql/lib.zig");
|
||||
//const util_tests = b.addTest("src/util/lib.zig");
|
||||
const http_tests = b.addTest("src/http/test.zig");
|
||||
//const sql_tests = b.addTest("src/sql/lib.zig");
|
||||
http_tests.addPackage(util_pkg);
|
||||
sql_tests.addPackage(util_pkg);
|
||||
//sql_tests.addPackage(util_pkg);
|
||||
|
||||
const unit_tests = b.step("unit-tests", "Run tests");
|
||||
unit_tests.dependOn(&util_tests.step);
|
||||
//unit_tests.dependOn(&util_tests.step);
|
||||
unit_tests.dependOn(&http_tests.step);
|
||||
unit_tests.dependOn(&sql_tests.step);
|
||||
//unit_tests.dependOn(&sql_tests.step);
|
||||
|
||||
const api_integration = b.addTest("./tests/api_integration/lib.zig");
|
||||
api_integration.addPackage(sql_pkg);
|
||||
|
|
|
@ -276,7 +276,10 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
username,
|
||||
password,
|
||||
self.community.id,
|
||||
.{ .invite_id = if (maybe_invite) |inv| inv.id else null, .email = opt.email },
|
||||
.{
|
||||
.invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null,
|
||||
.email = opt.email,
|
||||
},
|
||||
self.arena.allocator(),
|
||||
);
|
||||
|
||||
|
@ -348,5 +351,19 @@ fn ApiConn(comptime DbConn: type) type {
|
|||
if (!self.isAdmin()) return error.PermissionDenied;
|
||||
return try services.communities.query(self.db, args, self.arena.allocator());
|
||||
}
|
||||
|
||||
pub fn globalTimeline(self: *Self) ![]services.notes.Note {
|
||||
const result = try services.notes.query(self.db, .{}, self.arena.allocator());
|
||||
return result.items;
|
||||
}
|
||||
|
||||
pub fn localTimeline(self: *Self) ![]services.notes.Note {
|
||||
const result = try services.notes.query(
|
||||
self.db,
|
||||
.{ .community_id = self.community.id },
|
||||
self.arena.allocator(),
|
||||
);
|
||||
return result.items;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -94,7 +94,8 @@ pub const Actor = struct {
|
|||
created_at: DateTime,
|
||||
};
|
||||
|
||||
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Actor {
|
||||
pub const GetError = error{ NotFound, DatabaseFailure };
|
||||
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Actor {
|
||||
return db.queryRow(
|
||||
Actor,
|
||||
\\SELECT
|
||||
|
|
16
src/api/services/common.zig
Normal file
16
src/api/services/common.zig
Normal file
|
@ -0,0 +1,16 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
|
||||
pub const Direction = enum {
|
||||
ascending,
|
||||
descending,
|
||||
|
||||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||
};
|
||||
|
||||
pub const PageDirection = enum {
|
||||
forward,
|
||||
backward,
|
||||
|
||||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||
};
|
|
@ -2,6 +2,7 @@ const std = @import("std");
|
|||
const builtin = @import("builtin");
|
||||
const util = @import("util");
|
||||
const sql = @import("sql");
|
||||
const common = @import("./common.zig");
|
||||
|
||||
const Uuid = util.Uuid;
|
||||
const DateTime = util.DateTime;
|
||||
|
@ -82,11 +83,12 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st
|
|||
else => return error.DatabaseFailure,
|
||||
}
|
||||
|
||||
const name = options.name orelse host;
|
||||
db.insert("community", .{
|
||||
.id = id,
|
||||
.owner_id = null,
|
||||
.host = host,
|
||||
.name = options.name orelse host,
|
||||
.name = name,
|
||||
.scheme = scheme,
|
||||
.kind = options.kind,
|
||||
.created_at = DateTime.now(),
|
||||
|
@ -153,20 +155,8 @@ pub const QueryArgs = struct {
|
|||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||
};
|
||||
|
||||
pub const Direction = enum {
|
||||
ascending,
|
||||
descending,
|
||||
|
||||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||
};
|
||||
|
||||
pub const PageDirection = enum {
|
||||
forward,
|
||||
backward,
|
||||
|
||||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||
};
|
||||
|
||||
pub const Direction = common.Direction;
|
||||
pub const PageDirection = common.PageDirection;
|
||||
pub const Prev = std.meta.Child(std.meta.fieldInfo(QueryArgs, .prev).field_type);
|
||||
pub const OrderVal = std.meta.fieldInfo(Prev, .order_val).field_type;
|
||||
|
||||
|
@ -211,30 +201,6 @@ pub const QueryResult = struct {
|
|||
next_page: QueryArgs,
|
||||
};
|
||||
|
||||
const QueryBuilder = struct {
|
||||
array: std.ArrayList(u8),
|
||||
where_clauses_appended: usize = 0,
|
||||
|
||||
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
|
||||
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
|
||||
}
|
||||
|
||||
pub fn deinit(self: *const QueryBuilder) void {
|
||||
self.array.deinit();
|
||||
}
|
||||
|
||||
pub fn andWhere(self: *QueryBuilder, clause: []const u8) !void {
|
||||
if (self.where_clauses_appended == 0) {
|
||||
try self.array.appendSlice("WHERE ");
|
||||
} else {
|
||||
try self.array.appendSlice(" AND ");
|
||||
}
|
||||
|
||||
try self.array.appendSlice(clause);
|
||||
self.where_clauses_appended += 1;
|
||||
}
|
||||
};
|
||||
|
||||
const max_max_items = 100;
|
||||
|
||||
pub const QueryError = error{
|
||||
|
@ -246,7 +212,7 @@ pub const QueryError = error{
|
|||
// arguments.
|
||||
// `args.max_items` is only a request, and fewer entries may be returned.
|
||||
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult {
|
||||
var builder = QueryBuilder.init(alloc);
|
||||
var builder = sql.QueryBuilder.init(alloc);
|
||||
defer builder.deinit();
|
||||
|
||||
try builder.array.appendSlice(
|
||||
|
@ -266,21 +232,21 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
|
|||
if (args.prev) |prev| {
|
||||
if (prev.order_val != args.order_by) return error.PageArgMismatch;
|
||||
|
||||
try builder.andWhere(switch (args.order_by) {
|
||||
.name => "(name, id)",
|
||||
.host => "(host, id)",
|
||||
.created_at => "(created_at, id)",
|
||||
});
|
||||
_ = try builder.array.appendSlice(switch (args.direction) {
|
||||
switch (args.order_by) {
|
||||
.name => try builder.andWhere("(name, id)"),
|
||||
.host => try builder.andWhere("(host, id)"),
|
||||
.created_at => try builder.andWhere("(created_at, id)"),
|
||||
}
|
||||
switch (args.direction) {
|
||||
.ascending => switch (args.page_direction) {
|
||||
.forward => " > ",
|
||||
.backward => " < ",
|
||||
.forward => try builder.appendSlice(" > "),
|
||||
.backward => try builder.appendSlice(" < "),
|
||||
},
|
||||
.descending => switch (args.page_direction) {
|
||||
.forward => " < ",
|
||||
.backward => " > ",
|
||||
.forward => try builder.appendSlice(" < "),
|
||||
.backward => try builder.appendSlice(" > "),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
_ = try builder.array.appendSlice("($5, $6)");
|
||||
}
|
||||
|
@ -297,57 +263,52 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
|
|||
|
||||
_ = try builder.array.appendSlice("\nLIMIT $7");
|
||||
|
||||
const query_args = .{
|
||||
args.owner_id,
|
||||
args.like,
|
||||
args.created_before,
|
||||
args.created_after,
|
||||
if (args.prev) |prev| prev.order_val else null,
|
||||
if (args.prev) |prev| prev.id else null,
|
||||
max_items,
|
||||
const query_args = blk: {
|
||||
const ord_val =
|
||||
if (args.prev) |prev| @as(?QueryArgs.OrderVal, prev.order_val) else null;
|
||||
const id =
|
||||
if (args.prev) |prev| @as(?Uuid, prev.id) else null;
|
||||
break :blk .{
|
||||
args.owner_id,
|
||||
args.like,
|
||||
args.created_before,
|
||||
args.created_after,
|
||||
ord_val,
|
||||
id,
|
||||
max_items,
|
||||
};
|
||||
};
|
||||
|
||||
try builder.array.append(0);
|
||||
|
||||
var results = try db.queryWithOptions(
|
||||
var results = try db.queryRowsWithOptions(
|
||||
Community,
|
||||
std.meta.assumeSentinel(builder.array.items, 0),
|
||||
query_args,
|
||||
.{ .prep_allocator = alloc, .ignore_unused_arguments = true },
|
||||
max_items,
|
||||
.{ .allocator = alloc, .ignore_unused_arguments = true },
|
||||
);
|
||||
defer results.finish();
|
||||
|
||||
const result_buf = try alloc.alloc(Community, args.max_items);
|
||||
errdefer alloc.free(result_buf);
|
||||
|
||||
var count: usize = 0;
|
||||
errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c);
|
||||
|
||||
for (result_buf) |*c| {
|
||||
c.* = (try results.row(alloc)) orelse break;
|
||||
|
||||
count += 1;
|
||||
}
|
||||
errdefer util.deepFree(alloc, results);
|
||||
|
||||
var next_page = args;
|
||||
var prev_page = args;
|
||||
prev_page.page_direction = .backward;
|
||||
next_page.page_direction = .forward;
|
||||
if (count != 0) {
|
||||
if (results.len != 0) {
|
||||
prev_page.prev = .{
|
||||
.id = result_buf[0].id,
|
||||
.order_val = getOrderVal(result_buf[0], args.order_by),
|
||||
.id = results[0].id,
|
||||
.order_val = getOrderVal(results[0], args.order_by),
|
||||
};
|
||||
|
||||
next_page.prev = .{
|
||||
.id = result_buf[count - 1].id,
|
||||
.order_val = getOrderVal(result_buf[count - 1], args.order_by),
|
||||
.id = results[results.len - 1].id,
|
||||
.order_val = getOrderVal(results[results.len - 1], args.order_by),
|
||||
};
|
||||
}
|
||||
// TODO: This will give incorrect links on an empty page
|
||||
|
||||
return QueryResult{
|
||||
.items = result_buf[0..count],
|
||||
.items = results,
|
||||
|
||||
.next_page = next_page,
|
||||
.prev_page = prev_page,
|
||||
|
|
|
@ -71,7 +71,7 @@ pub fn create(db: anytype, created_by: Uuid, community_id: ?Uuid, options: Invit
|
|||
.max_uses = options.max_uses,
|
||||
.created_at = created_at,
|
||||
.expires_at = if (options.lifespan) |lifespan|
|
||||
created_at.add(lifespan)
|
||||
@as(?DateTime, created_at.add(lifespan))
|
||||
else
|
||||
null,
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
const sql = @import("sql");
|
||||
const common = @import("./common.zig");
|
||||
|
||||
const Uuid = util.Uuid;
|
||||
const DateTime = util.DateTime;
|
||||
|
@ -42,7 +43,7 @@ const selectStarFromNote = std.fmt.comptimePrint(
|
|||
\\SELECT {s}
|
||||
\\FROM note
|
||||
\\
|
||||
, .{util.comptimeJoin(",", std.meta.fieldNames(Note))});
|
||||
, .{util.comptimeJoinWithPrefix(",", "note.", std.meta.fieldNames(Note))});
|
||||
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
|
||||
return db.queryRow(
|
||||
Note,
|
||||
|
@ -57,3 +58,108 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
|
|||
else => error.DatabaseFailure,
|
||||
};
|
||||
}
|
||||
|
||||
const max_max_items = 100;
|
||||
|
||||
pub const QueryArgs = struct {
|
||||
pub const PageDirection = common.PageDirection;
|
||||
pub const Prev = std.meta.Child(std.meta.field(@This(), .prev).field_type);
|
||||
|
||||
max_items: usize = 20,
|
||||
|
||||
created_before: ?DateTime = null,
|
||||
created_after: ?DateTime = null,
|
||||
community_id: ?Uuid = null,
|
||||
|
||||
prev: ?struct {
|
||||
id: Uuid,
|
||||
created_at: DateTime,
|
||||
} = null,
|
||||
|
||||
page_direction: PageDirection = .forward,
|
||||
};
|
||||
|
||||
pub const QueryResult = struct {
|
||||
items: []Note,
|
||||
|
||||
prev_page: QueryArgs,
|
||||
next_page: QueryArgs,
|
||||
};
|
||||
|
||||
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult {
|
||||
var builder = sql.QueryBuilder.init(alloc);
|
||||
defer builder.deinit();
|
||||
|
||||
try builder.appendSlice(selectStarFromNote ++
|
||||
\\ JOIN actor ON actor.id = note.author_id
|
||||
\\
|
||||
);
|
||||
|
||||
if (args.created_before != null) try builder.andWhere("note.created_at < $1");
|
||||
if (args.created_after != null) try builder.andWhere("note.created_at > $2");
|
||||
if (args.prev != null) {
|
||||
try builder.andWhere("(note.created_at, note.id)");
|
||||
|
||||
switch (args.page_direction) {
|
||||
.forward => try builder.appendSlice(" < "),
|
||||
.backward => try builder.appendSlice(" > "),
|
||||
}
|
||||
try builder.appendSlice("($3, $4)");
|
||||
}
|
||||
if (args.community_id != null) try builder.andWhere("actor.community_id = $5");
|
||||
|
||||
try builder.appendSlice(
|
||||
\\
|
||||
\\ORDER BY note.created_at DESC
|
||||
\\LIMIT $6
|
||||
\\
|
||||
);
|
||||
|
||||
const max_items = if (args.max_items > max_max_items) max_max_items else args.max_items;
|
||||
|
||||
const query_args = blk: {
|
||||
const prev_created_at = if (args.prev) |prev| @as(?DateTime, prev.created_at) else null;
|
||||
const prev_id = if (args.prev) |prev| @as(?Uuid, prev.id) else null;
|
||||
|
||||
break :blk .{
|
||||
args.created_before,
|
||||
args.created_after,
|
||||
prev_created_at,
|
||||
prev_id,
|
||||
args.community_id,
|
||||
max_items,
|
||||
};
|
||||
};
|
||||
|
||||
const results = try db.queryRowsWithOptions(
|
||||
Note,
|
||||
try builder.terminate(),
|
||||
query_args,
|
||||
max_items,
|
||||
.{ .allocator = alloc, .ignore_unused_arguments = true },
|
||||
);
|
||||
errdefer util.deepFree(results);
|
||||
|
||||
var next_page = args;
|
||||
var prev_page = args;
|
||||
prev_page.page_direction = .backward;
|
||||
next_page.page_direction = .forward;
|
||||
if (results.len != 0) {
|
||||
prev_page.prev = .{
|
||||
.id = results[0].id,
|
||||
.created_at = results[0].created_at,
|
||||
};
|
||||
|
||||
next_page.prev = .{
|
||||
.id = results[results.len - 1].id,
|
||||
.created_at = results[results.len - 1].created_at,
|
||||
};
|
||||
}
|
||||
// TODO: this will give incorrect links on an empty page
|
||||
|
||||
return QueryResult{
|
||||
.items = results,
|
||||
.next_page = next_page,
|
||||
.prev_page = prev_page,
|
||||
};
|
||||
}
|
||||
|
|
124
src/http/headers.zig
Normal file
124
src/http/headers.zig
Normal file
|
@ -0,0 +1,124 @@
|
|||
const std = @import("std");
|
||||
|
||||
pub const Fields = struct {
|
||||
const HashContext = struct {
|
||||
const hash_seed = 1;
|
||||
pub fn eql(_: @This(), lhs: []const u8, rhs: []const u8, _: usize) bool {
|
||||
return std.ascii.eqlIgnoreCase(lhs, rhs);
|
||||
}
|
||||
pub fn hash(_: @This(), s: []const u8) u32 {
|
||||
var h = std.hash.Wyhash.init(hash_seed);
|
||||
for (s) |ch| {
|
||||
const c = [1]u8{std.ascii.toLower(ch)};
|
||||
h.update(&c);
|
||||
}
|
||||
return @truncate(u32, h.final());
|
||||
}
|
||||
};
|
||||
|
||||
const HashMap = std.ArrayHashMapUnmanaged(
|
||||
[]const u8,
|
||||
[]const u8,
|
||||
HashContext,
|
||||
true,
|
||||
);
|
||||
|
||||
unmanaged: HashMap,
|
||||
allocator: std.mem.Allocator,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator) Fields {
|
||||
return Fields{
|
||||
.unmanaged = .{},
|
||||
.allocator = allocator,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *Fields) void {
|
||||
var hash_iter = self.unmanaged.iterator();
|
||||
while (hash_iter.next()) |entry| {
|
||||
self.allocator.free(entry.key_ptr.*);
|
||||
self.allocator.free(entry.value_ptr.*);
|
||||
}
|
||||
|
||||
self.unmanaged.deinit(self.allocator);
|
||||
}
|
||||
|
||||
pub fn iterator(self: Fields) HashMap.Iterator {
|
||||
return self.unmanaged.iterator();
|
||||
}
|
||||
|
||||
pub fn get(self: Fields, key: []const u8) ?[]const u8 {
|
||||
return self.unmanaged.get(key);
|
||||
}
|
||||
|
||||
pub const ListIterator = struct {
|
||||
remaining: []const u8,
|
||||
|
||||
fn extractElement(self: *ListIterator) ?[]const u8 {
|
||||
if (self.remaining.len == 0) return null;
|
||||
|
||||
var start: usize = 0;
|
||||
var is_quoted = false;
|
||||
const end = for (self.remaining) |ch, i| {
|
||||
if (start == i and std.ascii.isWhitespace(ch)) {
|
||||
start += 1;
|
||||
} else if (ch == '"') {
|
||||
is_quoted = !is_quoted;
|
||||
}
|
||||
if (ch == ',' and !is_quoted) {
|
||||
break i;
|
||||
}
|
||||
} else self.remaining.len;
|
||||
|
||||
const str = self.remaining[start..end];
|
||||
if (end == self.remaining.len) {
|
||||
self.remaining = "";
|
||||
} else {
|
||||
self.remaining = self.remaining[end + 1 ..];
|
||||
}
|
||||
|
||||
return std.mem.trim(u8, str, " \t");
|
||||
}
|
||||
|
||||
pub fn next(self: *ListIterator) ?[]const u8 {
|
||||
while (self.extractElement()) |elem| {
|
||||
if (elem.len != 0) return elem;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn getList(self: Fields, key: []const u8) ?ListIterator {
|
||||
return if (self.unmanaged.get(key)) |hdr| ListIterator{ .remaining = hdr } else null;
|
||||
}
|
||||
|
||||
pub fn put(self: *Fields, key: []const u8, val: []const u8) !void {
|
||||
const key_clone = try self.allocator.alloc(u8, key.len);
|
||||
std.mem.copy(u8, key_clone, key);
|
||||
errdefer self.allocator.free(key_clone);
|
||||
|
||||
const val_clone = try self.allocator.alloc(u8, val.len);
|
||||
std.mem.copy(u8, val_clone, val);
|
||||
errdefer self.allocator.free(val_clone);
|
||||
|
||||
if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| {
|
||||
self.allocator.free(entry.key);
|
||||
self.allocator.free(entry.value);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn append(self: *Fields, key: []const u8, val: []const u8) !void {
|
||||
if (self.unmanaged.getEntry(key)) |entry| {
|
||||
const new_val = try std.mem.join(self.allocator, ", ", &.{ entry.value_ptr.*, val });
|
||||
self.allocator.free(entry.value_ptr.*);
|
||||
entry.value_ptr.* = new_val;
|
||||
} else {
|
||||
try self.put(key, val);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn count(self: Fields) usize {
|
||||
return self.unmanaged.count();
|
||||
}
|
||||
};
|
|
@ -10,22 +10,15 @@ pub const socket = @import("./socket.zig");
|
|||
pub const Method = std.http.Method;
|
||||
pub const Status = std.http.Status;
|
||||
|
||||
pub const Request = request.Request;
|
||||
pub const Request = request.Request(std.net.Stream.Reader);
|
||||
pub const serveConn = server.serveConn;
|
||||
pub const Response = server.Response;
|
||||
pub const Handler = server.Handler;
|
||||
|
||||
pub const Headers = std.HashMap([]const u8, []const u8, struct {
|
||||
pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {
|
||||
return ciutf8.eql(a, b);
|
||||
}
|
||||
pub const Fields = @import("./headers.zig").Fields;
|
||||
|
||||
pub fn hash(_: @This(), str: []const u8) u64 {
|
||||
return ciutf8.hash(str);
|
||||
}
|
||||
}, std.hash_map.default_max_load_percentage);
|
||||
|
||||
test {
|
||||
_ = server;
|
||||
_ = request;
|
||||
}
|
||||
pub const Protocol = enum {
|
||||
http_1_0,
|
||||
http_1_1,
|
||||
http_1_x,
|
||||
};
|
||||
|
|
|
@ -3,29 +3,23 @@ const http = @import("./lib.zig");
|
|||
|
||||
const parser = @import("./request/parser.zig");
|
||||
|
||||
pub const Request = struct {
|
||||
pub const Protocol = enum {
|
||||
http_1_0,
|
||||
http_1_1,
|
||||
pub fn Request(comptime Reader: type) type {
|
||||
return struct {
|
||||
protocol: http.Protocol,
|
||||
|
||||
method: http.Method,
|
||||
uri: []const u8,
|
||||
headers: http.Fields,
|
||||
|
||||
body: ?parser.TransferStream(Reader),
|
||||
|
||||
pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void {
|
||||
allocator.free(self.uri);
|
||||
self.headers.deinit();
|
||||
}
|
||||
};
|
||||
|
||||
protocol: Protocol,
|
||||
source_address: ?std.net.Address,
|
||||
|
||||
method: http.Method,
|
||||
uri: []const u8,
|
||||
headers: http.Headers,
|
||||
body: ?[]const u8 = null,
|
||||
|
||||
pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request {
|
||||
return parser.parse(alloc, reader, addr);
|
||||
}
|
||||
|
||||
pub fn parseFree(self: Request, alloc: std.mem.Allocator) void {
|
||||
parser.parseFree(alloc, self);
|
||||
}
|
||||
};
|
||||
|
||||
test {
|
||||
_ = parser;
|
||||
}
|
||||
|
||||
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
|
||||
return parser.parse(alloc, reader);
|
||||
}
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
const std = @import("std");
|
||||
const util = @import("util");
|
||||
const http = @import("../lib.zig");
|
||||
|
||||
const Method = http.Method;
|
||||
const Headers = http.Headers;
|
||||
const Fields = http.Fields;
|
||||
|
||||
const Request = @import("../request.zig").Request;
|
||||
|
||||
const request_buf_size = 1 << 16;
|
||||
const max_path_len = 1 << 10;
|
||||
const max_body_len = 1 << 12;
|
||||
|
||||
fn ParseError(comptime Reader: type) type {
|
||||
return error{
|
||||
|
@ -22,7 +20,7 @@ const Encoding = enum {
|
|||
chunked,
|
||||
};
|
||||
|
||||
pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request {
|
||||
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
|
||||
const method = try parseMethod(reader);
|
||||
const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
|
||||
error.StreamTooLong => return error.RequestUriTooLong,
|
||||
|
@ -33,28 +31,20 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address
|
|||
const proto = try parseProto(reader);
|
||||
|
||||
// discard \r\n
|
||||
_ = try reader.readByte();
|
||||
_ = try reader.readByte();
|
||||
switch (try reader.readByte()) {
|
||||
'\r' => if ((try reader.readByte()) != '\n') return error.BadRequest,
|
||||
'\n' => {},
|
||||
else => return error.BadRequest,
|
||||
}
|
||||
|
||||
var headers = try parseHeaders(alloc, reader);
|
||||
errdefer freeHeaders(alloc, &headers);
|
||||
errdefer headers.deinit();
|
||||
|
||||
const body = if (method.requestHasBody())
|
||||
try readBody(alloc, headers, reader)
|
||||
else
|
||||
null;
|
||||
errdefer if (body) |b| alloc.free(b);
|
||||
const body = try prepareBody(headers, reader);
|
||||
if (body != null and !method.requestHasBody()) return error.BadRequest;
|
||||
|
||||
const eff_addr = if (headers.get("X-Real-IP")) |ip|
|
||||
std.net.Address.parseIp(ip, address.getPort()) catch {
|
||||
return error.BadRequest;
|
||||
}
|
||||
else
|
||||
address;
|
||||
|
||||
return Request{
|
||||
return Request(@TypeOf(reader)){
|
||||
.protocol = proto,
|
||||
.source_address = eff_addr,
|
||||
|
||||
.method = method,
|
||||
.uri = uri,
|
||||
|
@ -79,7 +69,7 @@ fn parseMethod(reader: anytype) !Method {
|
|||
return error.MethodNotImplemented;
|
||||
}
|
||||
|
||||
fn parseProto(reader: anytype) !Request.Protocol {
|
||||
fn parseProto(reader: anytype) !http.Protocol {
|
||||
var buf: [8]u8 = undefined;
|
||||
const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
|
||||
error.StreamTooLong => return error.UnknownProtocol,
|
||||
|
@ -99,85 +89,145 @@ fn parseProto(reader: anytype) !Request.Protocol {
|
|||
return switch (buf[2]) {
|
||||
'0' => .http_1_0,
|
||||
'1' => .http_1_1,
|
||||
else => error.HttpVersionNotSupported,
|
||||
else => .http_1_x,
|
||||
};
|
||||
}
|
||||
|
||||
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers {
|
||||
var map = Headers.init(allocator);
|
||||
errdefer map.deinit();
|
||||
errdefer {
|
||||
var iter = map.iterator();
|
||||
while (iter.next()) |it| {
|
||||
allocator.free(it.key_ptr.*);
|
||||
allocator.free(it.value_ptr.*);
|
||||
}
|
||||
}
|
||||
// todo:
|
||||
//errdefer {
|
||||
//var iter = map.iterator();
|
||||
//while (iter.next()) |it| {
|
||||
//allocator.free(it.key_ptr);
|
||||
//allocator.free(it.value_ptr);
|
||||
//}
|
||||
//}
|
||||
|
||||
var buf: [1024]u8 = undefined;
|
||||
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
|
||||
var headers = Fields.init(allocator);
|
||||
|
||||
var buf: [4096]u8 = undefined;
|
||||
while (true) {
|
||||
const line = try reader.readUntilDelimiter(&buf, '\n');
|
||||
if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break;
|
||||
const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
|
||||
error.StreamTooLong => return error.HeaderLineTooLong,
|
||||
else => return err,
|
||||
};
|
||||
const line = std.mem.trimRight(u8, full_line, "\r");
|
||||
if (line.len == 0) break;
|
||||
|
||||
// TODO: handle multi-line headers
|
||||
const name = extractHeaderName(line) orelse continue;
|
||||
const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len;
|
||||
const value = line[name.len + 1 + 1 .. value_end];
|
||||
const name = std.mem.sliceTo(line, ':');
|
||||
if (!isTokenValid(name)) return error.BadRequest;
|
||||
if (name.len == line.len) return error.BadRequest;
|
||||
|
||||
if (name.len == 0 or value.len == 0) return error.BadRequest;
|
||||
const value = std.mem.trim(u8, line[name.len + 1 ..], " \t");
|
||||
|
||||
const name_alloc = try allocator.alloc(u8, name.len);
|
||||
errdefer allocator.free(name_alloc);
|
||||
const value_alloc = try allocator.alloc(u8, value.len);
|
||||
errdefer allocator.free(value_alloc);
|
||||
|
||||
@memcpy(name_alloc.ptr, name.ptr, name.len);
|
||||
@memcpy(value_alloc.ptr, value.ptr, value.len);
|
||||
|
||||
try map.put(name_alloc, value_alloc);
|
||||
try headers.append(name, value);
|
||||
}
|
||||
|
||||
return map;
|
||||
return headers;
|
||||
}
|
||||
|
||||
fn extractHeaderName(line: []const u8) ?[]const u8 {
|
||||
var index: usize = 0;
|
||||
fn isTokenValid(token: []const u8) bool {
|
||||
if (token.len == 0) return false;
|
||||
for (token) |ch| {
|
||||
switch (ch) {
|
||||
'"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false,
|
||||
|
||||
// TODO: handle whitespace
|
||||
while (index < line.len) : (index += 1) {
|
||||
if (line[index] == ':') {
|
||||
if (index == 0) return null;
|
||||
return line[0..index];
|
||||
'!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => {},
|
||||
else => if (!std.ascii.isAlphanumeric(ch)) return false,
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
return true;
|
||||
}
|
||||
|
||||
fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 {
|
||||
const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
|
||||
if (xfer_encoding != .identity) return error.UnsupportedMediaType;
|
||||
fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) {
|
||||
const hdr = headers.get("Transfer-Encoding");
|
||||
// TODO:
|
||||
// if (hder != null and protocol == .http_1_0) return error.BadRequest;
|
||||
const xfer_encoding = try parseEncoding(hdr);
|
||||
const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
|
||||
if (content_encoding != .identity) return error.UnsupportedMediaType;
|
||||
|
||||
const len_str = headers.get("Content-Length") orelse return null;
|
||||
const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
|
||||
if (len > max_body_len) return error.RequestEntityTooLarge;
|
||||
const body = try alloc.alloc(u8, len);
|
||||
errdefer alloc.free(body);
|
||||
switch (xfer_encoding) {
|
||||
.identity => {
|
||||
const len_str = headers.get("Content-Length") orelse return null;
|
||||
const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
|
||||
|
||||
reader.readNoEof(body) catch return error.BadRequest;
|
||||
return TransferStream(@TypeOf(reader)){ .underlying = .{ .identity = std.io.limitedReader(reader, len) } };
|
||||
},
|
||||
.chunked => {
|
||||
if (headers.get("Content-Length") != null) return error.BadRequest;
|
||||
return TransferStream(@TypeOf(reader)){
|
||||
.underlying = .{
|
||||
.chunked = try ChunkedStream(@TypeOf(reader)).init(reader),
|
||||
},
|
||||
};
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return body;
|
||||
fn ChunkedStream(comptime R: type) type {
|
||||
return struct {
|
||||
const Self = @This();
|
||||
|
||||
remaining: ?usize = 0,
|
||||
underlying: R,
|
||||
|
||||
const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream };
|
||||
fn init(reader: R) !Self {
|
||||
var self: Self = .{ .underlying = reader };
|
||||
return self;
|
||||
}
|
||||
|
||||
fn read(self: *Self, buf: []u8) !usize {
|
||||
var count: usize = 0;
|
||||
while (true) {
|
||||
if (count == buf.len) return count;
|
||||
if (self.remaining == null) return count;
|
||||
if (self.remaining.? == 0) self.remaining = try self.readChunkHeader();
|
||||
|
||||
const max_read = std.math.min(buf.len, self.remaining.?);
|
||||
const amt = try self.underlying.read(buf[count .. count + max_read]);
|
||||
if (amt != max_read) return error.EndOfStream;
|
||||
count += amt;
|
||||
self.remaining.? -= amt;
|
||||
if (self.remaining.? == 0) {
|
||||
var crlf: [2]u8 = undefined;
|
||||
_ = try self.underlying.readUntilDelimiter(&crlf, '\n');
|
||||
self.remaining = try self.readChunkHeader();
|
||||
}
|
||||
|
||||
if (count == buf.len) return count;
|
||||
}
|
||||
}
|
||||
|
||||
fn readChunkHeader(self: *Self) !?usize {
|
||||
// TODO: Pick a reasonable limit for this
|
||||
var buf = std.mem.zeroes([10]u8);
|
||||
const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| {
|
||||
return if (err == error.StreamTooLong) error.InvalidChunkHeader else err;
|
||||
};
|
||||
if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader;
|
||||
|
||||
const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader;
|
||||
|
||||
return if (size != 0) size else null;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub fn TransferStream(comptime R: type) type {
|
||||
return struct {
|
||||
const Error = R.Error || ChunkedStream(R).Error;
|
||||
const Reader = std.io.Reader(*@This(), Error, read);
|
||||
|
||||
underlying: union(enum) {
|
||||
identity: std.io.LimitedReader(R),
|
||||
chunked: ChunkedStream(R),
|
||||
},
|
||||
|
||||
pub fn read(self: *@This(), buf: []u8) Error!usize {
|
||||
return switch (self.underlying) {
|
||||
.identity => |*r| try r.read(buf),
|
||||
.chunked => |*r| try r.read(buf),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn reader(self: *@This()) Reader {
|
||||
return .{ .context = self };
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: assumes that there's only one encoding, not layered encodings
|
||||
|
@ -187,257 +237,3 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding {
|
|||
if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked;
|
||||
return error.UnsupportedMediaType;
|
||||
}
|
||||
|
||||
pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void {
|
||||
allocator.free(request.uri);
|
||||
freeHeaders(allocator, &request.headers);
|
||||
if (request.body) |body| allocator.free(body);
|
||||
}
|
||||
|
||||
fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void {
|
||||
var iter = headers.iterator();
|
||||
while (iter.next()) |it| {
|
||||
allocator.free(it.key_ptr.*);
|
||||
allocator.free(it.value_ptr.*);
|
||||
}
|
||||
headers.deinit();
|
||||
}
|
||||
|
||||
const _test = struct {
|
||||
const expectEqual = std.testing.expectEqual;
|
||||
const expectEqualStrings = std.testing.expectEqualStrings;
|
||||
|
||||
fn toCrlf(comptime str: []const u8) []const u8 {
|
||||
comptime {
|
||||
var buf: [str.len * 2]u8 = undefined;
|
||||
|
||||
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
|
||||
|
||||
var buf_len: usize = 0;
|
||||
for (str) |ch| {
|
||||
if (ch == '\n') {
|
||||
buf[buf_len] = '\r';
|
||||
buf_len += 1;
|
||||
}
|
||||
|
||||
buf[buf_len] = ch;
|
||||
buf_len += 1;
|
||||
}
|
||||
|
||||
return buf[0..buf_len];
|
||||
}
|
||||
}
|
||||
|
||||
fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers {
|
||||
var result = Headers.init(alloc);
|
||||
inline for (headers) |tup| {
|
||||
try result.put(tup[0], tup[1]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
fn areEqualHeaders(lhs: Headers, rhs: Headers) bool {
|
||||
if (lhs.count() != rhs.count()) return false;
|
||||
var iter = lhs.iterator();
|
||||
while (iter.next()) |it| {
|
||||
const rhs_val = rhs.get(it.key_ptr.*) orelse return false;
|
||||
if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
fn printHeaders(headers: Headers) void {
|
||||
var iter = headers.iterator();
|
||||
while (iter.next()) |it| {
|
||||
std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* });
|
||||
}
|
||||
}
|
||||
|
||||
fn expectEqualHeaders(expected: Headers, actual: Headers) !void {
|
||||
if (!areEqualHeaders(expected, actual)) {
|
||||
std.debug.print("\nexpected: \n", .{});
|
||||
printHeaders(expected);
|
||||
std.debug.print("\n\nfound: \n", .{});
|
||||
printHeaders(actual);
|
||||
std.debug.print("\n\n", .{});
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
}
|
||||
|
||||
fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void {
|
||||
var stream = std.io.fixedBufferStream(toCrlf(request));
|
||||
|
||||
const result = try parse(alloc, stream.reader());
|
||||
|
||||
try expectEqual(expected.method, result.method);
|
||||
try expectEqualStrings(expected.path, result.path);
|
||||
try expectEqualHeaders(expected.headers, result.headers);
|
||||
if ((expected.body == null) != (result.body == null)) {
|
||||
const null_str: []const u8 = "(null)";
|
||||
const exp = expected.body orelse null_str;
|
||||
const act = result.body orelse null_str;
|
||||
std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act });
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
if (expected.body != null) {
|
||||
try expectEqualStrings(expected.body.?, result.body.?);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// TOOD: failure test cases
|
||||
test "parse" {
|
||||
const testCase = _test.parseTestCase;
|
||||
var buf = [_]u8{0} ** (1 << 16);
|
||||
var fba = std.heap.FixedBufferAllocator.init(&buf);
|
||||
const alloc = fba.allocator();
|
||||
try testCase(alloc, (
|
||||
\\GET / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
), .{
|
||||
.method = .GET,
|
||||
.headers = try _test.makeHeaders(alloc, .{}),
|
||||
.path = "/",
|
||||
});
|
||||
|
||||
fba.reset();
|
||||
try testCase(alloc, (
|
||||
\\POST / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
), .{
|
||||
.method = .POST,
|
||||
.headers = try _test.makeHeaders(alloc, .{}),
|
||||
.path = "/",
|
||||
});
|
||||
|
||||
fba.reset();
|
||||
try testCase(alloc, (
|
||||
\\HEAD / HTTP/1.1
|
||||
\\Authorization: bearer <token>
|
||||
\\
|
||||
\\
|
||||
), .{
|
||||
.method = .HEAD,
|
||||
.headers = try _test.makeHeaders(alloc, .{
|
||||
.{ "Authorization", "bearer <token>" },
|
||||
}),
|
||||
.path = "/",
|
||||
});
|
||||
|
||||
fba.reset();
|
||||
try testCase(alloc, (
|
||||
\\POST /nonsense HTTP/1.1
|
||||
\\Authorization: bearer <token>
|
||||
\\Content-Length: 5
|
||||
\\
|
||||
\\12345
|
||||
), .{
|
||||
.method = .POST,
|
||||
.headers = try _test.makeHeaders(alloc, .{
|
||||
.{ "Authorization", "bearer <token>" },
|
||||
.{ "Content-Length", "5" },
|
||||
}),
|
||||
.path = "/nonsense",
|
||||
.body = "12345",
|
||||
});
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.MethodNotImplemented,
|
||||
testCase(alloc, (
|
||||
\\FOO /nonsense HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.MethodNotImplemented,
|
||||
testCase(alloc, (
|
||||
\\FOOBARBAZ /nonsense HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.RequestUriTooLong,
|
||||
testCase(alloc, (
|
||||
\\GET /
|
||||
++ ("a" ** 2048)), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.UnknownProtocol,
|
||||
testCase(alloc, (
|
||||
\\GET /nonsense SPECIALHTTP/1.1
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.UnknownProtocol,
|
||||
testCase(alloc, (
|
||||
\\GET /nonsense JSON/1.1
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.HttpVersionNotSupported,
|
||||
testCase(alloc, (
|
||||
\\GET /nonsense HTTP/1.9
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.HttpVersionNotSupported,
|
||||
testCase(alloc, (
|
||||
\\GET /nonsense HTTP/8.1
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.BadRequest,
|
||||
testCase(alloc, (
|
||||
\\GET /nonsense HTTP/blah blah blah
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.BadRequest,
|
||||
testCase(alloc, (
|
||||
\\GET /nonsense HTTP/1/1
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
|
||||
fba.reset();
|
||||
try std.testing.expectError(
|
||||
error.BadRequest,
|
||||
testCase(alloc, (
|
||||
\\GET /nonsense HTTP/1/1
|
||||
\\
|
||||
\\
|
||||
), undefined),
|
||||
);
|
||||
}
|
||||
|
|
282
src/http/request/test_parser.zig
Normal file
282
src/http/request/test_parser.zig
Normal file
|
@ -0,0 +1,282 @@
|
|||
const std = @import("std");
|
||||
const parser = @import("./parser.zig");
|
||||
const http = @import("../lib.zig");
|
||||
const t = std.testing;
|
||||
|
||||
const test_case = struct {
|
||||
fn parse(text: []const u8, expected: struct {
|
||||
protocol: http.Protocol = .http_1_1,
|
||||
method: http.Method = .GET,
|
||||
headers: []const std.meta.Tuple(&.{ []const u8, []const u8 }) = &.{},
|
||||
uri: []const u8 = "",
|
||||
}) !void {
|
||||
var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(text) };
|
||||
var actual = try parser.parse(t.allocator, stream.reader());
|
||||
defer actual.parseFree(t.allocator);
|
||||
|
||||
try t.expectEqual(expected.protocol, actual.protocol);
|
||||
try t.expectEqual(expected.method, actual.method);
|
||||
try t.expectEqualStrings(expected.uri, actual.uri);
|
||||
|
||||
try t.expectEqual(expected.headers.len, actual.headers.count());
|
||||
for (expected.headers) |hdr| {
|
||||
if (actual.headers.get(hdr[0])) |val| {
|
||||
try t.expectEqualStrings(hdr[1], val);
|
||||
} else {
|
||||
std.debug.print("Error: Header {s} expected to be present, was not.\n", .{hdr[0]});
|
||||
try t.expect(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
fn toCrlf(comptime str: []const u8) []const u8 {
|
||||
comptime {
|
||||
var buf: [str.len * 2]u8 = undefined;
|
||||
|
||||
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
|
||||
|
||||
var buf_len: usize = 0;
|
||||
for (str) |ch| {
|
||||
if (ch == '\n') {
|
||||
buf[buf_len] = '\r';
|
||||
buf_len += 1;
|
||||
}
|
||||
|
||||
buf[buf_len] = ch;
|
||||
buf_len += 1;
|
||||
}
|
||||
|
||||
return buf[0..buf_len];
|
||||
}
|
||||
}
|
||||
|
||||
test "HTTP/1.x parse - No body" {
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\GET / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .GET,
|
||||
.uri = "/",
|
||||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\POST / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .POST,
|
||||
.uri = "/",
|
||||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .GET,
|
||||
.uri = "/url/abcd",
|
||||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\GET / HTTP/1.0
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_0,
|
||||
.method = .GET,
|
||||
.uri = "/",
|
||||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\Content-Type: application/json
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .GET,
|
||||
.uri = "/url/abcd",
|
||||
.headers = &.{.{ "Content-Type", "application/json" }},
|
||||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\Content-Type: application/json
|
||||
\\Authorization: bearer <token>
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .GET,
|
||||
.uri = "/url/abcd",
|
||||
.headers = &.{
|
||||
.{ "Content-Type", "application/json" },
|
||||
.{ "Authorization", "bearer <token>" },
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Test without CRLF
|
||||
try test_case.parse(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\Content-Type: application/json
|
||||
\\Authorization: bearer <token>
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .GET,
|
||||
.uri = "/url/abcd",
|
||||
.headers = &.{
|
||||
.{ "Content-Type", "application/json" },
|
||||
.{ "Authorization", "bearer <token>" },
|
||||
},
|
||||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
\\POST / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .POST,
|
||||
.uri = "/",
|
||||
},
|
||||
);
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\GET / HTTP/1.2
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_x,
|
||||
.method = .GET,
|
||||
.uri = "/",
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
test "HTTP/1.x parse - unsupported protocol" {
|
||||
try t.expectError(error.UnknownProtocol, test_case.parse(
|
||||
\\GET / JSON/1.1
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.UnknownProtocol, test_case.parse(
|
||||
\\GET / SOMETHINGELSE/3.5
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.UnknownProtocol, test_case.parse(
|
||||
\\GET / /1.1
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.HttpVersionNotSupported, test_case.parse(
|
||||
\\GET / HTTP/2.1
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
}
|
||||
|
||||
test "HTTP/1.x parse - Unknown method" {
|
||||
try t.expectError(error.MethodNotImplemented, test_case.parse(
|
||||
\\ABCD / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.MethodNotImplemented, test_case.parse(
|
||||
\\PATCHPATCHPATCH / HTTP/1.1
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
}
|
||||
|
||||
test "HTTP/1.x parse - Too long" {
|
||||
try t.expectError(error.RequestUriTooLong, test_case.parse(
|
||||
std.fmt.comptimePrint("GET {s} HTTP/1.1\n\n", .{"a" ** 8192}),
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.HeaderLineTooLong, test_case.parse(
|
||||
std.fmt.comptimePrint("GET / HTTP/1.1\r\n{s}: abcd", .{"a" ** 8192}),
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.HeaderLineTooLong, test_case.parse(
|
||||
std.fmt.comptimePrint("GET / HTTP/1.1\r\nabcd: {s}", .{"a" ** 8192}),
|
||||
.{},
|
||||
));
|
||||
}
|
||||
|
||||
test "HTTP/1.x parse - bad requests" {
|
||||
try t.expectError(error.BadRequest, test_case.parse(
|
||||
\\GET / HTTP/1.1 blah blah
|
||||
\\
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.BadRequest, test_case.parse(
|
||||
\\GET / HTTP/1.1
|
||||
\\abcd : lksjdfkl
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
try t.expectError(error.BadRequest, test_case.parse(
|
||||
\\GET / HTTP/1.1
|
||||
\\ lksjfklsjdfklj
|
||||
\\
|
||||
,
|
||||
.{},
|
||||
));
|
||||
}
|
||||
|
||||
test "HTTP/1.x parse - Headers" {
|
||||
try test_case.parse(
|
||||
toCrlf(
|
||||
\\GET /url/abcd HTTP/1.1
|
||||
\\Content-Type: application/json
|
||||
\\Content-Type: application/xml
|
||||
\\
|
||||
\\
|
||||
),
|
||||
.{
|
||||
.protocol = .http_1_1,
|
||||
.method = .GET,
|
||||
.uri = "/url/abcd",
|
||||
.headers = &.{.{ "Content-Type", "application/json, application/xml" }},
|
||||
},
|
||||
);
|
||||
}
|
|
@ -3,13 +3,14 @@ const util = @import("util");
|
|||
const http = @import("./lib.zig");
|
||||
|
||||
const response = @import("./server/response.zig");
|
||||
const request = @import("./request.zig");
|
||||
|
||||
pub const Response = struct {
|
||||
alloc: std.mem.Allocator,
|
||||
stream: std.net.Stream,
|
||||
should_close: bool = false,
|
||||
pub const Stream = response.ResponseStream(std.net.Stream.Writer);
|
||||
pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream {
|
||||
pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream {
|
||||
if (headers.get("Connection")) |hdr| {
|
||||
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
|
||||
}
|
||||
|
@ -17,7 +18,7 @@ pub const Response = struct {
|
|||
return response.open(self.alloc, self.stream.writer(), headers, status);
|
||||
}
|
||||
|
||||
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream {
|
||||
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream {
|
||||
try response.writeRequestHeader(self.stream.writer(), headers, status);
|
||||
return self.stream;
|
||||
}
|
||||
|
@ -26,10 +27,6 @@ pub const Response = struct {
|
|||
const Request = http.Request;
|
||||
const request_buf_size = 1 << 16;
|
||||
|
||||
pub fn Handler(comptime Ctx: type) type {
|
||||
return fn (Ctx, Request, *Response) void;
|
||||
}
|
||||
|
||||
pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void {
|
||||
// TODO: Timeouts
|
||||
while (true) {
|
||||
|
@ -37,7 +34,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
|
|||
var arena = std.heap.ArenaAllocator.init(alloc);
|
||||
defer arena.deinit();
|
||||
|
||||
const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| {
|
||||
var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
|
||||
return handleError(conn.stream.writer(), err) catch {};
|
||||
};
|
||||
std.log.debug("done parsing", .{});
|
||||
|
@ -47,7 +44,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
|
|||
.stream = conn.stream,
|
||||
};
|
||||
|
||||
handler(ctx, req, &res);
|
||||
handler(ctx, &req, &res);
|
||||
std.log.debug("done handling", .{});
|
||||
|
||||
if (req.headers.get("Connection")) |hdr| {
|
||||
|
|
|
@ -2,20 +2,20 @@ const std = @import("std");
|
|||
const http = @import("../lib.zig");
|
||||
|
||||
const Status = http.Status;
|
||||
const Headers = http.Headers;
|
||||
const Fields = http.Fields;
|
||||
|
||||
const chunk_size = 16 * 1024;
|
||||
pub fn open(
|
||||
alloc: std.mem.Allocator,
|
||||
writer: anytype,
|
||||
headers: *const Headers,
|
||||
headers: *const Fields,
|
||||
status: Status,
|
||||
) !ResponseStream(@TypeOf(writer)) {
|
||||
const buf = try alloc.alloc(u8, chunk_size);
|
||||
errdefer alloc.free(buf);
|
||||
|
||||
try writeStatusLine(writer, status);
|
||||
try writeHeaders(writer, headers);
|
||||
try writeFields(writer, headers);
|
||||
|
||||
return ResponseStream(@TypeOf(writer)){
|
||||
.allocator = alloc,
|
||||
|
@ -25,9 +25,9 @@ pub fn open(
|
|||
};
|
||||
}
|
||||
|
||||
pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void {
|
||||
pub fn writeRequestHeader(writer: anytype, headers: *const Fields, status: Status) !void {
|
||||
try writeStatusLine(writer, status);
|
||||
try writeHeaders(writer, headers);
|
||||
try writeFields(writer, headers);
|
||||
try writer.writeAll("\r\n");
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ fn writeStatusLine(writer: anytype, status: Status) !void {
|
|||
try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text });
|
||||
}
|
||||
|
||||
fn writeHeaders(writer: anytype, headers: *const Headers) !void {
|
||||
fn writeFields(writer: anytype, headers: *const Fields) !void {
|
||||
var iter = headers.iterator();
|
||||
while (iter.next()) |header| {
|
||||
for (header.value_ptr.*) |ch| {
|
||||
|
@ -65,7 +65,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
|
|||
|
||||
allocator: std.mem.Allocator,
|
||||
base_writer: BaseWriter,
|
||||
headers: *const Headers,
|
||||
headers: *const Fields,
|
||||
buffer: []u8,
|
||||
buffer_pos: usize = 0,
|
||||
chunked: bool = false,
|
||||
|
@ -95,7 +95,6 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
|
|||
return;
|
||||
}
|
||||
|
||||
std.debug.print("{}\n", .{cursor});
|
||||
self.writeToBuffer(bytes[cursor .. cursor + remaining_in_chunk]);
|
||||
cursor += remaining_in_chunk;
|
||||
try self.flushChunk();
|
||||
|
@ -177,7 +176,7 @@ const _tests = struct {
|
|||
test "ResponseStream no headers empty body" {
|
||||
var buffer: [test_buffer_size]u8 = undefined;
|
||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||
var headers = Headers.init(std.testing.allocator);
|
||||
var headers = Fields.init(std.testing.allocator);
|
||||
defer headers.deinit();
|
||||
|
||||
{
|
||||
|
@ -205,7 +204,7 @@ const _tests = struct {
|
|||
test "ResponseStream empty body" {
|
||||
var buffer: [test_buffer_size]u8 = undefined;
|
||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||
var headers = Headers.init(std.testing.allocator);
|
||||
var headers = Fields.init(std.testing.allocator);
|
||||
defer headers.deinit();
|
||||
|
||||
try headers.put("Content-Type", "text/plain");
|
||||
|
@ -236,7 +235,7 @@ const _tests = struct {
|
|||
test "ResponseStream not 200 OK" {
|
||||
var buffer: [test_buffer_size]u8 = undefined;
|
||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||
var headers = Headers.init(std.testing.allocator);
|
||||
var headers = Fields.init(std.testing.allocator);
|
||||
defer headers.deinit();
|
||||
|
||||
try headers.put("Content-Type", "text/plain");
|
||||
|
@ -266,7 +265,7 @@ const _tests = struct {
|
|||
test "ResponseStream small body" {
|
||||
var buffer: [test_buffer_size]u8 = undefined;
|
||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||
var headers = Headers.init(std.testing.allocator);
|
||||
var headers = Fields.init(std.testing.allocator);
|
||||
defer headers.deinit();
|
||||
|
||||
try headers.put("Content-Type", "text/plain");
|
||||
|
@ -300,7 +299,7 @@ const _tests = struct {
|
|||
test "ResponseStream large body" {
|
||||
var buffer: [test_buffer_size]u8 = undefined;
|
||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||
var headers = Headers.init(std.testing.allocator);
|
||||
var headers = Fields.init(std.testing.allocator);
|
||||
defer headers.deinit();
|
||||
|
||||
try headers.put("Content-Type", "text/plain");
|
||||
|
@ -341,7 +340,7 @@ const _tests = struct {
|
|||
test "ResponseStream large body ending on chunk boundary" {
|
||||
var buffer: [test_buffer_size]u8 = undefined;
|
||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||
var headers = Headers.init(std.testing.allocator);
|
||||
var headers = Fields.init(std.testing.allocator);
|
||||
defer headers.deinit();
|
||||
|
||||
try headers.put("Content-Type", "text/plain");
|
||||
|
|
|
@ -23,21 +23,21 @@ const Opcode = enum(u4) {
|
|||
}
|
||||
};
|
||||
|
||||
pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket {
|
||||
pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
|
||||
const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
|
||||
const connection = req.headers.get("Connection") orelse return error.BadHandshake;
|
||||
if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
|
||||
if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;
|
||||
|
||||
const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
|
||||
if (try std.base64.standard.Decoder.calcSizeForSlice(key_hdr) != 16) return error.BadHandshake;
|
||||
if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake;
|
||||
var key: [16]u8 = undefined;
|
||||
std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;
|
||||
|
||||
const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
|
||||
if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
|
||||
|
||||
var headers = http.Headers.init(alloc);
|
||||
var headers = http.Fields.init(alloc);
|
||||
defer headers.deinit();
|
||||
|
||||
try headers.put("Upgrade", "websocket");
|
||||
|
@ -51,7 +51,7 @@ pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Respons
|
|||
var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined;
|
||||
_ = std.base64.standard.Encoder.encode(&hash_encoded, &hash);
|
||||
try headers.put("Sec-WebSocket-Accept", &hash_encoded);
|
||||
const stream = try res.upgrade(.switching_protcols, &headers);
|
||||
const stream = try res.upgrade(.switching_protocols, &headers);
|
||||
|
||||
return Socket{ .stream = stream };
|
||||
}
|
||||
|
@ -164,15 +164,15 @@ fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void {
|
|||
const initial_len: u7 = if (header.len < 126)
|
||||
@intCast(u7, header.len)
|
||||
else if (std.math.cast(u16, header.len)) |_|
|
||||
126
|
||||
@as(u7, 126)
|
||||
else
|
||||
127;
|
||||
@as(u7, 127);
|
||||
|
||||
var hdr_buf = [2]u8{ 0, 0 };
|
||||
hdr_buf[0] |= if (header.is_final) 0b1000_0000 else 0;
|
||||
hdr_buf[0] |= if (header.is_final) @as(u8, 0b1000_0000) else 0;
|
||||
hdr_buf[0] |= @as(u8, header.rsv) << 4;
|
||||
hdr_buf[0] |= @enumToInt(header.opcode);
|
||||
hdr_buf[1] |= if (header.masking_key) |_| 0b1000_0000 else 0;
|
||||
hdr_buf[1] |= if (header.masking_key) |_| @as(u8, 0b1000_0000) else 0;
|
||||
hdr_buf[1] |= initial_len;
|
||||
try writer.writeAll(&hdr_buf);
|
||||
if (initial_len == 126)
|
||||
|
|
3
src/http/test.zig
Normal file
3
src/http/test.zig
Normal file
|
@ -0,0 +1,3 @@
|
|||
test {
|
||||
_ = @import("./request/test_parser.zig");
|
||||
}
|
|
@ -13,10 +13,11 @@ pub const invites = @import("./controllers/invites.zig");
|
|||
pub const users = @import("./controllers/users.zig");
|
||||
pub const notes = @import("./controllers/notes.zig");
|
||||
pub const streaming = @import("./controllers/streaming.zig");
|
||||
pub const timelines = @import("./controllers/timelines.zig");
|
||||
|
||||
pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
|
||||
pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
|
||||
// TODO: hashmaps?
|
||||
var response = Response{ .headers = http.Headers.init(alloc), .res = res };
|
||||
var response = Response{ .headers = http.Fields.init(alloc), .res = res };
|
||||
defer response.headers.deinit();
|
||||
|
||||
const found = routeRequestInternal(api_source, req, &response, alloc);
|
||||
|
@ -24,7 +25,7 @@ pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response,
|
|||
if (!found) response.status(.not_found) catch {};
|
||||
}
|
||||
|
||||
fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
||||
fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
||||
inline for (routes) |route| {
|
||||
if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true;
|
||||
}
|
||||
|
@ -42,6 +43,8 @@ const routes = .{
|
|||
notes.create,
|
||||
notes.get,
|
||||
streaming.streaming,
|
||||
timelines.global,
|
||||
timelines.local,
|
||||
};
|
||||
|
||||
pub fn Context(comptime Route: type) type {
|
||||
|
@ -58,18 +61,21 @@ pub fn Context(comptime Route: type) type {
|
|||
// leave it as a simple string instead of void
|
||||
pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void;
|
||||
|
||||
base_request: http.Request,
|
||||
base_request: *http.Request,
|
||||
|
||||
allocator: std.mem.Allocator,
|
||||
|
||||
method: http.Method,
|
||||
uri: []const u8,
|
||||
headers: http.Headers,
|
||||
headers: http.Fields,
|
||||
|
||||
args: Args,
|
||||
body: Body,
|
||||
query: Query,
|
||||
|
||||
// TODO
|
||||
body_buf: ?[]const u8 = null,
|
||||
|
||||
fn parseArgs(path: []const u8) ?Args {
|
||||
var args: Args = undefined;
|
||||
var path_iter = util.PathIter.from(path);
|
||||
|
@ -94,7 +100,7 @@ pub fn Context(comptime Route: type) type {
|
|||
@compileError("Unsupported Type " ++ @typeName(T));
|
||||
}
|
||||
|
||||
pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
||||
pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
||||
if (req.method != Route.method) return false;
|
||||
var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?');
|
||||
var args: Args = parseArgs(path) orelse return false;
|
||||
|
@ -112,6 +118,8 @@ pub fn Context(comptime Route: type) type {
|
|||
.query = undefined,
|
||||
};
|
||||
|
||||
std.log.debug("Matched route {s}", .{path});
|
||||
|
||||
self.prepareAndHandle(api_source, req, res);
|
||||
|
||||
return true;
|
||||
|
@ -129,7 +137,7 @@ pub fn Context(comptime Route: type) type {
|
|||
};
|
||||
}
|
||||
|
||||
fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void {
|
||||
fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void {
|
||||
self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err);
|
||||
defer self.freeBody();
|
||||
|
||||
|
@ -141,16 +149,20 @@ pub fn Context(comptime Route: type) type {
|
|||
self.handle(response, &api_conn);
|
||||
}
|
||||
|
||||
fn parseBody(self: *Self, req: http.Request) !void {
|
||||
fn parseBody(self: *Self, req: *http.Request) !void {
|
||||
if (Body != void) {
|
||||
const body = req.body orelse return error.NoBody;
|
||||
var stream = req.body orelse return error.NoBody;
|
||||
const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16);
|
||||
errdefer self.allocator.free(body);
|
||||
self.body = try json_utils.parse(Body, body, self.allocator);
|
||||
self.body_buf = body;
|
||||
}
|
||||
}
|
||||
|
||||
fn freeBody(self: *Self) void {
|
||||
if (Body != void) {
|
||||
json_utils.parseFree(self.body, self.allocator);
|
||||
self.allocator.free(self.body_buf.?);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,7 +203,7 @@ pub fn Context(comptime Route: type) type {
|
|||
|
||||
pub const Response = struct {
|
||||
const Self = @This();
|
||||
headers: http.Headers,
|
||||
headers: http.Fields,
|
||||
res: *http.Response,
|
||||
opened: bool = false,
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
const api = @import("api");
|
||||
const std = @import("std");
|
||||
|
||||
pub const login = struct {
|
||||
pub const method = .POST;
|
||||
|
@ -12,6 +13,8 @@ pub const login = struct {
|
|||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
const token = try srv.login(req.body.username, req.body.password);
|
||||
|
||||
std.log.debug("{any}", .{res.headers});
|
||||
|
||||
try res.json(.ok, token);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
const std = @import("std");
|
||||
const api = @import("api");
|
||||
const util = @import("util");
|
||||
const query_utils = @import("../query.zig");
|
||||
|
||||
const QueryArgs = api.CommunityQueryArgs;
|
||||
const Uuid = util.Uuid;
|
||||
|
@ -25,89 +26,18 @@ pub const query = struct {
|
|||
pub const method = .GET;
|
||||
pub const path = "/communities";
|
||||
|
||||
// NOTE: This has to match QueryArgs
|
||||
// TODO: Support union fields in query strings natively, so we don't
|
||||
// have to keep these in sync
|
||||
pub const Query = struct {
|
||||
const OrderBy = QueryArgs.OrderBy;
|
||||
const Direction = QueryArgs.Direction;
|
||||
const PageDirection = QueryArgs.PageDirection;
|
||||
|
||||
// Max items to fetch
|
||||
max_items: usize = 20,
|
||||
|
||||
// Selection filters
|
||||
owner_id: ?Uuid = null,
|
||||
like: ?[]const u8 = null,
|
||||
created_before: ?DateTime = null,
|
||||
created_after: ?DateTime = null,
|
||||
|
||||
// Ordering parameter
|
||||
order_by: OrderBy = .created_at,
|
||||
direction: Direction = .ascending,
|
||||
|
||||
// the `prev` struct has a slightly different format to QueryArgs
|
||||
prev: struct {
|
||||
id: ?Uuid = null,
|
||||
|
||||
// Only one of these can be present, and must match order_by above
|
||||
name: ?[]const u8 = null,
|
||||
host: ?[]const u8 = null,
|
||||
created_at: ?DateTime = null,
|
||||
} = .{},
|
||||
|
||||
// What direction to scan the page window
|
||||
page_direction: PageDirection = .forward,
|
||||
|
||||
pub const format = formatQueryParams;
|
||||
};
|
||||
pub const Query = QueryArgs;
|
||||
|
||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||
const q = req.query;
|
||||
const query_matches = if (q.prev.id) |_| switch (q.order_by) {
|
||||
.name => q.prev.name != null and q.prev.host == null and q.prev.created_at == null,
|
||||
.host => q.prev.name == null and q.prev.host != null and q.prev.created_at == null,
|
||||
.created_at => q.prev.name == null and q.prev.host == null and q.prev.created_at != null,
|
||||
} else (q.prev.name == null and q.prev.host == null and q.prev.created_at == null);
|
||||
|
||||
if (!query_matches) return res.err(.bad_request, "prev.* parameters do not match", {});
|
||||
|
||||
const prev_arg: ?QueryArgs.Prev = if (q.prev.id) |id| .{
|
||||
.id = id,
|
||||
.order_val = switch (q.order_by) {
|
||||
.name => .{ .name = q.prev.name.? },
|
||||
.host => .{ .host = q.prev.host.? },
|
||||
.created_at => .{ .created_at = q.prev.created_at.? },
|
||||
},
|
||||
} else null;
|
||||
|
||||
const query_args = QueryArgs{
|
||||
.max_items = q.max_items,
|
||||
.owner_id = q.owner_id,
|
||||
.like = q.like,
|
||||
.created_before = q.created_before,
|
||||
.created_after = q.created_after,
|
||||
|
||||
.order_by = q.order_by,
|
||||
.direction = q.direction,
|
||||
|
||||
.prev = prev_arg,
|
||||
|
||||
.page_direction = q.page_direction,
|
||||
};
|
||||
|
||||
const results = try srv.queryCommunities(query_args);
|
||||
const results = try srv.queryCommunities(req.query);
|
||||
|
||||
var link = std.ArrayList(u8).init(req.allocator);
|
||||
const link_writer = link.writer();
|
||||
defer link.deinit();
|
||||
|
||||
const next_page = queryArgsToControllerQuery(results.next_page);
|
||||
const prev_page = queryArgsToControllerQuery(results.prev_page);
|
||||
|
||||
try writeLink(link_writer, srv.community, path, next_page, "next");
|
||||
try writeLink(link_writer, srv.community, path, results.next_page, "next");
|
||||
try link_writer.writeByte(',');
|
||||
try writeLink(link_writer, srv.community, path, prev_page, "prev");
|
||||
try writeLink(link_writer, srv.community, path, results.prev_page, "prev");
|
||||
|
||||
try res.headers.put("Link", link.items);
|
||||
|
||||
|
@ -129,7 +59,7 @@ fn writeLink(
|
|||
.{ @tagName(community.scheme), community.host, path },
|
||||
);
|
||||
|
||||
try std.fmt.format(writer, "{}", .{params});
|
||||
try query_utils.formatQuery(params, writer);
|
||||
|
||||
try std.fmt.format(
|
||||
writer,
|
||||
|
@ -137,70 +67,3 @@ fn writeLink(
|
|||
.{rel},
|
||||
);
|
||||
}
|
||||
|
||||
fn formatQueryParams(
|
||||
params: anytype,
|
||||
comptime fmt: []const u8,
|
||||
opt: std.fmt.FormatOptions,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
if (comptime std.meta.trait.is(.Pointer)(@TypeOf(params))) {
|
||||
return formatQueryParams(params.*, fmt, opt, writer);
|
||||
}
|
||||
|
||||
return formatRecursive("", params, writer);
|
||||
}
|
||||
|
||||
fn formatRecursive(comptime prefix: []const u8, params: anytype, writer: anytype) !void {
|
||||
inline for (std.meta.fields(@TypeOf(params))) |field| {
|
||||
const val = @field(params, field.name);
|
||||
const is_optional = comptime std.meta.trait.is(.Optional)(field.field_type);
|
||||
const present = if (comptime is_optional) val != null else true;
|
||||
if (present) {
|
||||
const unwrapped = if (is_optional) val.? else val;
|
||||
// TODO: percent-encode this
|
||||
_ = try switch (@TypeOf(unwrapped)) {
|
||||
[]const u8 => blk: {
|
||||
break :blk std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, unwrapped });
|
||||
},
|
||||
|
||||
else => |U| blk: {
|
||||
if (comptime std.meta.trait.isContainer(U) and std.meta.trait.hasFn("format")(U)) {
|
||||
break :blk std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped });
|
||||
}
|
||||
|
||||
break :blk switch (@typeInfo(U)) {
|
||||
.Enum => std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, @tagName(unwrapped) }),
|
||||
.Struct => formatRecursive(field.name ++ ".", unwrapped, writer),
|
||||
else => std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }),
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn queryArgsToControllerQuery(args: QueryArgs) query.Query {
|
||||
var result = query.Query{
|
||||
.max_items = args.max_items,
|
||||
.owner_id = args.owner_id,
|
||||
.like = args.like,
|
||||
.created_before = args.created_before,
|
||||
.created_after = args.created_after,
|
||||
.order_by = args.order_by,
|
||||
.direction = args.direction,
|
||||
.prev = .{},
|
||||
.page_direction = args.page_direction,
|
||||
};
|
||||
|
||||
if (args.prev) |prev| {
|
||||
result.prev = .{
|
||||
.id = prev.id,
|
||||
.name = if (prev.order_val == .name) prev.order_val.name else null,
|
||||
.host = if (prev.order_val == .host) prev.order_val.host else null,
|
||||
.created_at = if (prev.order_val == .created_at) prev.order_val.created_at else null,
|
||||
};
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
21
src/main/controllers/timelines.zig
Normal file
21
src/main/controllers/timelines.zig
Normal file
|
@ -0,0 +1,21 @@
|
|||
pub const global = struct {
|
||||
pub const method = .GET;
|
||||
pub const path = "/timelines/global";
|
||||
|
||||
pub fn handler(_: anytype, res: anytype, srv: anytype) !void {
|
||||
const results = try srv.globalTimeline();
|
||||
|
||||
try res.json(.ok, results);
|
||||
}
|
||||
};
|
||||
|
||||
pub const local = struct {
|
||||
pub const method = .GET;
|
||||
pub const path = "/timelines/local";
|
||||
|
||||
pub fn handler(_: anytype, res: anytype, srv: anytype) !void {
|
||||
const results = try srv.localTimeline();
|
||||
|
||||
try res.json(.ok, results);
|
||||
}
|
||||
};
|
|
@ -499,7 +499,7 @@ fn parseInternal(
|
|||
if (!fields_seen[i]) {
|
||||
if (field.default_value) |default_ptr| {
|
||||
if (!field.is_comptime) {
|
||||
const default = @ptrCast(*const field.field_type, default_ptr).*;
|
||||
const default = @ptrCast(*align(1) const field.field_type, default_ptr).*;
|
||||
@field(r, field.name) = default;
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -87,7 +87,7 @@ fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle(ctx: anytype, req: http.Request, res: *http.Response) void {
|
||||
fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void {
|
||||
c.routeRequest(ctx.src, req, res, ctx.allocator);
|
||||
}
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ pub fn up(db: anytype) !void {
|
|||
std.log.info("Running migration {s}", .{migration.name});
|
||||
try execScript(tx, migration.up, gpa.allocator());
|
||||
try tx.insert("migration", .{
|
||||
.name = migration.name,
|
||||
.name = @as([]const u8, migration.name),
|
||||
.applied_at = DateTime.now(),
|
||||
}, gpa.allocator());
|
||||
}
|
||||
|
|
|
@ -71,37 +71,132 @@ const QueryIter = @import("util").QueryIter;
|
|||
/// TODO: values are currently case-sensitive, and are not url-decoded properly.
|
||||
/// This should be fixed.
|
||||
pub fn parseQuery(comptime T: type, query: []const u8) !T {
|
||||
//if (!std.meta.trait.isContainer(T)) @compileError("T must be a struct");
|
||||
if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
|
||||
var iter = QueryIter.from(query);
|
||||
var result = T{};
|
||||
|
||||
var fields = Intermediary(T){};
|
||||
while (iter.next()) |pair| {
|
||||
try parseQueryPair(T, &result, pair.key, pair.value);
|
||||
// TODO: Hash map
|
||||
inline for (std.meta.fields(Intermediary(T))) |field| {
|
||||
if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) {
|
||||
@field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} };
|
||||
break;
|
||||
}
|
||||
} else std.log.debug("unknown param {s}", .{pair.key});
|
||||
}
|
||||
|
||||
return result;
|
||||
return (try parse(T, "", "", fields)).?;
|
||||
}
|
||||
|
||||
fn parseQueryPair(comptime T: type, result: *T, key: []const u8, value: ?[]const u8) !void {
|
||||
const key_part = std.mem.sliceTo(key, '.');
|
||||
const field_idx = std.meta.stringToEnum(std.meta.FieldEnum(T), key_part) orelse return error.UnknownField;
|
||||
fn parseScalar(comptime T: type, comptime name: []const u8, fields: anytype) !?T {
|
||||
const param = @field(fields, name);
|
||||
return switch (param) {
|
||||
.not_specified => null,
|
||||
.no_value => try parseQueryValue(T, null),
|
||||
.value => |v| try parseQueryValue(T, v),
|
||||
};
|
||||
}
|
||||
|
||||
inline for (std.meta.fields(T)) |info, idx| {
|
||||
if (@enumToInt(field_idx) == idx) {
|
||||
if (comptime isScalar(info.field_type)) {
|
||||
if (key_part.len == key.len) {
|
||||
@field(result, info.name) = try parseQueryValue(info.field_type, value);
|
||||
return;
|
||||
} else {
|
||||
return error.UnknownField;
|
||||
fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u8, fields: anytype) !?T {
|
||||
if (comptime isScalar(T)) return parseScalar(T, prefix ++ "." ++ name, fields);
|
||||
switch (@typeInfo(T)) {
|
||||
.Union => |info| {
|
||||
var result: ?T = null;
|
||||
inline for (info.fields) |field| {
|
||||
const F = field.field_type;
|
||||
|
||||
const maybe_value = try parse(F, prefix, field.name, fields);
|
||||
if (maybe_value) |value| {
|
||||
if (result != null) return error.DuplicateUnionField;
|
||||
|
||||
result = @unionInit(T, field.name, value);
|
||||
}
|
||||
}
|
||||
std.log.debug("{any}", .{result});
|
||||
return result;
|
||||
},
|
||||
|
||||
.Struct => |info| {
|
||||
var result: T = undefined;
|
||||
var fields_specified: usize = 0;
|
||||
|
||||
inline for (info.fields) |field| {
|
||||
const F = field.field_type;
|
||||
|
||||
var maybe_value: ?F = null;
|
||||
if (try parse(F, prefix ++ "." ++ name, field.name, fields)) |v| {
|
||||
maybe_value = v;
|
||||
} else if (field.default_value) |default| {
|
||||
maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*;
|
||||
}
|
||||
|
||||
if (maybe_value) |v| {
|
||||
fields_specified += 1;
|
||||
@field(result, field.name) = v;
|
||||
}
|
||||
}
|
||||
|
||||
if (fields_specified == 0) {
|
||||
return null;
|
||||
} else if (fields_specified != info.fields.len) {
|
||||
return error.PartiallySpecifiedStruct;
|
||||
} else {
|
||||
const remaining = std.mem.trimLeft(u8, key[key_part.len..], ".");
|
||||
return try parseQueryPair(info.field_type, &@field(result, info.name), remaining, value);
|
||||
return result;
|
||||
}
|
||||
},
|
||||
|
||||
// Only applies to non-scalar optionals
|
||||
.Optional => |info| return try parse(info.child, prefix, name, fields),
|
||||
|
||||
else => @compileError("tmp"),
|
||||
}
|
||||
}
|
||||
|
||||
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 {
|
||||
comptime {
|
||||
if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix);
|
||||
|
||||
var fields: []const []const u8 = &.{};
|
||||
|
||||
for (std.meta.fields(T)) |f| {
|
||||
const full_name = prefix ++ f.name;
|
||||
|
||||
if (isScalar(f.field_type)) {
|
||||
fields = fields ++ @as([]const []const u8, &.{full_name});
|
||||
} else {
|
||||
const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ ".";
|
||||
fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return error.UnknownField;
|
||||
return fields;
|
||||
}
|
||||
}
|
||||
|
||||
const QueryParam = union(enum) {
|
||||
not_specified: void,
|
||||
no_value: void,
|
||||
value: []const u8,
|
||||
};
|
||||
|
||||
fn Intermediary(comptime T: type) type {
|
||||
const field_names = recursiveFieldPaths(T, "..");
|
||||
|
||||
var fields: [field_names.len]std.builtin.Type.StructField = undefined;
|
||||
for (field_names) |name, i| fields[i] = .{
|
||||
.name = name,
|
||||
.field_type = QueryParam,
|
||||
.default_value = &QueryParam{ .not_specified = {} },
|
||||
.is_comptime = false,
|
||||
.alignment = @alignOf(QueryParam),
|
||||
};
|
||||
|
||||
return @Type(.{ .Struct = .{
|
||||
.layout = .Auto,
|
||||
.fields = &fields,
|
||||
.decls = &.{},
|
||||
.is_tuple = false,
|
||||
} });
|
||||
}
|
||||
|
||||
fn parseQueryValue(comptime T: type, value: ?[]const u8) !T {
|
||||
|
@ -157,6 +252,50 @@ fn isScalar(comptime T: type) bool {
|
|||
return false;
|
||||
}
|
||||
|
||||
pub fn formatQuery(params: anytype, writer: anytype) !void {
|
||||
try format("", "", params, writer);
|
||||
}
|
||||
|
||||
fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void {
|
||||
const T = @TypeOf(val);
|
||||
if (comptime std.meta.trait.isZigString(T)) return std.fmt.format(writer, "{s}={s}&", .{ name, val });
|
||||
_ = try switch (@typeInfo(T)) {
|
||||
.Enum => std.fmt.format(writer, "{s}={s}&", .{ name, @tagName(val) }),
|
||||
.Optional => if (val) |v| formatScalar(name, v, writer),
|
||||
else => std.fmt.format(writer, "{s}={}&", .{ name, val }),
|
||||
};
|
||||
}
|
||||
|
||||
fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void {
|
||||
const T = @TypeOf(params);
|
||||
const eff_prefix = if (prefix.len == 0) "" else prefix ++ ".";
|
||||
if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer);
|
||||
|
||||
switch (@typeInfo(T)) {
|
||||
.Struct => {
|
||||
inline for (std.meta.fields(T)) |field| {
|
||||
const val = @field(params, field.name);
|
||||
try format(eff_prefix ++ name, field.name, val, writer);
|
||||
}
|
||||
},
|
||||
.Union => {
|
||||
//inline for (std.meta.tags(T)) |tag| {
|
||||
inline for (std.meta.fields(T)) |field| {
|
||||
const tag = @field(std.meta.Tag(T), field.name);
|
||||
const tag_name = field.name;
|
||||
if (@as(std.meta.Tag(T), params) == tag) {
|
||||
const val = @field(params, tag_name);
|
||||
try format(prefix, tag_name, val, writer);
|
||||
}
|
||||
}
|
||||
},
|
||||
.Optional => {
|
||||
if (params) |p| try format(prefix, name, p, writer);
|
||||
},
|
||||
else => @compileError("Unsupported query type"),
|
||||
}
|
||||
}
|
||||
|
||||
test {
|
||||
const TestQuery = struct {
|
||||
int: usize = 3,
|
||||
|
|
|
@ -68,7 +68,7 @@ pub const QueryOptions = struct {
|
|||
// do not require allocators for prep. If an allocator is needed but not
|
||||
// provided, `error.AllocatorRequired` will be returned.
|
||||
// Only used with the postgres backend.
|
||||
prep_allocator: ?Allocator = null,
|
||||
allocator: ?Allocator = null,
|
||||
};
|
||||
|
||||
// Turns a value into its appropriate textual value (or null)
|
||||
|
|
|
@ -180,14 +180,25 @@ pub const Db = struct {
|
|||
const format_text = 0;
|
||||
const format_binary = 1;
|
||||
pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results {
|
||||
const alloc = opt.prep_allocator;
|
||||
const alloc = opt.allocator;
|
||||
const result = blk: {
|
||||
if (@TypeOf(args) != void and args.len > 0) {
|
||||
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
|
||||
defer arena.deinit();
|
||||
const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
|
||||
inline for (args) |arg, i| {
|
||||
params[i] = if (try common.prepareParamText(&arena, arg)) |slice|
|
||||
// TODO: The following is a fix for the stage1 compiler. remove this
|
||||
//inline for (args) |arg, i| {
|
||||
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
|
||||
const arg = @field(args, field.name);
|
||||
|
||||
// The stage1 compiler has issues with runtime branches that in any
|
||||
// way involve compile time values
|
||||
const maybe_slice = if (@import("builtin").zig_backend == .stage1)
|
||||
common.prepareParamText(&arena, arg) catch unreachable
|
||||
else
|
||||
try common.prepareParamText(&arena, arg);
|
||||
|
||||
params[i] = if (maybe_slice) |slice|
|
||||
slice.ptr
|
||||
else
|
||||
null;
|
||||
|
|
|
@ -118,7 +118,10 @@ pub const Db = struct {
|
|||
};
|
||||
|
||||
if (@TypeOf(args) != void) {
|
||||
inline for (args) |arg, i| {
|
||||
// TODO: Fix for stage1 compiler
|
||||
//inline for (args) |arg, i| {
|
||||
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
|
||||
const arg = @field(args, field.name);
|
||||
// SQLite treats $NNN args as having the name NNN, not index NNN.
|
||||
// As such, if you reference $2 and not $1 in your query (such as
|
||||
// when dynamically constructing queries), it could assign $2 the
|
||||
|
|
|
@ -25,6 +25,54 @@ pub const Engine = enum {
|
|||
sqlite,
|
||||
};
|
||||
|
||||
/// Helper for building queries at runtime. All constituent parts of the
|
||||
/// query should be defined at comptime, however the choice of whether
|
||||
/// or not to include them can occur at runtime.
|
||||
pub const QueryBuilder = struct {
|
||||
array: std.ArrayList(u8),
|
||||
where_clauses_appended: usize = 0,
|
||||
|
||||
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
|
||||
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
|
||||
}
|
||||
|
||||
pub fn deinit(self: *const QueryBuilder) void {
|
||||
self.array.deinit();
|
||||
}
|
||||
|
||||
/// Add a chunk of sql to the query without processing
|
||||
pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void {
|
||||
try self.array.appendSlice(sql);
|
||||
}
|
||||
|
||||
/// Add a where clause to the query. Clauses are assumed to be components
|
||||
/// in an overall expression in Conjunctive Normal Form (AND of OR's).
|
||||
/// https://en.wikipedia.org/wiki/Conjunctive_normal_form
|
||||
/// All calls to andWhere must be contiguous, that is, they cannot be
|
||||
/// interspersed with calls to appendSlice
|
||||
pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void {
|
||||
if (self.where_clauses_appended == 0) {
|
||||
try self.array.appendSlice("WHERE ");
|
||||
} else {
|
||||
try self.array.appendSlice(" AND ");
|
||||
}
|
||||
|
||||
try self.array.appendSlice(clause);
|
||||
self.where_clauses_appended += 1;
|
||||
}
|
||||
|
||||
pub fn str(self: *const QueryBuilder) []const u8 {
|
||||
return self.array.items;
|
||||
}
|
||||
|
||||
pub fn terminate(self: *QueryBuilder) ![:0]const u8 {
|
||||
std.debug.assert(self.array.items.len != 0);
|
||||
if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0);
|
||||
|
||||
return std.meta.assumeSentinel(self.array.items, 0);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: make this suck less
|
||||
pub const Config = union(Engine) {
|
||||
postgres: struct {
|
||||
|
@ -410,7 +458,7 @@ fn Tx(comptime tx_level: u8) type {
|
|||
args: anytype,
|
||||
alloc: ?Allocator,
|
||||
) QueryError!Results(RowType) {
|
||||
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
|
||||
return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc });
|
||||
}
|
||||
|
||||
/// Runs a query to completion and returns a row of results, unless the query
|
||||
|
@ -439,6 +487,45 @@ fn Tx(comptime tx_level: u8) type {
|
|||
return row;
|
||||
}
|
||||
|
||||
// Runs a query to completion and returns the results as a slice
|
||||
pub fn queryRowsWithOptions(
|
||||
self: Self,
|
||||
comptime RowType: type,
|
||||
q: [:0]const u8,
|
||||
args: anytype,
|
||||
max_items: ?usize,
|
||||
options: QueryOptions,
|
||||
) QueryRowError![]RowType {
|
||||
var results = try self.queryWithOptions(RowType, q, args, options);
|
||||
defer results.finish();
|
||||
|
||||
const alloc = options.allocator orelse return error.AllocatorRequired;
|
||||
|
||||
var result_array = std.ArrayList(RowType).init(alloc);
|
||||
errdefer result_array.deinit();
|
||||
if (max_items) |max| try result_array.ensureTotalCapacity(max);
|
||||
|
||||
errdefer for (result_array.items) |r| util.deepFree(alloc, r);
|
||||
|
||||
var too_many: bool = false;
|
||||
while (try results.row(alloc)) |row| {
|
||||
errdefer util.deepFree(alloc, row);
|
||||
if (max_items) |max| {
|
||||
if (result_array.items.len >= max) {
|
||||
util.deepFree(alloc, row);
|
||||
too_many = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
try result_array.append(row);
|
||||
}
|
||||
|
||||
if (too_many) return error.TooManyRows;
|
||||
|
||||
return result_array.toOwnedSlice();
|
||||
}
|
||||
|
||||
// Inserts a single value into a table
|
||||
pub fn insert(
|
||||
self: Self,
|
||||
|
@ -455,7 +542,7 @@ fn Tx(comptime tx_level: u8) type {
|
|||
inline for (fields) |field, i| {
|
||||
// This causes a compiler crash. Why?
|
||||
//const F = field.field_type;
|
||||
const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name));
|
||||
const F = @TypeOf(@field(value, field.name));
|
||||
// causes issues if F is @TypeOf(null), use dummy type
|
||||
types[i] = if (F == @TypeOf(null)) ?i64 else F;
|
||||
table_spec = comptime (table_spec ++ field.name ++ ",");
|
||||
|
@ -499,7 +586,7 @@ fn Tx(comptime tx_level: u8) type {
|
|||
alloc: ?std.mem.Allocator,
|
||||
comptime check_tx: bool,
|
||||
) !void {
|
||||
var results = try self.runSql(sql, args, .{ .prep_allocator = alloc }, check_tx);
|
||||
var results = try self.runSql(sql, args, .{ .allocator = alloc }, check_tx);
|
||||
defer results.finish();
|
||||
|
||||
while (try results.row()) |_| {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue