Compare commits
17 commits
e90d9daf77
...
4e81441a0d
Author | SHA1 | Date | |
---|---|---|---|
4e81441a0d | |||
bfd73b7a1f | |||
8694516180 | |||
5630c6160f | |||
139fc92854 | |||
83da3c914b | |||
7d13b2546b | |||
181e57a631 | |||
cfdc8c5761 | |||
cbf98c1cf3 | |||
d2151ae326 | |||
2d464f0820 | |||
438c72b7e9 | |||
b46e82746c | |||
4cb574bc91 | |||
5d94f6874d | |||
1ba1b18c39 |
28 changed files with 1101 additions and 673 deletions
12
build.zig
12
build.zig
|
@ -53,16 +53,16 @@ pub fn build(b: *std.build.Builder) void {
|
||||||
exe.linkSystemLibrary("pq");
|
exe.linkSystemLibrary("pq");
|
||||||
exe.linkLibC();
|
exe.linkLibC();
|
||||||
|
|
||||||
const util_tests = b.addTest("src/util/lib.zig");
|
//const util_tests = b.addTest("src/util/lib.zig");
|
||||||
const http_tests = b.addTest("src/http/lib.zig");
|
const http_tests = b.addTest("src/http/test.zig");
|
||||||
const sql_tests = b.addTest("src/sql/lib.zig");
|
//const sql_tests = b.addTest("src/sql/lib.zig");
|
||||||
http_tests.addPackage(util_pkg);
|
http_tests.addPackage(util_pkg);
|
||||||
sql_tests.addPackage(util_pkg);
|
//sql_tests.addPackage(util_pkg);
|
||||||
|
|
||||||
const unit_tests = b.step("unit-tests", "Run tests");
|
const unit_tests = b.step("unit-tests", "Run tests");
|
||||||
unit_tests.dependOn(&util_tests.step);
|
//unit_tests.dependOn(&util_tests.step);
|
||||||
unit_tests.dependOn(&http_tests.step);
|
unit_tests.dependOn(&http_tests.step);
|
||||||
unit_tests.dependOn(&sql_tests.step);
|
//unit_tests.dependOn(&sql_tests.step);
|
||||||
|
|
||||||
const api_integration = b.addTest("./tests/api_integration/lib.zig");
|
const api_integration = b.addTest("./tests/api_integration/lib.zig");
|
||||||
api_integration.addPackage(sql_pkg);
|
api_integration.addPackage(sql_pkg);
|
||||||
|
|
|
@ -276,7 +276,10 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
username,
|
username,
|
||||||
password,
|
password,
|
||||||
self.community.id,
|
self.community.id,
|
||||||
.{ .invite_id = if (maybe_invite) |inv| inv.id else null, .email = opt.email },
|
.{
|
||||||
|
.invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null,
|
||||||
|
.email = opt.email,
|
||||||
|
},
|
||||||
self.arena.allocator(),
|
self.arena.allocator(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -348,5 +351,19 @@ fn ApiConn(comptime DbConn: type) type {
|
||||||
if (!self.isAdmin()) return error.PermissionDenied;
|
if (!self.isAdmin()) return error.PermissionDenied;
|
||||||
return try services.communities.query(self.db, args, self.arena.allocator());
|
return try services.communities.query(self.db, args, self.arena.allocator());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn globalTimeline(self: *Self) ![]services.notes.Note {
|
||||||
|
const result = try services.notes.query(self.db, .{}, self.arena.allocator());
|
||||||
|
return result.items;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn localTimeline(self: *Self) ![]services.notes.Note {
|
||||||
|
const result = try services.notes.query(
|
||||||
|
self.db,
|
||||||
|
.{ .community_id = self.community.id },
|
||||||
|
self.arena.allocator(),
|
||||||
|
);
|
||||||
|
return result.items;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,7 +94,8 @@ pub const Actor = struct {
|
||||||
created_at: DateTime,
|
created_at: DateTime,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Actor {
|
pub const GetError = error{ NotFound, DatabaseFailure };
|
||||||
|
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Actor {
|
||||||
return db.queryRow(
|
return db.queryRow(
|
||||||
Actor,
|
Actor,
|
||||||
\\SELECT
|
\\SELECT
|
||||||
|
|
16
src/api/services/common.zig
Normal file
16
src/api/services/common.zig
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
const std = @import("std");
|
||||||
|
const util = @import("util");
|
||||||
|
|
||||||
|
pub const Direction = enum {
|
||||||
|
ascending,
|
||||||
|
descending,
|
||||||
|
|
||||||
|
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const PageDirection = enum {
|
||||||
|
forward,
|
||||||
|
backward,
|
||||||
|
|
||||||
|
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||||
|
};
|
|
@ -2,6 +2,7 @@ const std = @import("std");
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
const util = @import("util");
|
const util = @import("util");
|
||||||
const sql = @import("sql");
|
const sql = @import("sql");
|
||||||
|
const common = @import("./common.zig");
|
||||||
|
|
||||||
const Uuid = util.Uuid;
|
const Uuid = util.Uuid;
|
||||||
const DateTime = util.DateTime;
|
const DateTime = util.DateTime;
|
||||||
|
@ -82,11 +83,12 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st
|
||||||
else => return error.DatabaseFailure,
|
else => return error.DatabaseFailure,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const name = options.name orelse host;
|
||||||
db.insert("community", .{
|
db.insert("community", .{
|
||||||
.id = id,
|
.id = id,
|
||||||
.owner_id = null,
|
.owner_id = null,
|
||||||
.host = host,
|
.host = host,
|
||||||
.name = options.name orelse host,
|
.name = name,
|
||||||
.scheme = scheme,
|
.scheme = scheme,
|
||||||
.kind = options.kind,
|
.kind = options.kind,
|
||||||
.created_at = DateTime.now(),
|
.created_at = DateTime.now(),
|
||||||
|
@ -153,20 +155,8 @@ pub const QueryArgs = struct {
|
||||||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Direction = enum {
|
pub const Direction = common.Direction;
|
||||||
ascending,
|
pub const PageDirection = common.PageDirection;
|
||||||
descending,
|
|
||||||
|
|
||||||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const PageDirection = enum {
|
|
||||||
forward,
|
|
||||||
backward,
|
|
||||||
|
|
||||||
pub const jsonStringify = util.jsonSerializeEnumAsString;
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const Prev = std.meta.Child(std.meta.fieldInfo(QueryArgs, .prev).field_type);
|
pub const Prev = std.meta.Child(std.meta.fieldInfo(QueryArgs, .prev).field_type);
|
||||||
pub const OrderVal = std.meta.fieldInfo(Prev, .order_val).field_type;
|
pub const OrderVal = std.meta.fieldInfo(Prev, .order_val).field_type;
|
||||||
|
|
||||||
|
@ -211,30 +201,6 @@ pub const QueryResult = struct {
|
||||||
next_page: QueryArgs,
|
next_page: QueryArgs,
|
||||||
};
|
};
|
||||||
|
|
||||||
const QueryBuilder = struct {
|
|
||||||
array: std.ArrayList(u8),
|
|
||||||
where_clauses_appended: usize = 0,
|
|
||||||
|
|
||||||
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
|
|
||||||
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deinit(self: *const QueryBuilder) void {
|
|
||||||
self.array.deinit();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn andWhere(self: *QueryBuilder, clause: []const u8) !void {
|
|
||||||
if (self.where_clauses_appended == 0) {
|
|
||||||
try self.array.appendSlice("WHERE ");
|
|
||||||
} else {
|
|
||||||
try self.array.appendSlice(" AND ");
|
|
||||||
}
|
|
||||||
|
|
||||||
try self.array.appendSlice(clause);
|
|
||||||
self.where_clauses_appended += 1;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const max_max_items = 100;
|
const max_max_items = 100;
|
||||||
|
|
||||||
pub const QueryError = error{
|
pub const QueryError = error{
|
||||||
|
@ -246,7 +212,7 @@ pub const QueryError = error{
|
||||||
// arguments.
|
// arguments.
|
||||||
// `args.max_items` is only a request, and fewer entries may be returned.
|
// `args.max_items` is only a request, and fewer entries may be returned.
|
||||||
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult {
|
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult {
|
||||||
var builder = QueryBuilder.init(alloc);
|
var builder = sql.QueryBuilder.init(alloc);
|
||||||
defer builder.deinit();
|
defer builder.deinit();
|
||||||
|
|
||||||
try builder.array.appendSlice(
|
try builder.array.appendSlice(
|
||||||
|
@ -266,21 +232,21 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
|
||||||
if (args.prev) |prev| {
|
if (args.prev) |prev| {
|
||||||
if (prev.order_val != args.order_by) return error.PageArgMismatch;
|
if (prev.order_val != args.order_by) return error.PageArgMismatch;
|
||||||
|
|
||||||
try builder.andWhere(switch (args.order_by) {
|
switch (args.order_by) {
|
||||||
.name => "(name, id)",
|
.name => try builder.andWhere("(name, id)"),
|
||||||
.host => "(host, id)",
|
.host => try builder.andWhere("(host, id)"),
|
||||||
.created_at => "(created_at, id)",
|
.created_at => try builder.andWhere("(created_at, id)"),
|
||||||
});
|
}
|
||||||
_ = try builder.array.appendSlice(switch (args.direction) {
|
switch (args.direction) {
|
||||||
.ascending => switch (args.page_direction) {
|
.ascending => switch (args.page_direction) {
|
||||||
.forward => " > ",
|
.forward => try builder.appendSlice(" > "),
|
||||||
.backward => " < ",
|
.backward => try builder.appendSlice(" < "),
|
||||||
},
|
},
|
||||||
.descending => switch (args.page_direction) {
|
.descending => switch (args.page_direction) {
|
||||||
.forward => " < ",
|
.forward => try builder.appendSlice(" < "),
|
||||||
.backward => " > ",
|
.backward => try builder.appendSlice(" > "),
|
||||||
},
|
},
|
||||||
});
|
}
|
||||||
|
|
||||||
_ = try builder.array.appendSlice("($5, $6)");
|
_ = try builder.array.appendSlice("($5, $6)");
|
||||||
}
|
}
|
||||||
|
@ -297,57 +263,52 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
|
||||||
|
|
||||||
_ = try builder.array.appendSlice("\nLIMIT $7");
|
_ = try builder.array.appendSlice("\nLIMIT $7");
|
||||||
|
|
||||||
const query_args = .{
|
const query_args = blk: {
|
||||||
|
const ord_val =
|
||||||
|
if (args.prev) |prev| @as(?QueryArgs.OrderVal, prev.order_val) else null;
|
||||||
|
const id =
|
||||||
|
if (args.prev) |prev| @as(?Uuid, prev.id) else null;
|
||||||
|
break :blk .{
|
||||||
args.owner_id,
|
args.owner_id,
|
||||||
args.like,
|
args.like,
|
||||||
args.created_before,
|
args.created_before,
|
||||||
args.created_after,
|
args.created_after,
|
||||||
if (args.prev) |prev| prev.order_val else null,
|
ord_val,
|
||||||
if (args.prev) |prev| prev.id else null,
|
id,
|
||||||
max_items,
|
max_items,
|
||||||
};
|
};
|
||||||
|
};
|
||||||
|
|
||||||
try builder.array.append(0);
|
try builder.array.append(0);
|
||||||
|
|
||||||
var results = try db.queryWithOptions(
|
var results = try db.queryRowsWithOptions(
|
||||||
Community,
|
Community,
|
||||||
std.meta.assumeSentinel(builder.array.items, 0),
|
std.meta.assumeSentinel(builder.array.items, 0),
|
||||||
query_args,
|
query_args,
|
||||||
.{ .prep_allocator = alloc, .ignore_unused_arguments = true },
|
max_items,
|
||||||
|
.{ .allocator = alloc, .ignore_unused_arguments = true },
|
||||||
);
|
);
|
||||||
defer results.finish();
|
errdefer util.deepFree(alloc, results);
|
||||||
|
|
||||||
const result_buf = try alloc.alloc(Community, args.max_items);
|
|
||||||
errdefer alloc.free(result_buf);
|
|
||||||
|
|
||||||
var count: usize = 0;
|
|
||||||
errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c);
|
|
||||||
|
|
||||||
for (result_buf) |*c| {
|
|
||||||
c.* = (try results.row(alloc)) orelse break;
|
|
||||||
|
|
||||||
count += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
var next_page = args;
|
var next_page = args;
|
||||||
var prev_page = args;
|
var prev_page = args;
|
||||||
prev_page.page_direction = .backward;
|
prev_page.page_direction = .backward;
|
||||||
next_page.page_direction = .forward;
|
next_page.page_direction = .forward;
|
||||||
if (count != 0) {
|
if (results.len != 0) {
|
||||||
prev_page.prev = .{
|
prev_page.prev = .{
|
||||||
.id = result_buf[0].id,
|
.id = results[0].id,
|
||||||
.order_val = getOrderVal(result_buf[0], args.order_by),
|
.order_val = getOrderVal(results[0], args.order_by),
|
||||||
};
|
};
|
||||||
|
|
||||||
next_page.prev = .{
|
next_page.prev = .{
|
||||||
.id = result_buf[count - 1].id,
|
.id = results[results.len - 1].id,
|
||||||
.order_val = getOrderVal(result_buf[count - 1], args.order_by),
|
.order_val = getOrderVal(results[results.len - 1], args.order_by),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
// TODO: This will give incorrect links on an empty page
|
// TODO: This will give incorrect links on an empty page
|
||||||
|
|
||||||
return QueryResult{
|
return QueryResult{
|
||||||
.items = result_buf[0..count],
|
.items = results,
|
||||||
|
|
||||||
.next_page = next_page,
|
.next_page = next_page,
|
||||||
.prev_page = prev_page,
|
.prev_page = prev_page,
|
||||||
|
|
|
@ -71,7 +71,7 @@ pub fn create(db: anytype, created_by: Uuid, community_id: ?Uuid, options: Invit
|
||||||
.max_uses = options.max_uses,
|
.max_uses = options.max_uses,
|
||||||
.created_at = created_at,
|
.created_at = created_at,
|
||||||
.expires_at = if (options.lifespan) |lifespan|
|
.expires_at = if (options.lifespan) |lifespan|
|
||||||
created_at.add(lifespan)
|
@as(?DateTime, created_at.add(lifespan))
|
||||||
else
|
else
|
||||||
null,
|
null,
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const util = @import("util");
|
const util = @import("util");
|
||||||
const sql = @import("sql");
|
const sql = @import("sql");
|
||||||
|
const common = @import("./common.zig");
|
||||||
|
|
||||||
const Uuid = util.Uuid;
|
const Uuid = util.Uuid;
|
||||||
const DateTime = util.DateTime;
|
const DateTime = util.DateTime;
|
||||||
|
@ -42,7 +43,7 @@ const selectStarFromNote = std.fmt.comptimePrint(
|
||||||
\\SELECT {s}
|
\\SELECT {s}
|
||||||
\\FROM note
|
\\FROM note
|
||||||
\\
|
\\
|
||||||
, .{util.comptimeJoin(",", std.meta.fieldNames(Note))});
|
, .{util.comptimeJoinWithPrefix(",", "note.", std.meta.fieldNames(Note))});
|
||||||
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
|
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
|
||||||
return db.queryRow(
|
return db.queryRow(
|
||||||
Note,
|
Note,
|
||||||
|
@ -57,3 +58,108 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
|
||||||
else => error.DatabaseFailure,
|
else => error.DatabaseFailure,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const max_max_items = 100;
|
||||||
|
|
||||||
|
pub const QueryArgs = struct {
|
||||||
|
pub const PageDirection = common.PageDirection;
|
||||||
|
pub const Prev = std.meta.Child(std.meta.field(@This(), .prev).field_type);
|
||||||
|
|
||||||
|
max_items: usize = 20,
|
||||||
|
|
||||||
|
created_before: ?DateTime = null,
|
||||||
|
created_after: ?DateTime = null,
|
||||||
|
community_id: ?Uuid = null,
|
||||||
|
|
||||||
|
prev: ?struct {
|
||||||
|
id: Uuid,
|
||||||
|
created_at: DateTime,
|
||||||
|
} = null,
|
||||||
|
|
||||||
|
page_direction: PageDirection = .forward,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const QueryResult = struct {
|
||||||
|
items: []Note,
|
||||||
|
|
||||||
|
prev_page: QueryArgs,
|
||||||
|
next_page: QueryArgs,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult {
|
||||||
|
var builder = sql.QueryBuilder.init(alloc);
|
||||||
|
defer builder.deinit();
|
||||||
|
|
||||||
|
try builder.appendSlice(selectStarFromNote ++
|
||||||
|
\\ JOIN actor ON actor.id = note.author_id
|
||||||
|
\\
|
||||||
|
);
|
||||||
|
|
||||||
|
if (args.created_before != null) try builder.andWhere("note.created_at < $1");
|
||||||
|
if (args.created_after != null) try builder.andWhere("note.created_at > $2");
|
||||||
|
if (args.prev != null) {
|
||||||
|
try builder.andWhere("(note.created_at, note.id)");
|
||||||
|
|
||||||
|
switch (args.page_direction) {
|
||||||
|
.forward => try builder.appendSlice(" < "),
|
||||||
|
.backward => try builder.appendSlice(" > "),
|
||||||
|
}
|
||||||
|
try builder.appendSlice("($3, $4)");
|
||||||
|
}
|
||||||
|
if (args.community_id != null) try builder.andWhere("actor.community_id = $5");
|
||||||
|
|
||||||
|
try builder.appendSlice(
|
||||||
|
\\
|
||||||
|
\\ORDER BY note.created_at DESC
|
||||||
|
\\LIMIT $6
|
||||||
|
\\
|
||||||
|
);
|
||||||
|
|
||||||
|
const max_items = if (args.max_items > max_max_items) max_max_items else args.max_items;
|
||||||
|
|
||||||
|
const query_args = blk: {
|
||||||
|
const prev_created_at = if (args.prev) |prev| @as(?DateTime, prev.created_at) else null;
|
||||||
|
const prev_id = if (args.prev) |prev| @as(?Uuid, prev.id) else null;
|
||||||
|
|
||||||
|
break :blk .{
|
||||||
|
args.created_before,
|
||||||
|
args.created_after,
|
||||||
|
prev_created_at,
|
||||||
|
prev_id,
|
||||||
|
args.community_id,
|
||||||
|
max_items,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const results = try db.queryRowsWithOptions(
|
||||||
|
Note,
|
||||||
|
try builder.terminate(),
|
||||||
|
query_args,
|
||||||
|
max_items,
|
||||||
|
.{ .allocator = alloc, .ignore_unused_arguments = true },
|
||||||
|
);
|
||||||
|
errdefer util.deepFree(results);
|
||||||
|
|
||||||
|
var next_page = args;
|
||||||
|
var prev_page = args;
|
||||||
|
prev_page.page_direction = .backward;
|
||||||
|
next_page.page_direction = .forward;
|
||||||
|
if (results.len != 0) {
|
||||||
|
prev_page.prev = .{
|
||||||
|
.id = results[0].id,
|
||||||
|
.created_at = results[0].created_at,
|
||||||
|
};
|
||||||
|
|
||||||
|
next_page.prev = .{
|
||||||
|
.id = results[results.len - 1].id,
|
||||||
|
.created_at = results[results.len - 1].created_at,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
// TODO: this will give incorrect links on an empty page
|
||||||
|
|
||||||
|
return QueryResult{
|
||||||
|
.items = results,
|
||||||
|
.next_page = next_page,
|
||||||
|
.prev_page = prev_page,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
124
src/http/headers.zig
Normal file
124
src/http/headers.zig
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
pub const Fields = struct {
|
||||||
|
const HashContext = struct {
|
||||||
|
const hash_seed = 1;
|
||||||
|
pub fn eql(_: @This(), lhs: []const u8, rhs: []const u8, _: usize) bool {
|
||||||
|
return std.ascii.eqlIgnoreCase(lhs, rhs);
|
||||||
|
}
|
||||||
|
pub fn hash(_: @This(), s: []const u8) u32 {
|
||||||
|
var h = std.hash.Wyhash.init(hash_seed);
|
||||||
|
for (s) |ch| {
|
||||||
|
const c = [1]u8{std.ascii.toLower(ch)};
|
||||||
|
h.update(&c);
|
||||||
|
}
|
||||||
|
return @truncate(u32, h.final());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const HashMap = std.ArrayHashMapUnmanaged(
|
||||||
|
[]const u8,
|
||||||
|
[]const u8,
|
||||||
|
HashContext,
|
||||||
|
true,
|
||||||
|
);
|
||||||
|
|
||||||
|
unmanaged: HashMap,
|
||||||
|
allocator: std.mem.Allocator,
|
||||||
|
|
||||||
|
pub fn init(allocator: std.mem.Allocator) Fields {
|
||||||
|
return Fields{
|
||||||
|
.unmanaged = .{},
|
||||||
|
.allocator = allocator,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deinit(self: *Fields) void {
|
||||||
|
var hash_iter = self.unmanaged.iterator();
|
||||||
|
while (hash_iter.next()) |entry| {
|
||||||
|
self.allocator.free(entry.key_ptr.*);
|
||||||
|
self.allocator.free(entry.value_ptr.*);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.unmanaged.deinit(self.allocator);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iterator(self: Fields) HashMap.Iterator {
|
||||||
|
return self.unmanaged.iterator();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(self: Fields, key: []const u8) ?[]const u8 {
|
||||||
|
return self.unmanaged.get(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const ListIterator = struct {
|
||||||
|
remaining: []const u8,
|
||||||
|
|
||||||
|
fn extractElement(self: *ListIterator) ?[]const u8 {
|
||||||
|
if (self.remaining.len == 0) return null;
|
||||||
|
|
||||||
|
var start: usize = 0;
|
||||||
|
var is_quoted = false;
|
||||||
|
const end = for (self.remaining) |ch, i| {
|
||||||
|
if (start == i and std.ascii.isWhitespace(ch)) {
|
||||||
|
start += 1;
|
||||||
|
} else if (ch == '"') {
|
||||||
|
is_quoted = !is_quoted;
|
||||||
|
}
|
||||||
|
if (ch == ',' and !is_quoted) {
|
||||||
|
break i;
|
||||||
|
}
|
||||||
|
} else self.remaining.len;
|
||||||
|
|
||||||
|
const str = self.remaining[start..end];
|
||||||
|
if (end == self.remaining.len) {
|
||||||
|
self.remaining = "";
|
||||||
|
} else {
|
||||||
|
self.remaining = self.remaining[end + 1 ..];
|
||||||
|
}
|
||||||
|
|
||||||
|
return std.mem.trim(u8, str, " \t");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn next(self: *ListIterator) ?[]const u8 {
|
||||||
|
while (self.extractElement()) |elem| {
|
||||||
|
if (elem.len != 0) return elem;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn getList(self: Fields, key: []const u8) ?ListIterator {
|
||||||
|
return if (self.unmanaged.get(key)) |hdr| ListIterator{ .remaining = hdr } else null;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn put(self: *Fields, key: []const u8, val: []const u8) !void {
|
||||||
|
const key_clone = try self.allocator.alloc(u8, key.len);
|
||||||
|
std.mem.copy(u8, key_clone, key);
|
||||||
|
errdefer self.allocator.free(key_clone);
|
||||||
|
|
||||||
|
const val_clone = try self.allocator.alloc(u8, val.len);
|
||||||
|
std.mem.copy(u8, val_clone, val);
|
||||||
|
errdefer self.allocator.free(val_clone);
|
||||||
|
|
||||||
|
if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| {
|
||||||
|
self.allocator.free(entry.key);
|
||||||
|
self.allocator.free(entry.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn append(self: *Fields, key: []const u8, val: []const u8) !void {
|
||||||
|
if (self.unmanaged.getEntry(key)) |entry| {
|
||||||
|
const new_val = try std.mem.join(self.allocator, ", ", &.{ entry.value_ptr.*, val });
|
||||||
|
self.allocator.free(entry.value_ptr.*);
|
||||||
|
entry.value_ptr.* = new_val;
|
||||||
|
} else {
|
||||||
|
try self.put(key, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn count(self: Fields) usize {
|
||||||
|
return self.unmanaged.count();
|
||||||
|
}
|
||||||
|
};
|
|
@ -10,22 +10,15 @@ pub const socket = @import("./socket.zig");
|
||||||
pub const Method = std.http.Method;
|
pub const Method = std.http.Method;
|
||||||
pub const Status = std.http.Status;
|
pub const Status = std.http.Status;
|
||||||
|
|
||||||
pub const Request = request.Request;
|
pub const Request = request.Request(std.net.Stream.Reader);
|
||||||
pub const serveConn = server.serveConn;
|
pub const serveConn = server.serveConn;
|
||||||
pub const Response = server.Response;
|
pub const Response = server.Response;
|
||||||
pub const Handler = server.Handler;
|
pub const Handler = server.Handler;
|
||||||
|
|
||||||
pub const Headers = std.HashMap([]const u8, []const u8, struct {
|
pub const Fields = @import("./headers.zig").Fields;
|
||||||
pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {
|
|
||||||
return ciutf8.eql(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn hash(_: @This(), str: []const u8) u64 {
|
pub const Protocol = enum {
|
||||||
return ciutf8.hash(str);
|
http_1_0,
|
||||||
}
|
http_1_1,
|
||||||
}, std.hash_map.default_max_load_percentage);
|
http_1_x,
|
||||||
|
};
|
||||||
test {
|
|
||||||
_ = server;
|
|
||||||
_ = request;
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,29 +3,23 @@ const http = @import("./lib.zig");
|
||||||
|
|
||||||
const parser = @import("./request/parser.zig");
|
const parser = @import("./request/parser.zig");
|
||||||
|
|
||||||
pub const Request = struct {
|
pub fn Request(comptime Reader: type) type {
|
||||||
pub const Protocol = enum {
|
return struct {
|
||||||
http_1_0,
|
protocol: http.Protocol,
|
||||||
http_1_1,
|
|
||||||
};
|
|
||||||
|
|
||||||
protocol: Protocol,
|
|
||||||
source_address: ?std.net.Address,
|
|
||||||
|
|
||||||
method: http.Method,
|
method: http.Method,
|
||||||
uri: []const u8,
|
uri: []const u8,
|
||||||
headers: http.Headers,
|
headers: http.Fields,
|
||||||
body: ?[]const u8 = null,
|
|
||||||
|
|
||||||
pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request {
|
body: ?parser.TransferStream(Reader),
|
||||||
return parser.parse(alloc, reader, addr);
|
|
||||||
|
pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void {
|
||||||
|
allocator.free(self.uri);
|
||||||
|
self.headers.deinit();
|
||||||
}
|
}
|
||||||
|
};
|
||||||
pub fn parseFree(self: Request, alloc: std.mem.Allocator) void {
|
}
|
||||||
parser.parseFree(alloc, self);
|
|
||||||
}
|
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
|
||||||
};
|
return parser.parse(alloc, reader);
|
||||||
|
|
||||||
test {
|
|
||||||
_ = parser;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,13 @@
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const util = @import("util");
|
|
||||||
const http = @import("../lib.zig");
|
const http = @import("../lib.zig");
|
||||||
|
|
||||||
const Method = http.Method;
|
const Method = http.Method;
|
||||||
const Headers = http.Headers;
|
const Fields = http.Fields;
|
||||||
|
|
||||||
const Request = @import("../request.zig").Request;
|
const Request = @import("../request.zig").Request;
|
||||||
|
|
||||||
const request_buf_size = 1 << 16;
|
const request_buf_size = 1 << 16;
|
||||||
const max_path_len = 1 << 10;
|
const max_path_len = 1 << 10;
|
||||||
const max_body_len = 1 << 12;
|
|
||||||
|
|
||||||
fn ParseError(comptime Reader: type) type {
|
fn ParseError(comptime Reader: type) type {
|
||||||
return error{
|
return error{
|
||||||
|
@ -22,7 +20,7 @@ const Encoding = enum {
|
||||||
chunked,
|
chunked,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request {
|
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
|
||||||
const method = try parseMethod(reader);
|
const method = try parseMethod(reader);
|
||||||
const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
|
const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
|
||||||
error.StreamTooLong => return error.RequestUriTooLong,
|
error.StreamTooLong => return error.RequestUriTooLong,
|
||||||
|
@ -33,28 +31,20 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address
|
||||||
const proto = try parseProto(reader);
|
const proto = try parseProto(reader);
|
||||||
|
|
||||||
// discard \r\n
|
// discard \r\n
|
||||||
_ = try reader.readByte();
|
switch (try reader.readByte()) {
|
||||||
_ = try reader.readByte();
|
'\r' => if ((try reader.readByte()) != '\n') return error.BadRequest,
|
||||||
|
'\n' => {},
|
||||||
|
else => return error.BadRequest,
|
||||||
|
}
|
||||||
|
|
||||||
var headers = try parseHeaders(alloc, reader);
|
var headers = try parseHeaders(alloc, reader);
|
||||||
errdefer freeHeaders(alloc, &headers);
|
errdefer headers.deinit();
|
||||||
|
|
||||||
const body = if (method.requestHasBody())
|
const body = try prepareBody(headers, reader);
|
||||||
try readBody(alloc, headers, reader)
|
if (body != null and !method.requestHasBody()) return error.BadRequest;
|
||||||
else
|
|
||||||
null;
|
|
||||||
errdefer if (body) |b| alloc.free(b);
|
|
||||||
|
|
||||||
const eff_addr = if (headers.get("X-Real-IP")) |ip|
|
return Request(@TypeOf(reader)){
|
||||||
std.net.Address.parseIp(ip, address.getPort()) catch {
|
|
||||||
return error.BadRequest;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
address;
|
|
||||||
|
|
||||||
return Request{
|
|
||||||
.protocol = proto,
|
.protocol = proto,
|
||||||
.source_address = eff_addr,
|
|
||||||
|
|
||||||
.method = method,
|
.method = method,
|
||||||
.uri = uri,
|
.uri = uri,
|
||||||
|
@ -79,7 +69,7 @@ fn parseMethod(reader: anytype) !Method {
|
||||||
return error.MethodNotImplemented;
|
return error.MethodNotImplemented;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parseProto(reader: anytype) !Request.Protocol {
|
fn parseProto(reader: anytype) !http.Protocol {
|
||||||
var buf: [8]u8 = undefined;
|
var buf: [8]u8 = undefined;
|
||||||
const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
|
const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
|
||||||
error.StreamTooLong => return error.UnknownProtocol,
|
error.StreamTooLong => return error.UnknownProtocol,
|
||||||
|
@ -99,85 +89,145 @@ fn parseProto(reader: anytype) !Request.Protocol {
|
||||||
return switch (buf[2]) {
|
return switch (buf[2]) {
|
||||||
'0' => .http_1_0,
|
'0' => .http_1_0,
|
||||||
'1' => .http_1_1,
|
'1' => .http_1_1,
|
||||||
else => error.HttpVersionNotSupported,
|
else => .http_1_x,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers {
|
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
|
||||||
var map = Headers.init(allocator);
|
var headers = Fields.init(allocator);
|
||||||
errdefer map.deinit();
|
|
||||||
errdefer {
|
|
||||||
var iter = map.iterator();
|
|
||||||
while (iter.next()) |it| {
|
|
||||||
allocator.free(it.key_ptr.*);
|
|
||||||
allocator.free(it.value_ptr.*);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// todo:
|
|
||||||
//errdefer {
|
|
||||||
//var iter = map.iterator();
|
|
||||||
//while (iter.next()) |it| {
|
|
||||||
//allocator.free(it.key_ptr);
|
|
||||||
//allocator.free(it.value_ptr);
|
|
||||||
//}
|
|
||||||
//}
|
|
||||||
|
|
||||||
var buf: [1024]u8 = undefined;
|
|
||||||
|
|
||||||
|
var buf: [4096]u8 = undefined;
|
||||||
while (true) {
|
while (true) {
|
||||||
const line = try reader.readUntilDelimiter(&buf, '\n');
|
const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
|
||||||
if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break;
|
error.StreamTooLong => return error.HeaderLineTooLong,
|
||||||
|
else => return err,
|
||||||
|
};
|
||||||
|
const line = std.mem.trimRight(u8, full_line, "\r");
|
||||||
|
if (line.len == 0) break;
|
||||||
|
|
||||||
// TODO: handle multi-line headers
|
const name = std.mem.sliceTo(line, ':');
|
||||||
const name = extractHeaderName(line) orelse continue;
|
if (!isTokenValid(name)) return error.BadRequest;
|
||||||
const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len;
|
if (name.len == line.len) return error.BadRequest;
|
||||||
const value = line[name.len + 1 + 1 .. value_end];
|
|
||||||
|
|
||||||
if (name.len == 0 or value.len == 0) return error.BadRequest;
|
const value = std.mem.trim(u8, line[name.len + 1 ..], " \t");
|
||||||
|
|
||||||
const name_alloc = try allocator.alloc(u8, name.len);
|
try headers.append(name, value);
|
||||||
errdefer allocator.free(name_alloc);
|
|
||||||
const value_alloc = try allocator.alloc(u8, value.len);
|
|
||||||
errdefer allocator.free(value_alloc);
|
|
||||||
|
|
||||||
@memcpy(name_alloc.ptr, name.ptr, name.len);
|
|
||||||
@memcpy(value_alloc.ptr, value.ptr, value.len);
|
|
||||||
|
|
||||||
try map.put(name_alloc, value_alloc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return map;
|
return headers;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extractHeaderName(line: []const u8) ?[]const u8 {
|
fn isTokenValid(token: []const u8) bool {
|
||||||
var index: usize = 0;
|
if (token.len == 0) return false;
|
||||||
|
for (token) |ch| {
|
||||||
|
switch (ch) {
|
||||||
|
'"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false,
|
||||||
|
|
||||||
// TODO: handle whitespace
|
'!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => {},
|
||||||
while (index < line.len) : (index += 1) {
|
else => if (!std.ascii.isAlphanumeric(ch)) return false,
|
||||||
if (line[index] == ':') {
|
|
||||||
if (index == 0) return null;
|
|
||||||
return line[0..index];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 {
|
fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) {
|
||||||
const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
|
const hdr = headers.get("Transfer-Encoding");
|
||||||
if (xfer_encoding != .identity) return error.UnsupportedMediaType;
|
// TODO:
|
||||||
|
// if (hder != null and protocol == .http_1_0) return error.BadRequest;
|
||||||
|
const xfer_encoding = try parseEncoding(hdr);
|
||||||
const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
|
const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
|
||||||
if (content_encoding != .identity) return error.UnsupportedMediaType;
|
if (content_encoding != .identity) return error.UnsupportedMediaType;
|
||||||
|
|
||||||
|
switch (xfer_encoding) {
|
||||||
|
.identity => {
|
||||||
const len_str = headers.get("Content-Length") orelse return null;
|
const len_str = headers.get("Content-Length") orelse return null;
|
||||||
const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
|
const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
|
||||||
if (len > max_body_len) return error.RequestEntityTooLarge;
|
|
||||||
const body = try alloc.alloc(u8, len);
|
|
||||||
errdefer alloc.free(body);
|
|
||||||
|
|
||||||
reader.readNoEof(body) catch return error.BadRequest;
|
return TransferStream(@TypeOf(reader)){ .underlying = .{ .identity = std.io.limitedReader(reader, len) } };
|
||||||
|
},
|
||||||
|
.chunked => {
|
||||||
|
if (headers.get("Content-Length") != null) return error.BadRequest;
|
||||||
|
return TransferStream(@TypeOf(reader)){
|
||||||
|
.underlying = .{
|
||||||
|
.chunked = try ChunkedStream(@TypeOf(reader)).init(reader),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return body;
|
fn ChunkedStream(comptime R: type) type {
|
||||||
|
return struct {
|
||||||
|
const Self = @This();
|
||||||
|
|
||||||
|
remaining: ?usize = 0,
|
||||||
|
underlying: R,
|
||||||
|
|
||||||
|
const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream };
|
||||||
|
fn init(reader: R) !Self {
|
||||||
|
var self: Self = .{ .underlying = reader };
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read(self: *Self, buf: []u8) !usize {
|
||||||
|
var count: usize = 0;
|
||||||
|
while (true) {
|
||||||
|
if (count == buf.len) return count;
|
||||||
|
if (self.remaining == null) return count;
|
||||||
|
if (self.remaining.? == 0) self.remaining = try self.readChunkHeader();
|
||||||
|
|
||||||
|
const max_read = std.math.min(buf.len, self.remaining.?);
|
||||||
|
const amt = try self.underlying.read(buf[count .. count + max_read]);
|
||||||
|
if (amt != max_read) return error.EndOfStream;
|
||||||
|
count += amt;
|
||||||
|
self.remaining.? -= amt;
|
||||||
|
if (self.remaining.? == 0) {
|
||||||
|
var crlf: [2]u8 = undefined;
|
||||||
|
_ = try self.underlying.readUntilDelimiter(&crlf, '\n');
|
||||||
|
self.remaining = try self.readChunkHeader();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count == buf.len) return count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn readChunkHeader(self: *Self) !?usize {
|
||||||
|
// TODO: Pick a reasonable limit for this
|
||||||
|
var buf = std.mem.zeroes([10]u8);
|
||||||
|
const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| {
|
||||||
|
return if (err == error.StreamTooLong) error.InvalidChunkHeader else err;
|
||||||
|
};
|
||||||
|
if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader;
|
||||||
|
|
||||||
|
const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader;
|
||||||
|
|
||||||
|
return if (size != 0) size else null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn TransferStream(comptime R: type) type {
|
||||||
|
return struct {
|
||||||
|
const Error = R.Error || ChunkedStream(R).Error;
|
||||||
|
const Reader = std.io.Reader(*@This(), Error, read);
|
||||||
|
|
||||||
|
underlying: union(enum) {
|
||||||
|
identity: std.io.LimitedReader(R),
|
||||||
|
chunked: ChunkedStream(R),
|
||||||
|
},
|
||||||
|
|
||||||
|
pub fn read(self: *@This(), buf: []u8) Error!usize {
|
||||||
|
return switch (self.underlying) {
|
||||||
|
.identity => |*r| try r.read(buf),
|
||||||
|
.chunked => |*r| try r.read(buf),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reader(self: *@This()) Reader {
|
||||||
|
return .{ .context = self };
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: assumes that there's only one encoding, not layered encodings
|
// TODO: assumes that there's only one encoding, not layered encodings
|
||||||
|
@ -187,257 +237,3 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding {
|
||||||
if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked;
|
if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked;
|
||||||
return error.UnsupportedMediaType;
|
return error.UnsupportedMediaType;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void {
|
|
||||||
allocator.free(request.uri);
|
|
||||||
freeHeaders(allocator, &request.headers);
|
|
||||||
if (request.body) |body| allocator.free(body);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void {
|
|
||||||
var iter = headers.iterator();
|
|
||||||
while (iter.next()) |it| {
|
|
||||||
allocator.free(it.key_ptr.*);
|
|
||||||
allocator.free(it.value_ptr.*);
|
|
||||||
}
|
|
||||||
headers.deinit();
|
|
||||||
}
|
|
||||||
|
|
||||||
const _test = struct {
|
|
||||||
const expectEqual = std.testing.expectEqual;
|
|
||||||
const expectEqualStrings = std.testing.expectEqualStrings;
|
|
||||||
|
|
||||||
fn toCrlf(comptime str: []const u8) []const u8 {
|
|
||||||
comptime {
|
|
||||||
var buf: [str.len * 2]u8 = undefined;
|
|
||||||
|
|
||||||
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
|
|
||||||
|
|
||||||
var buf_len: usize = 0;
|
|
||||||
for (str) |ch| {
|
|
||||||
if (ch == '\n') {
|
|
||||||
buf[buf_len] = '\r';
|
|
||||||
buf_len += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
buf[buf_len] = ch;
|
|
||||||
buf_len += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf[0..buf_len];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers {
|
|
||||||
var result = Headers.init(alloc);
|
|
||||||
inline for (headers) |tup| {
|
|
||||||
try result.put(tup[0], tup[1]);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn areEqualHeaders(lhs: Headers, rhs: Headers) bool {
|
|
||||||
if (lhs.count() != rhs.count()) return false;
|
|
||||||
var iter = lhs.iterator();
|
|
||||||
while (iter.next()) |it| {
|
|
||||||
const rhs_val = rhs.get(it.key_ptr.*) orelse return false;
|
|
||||||
if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn printHeaders(headers: Headers) void {
|
|
||||||
var iter = headers.iterator();
|
|
||||||
while (iter.next()) |it| {
|
|
||||||
std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn expectEqualHeaders(expected: Headers, actual: Headers) !void {
|
|
||||||
if (!areEqualHeaders(expected, actual)) {
|
|
||||||
std.debug.print("\nexpected: \n", .{});
|
|
||||||
printHeaders(expected);
|
|
||||||
std.debug.print("\n\nfound: \n", .{});
|
|
||||||
printHeaders(actual);
|
|
||||||
std.debug.print("\n\n", .{});
|
|
||||||
return error.TestExpectedEqual;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void {
|
|
||||||
var stream = std.io.fixedBufferStream(toCrlf(request));
|
|
||||||
|
|
||||||
const result = try parse(alloc, stream.reader());
|
|
||||||
|
|
||||||
try expectEqual(expected.method, result.method);
|
|
||||||
try expectEqualStrings(expected.path, result.path);
|
|
||||||
try expectEqualHeaders(expected.headers, result.headers);
|
|
||||||
if ((expected.body == null) != (result.body == null)) {
|
|
||||||
const null_str: []const u8 = "(null)";
|
|
||||||
const exp = expected.body orelse null_str;
|
|
||||||
const act = result.body orelse null_str;
|
|
||||||
std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act });
|
|
||||||
return error.TestExpectedEqual;
|
|
||||||
}
|
|
||||||
if (expected.body != null) {
|
|
||||||
try expectEqualStrings(expected.body.?, result.body.?);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// TOOD: failure test cases
|
|
||||||
test "parse" {
|
|
||||||
const testCase = _test.parseTestCase;
|
|
||||||
var buf = [_]u8{0} ** (1 << 16);
|
|
||||||
var fba = std.heap.FixedBufferAllocator.init(&buf);
|
|
||||||
const alloc = fba.allocator();
|
|
||||||
try testCase(alloc, (
|
|
||||||
\\GET / HTTP/1.1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), .{
|
|
||||||
.method = .GET,
|
|
||||||
.headers = try _test.makeHeaders(alloc, .{}),
|
|
||||||
.path = "/",
|
|
||||||
});
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try testCase(alloc, (
|
|
||||||
\\POST / HTTP/1.1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), .{
|
|
||||||
.method = .POST,
|
|
||||||
.headers = try _test.makeHeaders(alloc, .{}),
|
|
||||||
.path = "/",
|
|
||||||
});
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try testCase(alloc, (
|
|
||||||
\\HEAD / HTTP/1.1
|
|
||||||
\\Authorization: bearer <token>
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), .{
|
|
||||||
.method = .HEAD,
|
|
||||||
.headers = try _test.makeHeaders(alloc, .{
|
|
||||||
.{ "Authorization", "bearer <token>" },
|
|
||||||
}),
|
|
||||||
.path = "/",
|
|
||||||
});
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try testCase(alloc, (
|
|
||||||
\\POST /nonsense HTTP/1.1
|
|
||||||
\\Authorization: bearer <token>
|
|
||||||
\\Content-Length: 5
|
|
||||||
\\
|
|
||||||
\\12345
|
|
||||||
), .{
|
|
||||||
.method = .POST,
|
|
||||||
.headers = try _test.makeHeaders(alloc, .{
|
|
||||||
.{ "Authorization", "bearer <token>" },
|
|
||||||
.{ "Content-Length", "5" },
|
|
||||||
}),
|
|
||||||
.path = "/nonsense",
|
|
||||||
.body = "12345",
|
|
||||||
});
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.MethodNotImplemented,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\FOO /nonsense HTTP/1.1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.MethodNotImplemented,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\FOOBARBAZ /nonsense HTTP/1.1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.RequestUriTooLong,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /
|
|
||||||
++ ("a" ** 2048)), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.UnknownProtocol,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /nonsense SPECIALHTTP/1.1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.UnknownProtocol,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /nonsense JSON/1.1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.HttpVersionNotSupported,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /nonsense HTTP/1.9
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.HttpVersionNotSupported,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /nonsense HTTP/8.1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.BadRequest,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /nonsense HTTP/blah blah blah
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.BadRequest,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /nonsense HTTP/1/1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
|
|
||||||
fba.reset();
|
|
||||||
try std.testing.expectError(
|
|
||||||
error.BadRequest,
|
|
||||||
testCase(alloc, (
|
|
||||||
\\GET /nonsense HTTP/1/1
|
|
||||||
\\
|
|
||||||
\\
|
|
||||||
), undefined),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
282
src/http/request/test_parser.zig
Normal file
282
src/http/request/test_parser.zig
Normal file
|
@ -0,0 +1,282 @@
|
||||||
|
const std = @import("std");
|
||||||
|
const parser = @import("./parser.zig");
|
||||||
|
const http = @import("../lib.zig");
|
||||||
|
const t = std.testing;
|
||||||
|
|
||||||
|
const test_case = struct {
|
||||||
|
fn parse(text: []const u8, expected: struct {
|
||||||
|
protocol: http.Protocol = .http_1_1,
|
||||||
|
method: http.Method = .GET,
|
||||||
|
headers: []const std.meta.Tuple(&.{ []const u8, []const u8 }) = &.{},
|
||||||
|
uri: []const u8 = "",
|
||||||
|
}) !void {
|
||||||
|
var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(text) };
|
||||||
|
var actual = try parser.parse(t.allocator, stream.reader());
|
||||||
|
defer actual.parseFree(t.allocator);
|
||||||
|
|
||||||
|
try t.expectEqual(expected.protocol, actual.protocol);
|
||||||
|
try t.expectEqual(expected.method, actual.method);
|
||||||
|
try t.expectEqualStrings(expected.uri, actual.uri);
|
||||||
|
|
||||||
|
try t.expectEqual(expected.headers.len, actual.headers.count());
|
||||||
|
for (expected.headers) |hdr| {
|
||||||
|
if (actual.headers.get(hdr[0])) |val| {
|
||||||
|
try t.expectEqualStrings(hdr[1], val);
|
||||||
|
} else {
|
||||||
|
std.debug.print("Error: Header {s} expected to be present, was not.\n", .{hdr[0]});
|
||||||
|
try t.expect(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fn toCrlf(comptime str: []const u8) []const u8 {
|
||||||
|
comptime {
|
||||||
|
var buf: [str.len * 2]u8 = undefined;
|
||||||
|
|
||||||
|
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
|
||||||
|
|
||||||
|
var buf_len: usize = 0;
|
||||||
|
for (str) |ch| {
|
||||||
|
if (ch == '\n') {
|
||||||
|
buf[buf_len] = '\r';
|
||||||
|
buf_len += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
buf[buf_len] = ch;
|
||||||
|
buf_len += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf[0..buf_len];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test "HTTP/1.x parse - No body" {
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\GET / HTTP/1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\POST / HTTP/1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .POST,
|
||||||
|
.uri = "/",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\GET /url/abcd HTTP/1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/url/abcd",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\GET / HTTP/1.0
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_0,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\GET /url/abcd HTTP/1.1
|
||||||
|
\\Content-Type: application/json
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/url/abcd",
|
||||||
|
.headers = &.{.{ "Content-Type", "application/json" }},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\GET /url/abcd HTTP/1.1
|
||||||
|
\\Content-Type: application/json
|
||||||
|
\\Authorization: bearer <token>
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/url/abcd",
|
||||||
|
.headers = &.{
|
||||||
|
.{ "Content-Type", "application/json" },
|
||||||
|
.{ "Authorization", "bearer <token>" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test without CRLF
|
||||||
|
try test_case.parse(
|
||||||
|
\\GET /url/abcd HTTP/1.1
|
||||||
|
\\Content-Type: application/json
|
||||||
|
\\Authorization: bearer <token>
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/url/abcd",
|
||||||
|
.headers = &.{
|
||||||
|
.{ "Content-Type", "application/json" },
|
||||||
|
.{ "Authorization", "bearer <token>" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
try test_case.parse(
|
||||||
|
\\POST / HTTP/1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .POST,
|
||||||
|
.uri = "/",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\GET / HTTP/1.2
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_x,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/",
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
test "HTTP/1.x parse - unsupported protocol" {
|
||||||
|
try t.expectError(error.UnknownProtocol, test_case.parse(
|
||||||
|
\\GET / JSON/1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.UnknownProtocol, test_case.parse(
|
||||||
|
\\GET / SOMETHINGELSE/3.5
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.UnknownProtocol, test_case.parse(
|
||||||
|
\\GET / /1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.HttpVersionNotSupported, test_case.parse(
|
||||||
|
\\GET / HTTP/2.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
test "HTTP/1.x parse - Unknown method" {
|
||||||
|
try t.expectError(error.MethodNotImplemented, test_case.parse(
|
||||||
|
\\ABCD / HTTP/1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.MethodNotImplemented, test_case.parse(
|
||||||
|
\\PATCHPATCHPATCH / HTTP/1.1
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
test "HTTP/1.x parse - Too long" {
|
||||||
|
try t.expectError(error.RequestUriTooLong, test_case.parse(
|
||||||
|
std.fmt.comptimePrint("GET {s} HTTP/1.1\n\n", .{"a" ** 8192}),
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.HeaderLineTooLong, test_case.parse(
|
||||||
|
std.fmt.comptimePrint("GET / HTTP/1.1\r\n{s}: abcd", .{"a" ** 8192}),
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.HeaderLineTooLong, test_case.parse(
|
||||||
|
std.fmt.comptimePrint("GET / HTTP/1.1\r\nabcd: {s}", .{"a" ** 8192}),
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
test "HTTP/1.x parse - bad requests" {
|
||||||
|
try t.expectError(error.BadRequest, test_case.parse(
|
||||||
|
\\GET / HTTP/1.1 blah blah
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.BadRequest, test_case.parse(
|
||||||
|
\\GET / HTTP/1.1
|
||||||
|
\\abcd : lksjdfkl
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
try t.expectError(error.BadRequest, test_case.parse(
|
||||||
|
\\GET / HTTP/1.1
|
||||||
|
\\ lksjfklsjdfklj
|
||||||
|
\\
|
||||||
|
,
|
||||||
|
.{},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
test "HTTP/1.x parse - Headers" {
|
||||||
|
try test_case.parse(
|
||||||
|
toCrlf(
|
||||||
|
\\GET /url/abcd HTTP/1.1
|
||||||
|
\\Content-Type: application/json
|
||||||
|
\\Content-Type: application/xml
|
||||||
|
\\
|
||||||
|
\\
|
||||||
|
),
|
||||||
|
.{
|
||||||
|
.protocol = .http_1_1,
|
||||||
|
.method = .GET,
|
||||||
|
.uri = "/url/abcd",
|
||||||
|
.headers = &.{.{ "Content-Type", "application/json, application/xml" }},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
|
@ -3,13 +3,14 @@ const util = @import("util");
|
||||||
const http = @import("./lib.zig");
|
const http = @import("./lib.zig");
|
||||||
|
|
||||||
const response = @import("./server/response.zig");
|
const response = @import("./server/response.zig");
|
||||||
|
const request = @import("./request.zig");
|
||||||
|
|
||||||
pub const Response = struct {
|
pub const Response = struct {
|
||||||
alloc: std.mem.Allocator,
|
alloc: std.mem.Allocator,
|
||||||
stream: std.net.Stream,
|
stream: std.net.Stream,
|
||||||
should_close: bool = false,
|
should_close: bool = false,
|
||||||
pub const Stream = response.ResponseStream(std.net.Stream.Writer);
|
pub const Stream = response.ResponseStream(std.net.Stream.Writer);
|
||||||
pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream {
|
pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream {
|
||||||
if (headers.get("Connection")) |hdr| {
|
if (headers.get("Connection")) |hdr| {
|
||||||
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
|
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
|
||||||
}
|
}
|
||||||
|
@ -17,7 +18,7 @@ pub const Response = struct {
|
||||||
return response.open(self.alloc, self.stream.writer(), headers, status);
|
return response.open(self.alloc, self.stream.writer(), headers, status);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream {
|
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream {
|
||||||
try response.writeRequestHeader(self.stream.writer(), headers, status);
|
try response.writeRequestHeader(self.stream.writer(), headers, status);
|
||||||
return self.stream;
|
return self.stream;
|
||||||
}
|
}
|
||||||
|
@ -26,10 +27,6 @@ pub const Response = struct {
|
||||||
const Request = http.Request;
|
const Request = http.Request;
|
||||||
const request_buf_size = 1 << 16;
|
const request_buf_size = 1 << 16;
|
||||||
|
|
||||||
pub fn Handler(comptime Ctx: type) type {
|
|
||||||
return fn (Ctx, Request, *Response) void;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void {
|
pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void {
|
||||||
// TODO: Timeouts
|
// TODO: Timeouts
|
||||||
while (true) {
|
while (true) {
|
||||||
|
@ -37,7 +34,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
|
||||||
var arena = std.heap.ArenaAllocator.init(alloc);
|
var arena = std.heap.ArenaAllocator.init(alloc);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
|
|
||||||
const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| {
|
var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
|
||||||
return handleError(conn.stream.writer(), err) catch {};
|
return handleError(conn.stream.writer(), err) catch {};
|
||||||
};
|
};
|
||||||
std.log.debug("done parsing", .{});
|
std.log.debug("done parsing", .{});
|
||||||
|
@ -47,7 +44,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
|
||||||
.stream = conn.stream,
|
.stream = conn.stream,
|
||||||
};
|
};
|
||||||
|
|
||||||
handler(ctx, req, &res);
|
handler(ctx, &req, &res);
|
||||||
std.log.debug("done handling", .{});
|
std.log.debug("done handling", .{});
|
||||||
|
|
||||||
if (req.headers.get("Connection")) |hdr| {
|
if (req.headers.get("Connection")) |hdr| {
|
||||||
|
|
|
@ -2,20 +2,20 @@ const std = @import("std");
|
||||||
const http = @import("../lib.zig");
|
const http = @import("../lib.zig");
|
||||||
|
|
||||||
const Status = http.Status;
|
const Status = http.Status;
|
||||||
const Headers = http.Headers;
|
const Fields = http.Fields;
|
||||||
|
|
||||||
const chunk_size = 16 * 1024;
|
const chunk_size = 16 * 1024;
|
||||||
pub fn open(
|
pub fn open(
|
||||||
alloc: std.mem.Allocator,
|
alloc: std.mem.Allocator,
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
headers: *const Headers,
|
headers: *const Fields,
|
||||||
status: Status,
|
status: Status,
|
||||||
) !ResponseStream(@TypeOf(writer)) {
|
) !ResponseStream(@TypeOf(writer)) {
|
||||||
const buf = try alloc.alloc(u8, chunk_size);
|
const buf = try alloc.alloc(u8, chunk_size);
|
||||||
errdefer alloc.free(buf);
|
errdefer alloc.free(buf);
|
||||||
|
|
||||||
try writeStatusLine(writer, status);
|
try writeStatusLine(writer, status);
|
||||||
try writeHeaders(writer, headers);
|
try writeFields(writer, headers);
|
||||||
|
|
||||||
return ResponseStream(@TypeOf(writer)){
|
return ResponseStream(@TypeOf(writer)){
|
||||||
.allocator = alloc,
|
.allocator = alloc,
|
||||||
|
@ -25,9 +25,9 @@ pub fn open(
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void {
|
pub fn writeRequestHeader(writer: anytype, headers: *const Fields, status: Status) !void {
|
||||||
try writeStatusLine(writer, status);
|
try writeStatusLine(writer, status);
|
||||||
try writeHeaders(writer, headers);
|
try writeFields(writer, headers);
|
||||||
try writer.writeAll("\r\n");
|
try writer.writeAll("\r\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ fn writeStatusLine(writer: anytype, status: Status) !void {
|
||||||
try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text });
|
try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text });
|
||||||
}
|
}
|
||||||
|
|
||||||
fn writeHeaders(writer: anytype, headers: *const Headers) !void {
|
fn writeFields(writer: anytype, headers: *const Fields) !void {
|
||||||
var iter = headers.iterator();
|
var iter = headers.iterator();
|
||||||
while (iter.next()) |header| {
|
while (iter.next()) |header| {
|
||||||
for (header.value_ptr.*) |ch| {
|
for (header.value_ptr.*) |ch| {
|
||||||
|
@ -65,7 +65,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
|
||||||
|
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
base_writer: BaseWriter,
|
base_writer: BaseWriter,
|
||||||
headers: *const Headers,
|
headers: *const Fields,
|
||||||
buffer: []u8,
|
buffer: []u8,
|
||||||
buffer_pos: usize = 0,
|
buffer_pos: usize = 0,
|
||||||
chunked: bool = false,
|
chunked: bool = false,
|
||||||
|
@ -95,7 +95,6 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std.debug.print("{}\n", .{cursor});
|
|
||||||
self.writeToBuffer(bytes[cursor .. cursor + remaining_in_chunk]);
|
self.writeToBuffer(bytes[cursor .. cursor + remaining_in_chunk]);
|
||||||
cursor += remaining_in_chunk;
|
cursor += remaining_in_chunk;
|
||||||
try self.flushChunk();
|
try self.flushChunk();
|
||||||
|
@ -177,7 +176,7 @@ const _tests = struct {
|
||||||
test "ResponseStream no headers empty body" {
|
test "ResponseStream no headers empty body" {
|
||||||
var buffer: [test_buffer_size]u8 = undefined;
|
var buffer: [test_buffer_size]u8 = undefined;
|
||||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||||
var headers = Headers.init(std.testing.allocator);
|
var headers = Fields.init(std.testing.allocator);
|
||||||
defer headers.deinit();
|
defer headers.deinit();
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -205,7 +204,7 @@ const _tests = struct {
|
||||||
test "ResponseStream empty body" {
|
test "ResponseStream empty body" {
|
||||||
var buffer: [test_buffer_size]u8 = undefined;
|
var buffer: [test_buffer_size]u8 = undefined;
|
||||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||||
var headers = Headers.init(std.testing.allocator);
|
var headers = Fields.init(std.testing.allocator);
|
||||||
defer headers.deinit();
|
defer headers.deinit();
|
||||||
|
|
||||||
try headers.put("Content-Type", "text/plain");
|
try headers.put("Content-Type", "text/plain");
|
||||||
|
@ -236,7 +235,7 @@ const _tests = struct {
|
||||||
test "ResponseStream not 200 OK" {
|
test "ResponseStream not 200 OK" {
|
||||||
var buffer: [test_buffer_size]u8 = undefined;
|
var buffer: [test_buffer_size]u8 = undefined;
|
||||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||||
var headers = Headers.init(std.testing.allocator);
|
var headers = Fields.init(std.testing.allocator);
|
||||||
defer headers.deinit();
|
defer headers.deinit();
|
||||||
|
|
||||||
try headers.put("Content-Type", "text/plain");
|
try headers.put("Content-Type", "text/plain");
|
||||||
|
@ -266,7 +265,7 @@ const _tests = struct {
|
||||||
test "ResponseStream small body" {
|
test "ResponseStream small body" {
|
||||||
var buffer: [test_buffer_size]u8 = undefined;
|
var buffer: [test_buffer_size]u8 = undefined;
|
||||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||||
var headers = Headers.init(std.testing.allocator);
|
var headers = Fields.init(std.testing.allocator);
|
||||||
defer headers.deinit();
|
defer headers.deinit();
|
||||||
|
|
||||||
try headers.put("Content-Type", "text/plain");
|
try headers.put("Content-Type", "text/plain");
|
||||||
|
@ -300,7 +299,7 @@ const _tests = struct {
|
||||||
test "ResponseStream large body" {
|
test "ResponseStream large body" {
|
||||||
var buffer: [test_buffer_size]u8 = undefined;
|
var buffer: [test_buffer_size]u8 = undefined;
|
||||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||||
var headers = Headers.init(std.testing.allocator);
|
var headers = Fields.init(std.testing.allocator);
|
||||||
defer headers.deinit();
|
defer headers.deinit();
|
||||||
|
|
||||||
try headers.put("Content-Type", "text/plain");
|
try headers.put("Content-Type", "text/plain");
|
||||||
|
@ -341,7 +340,7 @@ const _tests = struct {
|
||||||
test "ResponseStream large body ending on chunk boundary" {
|
test "ResponseStream large body ending on chunk boundary" {
|
||||||
var buffer: [test_buffer_size]u8 = undefined;
|
var buffer: [test_buffer_size]u8 = undefined;
|
||||||
var test_stream = std.io.fixedBufferStream(&buffer);
|
var test_stream = std.io.fixedBufferStream(&buffer);
|
||||||
var headers = Headers.init(std.testing.allocator);
|
var headers = Fields.init(std.testing.allocator);
|
||||||
defer headers.deinit();
|
defer headers.deinit();
|
||||||
|
|
||||||
try headers.put("Content-Type", "text/plain");
|
try headers.put("Content-Type", "text/plain");
|
||||||
|
|
|
@ -23,21 +23,21 @@ const Opcode = enum(u4) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket {
|
pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
|
||||||
const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
|
const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
|
||||||
const connection = req.headers.get("Connection") orelse return error.BadHandshake;
|
const connection = req.headers.get("Connection") orelse return error.BadHandshake;
|
||||||
if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
|
if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
|
||||||
if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;
|
if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;
|
||||||
|
|
||||||
const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
|
const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
|
||||||
if (try std.base64.standard.Decoder.calcSizeForSlice(key_hdr) != 16) return error.BadHandshake;
|
if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake;
|
||||||
var key: [16]u8 = undefined;
|
var key: [16]u8 = undefined;
|
||||||
std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;
|
std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;
|
||||||
|
|
||||||
const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
|
const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
|
||||||
if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
|
if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
|
||||||
|
|
||||||
var headers = http.Headers.init(alloc);
|
var headers = http.Fields.init(alloc);
|
||||||
defer headers.deinit();
|
defer headers.deinit();
|
||||||
|
|
||||||
try headers.put("Upgrade", "websocket");
|
try headers.put("Upgrade", "websocket");
|
||||||
|
@ -51,7 +51,7 @@ pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Respons
|
||||||
var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined;
|
var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined;
|
||||||
_ = std.base64.standard.Encoder.encode(&hash_encoded, &hash);
|
_ = std.base64.standard.Encoder.encode(&hash_encoded, &hash);
|
||||||
try headers.put("Sec-WebSocket-Accept", &hash_encoded);
|
try headers.put("Sec-WebSocket-Accept", &hash_encoded);
|
||||||
const stream = try res.upgrade(.switching_protcols, &headers);
|
const stream = try res.upgrade(.switching_protocols, &headers);
|
||||||
|
|
||||||
return Socket{ .stream = stream };
|
return Socket{ .stream = stream };
|
||||||
}
|
}
|
||||||
|
@ -164,15 +164,15 @@ fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void {
|
||||||
const initial_len: u7 = if (header.len < 126)
|
const initial_len: u7 = if (header.len < 126)
|
||||||
@intCast(u7, header.len)
|
@intCast(u7, header.len)
|
||||||
else if (std.math.cast(u16, header.len)) |_|
|
else if (std.math.cast(u16, header.len)) |_|
|
||||||
126
|
@as(u7, 126)
|
||||||
else
|
else
|
||||||
127;
|
@as(u7, 127);
|
||||||
|
|
||||||
var hdr_buf = [2]u8{ 0, 0 };
|
var hdr_buf = [2]u8{ 0, 0 };
|
||||||
hdr_buf[0] |= if (header.is_final) 0b1000_0000 else 0;
|
hdr_buf[0] |= if (header.is_final) @as(u8, 0b1000_0000) else 0;
|
||||||
hdr_buf[0] |= @as(u8, header.rsv) << 4;
|
hdr_buf[0] |= @as(u8, header.rsv) << 4;
|
||||||
hdr_buf[0] |= @enumToInt(header.opcode);
|
hdr_buf[0] |= @enumToInt(header.opcode);
|
||||||
hdr_buf[1] |= if (header.masking_key) |_| 0b1000_0000 else 0;
|
hdr_buf[1] |= if (header.masking_key) |_| @as(u8, 0b1000_0000) else 0;
|
||||||
hdr_buf[1] |= initial_len;
|
hdr_buf[1] |= initial_len;
|
||||||
try writer.writeAll(&hdr_buf);
|
try writer.writeAll(&hdr_buf);
|
||||||
if (initial_len == 126)
|
if (initial_len == 126)
|
||||||
|
|
3
src/http/test.zig
Normal file
3
src/http/test.zig
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
test {
|
||||||
|
_ = @import("./request/test_parser.zig");
|
||||||
|
}
|
|
@ -13,10 +13,11 @@ pub const invites = @import("./controllers/invites.zig");
|
||||||
pub const users = @import("./controllers/users.zig");
|
pub const users = @import("./controllers/users.zig");
|
||||||
pub const notes = @import("./controllers/notes.zig");
|
pub const notes = @import("./controllers/notes.zig");
|
||||||
pub const streaming = @import("./controllers/streaming.zig");
|
pub const streaming = @import("./controllers/streaming.zig");
|
||||||
|
pub const timelines = @import("./controllers/timelines.zig");
|
||||||
|
|
||||||
pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
|
pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
|
||||||
// TODO: hashmaps?
|
// TODO: hashmaps?
|
||||||
var response = Response{ .headers = http.Headers.init(alloc), .res = res };
|
var response = Response{ .headers = http.Fields.init(alloc), .res = res };
|
||||||
defer response.headers.deinit();
|
defer response.headers.deinit();
|
||||||
|
|
||||||
const found = routeRequestInternal(api_source, req, &response, alloc);
|
const found = routeRequestInternal(api_source, req, &response, alloc);
|
||||||
|
@ -24,7 +25,7 @@ pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response,
|
||||||
if (!found) response.status(.not_found) catch {};
|
if (!found) response.status(.not_found) catch {};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
||||||
inline for (routes) |route| {
|
inline for (routes) |route| {
|
||||||
if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true;
|
if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true;
|
||||||
}
|
}
|
||||||
|
@ -42,6 +43,8 @@ const routes = .{
|
||||||
notes.create,
|
notes.create,
|
||||||
notes.get,
|
notes.get,
|
||||||
streaming.streaming,
|
streaming.streaming,
|
||||||
|
timelines.global,
|
||||||
|
timelines.local,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn Context(comptime Route: type) type {
|
pub fn Context(comptime Route: type) type {
|
||||||
|
@ -58,18 +61,21 @@ pub fn Context(comptime Route: type) type {
|
||||||
// leave it as a simple string instead of void
|
// leave it as a simple string instead of void
|
||||||
pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void;
|
pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void;
|
||||||
|
|
||||||
base_request: http.Request,
|
base_request: *http.Request,
|
||||||
|
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
|
|
||||||
method: http.Method,
|
method: http.Method,
|
||||||
uri: []const u8,
|
uri: []const u8,
|
||||||
headers: http.Headers,
|
headers: http.Fields,
|
||||||
|
|
||||||
args: Args,
|
args: Args,
|
||||||
body: Body,
|
body: Body,
|
||||||
query: Query,
|
query: Query,
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
body_buf: ?[]const u8 = null,
|
||||||
|
|
||||||
fn parseArgs(path: []const u8) ?Args {
|
fn parseArgs(path: []const u8) ?Args {
|
||||||
var args: Args = undefined;
|
var args: Args = undefined;
|
||||||
var path_iter = util.PathIter.from(path);
|
var path_iter = util.PathIter.from(path);
|
||||||
|
@ -94,7 +100,7 @@ pub fn Context(comptime Route: type) type {
|
||||||
@compileError("Unsupported Type " ++ @typeName(T));
|
@compileError("Unsupported Type " ++ @typeName(T));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
|
||||||
if (req.method != Route.method) return false;
|
if (req.method != Route.method) return false;
|
||||||
var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?');
|
var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?');
|
||||||
var args: Args = parseArgs(path) orelse return false;
|
var args: Args = parseArgs(path) orelse return false;
|
||||||
|
@ -112,6 +118,8 @@ pub fn Context(comptime Route: type) type {
|
||||||
.query = undefined,
|
.query = undefined,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
std.log.debug("Matched route {s}", .{path});
|
||||||
|
|
||||||
self.prepareAndHandle(api_source, req, res);
|
self.prepareAndHandle(api_source, req, res);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -129,7 +137,7 @@ pub fn Context(comptime Route: type) type {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void {
|
fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void {
|
||||||
self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err);
|
self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err);
|
||||||
defer self.freeBody();
|
defer self.freeBody();
|
||||||
|
|
||||||
|
@ -141,16 +149,20 @@ pub fn Context(comptime Route: type) type {
|
||||||
self.handle(response, &api_conn);
|
self.handle(response, &api_conn);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parseBody(self: *Self, req: http.Request) !void {
|
fn parseBody(self: *Self, req: *http.Request) !void {
|
||||||
if (Body != void) {
|
if (Body != void) {
|
||||||
const body = req.body orelse return error.NoBody;
|
var stream = req.body orelse return error.NoBody;
|
||||||
|
const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16);
|
||||||
|
errdefer self.allocator.free(body);
|
||||||
self.body = try json_utils.parse(Body, body, self.allocator);
|
self.body = try json_utils.parse(Body, body, self.allocator);
|
||||||
|
self.body_buf = body;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn freeBody(self: *Self) void {
|
fn freeBody(self: *Self) void {
|
||||||
if (Body != void) {
|
if (Body != void) {
|
||||||
json_utils.parseFree(self.body, self.allocator);
|
json_utils.parseFree(self.body, self.allocator);
|
||||||
|
self.allocator.free(self.body_buf.?);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,7 +203,7 @@ pub fn Context(comptime Route: type) type {
|
||||||
|
|
||||||
pub const Response = struct {
|
pub const Response = struct {
|
||||||
const Self = @This();
|
const Self = @This();
|
||||||
headers: http.Headers,
|
headers: http.Fields,
|
||||||
res: *http.Response,
|
res: *http.Response,
|
||||||
opened: bool = false,
|
opened: bool = false,
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
const api = @import("api");
|
const api = @import("api");
|
||||||
|
const std = @import("std");
|
||||||
|
|
||||||
pub const login = struct {
|
pub const login = struct {
|
||||||
pub const method = .POST;
|
pub const method = .POST;
|
||||||
|
@ -12,6 +13,8 @@ pub const login = struct {
|
||||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||||
const token = try srv.login(req.body.username, req.body.password);
|
const token = try srv.login(req.body.username, req.body.password);
|
||||||
|
|
||||||
|
std.log.debug("{any}", .{res.headers});
|
||||||
|
|
||||||
try res.json(.ok, token);
|
try res.json(.ok, token);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const api = @import("api");
|
const api = @import("api");
|
||||||
const util = @import("util");
|
const util = @import("util");
|
||||||
|
const query_utils = @import("../query.zig");
|
||||||
|
|
||||||
const QueryArgs = api.CommunityQueryArgs;
|
const QueryArgs = api.CommunityQueryArgs;
|
||||||
const Uuid = util.Uuid;
|
const Uuid = util.Uuid;
|
||||||
|
@ -25,89 +26,18 @@ pub const query = struct {
|
||||||
pub const method = .GET;
|
pub const method = .GET;
|
||||||
pub const path = "/communities";
|
pub const path = "/communities";
|
||||||
|
|
||||||
// NOTE: This has to match QueryArgs
|
pub const Query = QueryArgs;
|
||||||
// TODO: Support union fields in query strings natively, so we don't
|
|
||||||
// have to keep these in sync
|
|
||||||
pub const Query = struct {
|
|
||||||
const OrderBy = QueryArgs.OrderBy;
|
|
||||||
const Direction = QueryArgs.Direction;
|
|
||||||
const PageDirection = QueryArgs.PageDirection;
|
|
||||||
|
|
||||||
// Max items to fetch
|
|
||||||
max_items: usize = 20,
|
|
||||||
|
|
||||||
// Selection filters
|
|
||||||
owner_id: ?Uuid = null,
|
|
||||||
like: ?[]const u8 = null,
|
|
||||||
created_before: ?DateTime = null,
|
|
||||||
created_after: ?DateTime = null,
|
|
||||||
|
|
||||||
// Ordering parameter
|
|
||||||
order_by: OrderBy = .created_at,
|
|
||||||
direction: Direction = .ascending,
|
|
||||||
|
|
||||||
// the `prev` struct has a slightly different format to QueryArgs
|
|
||||||
prev: struct {
|
|
||||||
id: ?Uuid = null,
|
|
||||||
|
|
||||||
// Only one of these can be present, and must match order_by above
|
|
||||||
name: ?[]const u8 = null,
|
|
||||||
host: ?[]const u8 = null,
|
|
||||||
created_at: ?DateTime = null,
|
|
||||||
} = .{},
|
|
||||||
|
|
||||||
// What direction to scan the page window
|
|
||||||
page_direction: PageDirection = .forward,
|
|
||||||
|
|
||||||
pub const format = formatQueryParams;
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
|
||||||
const q = req.query;
|
const results = try srv.queryCommunities(req.query);
|
||||||
const query_matches = if (q.prev.id) |_| switch (q.order_by) {
|
|
||||||
.name => q.prev.name != null and q.prev.host == null and q.prev.created_at == null,
|
|
||||||
.host => q.prev.name == null and q.prev.host != null and q.prev.created_at == null,
|
|
||||||
.created_at => q.prev.name == null and q.prev.host == null and q.prev.created_at != null,
|
|
||||||
} else (q.prev.name == null and q.prev.host == null and q.prev.created_at == null);
|
|
||||||
|
|
||||||
if (!query_matches) return res.err(.bad_request, "prev.* parameters do not match", {});
|
|
||||||
|
|
||||||
const prev_arg: ?QueryArgs.Prev = if (q.prev.id) |id| .{
|
|
||||||
.id = id,
|
|
||||||
.order_val = switch (q.order_by) {
|
|
||||||
.name => .{ .name = q.prev.name.? },
|
|
||||||
.host => .{ .host = q.prev.host.? },
|
|
||||||
.created_at => .{ .created_at = q.prev.created_at.? },
|
|
||||||
},
|
|
||||||
} else null;
|
|
||||||
|
|
||||||
const query_args = QueryArgs{
|
|
||||||
.max_items = q.max_items,
|
|
||||||
.owner_id = q.owner_id,
|
|
||||||
.like = q.like,
|
|
||||||
.created_before = q.created_before,
|
|
||||||
.created_after = q.created_after,
|
|
||||||
|
|
||||||
.order_by = q.order_by,
|
|
||||||
.direction = q.direction,
|
|
||||||
|
|
||||||
.prev = prev_arg,
|
|
||||||
|
|
||||||
.page_direction = q.page_direction,
|
|
||||||
};
|
|
||||||
|
|
||||||
const results = try srv.queryCommunities(query_args);
|
|
||||||
|
|
||||||
var link = std.ArrayList(u8).init(req.allocator);
|
var link = std.ArrayList(u8).init(req.allocator);
|
||||||
const link_writer = link.writer();
|
const link_writer = link.writer();
|
||||||
defer link.deinit();
|
defer link.deinit();
|
||||||
|
|
||||||
const next_page = queryArgsToControllerQuery(results.next_page);
|
try writeLink(link_writer, srv.community, path, results.next_page, "next");
|
||||||
const prev_page = queryArgsToControllerQuery(results.prev_page);
|
|
||||||
|
|
||||||
try writeLink(link_writer, srv.community, path, next_page, "next");
|
|
||||||
try link_writer.writeByte(',');
|
try link_writer.writeByte(',');
|
||||||
try writeLink(link_writer, srv.community, path, prev_page, "prev");
|
try writeLink(link_writer, srv.community, path, results.prev_page, "prev");
|
||||||
|
|
||||||
try res.headers.put("Link", link.items);
|
try res.headers.put("Link", link.items);
|
||||||
|
|
||||||
|
@ -129,7 +59,7 @@ fn writeLink(
|
||||||
.{ @tagName(community.scheme), community.host, path },
|
.{ @tagName(community.scheme), community.host, path },
|
||||||
);
|
);
|
||||||
|
|
||||||
try std.fmt.format(writer, "{}", .{params});
|
try query_utils.formatQuery(params, writer);
|
||||||
|
|
||||||
try std.fmt.format(
|
try std.fmt.format(
|
||||||
writer,
|
writer,
|
||||||
|
@ -137,70 +67,3 @@ fn writeLink(
|
||||||
.{rel},
|
.{rel},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn formatQueryParams(
|
|
||||||
params: anytype,
|
|
||||||
comptime fmt: []const u8,
|
|
||||||
opt: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
if (comptime std.meta.trait.is(.Pointer)(@TypeOf(params))) {
|
|
||||||
return formatQueryParams(params.*, fmt, opt, writer);
|
|
||||||
}
|
|
||||||
|
|
||||||
return formatRecursive("", params, writer);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn formatRecursive(comptime prefix: []const u8, params: anytype, writer: anytype) !void {
|
|
||||||
inline for (std.meta.fields(@TypeOf(params))) |field| {
|
|
||||||
const val = @field(params, field.name);
|
|
||||||
const is_optional = comptime std.meta.trait.is(.Optional)(field.field_type);
|
|
||||||
const present = if (comptime is_optional) val != null else true;
|
|
||||||
if (present) {
|
|
||||||
const unwrapped = if (is_optional) val.? else val;
|
|
||||||
// TODO: percent-encode this
|
|
||||||
_ = try switch (@TypeOf(unwrapped)) {
|
|
||||||
[]const u8 => blk: {
|
|
||||||
break :blk std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, unwrapped });
|
|
||||||
},
|
|
||||||
|
|
||||||
else => |U| blk: {
|
|
||||||
if (comptime std.meta.trait.isContainer(U) and std.meta.trait.hasFn("format")(U)) {
|
|
||||||
break :blk std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped });
|
|
||||||
}
|
|
||||||
|
|
||||||
break :blk switch (@typeInfo(U)) {
|
|
||||||
.Enum => std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, @tagName(unwrapped) }),
|
|
||||||
.Struct => formatRecursive(field.name ++ ".", unwrapped, writer),
|
|
||||||
else => std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }),
|
|
||||||
};
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn queryArgsToControllerQuery(args: QueryArgs) query.Query {
|
|
||||||
var result = query.Query{
|
|
||||||
.max_items = args.max_items,
|
|
||||||
.owner_id = args.owner_id,
|
|
||||||
.like = args.like,
|
|
||||||
.created_before = args.created_before,
|
|
||||||
.created_after = args.created_after,
|
|
||||||
.order_by = args.order_by,
|
|
||||||
.direction = args.direction,
|
|
||||||
.prev = .{},
|
|
||||||
.page_direction = args.page_direction,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (args.prev) |prev| {
|
|
||||||
result.prev = .{
|
|
||||||
.id = prev.id,
|
|
||||||
.name = if (prev.order_val == .name) prev.order_val.name else null,
|
|
||||||
.host = if (prev.order_val == .host) prev.order_val.host else null,
|
|
||||||
.created_at = if (prev.order_val == .created_at) prev.order_val.created_at else null,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
21
src/main/controllers/timelines.zig
Normal file
21
src/main/controllers/timelines.zig
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
pub const global = struct {
|
||||||
|
pub const method = .GET;
|
||||||
|
pub const path = "/timelines/global";
|
||||||
|
|
||||||
|
pub fn handler(_: anytype, res: anytype, srv: anytype) !void {
|
||||||
|
const results = try srv.globalTimeline();
|
||||||
|
|
||||||
|
try res.json(.ok, results);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const local = struct {
|
||||||
|
pub const method = .GET;
|
||||||
|
pub const path = "/timelines/local";
|
||||||
|
|
||||||
|
pub fn handler(_: anytype, res: anytype, srv: anytype) !void {
|
||||||
|
const results = try srv.localTimeline();
|
||||||
|
|
||||||
|
try res.json(.ok, results);
|
||||||
|
}
|
||||||
|
};
|
|
@ -499,7 +499,7 @@ fn parseInternal(
|
||||||
if (!fields_seen[i]) {
|
if (!fields_seen[i]) {
|
||||||
if (field.default_value) |default_ptr| {
|
if (field.default_value) |default_ptr| {
|
||||||
if (!field.is_comptime) {
|
if (!field.is_comptime) {
|
||||||
const default = @ptrCast(*const field.field_type, default_ptr).*;
|
const default = @ptrCast(*align(1) const field.field_type, default_ptr).*;
|
||||||
@field(r, field.name) = default;
|
@field(r, field.name) = default;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -87,7 +87,7 @@ fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle(ctx: anytype, req: http.Request, res: *http.Response) void {
|
fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void {
|
||||||
c.routeRequest(ctx.src, req, res, ctx.allocator);
|
c.routeRequest(ctx.src, req, res, ctx.allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ pub fn up(db: anytype) !void {
|
||||||
std.log.info("Running migration {s}", .{migration.name});
|
std.log.info("Running migration {s}", .{migration.name});
|
||||||
try execScript(tx, migration.up, gpa.allocator());
|
try execScript(tx, migration.up, gpa.allocator());
|
||||||
try tx.insert("migration", .{
|
try tx.insert("migration", .{
|
||||||
.name = migration.name,
|
.name = @as([]const u8, migration.name),
|
||||||
.applied_at = DateTime.now(),
|
.applied_at = DateTime.now(),
|
||||||
}, gpa.allocator());
|
}, gpa.allocator());
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,37 +71,132 @@ const QueryIter = @import("util").QueryIter;
|
||||||
/// TODO: values are currently case-sensitive, and are not url-decoded properly.
|
/// TODO: values are currently case-sensitive, and are not url-decoded properly.
|
||||||
/// This should be fixed.
|
/// This should be fixed.
|
||||||
pub fn parseQuery(comptime T: type, query: []const u8) !T {
|
pub fn parseQuery(comptime T: type, query: []const u8) !T {
|
||||||
//if (!std.meta.trait.isContainer(T)) @compileError("T must be a struct");
|
if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
|
||||||
var iter = QueryIter.from(query);
|
var iter = QueryIter.from(query);
|
||||||
var result = T{};
|
|
||||||
|
var fields = Intermediary(T){};
|
||||||
while (iter.next()) |pair| {
|
while (iter.next()) |pair| {
|
||||||
try parseQueryPair(T, &result, pair.key, pair.value);
|
// TODO: Hash map
|
||||||
|
inline for (std.meta.fields(Intermediary(T))) |field| {
|
||||||
|
if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) {
|
||||||
|
@field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} };
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else std.log.debug("unknown param {s}", .{pair.key});
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return (try parse(T, "", "", fields)).?;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parseQueryPair(comptime T: type, result: *T, key: []const u8, value: ?[]const u8) !void {
|
fn parseScalar(comptime T: type, comptime name: []const u8, fields: anytype) !?T {
|
||||||
const key_part = std.mem.sliceTo(key, '.');
|
const param = @field(fields, name);
|
||||||
const field_idx = std.meta.stringToEnum(std.meta.FieldEnum(T), key_part) orelse return error.UnknownField;
|
return switch (param) {
|
||||||
|
.not_specified => null,
|
||||||
|
.no_value => try parseQueryValue(T, null),
|
||||||
|
.value => |v| try parseQueryValue(T, v),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
inline for (std.meta.fields(T)) |info, idx| {
|
fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u8, fields: anytype) !?T {
|
||||||
if (@enumToInt(field_idx) == idx) {
|
if (comptime isScalar(T)) return parseScalar(T, prefix ++ "." ++ name, fields);
|
||||||
if (comptime isScalar(info.field_type)) {
|
switch (@typeInfo(T)) {
|
||||||
if (key_part.len == key.len) {
|
.Union => |info| {
|
||||||
@field(result, info.name) = try parseQueryValue(info.field_type, value);
|
var result: ?T = null;
|
||||||
return;
|
inline for (info.fields) |field| {
|
||||||
} else {
|
const F = field.field_type;
|
||||||
return error.UnknownField;
|
|
||||||
|
const maybe_value = try parse(F, prefix, field.name, fields);
|
||||||
|
if (maybe_value) |value| {
|
||||||
|
if (result != null) return error.DuplicateUnionField;
|
||||||
|
|
||||||
|
result = @unionInit(T, field.name, value);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
const remaining = std.mem.trimLeft(u8, key[key_part.len..], ".");
|
|
||||||
return try parseQueryPair(info.field_type, &@field(result, info.name), remaining, value);
|
|
||||||
}
|
}
|
||||||
|
std.log.debug("{any}", .{result});
|
||||||
|
return result;
|
||||||
|
},
|
||||||
|
|
||||||
|
.Struct => |info| {
|
||||||
|
var result: T = undefined;
|
||||||
|
var fields_specified: usize = 0;
|
||||||
|
|
||||||
|
inline for (info.fields) |field| {
|
||||||
|
const F = field.field_type;
|
||||||
|
|
||||||
|
var maybe_value: ?F = null;
|
||||||
|
if (try parse(F, prefix ++ "." ++ name, field.name, fields)) |v| {
|
||||||
|
maybe_value = v;
|
||||||
|
} else if (field.default_value) |default| {
|
||||||
|
maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (maybe_value) |v| {
|
||||||
|
fields_specified += 1;
|
||||||
|
@field(result, field.name) = v;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return error.UnknownField;
|
if (fields_specified == 0) {
|
||||||
|
return null;
|
||||||
|
} else if (fields_specified != info.fields.len) {
|
||||||
|
return error.PartiallySpecifiedStruct;
|
||||||
|
} else {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Only applies to non-scalar optionals
|
||||||
|
.Optional => |info| return try parse(info.child, prefix, name, fields),
|
||||||
|
|
||||||
|
else => @compileError("tmp"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 {
|
||||||
|
comptime {
|
||||||
|
if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix);
|
||||||
|
|
||||||
|
var fields: []const []const u8 = &.{};
|
||||||
|
|
||||||
|
for (std.meta.fields(T)) |f| {
|
||||||
|
const full_name = prefix ++ f.name;
|
||||||
|
|
||||||
|
if (isScalar(f.field_type)) {
|
||||||
|
fields = fields ++ @as([]const []const u8, &.{full_name});
|
||||||
|
} else {
|
||||||
|
const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ ".";
|
||||||
|
fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const QueryParam = union(enum) {
|
||||||
|
not_specified: void,
|
||||||
|
no_value: void,
|
||||||
|
value: []const u8,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn Intermediary(comptime T: type) type {
|
||||||
|
const field_names = recursiveFieldPaths(T, "..");
|
||||||
|
|
||||||
|
var fields: [field_names.len]std.builtin.Type.StructField = undefined;
|
||||||
|
for (field_names) |name, i| fields[i] = .{
|
||||||
|
.name = name,
|
||||||
|
.field_type = QueryParam,
|
||||||
|
.default_value = &QueryParam{ .not_specified = {} },
|
||||||
|
.is_comptime = false,
|
||||||
|
.alignment = @alignOf(QueryParam),
|
||||||
|
};
|
||||||
|
|
||||||
|
return @Type(.{ .Struct = .{
|
||||||
|
.layout = .Auto,
|
||||||
|
.fields = &fields,
|
||||||
|
.decls = &.{},
|
||||||
|
.is_tuple = false,
|
||||||
|
} });
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parseQueryValue(comptime T: type, value: ?[]const u8) !T {
|
fn parseQueryValue(comptime T: type, value: ?[]const u8) !T {
|
||||||
|
@ -157,6 +252,50 @@ fn isScalar(comptime T: type) bool {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn formatQuery(params: anytype, writer: anytype) !void {
|
||||||
|
try format("", "", params, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void {
|
||||||
|
const T = @TypeOf(val);
|
||||||
|
if (comptime std.meta.trait.isZigString(T)) return std.fmt.format(writer, "{s}={s}&", .{ name, val });
|
||||||
|
_ = try switch (@typeInfo(T)) {
|
||||||
|
.Enum => std.fmt.format(writer, "{s}={s}&", .{ name, @tagName(val) }),
|
||||||
|
.Optional => if (val) |v| formatScalar(name, v, writer),
|
||||||
|
else => std.fmt.format(writer, "{s}={}&", .{ name, val }),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void {
|
||||||
|
const T = @TypeOf(params);
|
||||||
|
const eff_prefix = if (prefix.len == 0) "" else prefix ++ ".";
|
||||||
|
if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer);
|
||||||
|
|
||||||
|
switch (@typeInfo(T)) {
|
||||||
|
.Struct => {
|
||||||
|
inline for (std.meta.fields(T)) |field| {
|
||||||
|
const val = @field(params, field.name);
|
||||||
|
try format(eff_prefix ++ name, field.name, val, writer);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
.Union => {
|
||||||
|
//inline for (std.meta.tags(T)) |tag| {
|
||||||
|
inline for (std.meta.fields(T)) |field| {
|
||||||
|
const tag = @field(std.meta.Tag(T), field.name);
|
||||||
|
const tag_name = field.name;
|
||||||
|
if (@as(std.meta.Tag(T), params) == tag) {
|
||||||
|
const val = @field(params, tag_name);
|
||||||
|
try format(prefix, tag_name, val, writer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
.Optional => {
|
||||||
|
if (params) |p| try format(prefix, name, p, writer);
|
||||||
|
},
|
||||||
|
else => @compileError("Unsupported query type"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test {
|
test {
|
||||||
const TestQuery = struct {
|
const TestQuery = struct {
|
||||||
int: usize = 3,
|
int: usize = 3,
|
||||||
|
|
|
@ -68,7 +68,7 @@ pub const QueryOptions = struct {
|
||||||
// do not require allocators for prep. If an allocator is needed but not
|
// do not require allocators for prep. If an allocator is needed but not
|
||||||
// provided, `error.AllocatorRequired` will be returned.
|
// provided, `error.AllocatorRequired` will be returned.
|
||||||
// Only used with the postgres backend.
|
// Only used with the postgres backend.
|
||||||
prep_allocator: ?Allocator = null,
|
allocator: ?Allocator = null,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Turns a value into its appropriate textual value (or null)
|
// Turns a value into its appropriate textual value (or null)
|
||||||
|
|
|
@ -180,14 +180,25 @@ pub const Db = struct {
|
||||||
const format_text = 0;
|
const format_text = 0;
|
||||||
const format_binary = 1;
|
const format_binary = 1;
|
||||||
pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results {
|
pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results {
|
||||||
const alloc = opt.prep_allocator;
|
const alloc = opt.allocator;
|
||||||
const result = blk: {
|
const result = blk: {
|
||||||
if (@TypeOf(args) != void and args.len > 0) {
|
if (@TypeOf(args) != void and args.len > 0) {
|
||||||
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
|
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
|
const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
|
||||||
inline for (args) |arg, i| {
|
// TODO: The following is a fix for the stage1 compiler. remove this
|
||||||
params[i] = if (try common.prepareParamText(&arena, arg)) |slice|
|
//inline for (args) |arg, i| {
|
||||||
|
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
|
||||||
|
const arg = @field(args, field.name);
|
||||||
|
|
||||||
|
// The stage1 compiler has issues with runtime branches that in any
|
||||||
|
// way involve compile time values
|
||||||
|
const maybe_slice = if (@import("builtin").zig_backend == .stage1)
|
||||||
|
common.prepareParamText(&arena, arg) catch unreachable
|
||||||
|
else
|
||||||
|
try common.prepareParamText(&arena, arg);
|
||||||
|
|
||||||
|
params[i] = if (maybe_slice) |slice|
|
||||||
slice.ptr
|
slice.ptr
|
||||||
else
|
else
|
||||||
null;
|
null;
|
||||||
|
|
|
@ -118,7 +118,10 @@ pub const Db = struct {
|
||||||
};
|
};
|
||||||
|
|
||||||
if (@TypeOf(args) != void) {
|
if (@TypeOf(args) != void) {
|
||||||
inline for (args) |arg, i| {
|
// TODO: Fix for stage1 compiler
|
||||||
|
//inline for (args) |arg, i| {
|
||||||
|
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
|
||||||
|
const arg = @field(args, field.name);
|
||||||
// SQLite treats $NNN args as having the name NNN, not index NNN.
|
// SQLite treats $NNN args as having the name NNN, not index NNN.
|
||||||
// As such, if you reference $2 and not $1 in your query (such as
|
// As such, if you reference $2 and not $1 in your query (such as
|
||||||
// when dynamically constructing queries), it could assign $2 the
|
// when dynamically constructing queries), it could assign $2 the
|
||||||
|
|
|
@ -25,6 +25,54 @@ pub const Engine = enum {
|
||||||
sqlite,
|
sqlite,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Helper for building queries at runtime. All constituent parts of the
|
||||||
|
/// query should be defined at comptime, however the choice of whether
|
||||||
|
/// or not to include them can occur at runtime.
|
||||||
|
pub const QueryBuilder = struct {
|
||||||
|
array: std.ArrayList(u8),
|
||||||
|
where_clauses_appended: usize = 0,
|
||||||
|
|
||||||
|
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
|
||||||
|
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deinit(self: *const QueryBuilder) void {
|
||||||
|
self.array.deinit();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a chunk of sql to the query without processing
|
||||||
|
pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void {
|
||||||
|
try self.array.appendSlice(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a where clause to the query. Clauses are assumed to be components
|
||||||
|
/// in an overall expression in Conjunctive Normal Form (AND of OR's).
|
||||||
|
/// https://en.wikipedia.org/wiki/Conjunctive_normal_form
|
||||||
|
/// All calls to andWhere must be contiguous, that is, they cannot be
|
||||||
|
/// interspersed with calls to appendSlice
|
||||||
|
pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void {
|
||||||
|
if (self.where_clauses_appended == 0) {
|
||||||
|
try self.array.appendSlice("WHERE ");
|
||||||
|
} else {
|
||||||
|
try self.array.appendSlice(" AND ");
|
||||||
|
}
|
||||||
|
|
||||||
|
try self.array.appendSlice(clause);
|
||||||
|
self.where_clauses_appended += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn str(self: *const QueryBuilder) []const u8 {
|
||||||
|
return self.array.items;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn terminate(self: *QueryBuilder) ![:0]const u8 {
|
||||||
|
std.debug.assert(self.array.items.len != 0);
|
||||||
|
if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0);
|
||||||
|
|
||||||
|
return std.meta.assumeSentinel(self.array.items, 0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO: make this suck less
|
// TODO: make this suck less
|
||||||
pub const Config = union(Engine) {
|
pub const Config = union(Engine) {
|
||||||
postgres: struct {
|
postgres: struct {
|
||||||
|
@ -410,7 +458,7 @@ fn Tx(comptime tx_level: u8) type {
|
||||||
args: anytype,
|
args: anytype,
|
||||||
alloc: ?Allocator,
|
alloc: ?Allocator,
|
||||||
) QueryError!Results(RowType) {
|
) QueryError!Results(RowType) {
|
||||||
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
|
return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc });
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs a query to completion and returns a row of results, unless the query
|
/// Runs a query to completion and returns a row of results, unless the query
|
||||||
|
@ -439,6 +487,45 @@ fn Tx(comptime tx_level: u8) type {
|
||||||
return row;
|
return row;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs a query to completion and returns the results as a slice
|
||||||
|
pub fn queryRowsWithOptions(
|
||||||
|
self: Self,
|
||||||
|
comptime RowType: type,
|
||||||
|
q: [:0]const u8,
|
||||||
|
args: anytype,
|
||||||
|
max_items: ?usize,
|
||||||
|
options: QueryOptions,
|
||||||
|
) QueryRowError![]RowType {
|
||||||
|
var results = try self.queryWithOptions(RowType, q, args, options);
|
||||||
|
defer results.finish();
|
||||||
|
|
||||||
|
const alloc = options.allocator orelse return error.AllocatorRequired;
|
||||||
|
|
||||||
|
var result_array = std.ArrayList(RowType).init(alloc);
|
||||||
|
errdefer result_array.deinit();
|
||||||
|
if (max_items) |max| try result_array.ensureTotalCapacity(max);
|
||||||
|
|
||||||
|
errdefer for (result_array.items) |r| util.deepFree(alloc, r);
|
||||||
|
|
||||||
|
var too_many: bool = false;
|
||||||
|
while (try results.row(alloc)) |row| {
|
||||||
|
errdefer util.deepFree(alloc, row);
|
||||||
|
if (max_items) |max| {
|
||||||
|
if (result_array.items.len >= max) {
|
||||||
|
util.deepFree(alloc, row);
|
||||||
|
too_many = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try result_array.append(row);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (too_many) return error.TooManyRows;
|
||||||
|
|
||||||
|
return result_array.toOwnedSlice();
|
||||||
|
}
|
||||||
|
|
||||||
// Inserts a single value into a table
|
// Inserts a single value into a table
|
||||||
pub fn insert(
|
pub fn insert(
|
||||||
self: Self,
|
self: Self,
|
||||||
|
@ -455,7 +542,7 @@ fn Tx(comptime tx_level: u8) type {
|
||||||
inline for (fields) |field, i| {
|
inline for (fields) |field, i| {
|
||||||
// This causes a compiler crash. Why?
|
// This causes a compiler crash. Why?
|
||||||
//const F = field.field_type;
|
//const F = field.field_type;
|
||||||
const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name));
|
const F = @TypeOf(@field(value, field.name));
|
||||||
// causes issues if F is @TypeOf(null), use dummy type
|
// causes issues if F is @TypeOf(null), use dummy type
|
||||||
types[i] = if (F == @TypeOf(null)) ?i64 else F;
|
types[i] = if (F == @TypeOf(null)) ?i64 else F;
|
||||||
table_spec = comptime (table_spec ++ field.name ++ ",");
|
table_spec = comptime (table_spec ++ field.name ++ ",");
|
||||||
|
@ -499,7 +586,7 @@ fn Tx(comptime tx_level: u8) type {
|
||||||
alloc: ?std.mem.Allocator,
|
alloc: ?std.mem.Allocator,
|
||||||
comptime check_tx: bool,
|
comptime check_tx: bool,
|
||||||
) !void {
|
) !void {
|
||||||
var results = try self.runSql(sql, args, .{ .prep_allocator = alloc }, check_tx);
|
var results = try self.runSql(sql, args, .{ .allocator = alloc }, check_tx);
|
||||||
defer results.finish();
|
defer results.finish();
|
||||||
|
|
||||||
while (try results.row()) |_| {}
|
while (try results.row()) |_| {}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue