Compare commits

...

17 commits

28 changed files with 1101 additions and 673 deletions


@@ -53,16 +53,16 @@ pub fn build(b: *std.build.Builder) void {
exe.linkSystemLibrary("pq");
exe.linkLibC();
const util_tests = b.addTest("src/util/lib.zig");
const http_tests = b.addTest("src/http/lib.zig");
const sql_tests = b.addTest("src/sql/lib.zig");
//const util_tests = b.addTest("src/util/lib.zig");
const http_tests = b.addTest("src/http/test.zig");
//const sql_tests = b.addTest("src/sql/lib.zig");
http_tests.addPackage(util_pkg);
sql_tests.addPackage(util_pkg);
//sql_tests.addPackage(util_pkg);
const unit_tests = b.step("unit-tests", "Run tests");
unit_tests.dependOn(&util_tests.step);
//unit_tests.dependOn(&util_tests.step);
unit_tests.dependOn(&http_tests.step);
unit_tests.dependOn(&sql_tests.step);
//unit_tests.dependOn(&sql_tests.step);
const api_integration = b.addTest("./tests/api_integration/lib.zig");
api_integration.addPackage(sql_pkg);


@@ -276,7 +276,10 @@ fn ApiConn(comptime DbConn: type) type {
username,
password,
self.community.id,
.{ .invite_id = if (maybe_invite) |inv| inv.id else null, .email = opt.email },
.{
.invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null,
.email = opt.email,
},
self.arena.allocator(),
);
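Note on the `@as(?Uuid, inv.id)` cast: on stage1 (the pre-self-hosted compiler this codebase targets), the two branches of an `if` expression do not always resolve to the optional peer type when the result feeds an anonymous struct literal, so one branch is cast explicitly. A standalone sketch of the pattern with hypothetical values (the same workaround recurs in invites.zig, socket.zig, and migrate.zig below):

const std = @import("std");

test "if branches coerce through an explicit optional cast" {
    const maybe: ?u32 = 7;
    // Casting one branch to ?u32 pins the result type; without it, stage1
    // can mis-resolve the peers `v` (u32) and `null` in this position.
    const id = if (maybe) |v| @as(?u32, v) else null;
    try std.testing.expectEqual(@as(?u32, 7), id);
}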
@@ -348,5 +351,19 @@ fn ApiConn(comptime DbConn: type) type {
if (!self.isAdmin()) return error.PermissionDenied;
return try services.communities.query(self.db, args, self.arena.allocator());
}
pub fn globalTimeline(self: *Self) ![]services.notes.Note {
const result = try services.notes.query(self.db, .{}, self.arena.allocator());
return result.items;
}
pub fn localTimeline(self: *Self) ![]services.notes.Note {
const result = try services.notes.query(
self.db,
.{ .community_id = self.community.id },
self.arena.allocator(),
);
return result.items;
}
};
}


@@ -94,7 +94,8 @@ pub const Actor = struct {
created_at: DateTime,
};
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Actor {
pub const GetError = error{ NotFound, DatabaseFailure };
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Actor {
return db.queryRow(
Actor,
\\SELECT
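With `get` now returning an explicit error set, call sites can switch over it exhaustively instead of forwarding an inferred set. A hypothetical wrapper (not part of this diff) showing the shape this enables:

pub fn getOrNull(db: anytype, id: Uuid, alloc: std.mem.Allocator) error{DatabaseFailure}!?Actor {
    return get(db, id, alloc) catch |err| switch (err) {
        // Absence becomes a value rather than an error; every other member
        // of GetError still propagates.
        error.NotFound => null,
        error.DatabaseFailure => return error.DatabaseFailure,
    };
}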


@@ -0,0 +1,16 @@
const std = @import("std");
const util = @import("util");
pub const Direction = enum {
ascending,
descending,
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
pub const PageDirection = enum {
forward,
backward,
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
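`util.jsonSerializeEnumAsString` is presumably a `jsonStringify`-compatible helper; declaring it on the enum makes `std.json.stringify` emit the tag name rather than the integer tag value. Expected behavior as a sketch, assuming it sits next to these declarations:

test "Direction serializes as its tag name" {
    var buf: [16]u8 = undefined;
    var stream = std.io.fixedBufferStream(&buf);
    // Without the jsonStringify declaration this would print the tag's
    // integer value instead of the string "ascending".
    try std.json.stringify(Direction.ascending, .{}, stream.writer());
    try std.testing.expectEqualStrings("\"ascending\"", stream.getWritten());
}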


@@ -2,6 +2,7 @@ const std = @import("std");
const builtin = @import("builtin");
const util = @import("util");
const sql = @import("sql");
const common = @import("./common.zig");
const Uuid = util.Uuid;
const DateTime = util.DateTime;
@@ -82,11 +83,12 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st
else => return error.DatabaseFailure,
}
const name = options.name orelse host;
db.insert("community", .{
.id = id,
.owner_id = null,
.host = host,
.name = options.name orelse host,
.name = name,
.scheme = scheme,
.kind = options.kind,
.created_at = DateTime.now(),
@@ -153,20 +155,8 @@ pub const QueryArgs = struct {
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
pub const Direction = enum {
ascending,
descending,
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
pub const PageDirection = enum {
forward,
backward,
pub const jsonStringify = util.jsonSerializeEnumAsString;
};
pub const Direction = common.Direction;
pub const PageDirection = common.PageDirection;
pub const Prev = std.meta.Child(std.meta.fieldInfo(QueryArgs, .prev).field_type);
pub const OrderVal = std.meta.fieldInfo(Prev, .order_val).field_type;
@@ -211,30 +201,6 @@ pub const QueryResult = struct {
next_page: QueryArgs,
};
const QueryBuilder = struct {
array: std.ArrayList(u8),
where_clauses_appended: usize = 0,
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
}
pub fn deinit(self: *const QueryBuilder) void {
self.array.deinit();
}
pub fn andWhere(self: *QueryBuilder, clause: []const u8) !void {
if (self.where_clauses_appended == 0) {
try self.array.appendSlice("WHERE ");
} else {
try self.array.appendSlice(" AND ");
}
try self.array.appendSlice(clause);
self.where_clauses_appended += 1;
}
};
const max_max_items = 100;
pub const QueryError = error{
@@ -246,7 +212,7 @@ pub const QueryError = error{
// arguments.
// `args.max_items` is only a request, and fewer entries may be returned.
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult {
var builder = QueryBuilder.init(alloc);
var builder = sql.QueryBuilder.init(alloc);
defer builder.deinit();
try builder.array.appendSlice(
@@ -266,21 +232,21 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
if (args.prev) |prev| {
if (prev.order_val != args.order_by) return error.PageArgMismatch;
try builder.andWhere(switch (args.order_by) {
.name => "(name, id)",
.host => "(host, id)",
.created_at => "(created_at, id)",
});
_ = try builder.array.appendSlice(switch (args.direction) {
switch (args.order_by) {
.name => try builder.andWhere("(name, id)"),
.host => try builder.andWhere("(host, id)"),
.created_at => try builder.andWhere("(created_at, id)"),
}
switch (args.direction) {
.ascending => switch (args.page_direction) {
.forward => " > ",
.backward => " < ",
.forward => try builder.appendSlice(" > "),
.backward => try builder.appendSlice(" < "),
},
.descending => switch (args.page_direction) {
.forward => " < ",
.backward => " > ",
.forward => try builder.appendSlice(" < "),
.backward => try builder.appendSlice(" > "),
},
});
}
_ = try builder.array.appendSlice("($5, $6)");
}
@@ -297,57 +263,52 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul
_ = try builder.array.appendSlice("\nLIMIT $7");
const query_args = .{
args.owner_id,
args.like,
args.created_before,
args.created_after,
if (args.prev) |prev| prev.order_val else null,
if (args.prev) |prev| prev.id else null,
max_items,
const query_args = blk: {
const ord_val =
if (args.prev) |prev| @as(?QueryArgs.OrderVal, prev.order_val) else null;
const id =
if (args.prev) |prev| @as(?Uuid, prev.id) else null;
break :blk .{
args.owner_id,
args.like,
args.created_before,
args.created_after,
ord_val,
id,
max_items,
};
};
try builder.array.append(0);
var results = try db.queryWithOptions(
var results = try db.queryRowsWithOptions(
Community,
std.meta.assumeSentinel(builder.array.items, 0),
query_args,
.{ .prep_allocator = alloc, .ignore_unused_arguments = true },
max_items,
.{ .allocator = alloc, .ignore_unused_arguments = true },
);
defer results.finish();
const result_buf = try alloc.alloc(Community, args.max_items);
errdefer alloc.free(result_buf);
var count: usize = 0;
errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c);
for (result_buf) |*c| {
c.* = (try results.row(alloc)) orelse break;
count += 1;
}
errdefer util.deepFree(alloc, results);
var next_page = args;
var prev_page = args;
prev_page.page_direction = .backward;
next_page.page_direction = .forward;
if (count != 0) {
if (results.len != 0) {
prev_page.prev = .{
.id = result_buf[0].id,
.order_val = getOrderVal(result_buf[0], args.order_by),
.id = results[0].id,
.order_val = getOrderVal(results[0], args.order_by),
};
next_page.prev = .{
.id = result_buf[count - 1].id,
.order_val = getOrderVal(result_buf[count - 1], args.order_by),
.id = results[results.len - 1].id,
.order_val = getOrderVal(results[results.len - 1], args.order_by),
};
}
// TODO: This will give incorrect links on an empty page
return QueryResult{
.items = result_buf[0..count],
.items = results,
.next_page = next_page,
.prev_page = prev_page,
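The `QueryBuilder` deleted above now lives in the `sql` package; its new definition is not part of this compare view. Judging from the call sites here and in notes.zig (`andWhere`, `appendSlice`, `terminate`, plus direct access to `array`), its shape is presumably close to this sketch:

pub const QueryBuilder = struct {
    array: std.ArrayList(u8),
    where_clauses_appended: usize = 0,

    pub fn init(alloc: std.mem.Allocator) QueryBuilder {
        return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
    }

    pub fn deinit(self: *const QueryBuilder) void {
        self.array.deinit();
    }

    pub fn appendSlice(self: *QueryBuilder, slice: []const u8) !void {
        try self.array.appendSlice(slice);
    }

    // The first clause is prefixed with "WHERE ", every later one with " AND ".
    pub fn andWhere(self: *QueryBuilder, clause: []const u8) !void {
        try self.appendSlice(if (self.where_clauses_appended == 0) "WHERE " else " AND ");
        try self.appendSlice(clause);
        self.where_clauses_appended += 1;
    }

    // Null-terminates the buffer for the libpq wrapper. notes.zig calls this;
    // this file still terminates manually via array.append(0) and
    // std.meta.assumeSentinel.
    pub fn terminate(self: *QueryBuilder) ![:0]const u8 {
        try self.array.append(0);
        return self.array.items[0 .. self.array.items.len - 1 :0];
    }
};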


@@ -71,7 +71,7 @@ pub fn create(db: anytype, created_by: Uuid, community_id: ?Uuid, options: Invit
.max_uses = options.max_uses,
.created_at = created_at,
.expires_at = if (options.lifespan) |lifespan|
created_at.add(lifespan)
@as(?DateTime, created_at.add(lifespan))
else
null,


@@ -1,6 +1,7 @@
const std = @import("std");
const util = @import("util");
const sql = @import("sql");
const common = @import("./common.zig");
const Uuid = util.Uuid;
const DateTime = util.DateTime;
@@ -42,7 +43,7 @@ const selectStarFromNote = std.fmt.comptimePrint(
\\SELECT {s}
\\FROM note
\\
, .{util.comptimeJoin(",", std.meta.fieldNames(Note))});
, .{util.comptimeJoinWithPrefix(",", "note.", std.meta.fieldNames(Note))});
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
return db.queryRow(
Note,
@@ -57,3 +58,108 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note {
else => error.DatabaseFailure,
};
}
const max_max_items = 100;
pub const QueryArgs = struct {
pub const PageDirection = common.PageDirection;
pub const Prev = std.meta.Child(std.meta.field(@This(), .prev).field_type);
max_items: usize = 20,
created_before: ?DateTime = null,
created_after: ?DateTime = null,
community_id: ?Uuid = null,
prev: ?struct {
id: Uuid,
created_at: DateTime,
} = null,
page_direction: PageDirection = .forward,
};
pub const QueryResult = struct {
items: []Note,
prev_page: QueryArgs,
next_page: QueryArgs,
};
pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult {
var builder = sql.QueryBuilder.init(alloc);
defer builder.deinit();
try builder.appendSlice(selectStarFromNote ++
\\ JOIN actor ON actor.id = note.author_id
\\
);
if (args.created_before != null) try builder.andWhere("note.created_at < $1");
if (args.created_after != null) try builder.andWhere("note.created_at > $2");
if (args.prev != null) {
try builder.andWhere("(note.created_at, note.id)");
switch (args.page_direction) {
.forward => try builder.appendSlice(" < "),
.backward => try builder.appendSlice(" > "),
}
try builder.appendSlice("($3, $4)");
}
if (args.community_id != null) try builder.andWhere("actor.community_id = $5");
try builder.appendSlice(
\\
\\ORDER BY note.created_at DESC
\\LIMIT $6
\\
);
const max_items = if (args.max_items > max_max_items) max_max_items else args.max_items;
const query_args = blk: {
const prev_created_at = if (args.prev) |prev| @as(?DateTime, prev.created_at) else null;
const prev_id = if (args.prev) |prev| @as(?Uuid, prev.id) else null;
break :blk .{
args.created_before,
args.created_after,
prev_created_at,
prev_id,
args.community_id,
max_items,
};
};
const results = try db.queryRowsWithOptions(
Note,
try builder.terminate(),
query_args,
max_items,
.{ .allocator = alloc, .ignore_unused_arguments = true },
);
errdefer util.deepFree(results);
var next_page = args;
var prev_page = args;
prev_page.page_direction = .backward;
next_page.page_direction = .forward;
if (results.len != 0) {
prev_page.prev = .{
.id = results[0].id,
.created_at = results[0].created_at,
};
next_page.prev = .{
.id = results[results.len - 1].id,
.created_at = results[results.len - 1].created_at,
};
}
// TODO: this will give incorrect links on an empty page
return QueryResult{
.items = results,
.next_page = next_page,
.prev_page = prev_page,
};
}
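For orientation: with every filter set and `page_direction == .forward`, the builder assembles roughly the statement below (a sketch; the `note.*` column list is elided, and `ignore_unused_arguments` lets the backend skip placeholders whose filters are null). The row-value comparison against `($3, $4)` is the keyset cursor: each page resumes strictly after the last `(created_at, id)` seen, which stays stable under concurrent inserts, unlike an `OFFSET`.

const sketch =
    \\SELECT note.id, note.author_id, ... FROM note
    \\ JOIN actor ON actor.id = note.author_id
    \\WHERE note.created_at < $1 AND note.created_at > $2
    \\  AND (note.created_at, note.id) < ($3, $4)
    \\  AND actor.community_id = $5
    \\ORDER BY note.created_at DESC
    \\LIMIT $6
;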

src/http/headers.zig (new file, 124 lines)

@@ -0,0 +1,124 @@
const std = @import("std");
pub const Fields = struct {
const HashContext = struct {
const hash_seed = 1;
pub fn eql(_: @This(), lhs: []const u8, rhs: []const u8, _: usize) bool {
return std.ascii.eqlIgnoreCase(lhs, rhs);
}
pub fn hash(_: @This(), s: []const u8) u32 {
var h = std.hash.Wyhash.init(hash_seed);
for (s) |ch| {
const c = [1]u8{std.ascii.toLower(ch)};
h.update(&c);
}
return @truncate(u32, h.final());
}
};
const HashMap = std.ArrayHashMapUnmanaged(
[]const u8,
[]const u8,
HashContext,
true,
);
unmanaged: HashMap,
allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator) Fields {
return Fields{
.unmanaged = .{},
.allocator = allocator,
};
}
pub fn deinit(self: *Fields) void {
var hash_iter = self.unmanaged.iterator();
while (hash_iter.next()) |entry| {
self.allocator.free(entry.key_ptr.*);
self.allocator.free(entry.value_ptr.*);
}
self.unmanaged.deinit(self.allocator);
}
pub fn iterator(self: Fields) HashMap.Iterator {
return self.unmanaged.iterator();
}
pub fn get(self: Fields, key: []const u8) ?[]const u8 {
return self.unmanaged.get(key);
}
pub const ListIterator = struct {
remaining: []const u8,
fn extractElement(self: *ListIterator) ?[]const u8 {
if (self.remaining.len == 0) return null;
var start: usize = 0;
var is_quoted = false;
const end = for (self.remaining) |ch, i| {
if (start == i and std.ascii.isWhitespace(ch)) {
start += 1;
} else if (ch == '"') {
is_quoted = !is_quoted;
}
if (ch == ',' and !is_quoted) {
break i;
}
} else self.remaining.len;
const str = self.remaining[start..end];
if (end == self.remaining.len) {
self.remaining = "";
} else {
self.remaining = self.remaining[end + 1 ..];
}
return std.mem.trim(u8, str, " \t");
}
pub fn next(self: *ListIterator) ?[]const u8 {
while (self.extractElement()) |elem| {
if (elem.len != 0) return elem;
}
return null;
}
};
pub fn getList(self: Fields, key: []const u8) ?ListIterator {
return if (self.unmanaged.get(key)) |hdr| ListIterator{ .remaining = hdr } else null;
}
pub fn put(self: *Fields, key: []const u8, val: []const u8) !void {
const key_clone = try self.allocator.alloc(u8, key.len);
std.mem.copy(u8, key_clone, key);
errdefer self.allocator.free(key_clone);
const val_clone = try self.allocator.alloc(u8, val.len);
std.mem.copy(u8, val_clone, val);
errdefer self.allocator.free(val_clone);
if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| {
self.allocator.free(entry.key);
self.allocator.free(entry.value);
}
}
pub fn append(self: *Fields, key: []const u8, val: []const u8) !void {
if (self.unmanaged.getEntry(key)) |entry| {
const new_val = try std.mem.join(self.allocator, ", ", &.{ entry.value_ptr.*, val });
self.allocator.free(entry.value_ptr.*);
entry.value_ptr.* = new_val;
} else {
try self.put(key, val);
}
}
pub fn count(self: Fields) usize {
return self.unmanaged.count();
}
};
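A sketch of the semantics this gives, as the code above reads: lookups ignore case (the hash context lowercases before hashing), `append` folds repeated fields into one comma-separated value per RFC 9110, and `getList` walks the elements of such a value:

test "Fields: case-insensitive keys, append joins, getList splits" {
    var fields = Fields.init(std.testing.allocator);
    defer fields.deinit();

    try fields.put("Content-Type", "text/plain");
    try fields.append("Accept", "text/html");
    try fields.append("Accept", "application/json");

    // Keys compare case-insensitively.
    try std.testing.expectEqualStrings("text/plain", fields.get("content-type").?);

    // The two Accept values were folded into "text/html, application/json".
    var iter = fields.getList("ACCEPT").?;
    try std.testing.expectEqualStrings("text/html", iter.next().?);
    try std.testing.expectEqualStrings("application/json", iter.next().?);
    try std.testing.expectEqual(@as(?[]const u8, null), iter.next());
}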


@@ -10,22 +10,15 @@ pub const socket = @import("./socket.zig");
pub const Method = std.http.Method;
pub const Status = std.http.Status;
pub const Request = request.Request;
pub const Request = request.Request(std.net.Stream.Reader);
pub const serveConn = server.serveConn;
pub const Response = server.Response;
pub const Handler = server.Handler;
pub const Headers = std.HashMap([]const u8, []const u8, struct {
pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {
return ciutf8.eql(a, b);
}
pub const Fields = @import("./headers.zig").Fields;
pub fn hash(_: @This(), str: []const u8) u64 {
return ciutf8.hash(str);
}
}, std.hash_map.default_max_load_percentage);
test {
_ = server;
_ = request;
}
pub const Protocol = enum {
http_1_0,
http_1_1,
http_1_x,
};


@@ -3,29 +3,23 @@ const http = @import("./lib.zig");
const parser = @import("./request/parser.zig");
pub const Request = struct {
pub const Protocol = enum {
http_1_0,
http_1_1,
pub fn Request(comptime Reader: type) type {
return struct {
protocol: http.Protocol,
method: http.Method,
uri: []const u8,
headers: http.Fields,
body: ?parser.TransferStream(Reader),
pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void {
allocator.free(self.uri);
self.headers.deinit();
}
};
protocol: Protocol,
source_address: ?std.net.Address,
method: http.Method,
uri: []const u8,
headers: http.Headers,
body: ?[]const u8 = null,
pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request {
return parser.parse(alloc, reader, addr);
}
pub fn parseFree(self: Request, alloc: std.mem.Allocator) void {
parser.parseFree(alloc, self);
}
};
test {
_ = parser;
}
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
return parser.parse(alloc, reader);
}


@@ -1,15 +1,13 @@
const std = @import("std");
const util = @import("util");
const http = @import("../lib.zig");
const Method = http.Method;
const Headers = http.Headers;
const Fields = http.Fields;
const Request = @import("../request.zig").Request;
const request_buf_size = 1 << 16;
const max_path_len = 1 << 10;
const max_body_len = 1 << 12;
fn ParseError(comptime Reader: type) type {
return error{
@@ -22,7 +20,7 @@ const Encoding = enum {
chunked,
};
pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request {
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
const method = try parseMethod(reader);
const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
error.StreamTooLong => return error.RequestUriTooLong,
@@ -33,28 +31,20 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address
const proto = try parseProto(reader);
// discard \r\n
_ = try reader.readByte();
_ = try reader.readByte();
switch (try reader.readByte()) {
'\r' => if ((try reader.readByte()) != '\n') return error.BadRequest,
'\n' => {},
else => return error.BadRequest,
}
var headers = try parseHeaders(alloc, reader);
errdefer freeHeaders(alloc, &headers);
errdefer headers.deinit();
const body = if (method.requestHasBody())
try readBody(alloc, headers, reader)
else
null;
errdefer if (body) |b| alloc.free(b);
const body = try prepareBody(headers, reader);
if (body != null and !method.requestHasBody()) return error.BadRequest;
const eff_addr = if (headers.get("X-Real-IP")) |ip|
std.net.Address.parseIp(ip, address.getPort()) catch {
return error.BadRequest;
}
else
address;
return Request{
return Request(@TypeOf(reader)){
.protocol = proto,
.source_address = eff_addr,
.method = method,
.uri = uri,
@@ -79,7 +69,7 @@ fn parseMethod(reader: anytype) !Method {
return error.MethodNotImplemented;
}
fn parseProto(reader: anytype) !Request.Protocol {
fn parseProto(reader: anytype) !http.Protocol {
var buf: [8]u8 = undefined;
const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
error.StreamTooLong => return error.UnknownProtocol,
@@ -99,85 +89,145 @@ fn parseProto(reader: anytype) !Request.Protocol {
return switch (buf[2]) {
'0' => .http_1_0,
'1' => .http_1_1,
else => error.HttpVersionNotSupported,
else => .http_1_x,
};
}
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers {
var map = Headers.init(allocator);
errdefer map.deinit();
errdefer {
var iter = map.iterator();
while (iter.next()) |it| {
allocator.free(it.key_ptr.*);
allocator.free(it.value_ptr.*);
}
}
// todo:
//errdefer {
//var iter = map.iterator();
//while (iter.next()) |it| {
//allocator.free(it.key_ptr);
//allocator.free(it.value_ptr);
//}
//}
var buf: [1024]u8 = undefined;
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
var headers = Fields.init(allocator);
var buf: [4096]u8 = undefined;
while (true) {
const line = try reader.readUntilDelimiter(&buf, '\n');
if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break;
const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
error.StreamTooLong => return error.HeaderLineTooLong,
else => return err,
};
const line = std.mem.trimRight(u8, full_line, "\r");
if (line.len == 0) break;
// TODO: handle multi-line headers
const name = extractHeaderName(line) orelse continue;
const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len;
const value = line[name.len + 1 + 1 .. value_end];
const name = std.mem.sliceTo(line, ':');
if (!isTokenValid(name)) return error.BadRequest;
if (name.len == line.len) return error.BadRequest;
if (name.len == 0 or value.len == 0) return error.BadRequest;
const value = std.mem.trim(u8, line[name.len + 1 ..], " \t");
const name_alloc = try allocator.alloc(u8, name.len);
errdefer allocator.free(name_alloc);
const value_alloc = try allocator.alloc(u8, value.len);
errdefer allocator.free(value_alloc);
@memcpy(name_alloc.ptr, name.ptr, name.len);
@memcpy(value_alloc.ptr, value.ptr, value.len);
try map.put(name_alloc, value_alloc);
try headers.append(name, value);
}
return map;
return headers;
}
fn extractHeaderName(line: []const u8) ?[]const u8 {
var index: usize = 0;
fn isTokenValid(token: []const u8) bool {
if (token.len == 0) return false;
for (token) |ch| {
switch (ch) {
'"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false,
// TODO: handle whitespace
while (index < line.len) : (index += 1) {
if (line[index] == ':') {
if (index == 0) return null;
return line[0..index];
'!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => {},
else => if (!std.ascii.isAlphanumeric(ch)) return false,
}
}
return null;
return true;
}
fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 {
const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
if (xfer_encoding != .identity) return error.UnsupportedMediaType;
fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) {
const hdr = headers.get("Transfer-Encoding");
// TODO:
// if (hdr != null and protocol == .http_1_0) return error.BadRequest;
const xfer_encoding = try parseEncoding(hdr);
const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
if (content_encoding != .identity) return error.UnsupportedMediaType;
const len_str = headers.get("Content-Length") orelse return null;
const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
if (len > max_body_len) return error.RequestEntityTooLarge;
const body = try alloc.alloc(u8, len);
errdefer alloc.free(body);
switch (xfer_encoding) {
.identity => {
const len_str = headers.get("Content-Length") orelse return null;
const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
reader.readNoEof(body) catch return error.BadRequest;
return TransferStream(@TypeOf(reader)){ .underlying = .{ .identity = std.io.limitedReader(reader, len) } };
},
.chunked => {
if (headers.get("Content-Length") != null) return error.BadRequest;
return TransferStream(@TypeOf(reader)){
.underlying = .{
.chunked = try ChunkedStream(@TypeOf(reader)).init(reader),
},
};
},
}
}
return body;
fn ChunkedStream(comptime R: type) type {
return struct {
const Self = @This();
remaining: ?usize = 0,
underlying: R,
const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream };
fn init(reader: R) !Self {
var self: Self = .{ .underlying = reader };
return self;
}
fn read(self: *Self, buf: []u8) !usize {
var count: usize = 0;
while (true) {
if (count == buf.len) return count;
if (self.remaining == null) return count;
if (self.remaining.? == 0) self.remaining = try self.readChunkHeader();
const max_read = std.math.min(buf.len, self.remaining.?);
const amt = try self.underlying.read(buf[count .. count + max_read]);
if (amt != max_read) return error.EndOfStream;
count += amt;
self.remaining.? -= amt;
if (self.remaining.? == 0) {
var crlf: [2]u8 = undefined;
_ = try self.underlying.readUntilDelimiter(&crlf, '\n');
self.remaining = try self.readChunkHeader();
}
if (count == buf.len) return count;
}
}
fn readChunkHeader(self: *Self) !?usize {
// TODO: Pick a reasonable limit for this
var buf = std.mem.zeroes([10]u8);
const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| {
return if (err == error.StreamTooLong) error.InvalidChunkHeader else err;
};
if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader;
const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader;
return if (size != 0) size else null;
}
};
}
pub fn TransferStream(comptime R: type) type {
return struct {
const Error = R.Error || ChunkedStream(R).Error;
const Reader = std.io.Reader(*@This(), Error, read);
underlying: union(enum) {
identity: std.io.LimitedReader(R),
chunked: ChunkedStream(R),
},
pub fn read(self: *@This(), buf: []u8) Error!usize {
return switch (self.underlying) {
.identity => |*r| try r.read(buf),
.chunked => |*r| try r.read(buf),
};
}
pub fn reader(self: *@This()) Reader {
return .{ .context = self };
}
};
}
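A sketch of the chunked path end to end. The test is hypothetical and would have to live in this file, since `ChunkedStream` is file-private; it assumes the behavior of the code above:

test "ChunkedStream reassembles a chunked body" {
    var src = std.io.fixedBufferStream("5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n");
    const R = @TypeOf(src.reader());
    var xfer = TransferStream(R){
        .underlying = .{ .chunked = try ChunkedStream(R).init(src.reader()) },
    };
    // Chunk-size lines ("5", "6", then the "0" terminator) are consumed by
    // readChunkHeader; only the payload bytes reach the caller.
    const body = try xfer.reader().readAllAlloc(std.testing.allocator, 1 << 10);
    defer std.testing.allocator.free(body);
    try std.testing.expectEqualStrings("hello world", body);
}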
// TODO: assumes that there's only one encoding, not layered encodings
@@ -187,257 +237,3 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding {
if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked;
return error.UnsupportedMediaType;
}
pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void {
allocator.free(request.uri);
freeHeaders(allocator, &request.headers);
if (request.body) |body| allocator.free(body);
}
fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void {
var iter = headers.iterator();
while (iter.next()) |it| {
allocator.free(it.key_ptr.*);
allocator.free(it.value_ptr.*);
}
headers.deinit();
}
const _test = struct {
const expectEqual = std.testing.expectEqual;
const expectEqualStrings = std.testing.expectEqualStrings;
fn toCrlf(comptime str: []const u8) []const u8 {
comptime {
var buf: [str.len * 2]u8 = undefined;
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
var buf_len: usize = 0;
for (str) |ch| {
if (ch == '\n') {
buf[buf_len] = '\r';
buf_len += 1;
}
buf[buf_len] = ch;
buf_len += 1;
}
return buf[0..buf_len];
}
}
fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers {
var result = Headers.init(alloc);
inline for (headers) |tup| {
try result.put(tup[0], tup[1]);
}
return result;
}
fn areEqualHeaders(lhs: Headers, rhs: Headers) bool {
if (lhs.count() != rhs.count()) return false;
var iter = lhs.iterator();
while (iter.next()) |it| {
const rhs_val = rhs.get(it.key_ptr.*) orelse return false;
if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false;
}
return true;
}
fn printHeaders(headers: Headers) void {
var iter = headers.iterator();
while (iter.next()) |it| {
std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* });
}
}
fn expectEqualHeaders(expected: Headers, actual: Headers) !void {
if (!areEqualHeaders(expected, actual)) {
std.debug.print("\nexpected: \n", .{});
printHeaders(expected);
std.debug.print("\n\nfound: \n", .{});
printHeaders(actual);
std.debug.print("\n\n", .{});
return error.TestExpectedEqual;
}
}
fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void {
var stream = std.io.fixedBufferStream(toCrlf(request));
const result = try parse(alloc, stream.reader());
try expectEqual(expected.method, result.method);
try expectEqualStrings(expected.path, result.path);
try expectEqualHeaders(expected.headers, result.headers);
if ((expected.body == null) != (result.body == null)) {
const null_str: []const u8 = "(null)";
const exp = expected.body orelse null_str;
const act = result.body orelse null_str;
std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act });
return error.TestExpectedEqual;
}
if (expected.body != null) {
try expectEqualStrings(expected.body.?, result.body.?);
}
}
};
// TODO: failure test cases
test "parse" {
const testCase = _test.parseTestCase;
var buf = [_]u8{0} ** (1 << 16);
var fba = std.heap.FixedBufferAllocator.init(&buf);
const alloc = fba.allocator();
try testCase(alloc, (
\\GET / HTTP/1.1
\\
\\
), .{
.method = .GET,
.headers = try _test.makeHeaders(alloc, .{}),
.path = "/",
});
fba.reset();
try testCase(alloc, (
\\POST / HTTP/1.1
\\
\\
), .{
.method = .POST,
.headers = try _test.makeHeaders(alloc, .{}),
.path = "/",
});
fba.reset();
try testCase(alloc, (
\\HEAD / HTTP/1.1
\\Authorization: bearer <token>
\\
\\
), .{
.method = .HEAD,
.headers = try _test.makeHeaders(alloc, .{
.{ "Authorization", "bearer <token>" },
}),
.path = "/",
});
fba.reset();
try testCase(alloc, (
\\POST /nonsense HTTP/1.1
\\Authorization: bearer <token>
\\Content-Length: 5
\\
\\12345
), .{
.method = .POST,
.headers = try _test.makeHeaders(alloc, .{
.{ "Authorization", "bearer <token>" },
.{ "Content-Length", "5" },
}),
.path = "/nonsense",
.body = "12345",
});
fba.reset();
try std.testing.expectError(
error.MethodNotImplemented,
testCase(alloc, (
\\FOO /nonsense HTTP/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.MethodNotImplemented,
testCase(alloc, (
\\FOOBARBAZ /nonsense HTTP/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.RequestUriTooLong,
testCase(alloc, (
\\GET /
++ ("a" ** 2048)), undefined),
);
fba.reset();
try std.testing.expectError(
error.UnknownProtocol,
testCase(alloc, (
\\GET /nonsense SPECIALHTTP/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.UnknownProtocol,
testCase(alloc, (
\\GET /nonsense JSON/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.HttpVersionNotSupported,
testCase(alloc, (
\\GET /nonsense HTTP/1.9
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.HttpVersionNotSupported,
testCase(alloc, (
\\GET /nonsense HTTP/8.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.BadRequest,
testCase(alloc, (
\\GET /nonsense HTTP/blah blah blah
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.BadRequest,
testCase(alloc, (
\\GET /nonsense HTTP/1/1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.BadRequest,
testCase(alloc, (
\\GET /nonsense HTTP/1/1
\\
\\
), undefined),
);
}


@@ -0,0 +1,282 @@
const std = @import("std");
const parser = @import("./parser.zig");
const http = @import("../lib.zig");
const t = std.testing;
const test_case = struct {
fn parse(text: []const u8, expected: struct {
protocol: http.Protocol = .http_1_1,
method: http.Method = .GET,
headers: []const std.meta.Tuple(&.{ []const u8, []const u8 }) = &.{},
uri: []const u8 = "",
}) !void {
var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(text) };
var actual = try parser.parse(t.allocator, stream.reader());
defer actual.parseFree(t.allocator);
try t.expectEqual(expected.protocol, actual.protocol);
try t.expectEqual(expected.method, actual.method);
try t.expectEqualStrings(expected.uri, actual.uri);
try t.expectEqual(expected.headers.len, actual.headers.count());
for (expected.headers) |hdr| {
if (actual.headers.get(hdr[0])) |val| {
try t.expectEqualStrings(hdr[1], val);
} else {
std.debug.print("Error: Header {s} expected to be present, was not.\n", .{hdr[0]});
try t.expect(false);
}
}
}
};
fn toCrlf(comptime str: []const u8) []const u8 {
comptime {
var buf: [str.len * 2]u8 = undefined;
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
var buf_len: usize = 0;
for (str) |ch| {
if (ch == '\n') {
buf[buf_len] = '\r';
buf_len += 1;
}
buf[buf_len] = ch;
buf_len += 1;
}
return buf[0..buf_len];
}
}
test "HTTP/1.x parse - No body" {
try test_case.parse(
toCrlf(
\\GET / HTTP/1.1
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\POST / HTTP/1.1
\\
\\
),
.{
.protocol = .http_1_1,
.method = .POST,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\GET /url/abcd HTTP/1.1
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
},
);
try test_case.parse(
toCrlf(
\\GET / HTTP/1.0
\\
\\
),
.{
.protocol = .http_1_0,
.method = .GET,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
.headers = &.{.{ "Content-Type", "application/json" }},
},
);
try test_case.parse(
toCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\Authorization: bearer <token>
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
.headers = &.{
.{ "Content-Type", "application/json" },
.{ "Authorization", "bearer <token>" },
},
},
);
// Test without CRLF
try test_case.parse(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\Authorization: bearer <token>
\\
\\
,
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
.headers = &.{
.{ "Content-Type", "application/json" },
.{ "Authorization", "bearer <token>" },
},
},
);
try test_case.parse(
\\POST / HTTP/1.1
\\
\\
,
.{
.protocol = .http_1_1,
.method = .POST,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\GET / HTTP/1.2
\\
\\
),
.{
.protocol = .http_1_x,
.method = .GET,
.uri = "/",
},
);
}
test "HTTP/1.x parse - unsupported protocol" {
try t.expectError(error.UnknownProtocol, test_case.parse(
\\GET / JSON/1.1
\\
\\
,
.{},
));
try t.expectError(error.UnknownProtocol, test_case.parse(
\\GET / SOMETHINGELSE/3.5
\\
\\
,
.{},
));
try t.expectError(error.UnknownProtocol, test_case.parse(
\\GET / /1.1
\\
\\
,
.{},
));
try t.expectError(error.HttpVersionNotSupported, test_case.parse(
\\GET / HTTP/2.1
\\
\\
,
.{},
));
}
test "HTTP/1.x parse - Unknown method" {
try t.expectError(error.MethodNotImplemented, test_case.parse(
\\ABCD / HTTP/1.1
\\
\\
,
.{},
));
try t.expectError(error.MethodNotImplemented, test_case.parse(
\\PATCHPATCHPATCH / HTTP/1.1
\\
\\
,
.{},
));
}
test "HTTP/1.x parse - Too long" {
try t.expectError(error.RequestUriTooLong, test_case.parse(
std.fmt.comptimePrint("GET {s} HTTP/1.1\n\n", .{"a" ** 8192}),
.{},
));
try t.expectError(error.HeaderLineTooLong, test_case.parse(
std.fmt.comptimePrint("GET / HTTP/1.1\r\n{s}: abcd", .{"a" ** 8192}),
.{},
));
try t.expectError(error.HeaderLineTooLong, test_case.parse(
std.fmt.comptimePrint("GET / HTTP/1.1\r\nabcd: {s}", .{"a" ** 8192}),
.{},
));
}
test "HTTP/1.x parse - bad requests" {
try t.expectError(error.BadRequest, test_case.parse(
\\GET / HTTP/1.1 blah blah
\\
\\
,
.{},
));
try t.expectError(error.BadRequest, test_case.parse(
\\GET / HTTP/1.1
\\abcd : lksjdfkl
\\
,
.{},
));
try t.expectError(error.BadRequest, test_case.parse(
\\GET / HTTP/1.1
\\ lksjfklsjdfklj
\\
,
.{},
));
}
test "HTTP/1.x parse - Headers" {
try test_case.parse(
toCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\Content-Type: application/xml
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
.headers = &.{.{ "Content-Type", "application/json, application/xml" }},
},
);
}


@@ -3,13 +3,14 @@ const util = @import("util");
const http = @import("./lib.zig");
const response = @import("./server/response.zig");
const request = @import("./request.zig");
pub const Response = struct {
alloc: std.mem.Allocator,
stream: std.net.Stream,
should_close: bool = false,
pub const Stream = response.ResponseStream(std.net.Stream.Writer);
pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream {
pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream {
if (headers.get("Connection")) |hdr| {
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
}
@@ -17,7 +18,7 @@ pub const Response = struct {
return response.open(self.alloc, self.stream.writer(), headers, status);
}
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream {
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream {
try response.writeRequestHeader(self.stream.writer(), headers, status);
return self.stream;
}
@@ -26,10 +27,6 @@ pub const Response = struct {
const Request = http.Request;
const request_buf_size = 1 << 16;
pub fn Handler(comptime Ctx: type) type {
return fn (Ctx, Request, *Response) void;
}
pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void {
// TODO: Timeouts
while (true) {
@@ -37,7 +34,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
var arena = std.heap.ArenaAllocator.init(alloc);
defer arena.deinit();
const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| {
var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
return handleError(conn.stream.writer(), err) catch {};
};
std.log.debug("done parsing", .{});
@@ -47,7 +44,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
.stream = conn.stream,
};
handler(ctx, req, &res);
handler(ctx, &req, &res);
std.log.debug("done handling", .{});
if (req.headers.get("Connection")) |hdr| {


@@ -2,20 +2,20 @@ const std = @import("std");
const http = @import("../lib.zig");
const Status = http.Status;
const Headers = http.Headers;
const Fields = http.Fields;
const chunk_size = 16 * 1024;
pub fn open(
alloc: std.mem.Allocator,
writer: anytype,
headers: *const Headers,
headers: *const Fields,
status: Status,
) !ResponseStream(@TypeOf(writer)) {
const buf = try alloc.alloc(u8, chunk_size);
errdefer alloc.free(buf);
try writeStatusLine(writer, status);
try writeHeaders(writer, headers);
try writeFields(writer, headers);
return ResponseStream(@TypeOf(writer)){
.allocator = alloc,
@@ -25,9 +25,9 @@ pub fn open(
};
}
pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void {
pub fn writeRequestHeader(writer: anytype, headers: *const Fields, status: Status) !void {
try writeStatusLine(writer, status);
try writeHeaders(writer, headers);
try writeFields(writer, headers);
try writer.writeAll("\r\n");
}
@@ -36,7 +36,7 @@ fn writeStatusLine(writer: anytype, status: Status) !void {
try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text });
}
fn writeHeaders(writer: anytype, headers: *const Headers) !void {
fn writeFields(writer: anytype, headers: *const Fields) !void {
var iter = headers.iterator();
while (iter.next()) |header| {
for (header.value_ptr.*) |ch| {
@@ -65,7 +65,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
allocator: std.mem.Allocator,
base_writer: BaseWriter,
headers: *const Headers,
headers: *const Fields,
buffer: []u8,
buffer_pos: usize = 0,
chunked: bool = false,
@@ -95,7 +95,6 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
return;
}
std.debug.print("{}\n", .{cursor});
self.writeToBuffer(bytes[cursor .. cursor + remaining_in_chunk]);
cursor += remaining_in_chunk;
try self.flushChunk();
@@ -177,7 +176,7 @@ const _tests = struct {
test "ResponseStream no headers empty body" {
var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator);
var headers = Fields.init(std.testing.allocator);
defer headers.deinit();
{
@@ -205,7 +204,7 @@ const _tests = struct {
test "ResponseStream empty body" {
var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator);
var headers = Fields.init(std.testing.allocator);
defer headers.deinit();
try headers.put("Content-Type", "text/plain");
@@ -236,7 +235,7 @@ const _tests = struct {
test "ResponseStream not 200 OK" {
var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator);
var headers = Fields.init(std.testing.allocator);
defer headers.deinit();
try headers.put("Content-Type", "text/plain");
@@ -266,7 +265,7 @@ const _tests = struct {
test "ResponseStream small body" {
var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator);
var headers = Fields.init(std.testing.allocator);
defer headers.deinit();
try headers.put("Content-Type", "text/plain");
@@ -300,7 +299,7 @@ const _tests = struct {
test "ResponseStream large body" {
var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator);
var headers = Fields.init(std.testing.allocator);
defer headers.deinit();
try headers.put("Content-Type", "text/plain");
@@ -341,7 +340,7 @@ const _tests = struct {
test "ResponseStream large body ending on chunk boundary" {
var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator);
var headers = Fields.init(std.testing.allocator);
defer headers.deinit();
try headers.put("Content-Type", "text/plain");


@@ -23,21 +23,21 @@ const Opcode = enum(u4) {
}
};
pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket {
pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
const connection = req.headers.get("Connection") orelse return error.BadHandshake;
if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;
const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
if (try std.base64.standard.Decoder.calcSizeForSlice(key_hdr) != 16) return error.BadHandshake;
if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake;
var key: [16]u8 = undefined;
std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;
const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
var headers = http.Headers.init(alloc);
var headers = http.Fields.init(alloc);
defer headers.deinit();
try headers.put("Upgrade", "websocket");
@@ -51,7 +51,7 @@ pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Respons
var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined;
_ = std.base64.standard.Encoder.encode(&hash_encoded, &hash);
try headers.put("Sec-WebSocket-Accept", &hash_encoded);
const stream = try res.upgrade(.switching_protcols, &headers);
const stream = try res.upgrade(.switching_protocols, &headers);
return Socket{ .stream = stream };
}
@@ -164,15 +164,15 @@ fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void {
const initial_len: u7 = if (header.len < 126)
@intCast(u7, header.len)
else if (std.math.cast(u16, header.len)) |_|
126
@as(u7, 126)
else
127;
@as(u7, 127);
var hdr_buf = [2]u8{ 0, 0 };
hdr_buf[0] |= if (header.is_final) 0b1000_0000 else 0;
hdr_buf[0] |= if (header.is_final) @as(u8, 0b1000_0000) else 0;
hdr_buf[0] |= @as(u8, header.rsv) << 4;
hdr_buf[0] |= @enumToInt(header.opcode);
hdr_buf[1] |= if (header.masking_key) |_| 0b1000_0000 else 0;
hdr_buf[1] |= if (header.masking_key) |_| @as(u8, 0b1000_0000) else 0;
hdr_buf[1] |= initial_len;
try writer.writeAll(&hdr_buf);
if (initial_len == 126)

src/http/test.zig (new file, 3 lines)

@@ -0,0 +1,3 @@
test {
_ = @import("./request/test_parser.zig");
}


@@ -13,10 +13,11 @@ pub const invites = @import("./controllers/invites.zig");
pub const users = @import("./controllers/users.zig");
pub const notes = @import("./controllers/notes.zig");
pub const streaming = @import("./controllers/streaming.zig");
pub const timelines = @import("./controllers/timelines.zig");
pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
// TODO: hashmaps?
var response = Response{ .headers = http.Headers.init(alloc), .res = res };
var response = Response{ .headers = http.Fields.init(alloc), .res = res };
defer response.headers.deinit();
const found = routeRequestInternal(api_source, req, &response, alloc);
@@ -24,7 +25,7 @@ pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response,
if (!found) response.status(.not_found) catch {};
}
fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
inline for (routes) |route| {
if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true;
}
@@ -42,6 +43,8 @@ const routes = .{
notes.create,
notes.get,
streaming.streaming,
timelines.global,
timelines.local,
};
pub fn Context(comptime Route: type) type {
@@ -58,18 +61,21 @@ pub fn Context(comptime Route: type) type {
// leave it as a simple string instead of void
pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void;
base_request: http.Request,
base_request: *http.Request,
allocator: std.mem.Allocator,
method: http.Method,
uri: []const u8,
headers: http.Headers,
headers: http.Fields,
args: Args,
body: Body,
query: Query,
// TODO
body_buf: ?[]const u8 = null,
fn parseArgs(path: []const u8) ?Args {
var args: Args = undefined;
var path_iter = util.PathIter.from(path);
@@ -94,7 +100,7 @@ pub fn Context(comptime Route: type) type {
@compileError("Unsupported Type " ++ @typeName(T));
}
pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
if (req.method != Route.method) return false;
var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?');
var args: Args = parseArgs(path) orelse return false;
@@ -112,6 +118,8 @@ pub fn Context(comptime Route: type) type {
.query = undefined,
};
std.log.debug("Matched route {s}", .{path});
self.prepareAndHandle(api_source, req, res);
return true;
@@ -129,7 +137,7 @@ pub fn Context(comptime Route: type) type {
};
}
fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void {
fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void {
self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err);
defer self.freeBody();
@@ -141,16 +149,20 @@ pub fn Context(comptime Route: type) type {
self.handle(response, &api_conn);
}
fn parseBody(self: *Self, req: http.Request) !void {
fn parseBody(self: *Self, req: *http.Request) !void {
if (Body != void) {
const body = req.body orelse return error.NoBody;
var stream = req.body orelse return error.NoBody;
const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16);
errdefer self.allocator.free(body);
self.body = try json_utils.parse(Body, body, self.allocator);
self.body_buf = body;
}
}
fn freeBody(self: *Self) void {
if (Body != void) {
json_utils.parseFree(self.body, self.allocator);
self.allocator.free(self.body_buf.?);
}
}
@@ -191,7 +203,7 @@ pub fn Context(comptime Route: type) type {
pub const Response = struct {
const Self = @This();
headers: http.Headers,
headers: http.Fields,
res: *http.Response,
opened: bool = false,


@@ -1,4 +1,5 @@
const api = @import("api");
const std = @import("std");
pub const login = struct {
pub const method = .POST;
@@ -12,6 +13,8 @@ pub const login = struct {
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const token = try srv.login(req.body.username, req.body.password);
std.log.debug("{any}", .{res.headers});
try res.json(.ok, token);
}
};


@@ -1,6 +1,7 @@
const std = @import("std");
const api = @import("api");
const util = @import("util");
const query_utils = @import("../query.zig");
const QueryArgs = api.CommunityQueryArgs;
const Uuid = util.Uuid;
@@ -25,89 +26,18 @@ pub const query = struct {
pub const method = .GET;
pub const path = "/communities";
// NOTE: This has to match QueryArgs
// TODO: Support union fields in query strings natively, so we don't
// have to keep these in sync
pub const Query = struct {
const OrderBy = QueryArgs.OrderBy;
const Direction = QueryArgs.Direction;
const PageDirection = QueryArgs.PageDirection;
// Max items to fetch
max_items: usize = 20,
// Selection filters
owner_id: ?Uuid = null,
like: ?[]const u8 = null,
created_before: ?DateTime = null,
created_after: ?DateTime = null,
// Ordering parameter
order_by: OrderBy = .created_at,
direction: Direction = .ascending,
// the `prev` struct has a slightly different format to QueryArgs
prev: struct {
id: ?Uuid = null,
// Only one of these can be present, and must match order_by above
name: ?[]const u8 = null,
host: ?[]const u8 = null,
created_at: ?DateTime = null,
} = .{},
// What direction to scan the page window
page_direction: PageDirection = .forward,
pub const format = formatQueryParams;
};
pub const Query = QueryArgs;
pub fn handler(req: anytype, res: anytype, srv: anytype) !void {
const q = req.query;
const query_matches = if (q.prev.id) |_| switch (q.order_by) {
.name => q.prev.name != null and q.prev.host == null and q.prev.created_at == null,
.host => q.prev.name == null and q.prev.host != null and q.prev.created_at == null,
.created_at => q.prev.name == null and q.prev.host == null and q.prev.created_at != null,
} else (q.prev.name == null and q.prev.host == null and q.prev.created_at == null);
if (!query_matches) return res.err(.bad_request, "prev.* parameters do not match", {});
const prev_arg: ?QueryArgs.Prev = if (q.prev.id) |id| .{
.id = id,
.order_val = switch (q.order_by) {
.name => .{ .name = q.prev.name.? },
.host => .{ .host = q.prev.host.? },
.created_at => .{ .created_at = q.prev.created_at.? },
},
} else null;
const query_args = QueryArgs{
.max_items = q.max_items,
.owner_id = q.owner_id,
.like = q.like,
.created_before = q.created_before,
.created_after = q.created_after,
.order_by = q.order_by,
.direction = q.direction,
.prev = prev_arg,
.page_direction = q.page_direction,
};
const results = try srv.queryCommunities(query_args);
const results = try srv.queryCommunities(req.query);
var link = std.ArrayList(u8).init(req.allocator);
const link_writer = link.writer();
defer link.deinit();
const next_page = queryArgsToControllerQuery(results.next_page);
const prev_page = queryArgsToControllerQuery(results.prev_page);
try writeLink(link_writer, srv.community, path, next_page, "next");
try writeLink(link_writer, srv.community, path, results.next_page, "next");
try link_writer.writeByte(',');
try writeLink(link_writer, srv.community, path, prev_page, "prev");
try writeLink(link_writer, srv.community, path, results.prev_page, "prev");
try res.headers.put("Link", link.items);
@@ -129,7 +59,7 @@ fn writeLink(
.{ @tagName(community.scheme), community.host, path },
);
try std.fmt.format(writer, "{}", .{params});
try query_utils.formatQuery(params, writer);
try std.fmt.format(
writer,
@@ -137,70 +67,3 @@ fn writeLink(
.{rel},
);
}
fn formatQueryParams(
params: anytype,
comptime fmt: []const u8,
opt: std.fmt.FormatOptions,
writer: anytype,
) !void {
if (comptime std.meta.trait.is(.Pointer)(@TypeOf(params))) {
return formatQueryParams(params.*, fmt, opt, writer);
}
return formatRecursive("", params, writer);
}
fn formatRecursive(comptime prefix: []const u8, params: anytype, writer: anytype) !void {
inline for (std.meta.fields(@TypeOf(params))) |field| {
const val = @field(params, field.name);
const is_optional = comptime std.meta.trait.is(.Optional)(field.field_type);
const present = if (comptime is_optional) val != null else true;
if (present) {
const unwrapped = if (is_optional) val.? else val;
// TODO: percent-encode this
_ = try switch (@TypeOf(unwrapped)) {
[]const u8 => blk: {
break :blk std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, unwrapped });
},
else => |U| blk: {
if (comptime std.meta.trait.isContainer(U) and std.meta.trait.hasFn("format")(U)) {
break :blk std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped });
}
break :blk switch (@typeInfo(U)) {
.Enum => std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, @tagName(unwrapped) }),
.Struct => formatRecursive(field.name ++ ".", unwrapped, writer),
else => std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }),
};
},
};
}
}
}
fn queryArgsToControllerQuery(args: QueryArgs) query.Query {
var result = query.Query{
.max_items = args.max_items,
.owner_id = args.owner_id,
.like = args.like,
.created_before = args.created_before,
.created_after = args.created_after,
.order_by = args.order_by,
.direction = args.direction,
.prev = .{},
.page_direction = args.page_direction,
};
if (args.prev) |prev| {
result.prev = .{
.id = prev.id,
.name = if (prev.order_val == .name) prev.order_val.name else null,
.host = if (prev.order_val == .host) prev.order_val.host else null,
.created_at = if (prev.order_val == .created_at) prev.order_val.created_at else null,
};
}
return result;
}


@@ -0,0 +1,21 @@
pub const global = struct {
pub const method = .GET;
pub const path = "/timelines/global";
pub fn handler(_: anytype, res: anytype, srv: anytype) !void {
const results = try srv.globalTimeline();
try res.json(.ok, results);
}
};
pub const local = struct {
pub const method = .GET;
pub const path = "/timelines/local";
pub fn handler(_: anytype, res: anytype, srv: anytype) !void {
const results = try srv.localTimeline();
try res.json(.ok, results);
}
};


@@ -499,7 +499,7 @@ fn parseInternal(
if (!fields_seen[i]) {
if (field.default_value) |default_ptr| {
if (!field.is_comptime) {
const default = @ptrCast(*const field.field_type, default_ptr).*;
const default = @ptrCast(*align(1) const field.field_type, default_ptr).*;
@field(r, field.name) = default;
}
} else {
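The added `align(1)` reflects that `StructField.default_value` is a type-erased `?*const anyopaque`, which carries no alignment guarantee in its type; the cast therefore must not assume `field_type`'s natural alignment. A minimal standalone illustration of the same idea (hypothetical values):

const std = @import("std");

test "reading through a type-erased pointer without assuming alignment" {
    const F = u32;
    const value: F = 42;
    const erased: *const anyopaque = &value;
    // *align(1) asserts nothing about alignment, so the cast stays valid
    // even when the erased pointer's provenance is unknown.
    const read = @ptrCast(*align(1) const F, erased).*;
    try std.testing.expectEqual(value, read);
}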


@@ -87,7 +87,7 @@ fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void {
}
}
fn handle(ctx: anytype, req: http.Request, res: *http.Response) void {
fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void {
c.routeRequest(ctx.src, req, res, ctx.allocator);
}


@@ -53,7 +53,7 @@ pub fn up(db: anytype) !void {
std.log.info("Running migration {s}", .{migration.name});
try execScript(tx, migration.up, gpa.allocator());
try tx.insert("migration", .{
.name = migration.name,
.name = @as([]const u8, migration.name),
.applied_at = DateTime.now(),
}, gpa.allocator());
}
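This cast is the same class of stage1 workaround seen elsewhere in the diff: a string literal inside an anonymous struct literal keeps its pointer-to-array type, so each migration would otherwise produce a differently typed `.name` field. Casting normalizes it to a slice; a standalone sketch:

const std = @import("std");

test "string literals keep their array type in anonymous structs" {
    const raw = .{ .name = "init" };
    const cast = .{ .name = @as([]const u8, "init") };
    // The literal's length is baked into the field type unless cast away.
    try std.testing.expectEqual(*const [4:0]u8, @TypeOf(raw.name));
    try std.testing.expectEqual([]const u8, @TypeOf(cast.name));
}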


@@ -71,37 +71,132 @@ const QueryIter = @import("util").QueryIter;
/// TODO: values are currently case-sensitive, and are not url-decoded properly.
/// This should be fixed.
pub fn parseQuery(comptime T: type, query: []const u8) !T {
//if (!std.meta.trait.isContainer(T)) @compileError("T must be a struct");
if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
var iter = QueryIter.from(query);
var result = T{};
var fields = Intermediary(T){};
while (iter.next()) |pair| {
try parseQueryPair(T, &result, pair.key, pair.value);
// TODO: Hash map
inline for (std.meta.fields(Intermediary(T))) |field| {
if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) {
@field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} };
break;
}
} else std.log.debug("unknown param {s}", .{pair.key});
}
return result;
return (try parse(T, "", "", fields)).?;
}
fn parseQueryPair(comptime T: type, result: *T, key: []const u8, value: ?[]const u8) !void {
const key_part = std.mem.sliceTo(key, '.');
const field_idx = std.meta.stringToEnum(std.meta.FieldEnum(T), key_part) orelse return error.UnknownField;
fn parseScalar(comptime T: type, comptime name: []const u8, fields: anytype) !?T {
const param = @field(fields, name);
return switch (param) {
.not_specified => null,
.no_value => try parseQueryValue(T, null),
.value => |v| try parseQueryValue(T, v),
};
}
inline for (std.meta.fields(T)) |info, idx| {
if (@enumToInt(field_idx) == idx) {
if (comptime isScalar(info.field_type)) {
if (key_part.len == key.len) {
@field(result, info.name) = try parseQueryValue(info.field_type, value);
return;
} else {
return error.UnknownField;
fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u8, fields: anytype) !?T {
if (comptime isScalar(T)) return parseScalar(T, prefix ++ "." ++ name, fields);
switch (@typeInfo(T)) {
.Union => |info| {
var result: ?T = null;
inline for (info.fields) |field| {
const F = field.field_type;
const maybe_value = try parse(F, prefix, field.name, fields);
if (maybe_value) |value| {
if (result != null) return error.DuplicateUnionField;
result = @unionInit(T, field.name, value);
}
}
std.log.debug("{any}", .{result});
return result;
},
.Struct => |info| {
var result: T = undefined;
var fields_specified: usize = 0;
inline for (info.fields) |field| {
const F = field.field_type;
var maybe_value: ?F = null;
if (try parse(F, prefix ++ "." ++ name, field.name, fields)) |v| {
maybe_value = v;
} else if (field.default_value) |default| {
maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*;
}
if (maybe_value) |v| {
fields_specified += 1;
@field(result, field.name) = v;
}
}
if (fields_specified == 0) {
return null;
} else if (fields_specified != info.fields.len) {
return error.PartiallySpecifiedStruct;
} else {
const remaining = std.mem.trimLeft(u8, key[key_part.len..], ".");
return try parseQueryPair(info.field_type, &@field(result, info.name), remaining, value);
return result;
}
},
// Only applies to non-scalar optionals
.Optional => |info| return try parse(info.child, prefix, name, fields),
else => @compileError("tmp"),
}
}
fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 {
comptime {
if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix);
var fields: []const []const u8 = &.{};
for (std.meta.fields(T)) |f| {
const full_name = prefix ++ f.name;
if (isScalar(f.field_type)) {
fields = fields ++ @as([]const []const u8, &.{full_name});
} else {
const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ ".";
fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix);
}
}
}
return error.UnknownField;
return fields;
}
}
const QueryParam = union(enum) {
not_specified: void,
no_value: void,
value: []const u8,
};
fn Intermediary(comptime T: type) type {
const field_names = recursiveFieldPaths(T, "..");
var fields: [field_names.len]std.builtin.Type.StructField = undefined;
for (field_names) |name, i| fields[i] = .{
.name = name,
.field_type = QueryParam,
.default_value = &QueryParam{ .not_specified = {} },
.is_comptime = false,
.alignment = @alignOf(QueryParam),
};
return @Type(.{ .Struct = .{
.layout = .Auto,
.fields = &fields,
.decls = &.{},
.is_tuple = false,
} });
}
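// Editor's sketch (not part of this commit): Intermediary flattens the target
// type into one QueryParam field per dotted path, prefixed with ".." so the
// matcher above can strip it with field.name[2..].
test "Intermediary sketch" {
    const T = struct { a: usize, inner: struct { b: bool } };
    const I = Intermediary(T);
    try std.testing.expect(@hasField(I, "..a"));
    try std.testing.expect(@hasField(I, "..inner.b"));
}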
fn parseQueryValue(comptime T: type, value: ?[]const u8) !T {
@ -157,6 +252,50 @@ fn isScalar(comptime T: type) bool {
return false;
}
pub fn formatQuery(params: anytype, writer: anytype) !void {
try format("", "", params, writer);
}
fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void {
const T = @TypeOf(val);
if (comptime std.meta.trait.isZigString(T)) return std.fmt.format(writer, "{s}={s}&", .{ name, val });
_ = try switch (@typeInfo(T)) {
.Enum => std.fmt.format(writer, "{s}={s}&", .{ name, @tagName(val) }),
.Optional => if (val) |v| formatScalar(name, v, writer),
else => std.fmt.format(writer, "{s}={}&", .{ name, val }),
};
}
fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void {
const T = @TypeOf(params);
const eff_prefix = if (prefix.len == 0) "" else prefix ++ ".";
if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer);
switch (@typeInfo(T)) {
.Struct => {
inline for (std.meta.fields(T)) |field| {
const val = @field(params, field.name);
try format(eff_prefix ++ name, field.name, val, writer);
}
},
.Union => {
//inline for (std.meta.tags(T)) |tag| {
inline for (std.meta.fields(T)) |field| {
const tag = @field(std.meta.Tag(T), field.name);
const tag_name = field.name;
if (@as(std.meta.Tag(T), params) == tag) {
const val = @field(params, tag_name);
try format(prefix, tag_name, val, writer);
}
}
},
.Optional => {
if (params) |p| try format(prefix, name, p, writer);
},
else => @compileError("Unsupported query type"),
}
}
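// Editor's sketch (not part of this commit): formatQuery is the inverse of
// parseQuery, emitting `key=value&` pairs (with a trailing '&') and dotted
// prefixes for nested structs.
test "formatQuery sketch" {
    var buf = std.ArrayList(u8).init(std.testing.allocator);
    defer buf.deinit();
    try formatQuery(.{ .max_items = @as(usize, 10) }, buf.writer());
    try std.testing.expectEqualStrings("max_items=10&", buf.items);
}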
test {
const TestQuery = struct {
int: usize = 3,

View file

@ -68,7 +68,7 @@ pub const QueryOptions = struct {
// do not require allocators for prep. If an allocator is needed but not
// provided, `error.AllocatorRequired` will be returned.
// For statement preparation, this is only used with the postgres backend.
prep_allocator: ?Allocator = null,
allocator: ?Allocator = null,
};
// Turns a value into its appropriate textual value (or null)

View file

@ -180,14 +180,25 @@ pub const Db = struct {
const format_text = 0;
const format_binary = 1;
pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results {
const alloc = opt.prep_allocator;
const alloc = opt.allocator;
const result = blk: {
if (@TypeOf(args) != void and args.len > 0) {
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
defer arena.deinit();
const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
inline for (args) |arg, i| {
params[i] = if (try common.prepareParamText(&arena, arg)) |slice|
// TODO: The following is a workaround for the stage1 compiler; remove it once stage1 support is dropped.
//inline for (args) |arg, i| {
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
const arg = @field(args, field.name);
// The stage1 compiler has issues with runtime branches that in any
// way involve compile time values
const maybe_slice = if (@import("builtin").zig_backend == .stage1)
common.prepareParamText(&arena, arg) catch unreachable
else
try common.prepareParamText(&arena, arg);
params[i] = if (maybe_slice) |slice|
slice.ptr
else
null;

View file

@ -118,7 +118,10 @@ pub const Db = struct {
};
if (@TypeOf(args) != void) {
inline for (args) |arg, i| {
// TODO: workaround for the stage1 compiler; remove once stage1 support is dropped
//inline for (args) |arg, i| {
inline for (std.meta.fields(@TypeOf(args))) |field, i| {
const arg = @field(args, field.name);
// SQLite treats $NNN args as having the name NNN, not index NNN.
// As such, if you reference $2 and not $1 in your query (such as
// when dynamically constructing queries), it could assign $2 the

View file

@ -25,6 +25,54 @@ pub const Engine = enum {
sqlite,
};
/// Helper for building queries at runtime. All constituent parts of the
/// query should be defined at comptime, however the choice of whether
/// or not to include them can occur at runtime.
pub const QueryBuilder = struct {
array: std.ArrayList(u8),
where_clauses_appended: usize = 0,
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
}
pub fn deinit(self: *const QueryBuilder) void {
self.array.deinit();
}
/// Add a chunk of sql to the query without processing
pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void {
try self.array.appendSlice(sql);
}
/// Add a where clause to the query. Clauses are assumed to be components
/// in an overall expression in Conjunctive Normal Form (an AND of ORs).
/// https://en.wikipedia.org/wiki/Conjunctive_normal_form
/// All calls to andWhere must be contiguous; they cannot be interspersed
/// with calls to appendSlice.
pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void {
if (self.where_clauses_appended == 0) {
try self.array.appendSlice("WHERE ");
} else {
try self.array.appendSlice(" AND ");
}
try self.array.appendSlice(clause);
self.where_clauses_appended += 1;
}
pub fn str(self: *const QueryBuilder) []const u8 {
return self.array.items;
}
pub fn terminate(self: *QueryBuilder) ![:0]const u8 {
    std.debug.assert(self.array.items.len != 0);
    if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0);
    // Keep the NUL out of the slice length; the sentinel must sit at items[len].
    return self.array.items[0 .. self.array.items.len - 1 :0];
}
};
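// Editor's sketch (not part of this commit): clause text is comptime, while
// the decision to include it happens at runtime. The table and column names
// here are hypothetical.
test "QueryBuilder sketch" {
    var builder = QueryBuilder.init(std.testing.allocator);
    defer builder.deinit();
    try builder.appendSlice("SELECT id FROM note ");
    const filter_by_community = true; // imagine a runtime flag
    if (filter_by_community) try builder.andWhere("community_id = $1");
    try builder.andWhere("created_at < $2");
    try std.testing.expectEqualStrings(
        "SELECT id FROM note WHERE community_id = $1 AND created_at < $2",
        builder.str(),
    );
}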
// TODO: make this suck less
pub const Config = union(Engine) {
postgres: struct {
@ -410,7 +458,7 @@ fn Tx(comptime tx_level: u8) type {
args: anytype,
alloc: ?Allocator,
) QueryError!Results(RowType) {
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc });
}
/// Runs a query to completion and returns a row of results, unless the query
@ -439,6 +487,45 @@ fn Tx(comptime tx_level: u8) type {
return row;
}
// Runs a query to completion and returns the results as a slice
pub fn queryRowsWithOptions(
self: Self,
comptime RowType: type,
q: [:0]const u8,
args: anytype,
max_items: ?usize,
options: QueryOptions,
) QueryRowError![]RowType {
var results = try self.queryWithOptions(RowType, q, args, options);
defer results.finish();
const alloc = options.allocator orelse return error.AllocatorRequired;
var result_array = std.ArrayList(RowType).init(alloc);
errdefer result_array.deinit();
if (max_items) |max| try result_array.ensureTotalCapacity(max);
errdefer for (result_array.items) |r| util.deepFree(alloc, r);
var too_many: bool = false;
while (try results.row(alloc)) |row| {
errdefer util.deepFree(alloc, row);
if (max_items) |max| {
if (result_array.items.len >= max) {
util.deepFree(alloc, row);
too_many = true;
continue;
}
}
try result_array.append(row);
}
if (too_many) return error.TooManyRows;
return result_array.toOwnedSlice();
}
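// Editor's sketch (not part of this commit): a caller passes the allocator
// through QueryOptions and bounds the result set; exceeding `max_items`
// yields error.TooManyRows rather than silent truncation. `Note` and the
// query text are hypothetical:
//
//     const notes = try tx.queryRowsWithOptions(
//         Note,
//         "SELECT id, content FROM note WHERE community_id = $1",
//         .{community_id},
//         max_items,
//         .{ .allocator = alloc },
//     );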
// Inserts a single value into a table
pub fn insert(
self: Self,
@ -455,7 +542,7 @@ fn Tx(comptime tx_level: u8) type {
inline for (fields) |field, i| {
// This causes a compiler crash. Why?
//const F = field.field_type;
const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name));
const F = @TypeOf(@field(value, field.name));
// causes issues if F is @TypeOf(null), use dummy type
types[i] = if (F == @TypeOf(null)) ?i64 else F;
table_spec = comptime (table_spec ++ field.name ++ ",");
@ -499,7 +586,7 @@ fn Tx(comptime tx_level: u8) type {
alloc: ?std.mem.Allocator,
comptime check_tx: bool,
) !void {
var results = try self.runSql(sql, args, .{ .prep_allocator = alloc }, check_tx);
var results = try self.runSql(sql, args, .{ .allocator = alloc }, check_tx);
defer results.finish();
while (try results.row()) |_| {}