diff --git a/build.zig b/build.zig index 108f884..9cbd45c 100644 --- a/build.zig +++ b/build.zig @@ -53,16 +53,16 @@ pub fn build(b: *std.build.Builder) void { exe.linkSystemLibrary("pq"); exe.linkLibC(); - const util_tests = b.addTest("src/util/lib.zig"); - const http_tests = b.addTest("src/http/lib.zig"); - const sql_tests = b.addTest("src/sql/lib.zig"); + //const util_tests = b.addTest("src/util/lib.zig"); + const http_tests = b.addTest("src/http/test.zig"); + //const sql_tests = b.addTest("src/sql/lib.zig"); http_tests.addPackage(util_pkg); - sql_tests.addPackage(util_pkg); + //sql_tests.addPackage(util_pkg); const unit_tests = b.step("unit-tests", "Run tests"); - unit_tests.dependOn(&util_tests.step); + //unit_tests.dependOn(&util_tests.step); unit_tests.dependOn(&http_tests.step); - unit_tests.dependOn(&sql_tests.step); + //unit_tests.dependOn(&sql_tests.step); const api_integration = b.addTest("./tests/api_integration/lib.zig"); api_integration.addPackage(sql_pkg); diff --git a/src/api/lib.zig b/src/api/lib.zig index a48170a..fb8f67e 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -276,7 +276,10 @@ fn ApiConn(comptime DbConn: type) type { username, password, self.community.id, - .{ .invite_id = if (maybe_invite) |inv| inv.id else null, .email = opt.email }, + .{ + .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, + .email = opt.email, + }, self.arena.allocator(), ); @@ -348,5 +351,19 @@ fn ApiConn(comptime DbConn: type) type { if (!self.isAdmin()) return error.PermissionDenied; return try services.communities.query(self.db, args, self.arena.allocator()); } + + pub fn globalTimeline(self: *Self) ![]services.notes.Note { + const result = try services.notes.query(self.db, .{}, self.arena.allocator()); + return result.items; + } + + pub fn localTimeline(self: *Self) ![]services.notes.Note { + const result = try services.notes.query( + self.db, + .{ .community_id = self.community.id }, + self.arena.allocator(), + ); + return result.items; + } }; } diff --git a/src/api/services/actors.zig b/src/api/services/actors.zig index ee7c38d..2a7f964 100644 --- a/src/api/services/actors.zig +++ b/src/api/services/actors.zig @@ -94,7 +94,8 @@ pub const Actor = struct { created_at: DateTime, }; -pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Actor { +pub const GetError = error{ NotFound, DatabaseFailure }; +pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Actor { return db.queryRow( Actor, \\SELECT diff --git a/src/api/services/common.zig b/src/api/services/common.zig new file mode 100644 index 0000000..8a407d5 --- /dev/null +++ b/src/api/services/common.zig @@ -0,0 +1,16 @@ +const std = @import("std"); +const util = @import("util"); + +pub const Direction = enum { + ascending, + descending, + + pub const jsonStringify = util.jsonSerializeEnumAsString; +}; + +pub const PageDirection = enum { + forward, + backward, + + pub const jsonStringify = util.jsonSerializeEnumAsString; +}; diff --git a/src/api/services/communities.zig b/src/api/services/communities.zig index 84f07b6..780b7d5 100644 --- a/src/api/services/communities.zig +++ b/src/api/services/communities.zig @@ -2,6 +2,7 @@ const std = @import("std"); const builtin = @import("builtin"); const util = @import("util"); const sql = @import("sql"); +const common = @import("./common.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; @@ -82,11 +83,12 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st else => return error.DatabaseFailure, } + const name = options.name orelse host; db.insert("community", .{ .id = id, .owner_id = null, .host = host, - .name = options.name orelse host, + .name = name, .scheme = scheme, .kind = options.kind, .created_at = DateTime.now(), @@ -153,20 +155,8 @@ pub const QueryArgs = struct { pub const jsonStringify = util.jsonSerializeEnumAsString; }; - pub const Direction = enum { - ascending, - descending, - - pub const jsonStringify = util.jsonSerializeEnumAsString; - }; - - pub const PageDirection = enum { - forward, - backward, - - pub const jsonStringify = util.jsonSerializeEnumAsString; - }; - + pub const Direction = common.Direction; + pub const PageDirection = common.PageDirection; pub const Prev = std.meta.Child(std.meta.fieldInfo(QueryArgs, .prev).field_type); pub const OrderVal = std.meta.fieldInfo(Prev, .order_val).field_type; @@ -211,30 +201,6 @@ pub const QueryResult = struct { next_page: QueryArgs, }; -const QueryBuilder = struct { - array: std.ArrayList(u8), - where_clauses_appended: usize = 0, - - pub fn init(alloc: std.mem.Allocator) QueryBuilder { - return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) }; - } - - pub fn deinit(self: *const QueryBuilder) void { - self.array.deinit(); - } - - pub fn andWhere(self: *QueryBuilder, clause: []const u8) !void { - if (self.where_clauses_appended == 0) { - try self.array.appendSlice("WHERE "); - } else { - try self.array.appendSlice(" AND "); - } - - try self.array.appendSlice(clause); - self.where_clauses_appended += 1; - } -}; - const max_max_items = 100; pub const QueryError = error{ @@ -246,7 +212,7 @@ pub const QueryError = error{ // arguments. // `args.max_items` is only a request, and fewer entries may be returned. pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult { - var builder = QueryBuilder.init(alloc); + var builder = sql.QueryBuilder.init(alloc); defer builder.deinit(); try builder.array.appendSlice( @@ -266,21 +232,21 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul if (args.prev) |prev| { if (prev.order_val != args.order_by) return error.PageArgMismatch; - try builder.andWhere(switch (args.order_by) { - .name => "(name, id)", - .host => "(host, id)", - .created_at => "(created_at, id)", - }); - _ = try builder.array.appendSlice(switch (args.direction) { + switch (args.order_by) { + .name => try builder.andWhere("(name, id)"), + .host => try builder.andWhere("(host, id)"), + .created_at => try builder.andWhere("(created_at, id)"), + } + switch (args.direction) { .ascending => switch (args.page_direction) { - .forward => " > ", - .backward => " < ", + .forward => try builder.appendSlice(" > "), + .backward => try builder.appendSlice(" < "), }, .descending => switch (args.page_direction) { - .forward => " < ", - .backward => " > ", + .forward => try builder.appendSlice(" < "), + .backward => try builder.appendSlice(" > "), }, - }); + } _ = try builder.array.appendSlice("($5, $6)"); } @@ -297,57 +263,52 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul _ = try builder.array.appendSlice("\nLIMIT $7"); - const query_args = .{ - args.owner_id, - args.like, - args.created_before, - args.created_after, - if (args.prev) |prev| prev.order_val else null, - if (args.prev) |prev| prev.id else null, - max_items, + const query_args = blk: { + const ord_val = + if (args.prev) |prev| @as(?QueryArgs.OrderVal, prev.order_val) else null; + const id = + if (args.prev) |prev| @as(?Uuid, prev.id) else null; + break :blk .{ + args.owner_id, + args.like, + args.created_before, + args.created_after, + ord_val, + id, + max_items, + }; }; try builder.array.append(0); - var results = try db.queryWithOptions( + var results = try db.queryRowsWithOptions( Community, std.meta.assumeSentinel(builder.array.items, 0), query_args, - .{ .prep_allocator = alloc, .ignore_unused_arguments = true }, + max_items, + .{ .allocator = alloc, .ignore_unused_arguments = true }, ); - defer results.finish(); - - const result_buf = try alloc.alloc(Community, args.max_items); - errdefer alloc.free(result_buf); - - var count: usize = 0; - errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c); - - for (result_buf) |*c| { - c.* = (try results.row(alloc)) orelse break; - - count += 1; - } + errdefer util.deepFree(alloc, results); var next_page = args; var prev_page = args; prev_page.page_direction = .backward; next_page.page_direction = .forward; - if (count != 0) { + if (results.len != 0) { prev_page.prev = .{ - .id = result_buf[0].id, - .order_val = getOrderVal(result_buf[0], args.order_by), + .id = results[0].id, + .order_val = getOrderVal(results[0], args.order_by), }; next_page.prev = .{ - .id = result_buf[count - 1].id, - .order_val = getOrderVal(result_buf[count - 1], args.order_by), + .id = results[results.len - 1].id, + .order_val = getOrderVal(results[results.len - 1], args.order_by), }; } // TODO: This will give incorrect links on an empty page return QueryResult{ - .items = result_buf[0..count], + .items = results, .next_page = next_page, .prev_page = prev_page, diff --git a/src/api/services/invites.zig b/src/api/services/invites.zig index 09a156c..2a58354 100644 --- a/src/api/services/invites.zig +++ b/src/api/services/invites.zig @@ -71,7 +71,7 @@ pub fn create(db: anytype, created_by: Uuid, community_id: ?Uuid, options: Invit .max_uses = options.max_uses, .created_at = created_at, .expires_at = if (options.lifespan) |lifespan| - created_at.add(lifespan) + @as(?DateTime, created_at.add(lifespan)) else null, diff --git a/src/api/services/notes.zig b/src/api/services/notes.zig index edb8c32..6fda4f6 100644 --- a/src/api/services/notes.zig +++ b/src/api/services/notes.zig @@ -1,6 +1,7 @@ const std = @import("std"); const util = @import("util"); const sql = @import("sql"); +const common = @import("./common.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; @@ -42,7 +43,7 @@ const selectStarFromNote = std.fmt.comptimePrint( \\SELECT {s} \\FROM note \\ -, .{util.comptimeJoin(",", std.meta.fieldNames(Note))}); +, .{util.comptimeJoinWithPrefix(",", "note.", std.meta.fieldNames(Note))}); pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note { return db.queryRow( Note, @@ -57,3 +58,108 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note { else => error.DatabaseFailure, }; } + +const max_max_items = 100; + +pub const QueryArgs = struct { + pub const PageDirection = common.PageDirection; + pub const Prev = std.meta.Child(std.meta.field(@This(), .prev).field_type); + + max_items: usize = 20, + + created_before: ?DateTime = null, + created_after: ?DateTime = null, + community_id: ?Uuid = null, + + prev: ?struct { + id: Uuid, + created_at: DateTime, + } = null, + + page_direction: PageDirection = .forward, +}; + +pub const QueryResult = struct { + items: []Note, + + prev_page: QueryArgs, + next_page: QueryArgs, +}; + +pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult { + var builder = sql.QueryBuilder.init(alloc); + defer builder.deinit(); + + try builder.appendSlice(selectStarFromNote ++ + \\ JOIN actor ON actor.id = note.author_id + \\ + ); + + if (args.created_before != null) try builder.andWhere("note.created_at < $1"); + if (args.created_after != null) try builder.andWhere("note.created_at > $2"); + if (args.prev != null) { + try builder.andWhere("(note.created_at, note.id)"); + + switch (args.page_direction) { + .forward => try builder.appendSlice(" < "), + .backward => try builder.appendSlice(" > "), + } + try builder.appendSlice("($3, $4)"); + } + if (args.community_id != null) try builder.andWhere("actor.community_id = $5"); + + try builder.appendSlice( + \\ + \\ORDER BY note.created_at DESC + \\LIMIT $6 + \\ + ); + + const max_items = if (args.max_items > max_max_items) max_max_items else args.max_items; + + const query_args = blk: { + const prev_created_at = if (args.prev) |prev| @as(?DateTime, prev.created_at) else null; + const prev_id = if (args.prev) |prev| @as(?Uuid, prev.id) else null; + + break :blk .{ + args.created_before, + args.created_after, + prev_created_at, + prev_id, + args.community_id, + max_items, + }; + }; + + const results = try db.queryRowsWithOptions( + Note, + try builder.terminate(), + query_args, + max_items, + .{ .allocator = alloc, .ignore_unused_arguments = true }, + ); + errdefer util.deepFree(results); + + var next_page = args; + var prev_page = args; + prev_page.page_direction = .backward; + next_page.page_direction = .forward; + if (results.len != 0) { + prev_page.prev = .{ + .id = results[0].id, + .created_at = results[0].created_at, + }; + + next_page.prev = .{ + .id = results[results.len - 1].id, + .created_at = results[results.len - 1].created_at, + }; + } + // TODO: this will give incorrect links on an empty page + + return QueryResult{ + .items = results, + .next_page = next_page, + .prev_page = prev_page, + }; +} diff --git a/src/http/headers.zig b/src/http/headers.zig new file mode 100644 index 0000000..4f45029 --- /dev/null +++ b/src/http/headers.zig @@ -0,0 +1,124 @@ +const std = @import("std"); + +pub const Fields = struct { + const HashContext = struct { + const hash_seed = 1; + pub fn eql(_: @This(), lhs: []const u8, rhs: []const u8, _: usize) bool { + return std.ascii.eqlIgnoreCase(lhs, rhs); + } + pub fn hash(_: @This(), s: []const u8) u32 { + var h = std.hash.Wyhash.init(hash_seed); + for (s) |ch| { + const c = [1]u8{std.ascii.toLower(ch)}; + h.update(&c); + } + return @truncate(u32, h.final()); + } + }; + + const HashMap = std.ArrayHashMapUnmanaged( + []const u8, + []const u8, + HashContext, + true, + ); + + unmanaged: HashMap, + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) Fields { + return Fields{ + .unmanaged = .{}, + .allocator = allocator, + }; + } + + pub fn deinit(self: *Fields) void { + var hash_iter = self.unmanaged.iterator(); + while (hash_iter.next()) |entry| { + self.allocator.free(entry.key_ptr.*); + self.allocator.free(entry.value_ptr.*); + } + + self.unmanaged.deinit(self.allocator); + } + + pub fn iterator(self: Fields) HashMap.Iterator { + return self.unmanaged.iterator(); + } + + pub fn get(self: Fields, key: []const u8) ?[]const u8 { + return self.unmanaged.get(key); + } + + pub const ListIterator = struct { + remaining: []const u8, + + fn extractElement(self: *ListIterator) ?[]const u8 { + if (self.remaining.len == 0) return null; + + var start: usize = 0; + var is_quoted = false; + const end = for (self.remaining) |ch, i| { + if (start == i and std.ascii.isWhitespace(ch)) { + start += 1; + } else if (ch == '"') { + is_quoted = !is_quoted; + } + if (ch == ',' and !is_quoted) { + break i; + } + } else self.remaining.len; + + const str = self.remaining[start..end]; + if (end == self.remaining.len) { + self.remaining = ""; + } else { + self.remaining = self.remaining[end + 1 ..]; + } + + return std.mem.trim(u8, str, " \t"); + } + + pub fn next(self: *ListIterator) ?[]const u8 { + while (self.extractElement()) |elem| { + if (elem.len != 0) return elem; + } + + return null; + } + }; + + pub fn getList(self: Fields, key: []const u8) ?ListIterator { + return if (self.unmanaged.get(key)) |hdr| ListIterator{ .remaining = hdr } else null; + } + + pub fn put(self: *Fields, key: []const u8, val: []const u8) !void { + const key_clone = try self.allocator.alloc(u8, key.len); + std.mem.copy(u8, key_clone, key); + errdefer self.allocator.free(key_clone); + + const val_clone = try self.allocator.alloc(u8, val.len); + std.mem.copy(u8, val_clone, val); + errdefer self.allocator.free(val_clone); + + if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| { + self.allocator.free(entry.key); + self.allocator.free(entry.value); + } + } + + pub fn append(self: *Fields, key: []const u8, val: []const u8) !void { + if (self.unmanaged.getEntry(key)) |entry| { + const new_val = try std.mem.join(self.allocator, ", ", &.{ entry.value_ptr.*, val }); + self.allocator.free(entry.value_ptr.*); + entry.value_ptr.* = new_val; + } else { + try self.put(key, val); + } + } + + pub fn count(self: Fields) usize { + return self.unmanaged.count(); + } +}; diff --git a/src/http/lib.zig b/src/http/lib.zig index eff1e45..ce97ad1 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -10,22 +10,15 @@ pub const socket = @import("./socket.zig"); pub const Method = std.http.Method; pub const Status = std.http.Status; -pub const Request = request.Request; +pub const Request = request.Request(std.net.Stream.Reader); pub const serveConn = server.serveConn; pub const Response = server.Response; pub const Handler = server.Handler; -pub const Headers = std.HashMap([]const u8, []const u8, struct { - pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { - return ciutf8.eql(a, b); - } +pub const Fields = @import("./headers.zig").Fields; - pub fn hash(_: @This(), str: []const u8) u64 { - return ciutf8.hash(str); - } -}, std.hash_map.default_max_load_percentage); - -test { - _ = server; - _ = request; -} +pub const Protocol = enum { + http_1_0, + http_1_1, + http_1_x, +}; diff --git a/src/http/request.zig b/src/http/request.zig index e6fd79c..b713d16 100644 --- a/src/http/request.zig +++ b/src/http/request.zig @@ -3,29 +3,23 @@ const http = @import("./lib.zig"); const parser = @import("./request/parser.zig"); -pub const Request = struct { - pub const Protocol = enum { - http_1_0, - http_1_1, +pub fn Request(comptime Reader: type) type { + return struct { + protocol: http.Protocol, + + method: http.Method, + uri: []const u8, + headers: http.Fields, + + body: ?parser.TransferStream(Reader), + + pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void { + allocator.free(self.uri); + self.headers.deinit(); + } }; - - protocol: Protocol, - source_address: ?std.net.Address, - - method: http.Method, - uri: []const u8, - headers: http.Headers, - body: ?[]const u8 = null, - - pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request { - return parser.parse(alloc, reader, addr); - } - - pub fn parseFree(self: Request, alloc: std.mem.Allocator) void { - parser.parseFree(alloc, self); - } -}; - -test { - _ = parser; +} + +pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) { + return parser.parse(alloc, reader); } diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index 8a7c19f..57fae33 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -1,15 +1,13 @@ const std = @import("std"); -const util = @import("util"); const http = @import("../lib.zig"); const Method = http.Method; -const Headers = http.Headers; +const Fields = http.Fields; const Request = @import("../request.zig").Request; const request_buf_size = 1 << 16; const max_path_len = 1 << 10; -const max_body_len = 1 << 12; fn ParseError(comptime Reader: type) type { return error{ @@ -22,7 +20,7 @@ const Encoding = enum { chunked, }; -pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request { +pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) { const method = try parseMethod(reader); const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { error.StreamTooLong => return error.RequestUriTooLong, @@ -33,28 +31,20 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address const proto = try parseProto(reader); // discard \r\n - _ = try reader.readByte(); - _ = try reader.readByte(); + switch (try reader.readByte()) { + '\r' => if ((try reader.readByte()) != '\n') return error.BadRequest, + '\n' => {}, + else => return error.BadRequest, + } var headers = try parseHeaders(alloc, reader); - errdefer freeHeaders(alloc, &headers); + errdefer headers.deinit(); - const body = if (method.requestHasBody()) - try readBody(alloc, headers, reader) - else - null; - errdefer if (body) |b| alloc.free(b); + const body = try prepareBody(headers, reader); + if (body != null and !method.requestHasBody()) return error.BadRequest; - const eff_addr = if (headers.get("X-Real-IP")) |ip| - std.net.Address.parseIp(ip, address.getPort()) catch { - return error.BadRequest; - } - else - address; - - return Request{ + return Request(@TypeOf(reader)){ .protocol = proto, - .source_address = eff_addr, .method = method, .uri = uri, @@ -79,7 +69,7 @@ fn parseMethod(reader: anytype) !Method { return error.MethodNotImplemented; } -fn parseProto(reader: anytype) !Request.Protocol { +fn parseProto(reader: anytype) !http.Protocol { var buf: [8]u8 = undefined; const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { error.StreamTooLong => return error.UnknownProtocol, @@ -99,85 +89,145 @@ fn parseProto(reader: anytype) !Request.Protocol { return switch (buf[2]) { '0' => .http_1_0, '1' => .http_1_1, - else => error.HttpVersionNotSupported, + else => .http_1_x, }; } -fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { - var map = Headers.init(allocator); - errdefer map.deinit(); - errdefer { - var iter = map.iterator(); - while (iter.next()) |it| { - allocator.free(it.key_ptr.*); - allocator.free(it.value_ptr.*); - } - } - // todo: - //errdefer { - //var iter = map.iterator(); - //while (iter.next()) |it| { - //allocator.free(it.key_ptr); - //allocator.free(it.value_ptr); - //} - //} - - var buf: [1024]u8 = undefined; +fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields { + var headers = Fields.init(allocator); + var buf: [4096]u8 = undefined; while (true) { - const line = try reader.readUntilDelimiter(&buf, '\n'); - if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; + const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) { + error.StreamTooLong => return error.HeaderLineTooLong, + else => return err, + }; + const line = std.mem.trimRight(u8, full_line, "\r"); + if (line.len == 0) break; - // TODO: handle multi-line headers - const name = extractHeaderName(line) orelse continue; - const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len; - const value = line[name.len + 1 + 1 .. value_end]; + const name = std.mem.sliceTo(line, ':'); + if (!isTokenValid(name)) return error.BadRequest; + if (name.len == line.len) return error.BadRequest; - if (name.len == 0 or value.len == 0) return error.BadRequest; + const value = std.mem.trim(u8, line[name.len + 1 ..], " \t"); - const name_alloc = try allocator.alloc(u8, name.len); - errdefer allocator.free(name_alloc); - const value_alloc = try allocator.alloc(u8, value.len); - errdefer allocator.free(value_alloc); - - @memcpy(name_alloc.ptr, name.ptr, name.len); - @memcpy(value_alloc.ptr, value.ptr, value.len); - - try map.put(name_alloc, value_alloc); + try headers.append(name, value); } - return map; + return headers; } -fn extractHeaderName(line: []const u8) ?[]const u8 { - var index: usize = 0; +fn isTokenValid(token: []const u8) bool { + if (token.len == 0) return false; + for (token) |ch| { + switch (ch) { + '"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false, - // TODO: handle whitespace - while (index < line.len) : (index += 1) { - if (line[index] == ':') { - if (index == 0) return null; - return line[0..index]; + '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => {}, + else => if (!std.ascii.isAlphanumeric(ch)) return false, } } - return null; + return true; } -fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 { - const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding")); - if (xfer_encoding != .identity) return error.UnsupportedMediaType; +fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) { + const hdr = headers.get("Transfer-Encoding"); + // TODO: + // if (hder != null and protocol == .http_1_0) return error.BadRequest; + const xfer_encoding = try parseEncoding(hdr); const content_encoding = try parseEncoding(headers.get("Content-Encoding")); if (content_encoding != .identity) return error.UnsupportedMediaType; - const len_str = headers.get("Content-Length") orelse return null; - const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest; - if (len > max_body_len) return error.RequestEntityTooLarge; - const body = try alloc.alloc(u8, len); - errdefer alloc.free(body); + switch (xfer_encoding) { + .identity => { + const len_str = headers.get("Content-Length") orelse return null; + const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest; - reader.readNoEof(body) catch return error.BadRequest; + return TransferStream(@TypeOf(reader)){ .underlying = .{ .identity = std.io.limitedReader(reader, len) } }; + }, + .chunked => { + if (headers.get("Content-Length") != null) return error.BadRequest; + return TransferStream(@TypeOf(reader)){ + .underlying = .{ + .chunked = try ChunkedStream(@TypeOf(reader)).init(reader), + }, + }; + }, + } +} - return body; +fn ChunkedStream(comptime R: type) type { + return struct { + const Self = @This(); + + remaining: ?usize = 0, + underlying: R, + + const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream }; + fn init(reader: R) !Self { + var self: Self = .{ .underlying = reader }; + return self; + } + + fn read(self: *Self, buf: []u8) !usize { + var count: usize = 0; + while (true) { + if (count == buf.len) return count; + if (self.remaining == null) return count; + if (self.remaining.? == 0) self.remaining = try self.readChunkHeader(); + + const max_read = std.math.min(buf.len, self.remaining.?); + const amt = try self.underlying.read(buf[count .. count + max_read]); + if (amt != max_read) return error.EndOfStream; + count += amt; + self.remaining.? -= amt; + if (self.remaining.? == 0) { + var crlf: [2]u8 = undefined; + _ = try self.underlying.readUntilDelimiter(&crlf, '\n'); + self.remaining = try self.readChunkHeader(); + } + + if (count == buf.len) return count; + } + } + + fn readChunkHeader(self: *Self) !?usize { + // TODO: Pick a reasonable limit for this + var buf = std.mem.zeroes([10]u8); + const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| { + return if (err == error.StreamTooLong) error.InvalidChunkHeader else err; + }; + if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader; + + const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader; + + return if (size != 0) size else null; + } + }; +} + +pub fn TransferStream(comptime R: type) type { + return struct { + const Error = R.Error || ChunkedStream(R).Error; + const Reader = std.io.Reader(*@This(), Error, read); + + underlying: union(enum) { + identity: std.io.LimitedReader(R), + chunked: ChunkedStream(R), + }, + + pub fn read(self: *@This(), buf: []u8) Error!usize { + return switch (self.underlying) { + .identity => |*r| try r.read(buf), + .chunked => |*r| try r.read(buf), + }; + } + + pub fn reader(self: *@This()) Reader { + return .{ .context = self }; + } + }; } // TODO: assumes that there's only one encoding, not layered encodings @@ -187,257 +237,3 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding { if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked; return error.UnsupportedMediaType; } - -pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void { - allocator.free(request.uri); - freeHeaders(allocator, &request.headers); - if (request.body) |body| allocator.free(body); -} - -fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void { - var iter = headers.iterator(); - while (iter.next()) |it| { - allocator.free(it.key_ptr.*); - allocator.free(it.value_ptr.*); - } - headers.deinit(); -} - -const _test = struct { - const expectEqual = std.testing.expectEqual; - const expectEqualStrings = std.testing.expectEqualStrings; - - fn toCrlf(comptime str: []const u8) []const u8 { - comptime { - var buf: [str.len * 2]u8 = undefined; - - @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 - - var buf_len: usize = 0; - for (str) |ch| { - if (ch == '\n') { - buf[buf_len] = '\r'; - buf_len += 1; - } - - buf[buf_len] = ch; - buf_len += 1; - } - - return buf[0..buf_len]; - } - } - - fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers { - var result = Headers.init(alloc); - inline for (headers) |tup| { - try result.put(tup[0], tup[1]); - } - return result; - } - - fn areEqualHeaders(lhs: Headers, rhs: Headers) bool { - if (lhs.count() != rhs.count()) return false; - var iter = lhs.iterator(); - while (iter.next()) |it| { - const rhs_val = rhs.get(it.key_ptr.*) orelse return false; - if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false; - } - return true; - } - - fn printHeaders(headers: Headers) void { - var iter = headers.iterator(); - while (iter.next()) |it| { - std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* }); - } - } - - fn expectEqualHeaders(expected: Headers, actual: Headers) !void { - if (!areEqualHeaders(expected, actual)) { - std.debug.print("\nexpected: \n", .{}); - printHeaders(expected); - std.debug.print("\n\nfound: \n", .{}); - printHeaders(actual); - std.debug.print("\n\n", .{}); - return error.TestExpectedEqual; - } - } - - fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void { - var stream = std.io.fixedBufferStream(toCrlf(request)); - - const result = try parse(alloc, stream.reader()); - - try expectEqual(expected.method, result.method); - try expectEqualStrings(expected.path, result.path); - try expectEqualHeaders(expected.headers, result.headers); - if ((expected.body == null) != (result.body == null)) { - const null_str: []const u8 = "(null)"; - const exp = expected.body orelse null_str; - const act = result.body orelse null_str; - std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act }); - return error.TestExpectedEqual; - } - if (expected.body != null) { - try expectEqualStrings(expected.body.?, result.body.?); - } - } -}; - -// TOOD: failure test cases -test "parse" { - const testCase = _test.parseTestCase; - var buf = [_]u8{0} ** (1 << 16); - var fba = std.heap.FixedBufferAllocator.init(&buf); - const alloc = fba.allocator(); - try testCase(alloc, ( - \\GET / HTTP/1.1 - \\ - \\ - ), .{ - .method = .GET, - .headers = try _test.makeHeaders(alloc, .{}), - .path = "/", - }); - - fba.reset(); - try testCase(alloc, ( - \\POST / HTTP/1.1 - \\ - \\ - ), .{ - .method = .POST, - .headers = try _test.makeHeaders(alloc, .{}), - .path = "/", - }); - - fba.reset(); - try testCase(alloc, ( - \\HEAD / HTTP/1.1 - \\Authorization: bearer - \\ - \\ - ), .{ - .method = .HEAD, - .headers = try _test.makeHeaders(alloc, .{ - .{ "Authorization", "bearer " }, - }), - .path = "/", - }); - - fba.reset(); - try testCase(alloc, ( - \\POST /nonsense HTTP/1.1 - \\Authorization: bearer - \\Content-Length: 5 - \\ - \\12345 - ), .{ - .method = .POST, - .headers = try _test.makeHeaders(alloc, .{ - .{ "Authorization", "bearer " }, - .{ "Content-Length", "5" }, - }), - .path = "/nonsense", - .body = "12345", - }); - - fba.reset(); - try std.testing.expectError( - error.MethodNotImplemented, - testCase(alloc, ( - \\FOO /nonsense HTTP/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.MethodNotImplemented, - testCase(alloc, ( - \\FOOBARBAZ /nonsense HTTP/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.RequestUriTooLong, - testCase(alloc, ( - \\GET / - ++ ("a" ** 2048)), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.UnknownProtocol, - testCase(alloc, ( - \\GET /nonsense SPECIALHTTP/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.UnknownProtocol, - testCase(alloc, ( - \\GET /nonsense JSON/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.HttpVersionNotSupported, - testCase(alloc, ( - \\GET /nonsense HTTP/1.9 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.HttpVersionNotSupported, - testCase(alloc, ( - \\GET /nonsense HTTP/8.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.BadRequest, - testCase(alloc, ( - \\GET /nonsense HTTP/blah blah blah - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.BadRequest, - testCase(alloc, ( - \\GET /nonsense HTTP/1/1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.BadRequest, - testCase(alloc, ( - \\GET /nonsense HTTP/1/1 - \\ - \\ - ), undefined), - ); -} diff --git a/src/http/request/test_parser.zig b/src/http/request/test_parser.zig new file mode 100644 index 0000000..55a66d6 --- /dev/null +++ b/src/http/request/test_parser.zig @@ -0,0 +1,282 @@ +const std = @import("std"); +const parser = @import("./parser.zig"); +const http = @import("../lib.zig"); +const t = std.testing; + +const test_case = struct { + fn parse(text: []const u8, expected: struct { + protocol: http.Protocol = .http_1_1, + method: http.Method = .GET, + headers: []const std.meta.Tuple(&.{ []const u8, []const u8 }) = &.{}, + uri: []const u8 = "", + }) !void { + var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(text) }; + var actual = try parser.parse(t.allocator, stream.reader()); + defer actual.parseFree(t.allocator); + + try t.expectEqual(expected.protocol, actual.protocol); + try t.expectEqual(expected.method, actual.method); + try t.expectEqualStrings(expected.uri, actual.uri); + + try t.expectEqual(expected.headers.len, actual.headers.count()); + for (expected.headers) |hdr| { + if (actual.headers.get(hdr[0])) |val| { + try t.expectEqualStrings(hdr[1], val); + } else { + std.debug.print("Error: Header {s} expected to be present, was not.\n", .{hdr[0]}); + try t.expect(false); + } + } + } +}; + +fn toCrlf(comptime str: []const u8) []const u8 { + comptime { + var buf: [str.len * 2]u8 = undefined; + + @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 + + var buf_len: usize = 0; + for (str) |ch| { + if (ch == '\n') { + buf[buf_len] = '\r'; + buf_len += 1; + } + + buf[buf_len] = ch; + buf_len += 1; + } + + return buf[0..buf_len]; + } +} + +test "HTTP/1.x parse - No body" { + try test_case.parse( + toCrlf( + \\GET / HTTP/1.1 + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\POST / HTTP/1.1 + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .POST, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\GET /url/abcd HTTP/1.1 + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + }, + ); + try test_case.parse( + toCrlf( + \\GET / HTTP/1.0 + \\ + \\ + ), + .{ + .protocol = .http_1_0, + .method = .GET, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\GET /url/abcd HTTP/1.1 + \\Content-Type: application/json + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + .headers = &.{.{ "Content-Type", "application/json" }}, + }, + ); + try test_case.parse( + toCrlf( + \\GET /url/abcd HTTP/1.1 + \\Content-Type: application/json + \\Authorization: bearer + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + .headers = &.{ + .{ "Content-Type", "application/json" }, + .{ "Authorization", "bearer " }, + }, + }, + ); + + // Test without CRLF + try test_case.parse( + \\GET /url/abcd HTTP/1.1 + \\Content-Type: application/json + \\Authorization: bearer + \\ + \\ + , + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + .headers = &.{ + .{ "Content-Type", "application/json" }, + .{ "Authorization", "bearer " }, + }, + }, + ); + try test_case.parse( + \\POST / HTTP/1.1 + \\ + \\ + , + .{ + .protocol = .http_1_1, + .method = .POST, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\GET / HTTP/1.2 + \\ + \\ + ), + .{ + .protocol = .http_1_x, + .method = .GET, + .uri = "/", + }, + ); +} + +test "HTTP/1.x parse - unsupported protocol" { + try t.expectError(error.UnknownProtocol, test_case.parse( + \\GET / JSON/1.1 + \\ + \\ + , + .{}, + )); + try t.expectError(error.UnknownProtocol, test_case.parse( + \\GET / SOMETHINGELSE/3.5 + \\ + \\ + , + .{}, + )); + try t.expectError(error.UnknownProtocol, test_case.parse( + \\GET / /1.1 + \\ + \\ + , + .{}, + )); + try t.expectError(error.HttpVersionNotSupported, test_case.parse( + \\GET / HTTP/2.1 + \\ + \\ + , + .{}, + )); +} + +test "HTTP/1.x parse - Unknown method" { + try t.expectError(error.MethodNotImplemented, test_case.parse( + \\ABCD / HTTP/1.1 + \\ + \\ + , + .{}, + )); + try t.expectError(error.MethodNotImplemented, test_case.parse( + \\PATCHPATCHPATCH / HTTP/1.1 + \\ + \\ + , + .{}, + )); +} + +test "HTTP/1.x parse - Too long" { + try t.expectError(error.RequestUriTooLong, test_case.parse( + std.fmt.comptimePrint("GET {s} HTTP/1.1\n\n", .{"a" ** 8192}), + .{}, + )); + try t.expectError(error.HeaderLineTooLong, test_case.parse( + std.fmt.comptimePrint("GET / HTTP/1.1\r\n{s}: abcd", .{"a" ** 8192}), + .{}, + )); + try t.expectError(error.HeaderLineTooLong, test_case.parse( + std.fmt.comptimePrint("GET / HTTP/1.1\r\nabcd: {s}", .{"a" ** 8192}), + .{}, + )); +} + +test "HTTP/1.x parse - bad requests" { + try t.expectError(error.BadRequest, test_case.parse( + \\GET / HTTP/1.1 blah blah + \\ + \\ + , + .{}, + )); + try t.expectError(error.BadRequest, test_case.parse( + \\GET / HTTP/1.1 + \\abcd : lksjdfkl + \\ + , + .{}, + )); + try t.expectError(error.BadRequest, test_case.parse( + \\GET / HTTP/1.1 + \\ lksjfklsjdfklj + \\ + , + .{}, + )); +} + +test "HTTP/1.x parse - Headers" { + try test_case.parse( + toCrlf( + \\GET /url/abcd HTTP/1.1 + \\Content-Type: application/json + \\Content-Type: application/xml + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + .headers = &.{.{ "Content-Type", "application/json, application/xml" }}, + }, + ); +} diff --git a/src/http/server.zig b/src/http/server.zig index 269fc56..d81d77a 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -3,13 +3,14 @@ const util = @import("util"); const http = @import("./lib.zig"); const response = @import("./server/response.zig"); +const request = @import("./request.zig"); pub const Response = struct { alloc: std.mem.Allocator, stream: std.net.Stream, should_close: bool = false, pub const Stream = response.ResponseStream(std.net.Stream.Writer); - pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream { + pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream { if (headers.get("Connection")) |hdr| { if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; } @@ -17,7 +18,7 @@ pub const Response = struct { return response.open(self.alloc, self.stream.writer(), headers, status); } - pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream { + pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream { try response.writeRequestHeader(self.stream.writer(), headers, status); return self.stream; } @@ -26,10 +27,6 @@ pub const Response = struct { const Request = http.Request; const request_buf_size = 1 << 16; -pub fn Handler(comptime Ctx: type) type { - return fn (Ctx, Request, *Response) void; -} - pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void { // TODO: Timeouts while (true) { @@ -37,7 +34,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a var arena = std.heap.ArenaAllocator.init(alloc); defer arena.deinit(); - const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| { + var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| { return handleError(conn.stream.writer(), err) catch {}; }; std.log.debug("done parsing", .{}); @@ -47,7 +44,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a .stream = conn.stream, }; - handler(ctx, req, &res); + handler(ctx, &req, &res); std.log.debug("done handling", .{}); if (req.headers.get("Connection")) |hdr| { diff --git a/src/http/server/response.zig b/src/http/server/response.zig index 615f0c1..296382d 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -2,20 +2,20 @@ const std = @import("std"); const http = @import("../lib.zig"); const Status = http.Status; -const Headers = http.Headers; +const Fields = http.Fields; const chunk_size = 16 * 1024; pub fn open( alloc: std.mem.Allocator, writer: anytype, - headers: *const Headers, + headers: *const Fields, status: Status, ) !ResponseStream(@TypeOf(writer)) { const buf = try alloc.alloc(u8, chunk_size); errdefer alloc.free(buf); try writeStatusLine(writer, status); - try writeHeaders(writer, headers); + try writeFields(writer, headers); return ResponseStream(@TypeOf(writer)){ .allocator = alloc, @@ -25,9 +25,9 @@ pub fn open( }; } -pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void { +pub fn writeRequestHeader(writer: anytype, headers: *const Fields, status: Status) !void { try writeStatusLine(writer, status); - try writeHeaders(writer, headers); + try writeFields(writer, headers); try writer.writeAll("\r\n"); } @@ -36,7 +36,7 @@ fn writeStatusLine(writer: anytype, status: Status) !void { try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text }); } -fn writeHeaders(writer: anytype, headers: *const Headers) !void { +fn writeFields(writer: anytype, headers: *const Fields) !void { var iter = headers.iterator(); while (iter.next()) |header| { for (header.value_ptr.*) |ch| { @@ -65,7 +65,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type { allocator: std.mem.Allocator, base_writer: BaseWriter, - headers: *const Headers, + headers: *const Fields, buffer: []u8, buffer_pos: usize = 0, chunked: bool = false, @@ -95,7 +95,6 @@ pub fn ResponseStream(comptime BaseWriter: type) type { return; } - std.debug.print("{}\n", .{cursor}); self.writeToBuffer(bytes[cursor .. cursor + remaining_in_chunk]); cursor += remaining_in_chunk; try self.flushChunk(); @@ -177,7 +176,7 @@ const _tests = struct { test "ResponseStream no headers empty body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); { @@ -205,7 +204,7 @@ const _tests = struct { test "ResponseStream empty body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -236,7 +235,7 @@ const _tests = struct { test "ResponseStream not 200 OK" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -266,7 +265,7 @@ const _tests = struct { test "ResponseStream small body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -300,7 +299,7 @@ const _tests = struct { test "ResponseStream large body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -341,7 +340,7 @@ const _tests = struct { test "ResponseStream large body ending on chunk boundary" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); diff --git a/src/http/socket.zig b/src/http/socket.zig index 24c10d4..eab1a1d 100644 --- a/src/http/socket.zig +++ b/src/http/socket.zig @@ -23,21 +23,21 @@ const Opcode = enum(u4) { } }; -pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket { +pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket { const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake; const connection = req.headers.get("Connection") orelse return error.BadHandshake; if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake; if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake; const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake; - if (try std.base64.standard.Decoder.calcSizeForSlice(key_hdr) != 16) return error.BadHandshake; + if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake; var key: [16]u8 = undefined; std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake; const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake; if (!std.mem.eql(u8, "13", version)) return error.BadHandshake; - var headers = http.Headers.init(alloc); + var headers = http.Fields.init(alloc); defer headers.deinit(); try headers.put("Upgrade", "websocket"); @@ -51,7 +51,7 @@ pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Respons var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined; _ = std.base64.standard.Encoder.encode(&hash_encoded, &hash); try headers.put("Sec-WebSocket-Accept", &hash_encoded); - const stream = try res.upgrade(.switching_protcols, &headers); + const stream = try res.upgrade(.switching_protocols, &headers); return Socket{ .stream = stream }; } @@ -164,15 +164,15 @@ fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void { const initial_len: u7 = if (header.len < 126) @intCast(u7, header.len) else if (std.math.cast(u16, header.len)) |_| - 126 + @as(u7, 126) else - 127; + @as(u7, 127); var hdr_buf = [2]u8{ 0, 0 }; - hdr_buf[0] |= if (header.is_final) 0b1000_0000 else 0; + hdr_buf[0] |= if (header.is_final) @as(u8, 0b1000_0000) else 0; hdr_buf[0] |= @as(u8, header.rsv) << 4; hdr_buf[0] |= @enumToInt(header.opcode); - hdr_buf[1] |= if (header.masking_key) |_| 0b1000_0000 else 0; + hdr_buf[1] |= if (header.masking_key) |_| @as(u8, 0b1000_0000) else 0; hdr_buf[1] |= initial_len; try writer.writeAll(&hdr_buf); if (initial_len == 126) diff --git a/src/http/test.zig b/src/http/test.zig new file mode 100644 index 0000000..1441ec2 --- /dev/null +++ b/src/http/test.zig @@ -0,0 +1,3 @@ +test { + _ = @import("./request/test_parser.zig"); +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 31fbe90..902711e 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -13,10 +13,11 @@ pub const invites = @import("./controllers/invites.zig"); pub const users = @import("./controllers/users.zig"); pub const notes = @import("./controllers/notes.zig"); pub const streaming = @import("./controllers/streaming.zig"); +pub const timelines = @import("./controllers/timelines.zig"); -pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { +pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? - var response = Response{ .headers = http.Headers.init(alloc), .res = res }; + var response = Response{ .headers = http.Fields.init(alloc), .res = res }; defer response.headers.deinit(); const found = routeRequestInternal(api_source, req, &response, alloc); @@ -24,7 +25,7 @@ pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, if (!found) response.status(.not_found) catch {}; } -fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { +fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { inline for (routes) |route| { if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true; } @@ -42,6 +43,8 @@ const routes = .{ notes.create, notes.get, streaming.streaming, + timelines.global, + timelines.local, }; pub fn Context(comptime Route: type) type { @@ -58,18 +61,21 @@ pub fn Context(comptime Route: type) type { // leave it as a simple string instead of void pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; - base_request: http.Request, + base_request: *http.Request, allocator: std.mem.Allocator, method: http.Method, uri: []const u8, - headers: http.Headers, + headers: http.Fields, args: Args, body: Body, query: Query, + // TODO + body_buf: ?[]const u8 = null, + fn parseArgs(path: []const u8) ?Args { var args: Args = undefined; var path_iter = util.PathIter.from(path); @@ -94,7 +100,7 @@ pub fn Context(comptime Route: type) type { @compileError("Unsupported Type " ++ @typeName(T)); } - pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { + pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { if (req.method != Route.method) return false; var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); var args: Args = parseArgs(path) orelse return false; @@ -112,6 +118,8 @@ pub fn Context(comptime Route: type) type { .query = undefined, }; + std.log.debug("Matched route {s}", .{path}); + self.prepareAndHandle(api_source, req, res); return true; @@ -129,7 +137,7 @@ pub fn Context(comptime Route: type) type { }; } - fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void { + fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void { self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err); defer self.freeBody(); @@ -141,16 +149,20 @@ pub fn Context(comptime Route: type) type { self.handle(response, &api_conn); } - fn parseBody(self: *Self, req: http.Request) !void { + fn parseBody(self: *Self, req: *http.Request) !void { if (Body != void) { - const body = req.body orelse return error.NoBody; + var stream = req.body orelse return error.NoBody; + const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16); + errdefer self.allocator.free(body); self.body = try json_utils.parse(Body, body, self.allocator); + self.body_buf = body; } } fn freeBody(self: *Self) void { if (Body != void) { json_utils.parseFree(self.body, self.allocator); + self.allocator.free(self.body_buf.?); } } @@ -191,7 +203,7 @@ pub fn Context(comptime Route: type) type { pub const Response = struct { const Self = @This(); - headers: http.Headers, + headers: http.Fields, res: *http.Response, opened: bool = false, diff --git a/src/main/controllers/auth.zig b/src/main/controllers/auth.zig index 95be918..1b6e652 100644 --- a/src/main/controllers/auth.zig +++ b/src/main/controllers/auth.zig @@ -1,4 +1,5 @@ const api = @import("api"); +const std = @import("std"); pub const login = struct { pub const method = .POST; @@ -12,6 +13,8 @@ pub const login = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const token = try srv.login(req.body.username, req.body.password); + std.log.debug("{any}", .{res.headers}); + try res.json(.ok, token); } }; diff --git a/src/main/controllers/communities.zig b/src/main/controllers/communities.zig index f6b4c69..3aeec1f 100644 --- a/src/main/controllers/communities.zig +++ b/src/main/controllers/communities.zig @@ -1,6 +1,7 @@ const std = @import("std"); const api = @import("api"); const util = @import("util"); +const query_utils = @import("../query.zig"); const QueryArgs = api.CommunityQueryArgs; const Uuid = util.Uuid; @@ -25,89 +26,18 @@ pub const query = struct { pub const method = .GET; pub const path = "/communities"; - // NOTE: This has to match QueryArgs - // TODO: Support union fields in query strings natively, so we don't - // have to keep these in sync - pub const Query = struct { - const OrderBy = QueryArgs.OrderBy; - const Direction = QueryArgs.Direction; - const PageDirection = QueryArgs.PageDirection; - - // Max items to fetch - max_items: usize = 20, - - // Selection filters - owner_id: ?Uuid = null, - like: ?[]const u8 = null, - created_before: ?DateTime = null, - created_after: ?DateTime = null, - - // Ordering parameter - order_by: OrderBy = .created_at, - direction: Direction = .ascending, - - // the `prev` struct has a slightly different format to QueryArgs - prev: struct { - id: ?Uuid = null, - - // Only one of these can be present, and must match order_by above - name: ?[]const u8 = null, - host: ?[]const u8 = null, - created_at: ?DateTime = null, - } = .{}, - - // What direction to scan the page window - page_direction: PageDirection = .forward, - - pub const format = formatQueryParams; - }; + pub const Query = QueryArgs; pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - const q = req.query; - const query_matches = if (q.prev.id) |_| switch (q.order_by) { - .name => q.prev.name != null and q.prev.host == null and q.prev.created_at == null, - .host => q.prev.name == null and q.prev.host != null and q.prev.created_at == null, - .created_at => q.prev.name == null and q.prev.host == null and q.prev.created_at != null, - } else (q.prev.name == null and q.prev.host == null and q.prev.created_at == null); - - if (!query_matches) return res.err(.bad_request, "prev.* parameters do not match", {}); - - const prev_arg: ?QueryArgs.Prev = if (q.prev.id) |id| .{ - .id = id, - .order_val = switch (q.order_by) { - .name => .{ .name = q.prev.name.? }, - .host => .{ .host = q.prev.host.? }, - .created_at => .{ .created_at = q.prev.created_at.? }, - }, - } else null; - - const query_args = QueryArgs{ - .max_items = q.max_items, - .owner_id = q.owner_id, - .like = q.like, - .created_before = q.created_before, - .created_after = q.created_after, - - .order_by = q.order_by, - .direction = q.direction, - - .prev = prev_arg, - - .page_direction = q.page_direction, - }; - - const results = try srv.queryCommunities(query_args); + const results = try srv.queryCommunities(req.query); var link = std.ArrayList(u8).init(req.allocator); const link_writer = link.writer(); defer link.deinit(); - const next_page = queryArgsToControllerQuery(results.next_page); - const prev_page = queryArgsToControllerQuery(results.prev_page); - - try writeLink(link_writer, srv.community, path, next_page, "next"); + try writeLink(link_writer, srv.community, path, results.next_page, "next"); try link_writer.writeByte(','); - try writeLink(link_writer, srv.community, path, prev_page, "prev"); + try writeLink(link_writer, srv.community, path, results.prev_page, "prev"); try res.headers.put("Link", link.items); @@ -129,7 +59,7 @@ fn writeLink( .{ @tagName(community.scheme), community.host, path }, ); - try std.fmt.format(writer, "{}", .{params}); + try query_utils.formatQuery(params, writer); try std.fmt.format( writer, @@ -137,70 +67,3 @@ fn writeLink( .{rel}, ); } - -fn formatQueryParams( - params: anytype, - comptime fmt: []const u8, - opt: std.fmt.FormatOptions, - writer: anytype, -) !void { - if (comptime std.meta.trait.is(.Pointer)(@TypeOf(params))) { - return formatQueryParams(params.*, fmt, opt, writer); - } - - return formatRecursive("", params, writer); -} - -fn formatRecursive(comptime prefix: []const u8, params: anytype, writer: anytype) !void { - inline for (std.meta.fields(@TypeOf(params))) |field| { - const val = @field(params, field.name); - const is_optional = comptime std.meta.trait.is(.Optional)(field.field_type); - const present = if (comptime is_optional) val != null else true; - if (present) { - const unwrapped = if (is_optional) val.? else val; - // TODO: percent-encode this - _ = try switch (@TypeOf(unwrapped)) { - []const u8 => blk: { - break :blk std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, unwrapped }); - }, - - else => |U| blk: { - if (comptime std.meta.trait.isContainer(U) and std.meta.trait.hasFn("format")(U)) { - break :blk std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }); - } - - break :blk switch (@typeInfo(U)) { - .Enum => std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, @tagName(unwrapped) }), - .Struct => formatRecursive(field.name ++ ".", unwrapped, writer), - else => std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }), - }; - }, - }; - } - } -} - -fn queryArgsToControllerQuery(args: QueryArgs) query.Query { - var result = query.Query{ - .max_items = args.max_items, - .owner_id = args.owner_id, - .like = args.like, - .created_before = args.created_before, - .created_after = args.created_after, - .order_by = args.order_by, - .direction = args.direction, - .prev = .{}, - .page_direction = args.page_direction, - }; - - if (args.prev) |prev| { - result.prev = .{ - .id = prev.id, - .name = if (prev.order_val == .name) prev.order_val.name else null, - .host = if (prev.order_val == .host) prev.order_val.host else null, - .created_at = if (prev.order_val == .created_at) prev.order_val.created_at else null, - }; - } - - return result; -} diff --git a/src/main/controllers/timelines.zig b/src/main/controllers/timelines.zig new file mode 100644 index 0000000..092d7c8 --- /dev/null +++ b/src/main/controllers/timelines.zig @@ -0,0 +1,21 @@ +pub const global = struct { + pub const method = .GET; + pub const path = "/timelines/global"; + + pub fn handler(_: anytype, res: anytype, srv: anytype) !void { + const results = try srv.globalTimeline(); + + try res.json(.ok, results); + } +}; + +pub const local = struct { + pub const method = .GET; + pub const path = "/timelines/local"; + + pub fn handler(_: anytype, res: anytype, srv: anytype) !void { + const results = try srv.localTimeline(); + + try res.json(.ok, results); + } +}; diff --git a/src/main/json.zig b/src/main/json.zig index b6bdff0..21474cc 100644 --- a/src/main/json.zig +++ b/src/main/json.zig @@ -499,7 +499,7 @@ fn parseInternal( if (!fields_seen[i]) { if (field.default_value) |default_ptr| { if (!field.is_comptime) { - const default = @ptrCast(*const field.field_type, default_ptr).*; + const default = @ptrCast(*align(1) const field.field_type, default_ptr).*; @field(r, field.name) = default; } } else { diff --git a/src/main/main.zig b/src/main/main.zig index 5870bc2..9ff01fe 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -87,7 +87,7 @@ fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void { } } -fn handle(ctx: anytype, req: http.Request, res: *http.Response) void { +fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void { c.routeRequest(ctx.src, req, res, ctx.allocator); } diff --git a/src/main/migrations.zig b/src/main/migrations.zig index 459e75a..b0abc02 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -53,7 +53,7 @@ pub fn up(db: anytype) !void { std.log.info("Running migration {s}", .{migration.name}); try execScript(tx, migration.up, gpa.allocator()); try tx.insert("migration", .{ - .name = migration.name, + .name = @as([]const u8, migration.name), .applied_at = DateTime.now(), }, gpa.allocator()); } diff --git a/src/main/query.zig b/src/main/query.zig index e3bc725..d79ada8 100644 --- a/src/main/query.zig +++ b/src/main/query.zig @@ -71,37 +71,132 @@ const QueryIter = @import("util").QueryIter; /// TODO: values are currently case-sensitive, and are not url-decoded properly. /// This should be fixed. pub fn parseQuery(comptime T: type, query: []const u8) !T { - //if (!std.meta.trait.isContainer(T)) @compileError("T must be a struct"); + if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct"); var iter = QueryIter.from(query); - var result = T{}; + + var fields = Intermediary(T){}; while (iter.next()) |pair| { - try parseQueryPair(T, &result, pair.key, pair.value); + // TODO: Hash map + inline for (std.meta.fields(Intermediary(T))) |field| { + if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) { + @field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} }; + break; + } + } else std.log.debug("unknown param {s}", .{pair.key}); } - return result; + return (try parse(T, "", "", fields)).?; } -fn parseQueryPair(comptime T: type, result: *T, key: []const u8, value: ?[]const u8) !void { - const key_part = std.mem.sliceTo(key, '.'); - const field_idx = std.meta.stringToEnum(std.meta.FieldEnum(T), key_part) orelse return error.UnknownField; +fn parseScalar(comptime T: type, comptime name: []const u8, fields: anytype) !?T { + const param = @field(fields, name); + return switch (param) { + .not_specified => null, + .no_value => try parseQueryValue(T, null), + .value => |v| try parseQueryValue(T, v), + }; +} - inline for (std.meta.fields(T)) |info, idx| { - if (@enumToInt(field_idx) == idx) { - if (comptime isScalar(info.field_type)) { - if (key_part.len == key.len) { - @field(result, info.name) = try parseQueryValue(info.field_type, value); - return; - } else { - return error.UnknownField; +fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u8, fields: anytype) !?T { + if (comptime isScalar(T)) return parseScalar(T, prefix ++ "." ++ name, fields); + switch (@typeInfo(T)) { + .Union => |info| { + var result: ?T = null; + inline for (info.fields) |field| { + const F = field.field_type; + + const maybe_value = try parse(F, prefix, field.name, fields); + if (maybe_value) |value| { + if (result != null) return error.DuplicateUnionField; + + result = @unionInit(T, field.name, value); } + } + std.log.debug("{any}", .{result}); + return result; + }, + + .Struct => |info| { + var result: T = undefined; + var fields_specified: usize = 0; + + inline for (info.fields) |field| { + const F = field.field_type; + + var maybe_value: ?F = null; + if (try parse(F, prefix ++ "." ++ name, field.name, fields)) |v| { + maybe_value = v; + } else if (field.default_value) |default| { + maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*; + } + + if (maybe_value) |v| { + fields_specified += 1; + @field(result, field.name) = v; + } + } + + if (fields_specified == 0) { + return null; + } else if (fields_specified != info.fields.len) { + return error.PartiallySpecifiedStruct; } else { - const remaining = std.mem.trimLeft(u8, key[key_part.len..], "."); - return try parseQueryPair(info.field_type, &@field(result, info.name), remaining, value); + return result; + } + }, + + // Only applies to non-scalar optionals + .Optional => |info| return try parse(info.child, prefix, name, fields), + + else => @compileError("tmp"), + } +} + +fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 { + comptime { + if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix); + + var fields: []const []const u8 = &.{}; + + for (std.meta.fields(T)) |f| { + const full_name = prefix ++ f.name; + + if (isScalar(f.field_type)) { + fields = fields ++ @as([]const []const u8, &.{full_name}); + } else { + const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ "."; + fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix); } } - } - return error.UnknownField; + return fields; + } +} + +const QueryParam = union(enum) { + not_specified: void, + no_value: void, + value: []const u8, +}; + +fn Intermediary(comptime T: type) type { + const field_names = recursiveFieldPaths(T, ".."); + + var fields: [field_names.len]std.builtin.Type.StructField = undefined; + for (field_names) |name, i| fields[i] = .{ + .name = name, + .field_type = QueryParam, + .default_value = &QueryParam{ .not_specified = {} }, + .is_comptime = false, + .alignment = @alignOf(QueryParam), + }; + + return @Type(.{ .Struct = .{ + .layout = .Auto, + .fields = &fields, + .decls = &.{}, + .is_tuple = false, + } }); } fn parseQueryValue(comptime T: type, value: ?[]const u8) !T { @@ -157,6 +252,50 @@ fn isScalar(comptime T: type) bool { return false; } +pub fn formatQuery(params: anytype, writer: anytype) !void { + try format("", "", params, writer); +} + +fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void { + const T = @TypeOf(val); + if (comptime std.meta.trait.isZigString(T)) return std.fmt.format(writer, "{s}={s}&", .{ name, val }); + _ = try switch (@typeInfo(T)) { + .Enum => std.fmt.format(writer, "{s}={s}&", .{ name, @tagName(val) }), + .Optional => if (val) |v| formatScalar(name, v, writer), + else => std.fmt.format(writer, "{s}={}&", .{ name, val }), + }; +} + +fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void { + const T = @TypeOf(params); + const eff_prefix = if (prefix.len == 0) "" else prefix ++ "."; + if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer); + + switch (@typeInfo(T)) { + .Struct => { + inline for (std.meta.fields(T)) |field| { + const val = @field(params, field.name); + try format(eff_prefix ++ name, field.name, val, writer); + } + }, + .Union => { + //inline for (std.meta.tags(T)) |tag| { + inline for (std.meta.fields(T)) |field| { + const tag = @field(std.meta.Tag(T), field.name); + const tag_name = field.name; + if (@as(std.meta.Tag(T), params) == tag) { + const val = @field(params, tag_name); + try format(prefix, tag_name, val, writer); + } + } + }, + .Optional => { + if (params) |p| try format(prefix, name, p, writer); + }, + else => @compileError("Unsupported query type"), + } +} + test { const TestQuery = struct { int: usize = 3, diff --git a/src/sql/engines/common.zig b/src/sql/engines/common.zig index 67fef2f..b50b7d0 100644 --- a/src/sql/engines/common.zig +++ b/src/sql/engines/common.zig @@ -68,7 +68,7 @@ pub const QueryOptions = struct { // do not require allocators for prep. If an allocator is needed but not // provided, `error.AllocatorRequired` will be returned. // Only used with the postgres backend. - prep_allocator: ?Allocator = null, + allocator: ?Allocator = null, }; // Turns a value into its appropriate textual value (or null) diff --git a/src/sql/engines/postgres.zig b/src/sql/engines/postgres.zig index 94ef674..bb24415 100644 --- a/src/sql/engines/postgres.zig +++ b/src/sql/engines/postgres.zig @@ -180,14 +180,25 @@ pub const Db = struct { const format_text = 0; const format_binary = 1; pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results { - const alloc = opt.prep_allocator; + const alloc = opt.allocator; const result = blk: { if (@TypeOf(args) != void and args.len > 0) { var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired); defer arena.deinit(); const params = try arena.allocator().alloc(?[*:0]const u8, args.len); - inline for (args) |arg, i| { - params[i] = if (try common.prepareParamText(&arena, arg)) |slice| + // TODO: The following is a fix for the stage1 compiler. remove this + //inline for (args) |arg, i| { + inline for (std.meta.fields(@TypeOf(args))) |field, i| { + const arg = @field(args, field.name); + + // The stage1 compiler has issues with runtime branches that in any + // way involve compile time values + const maybe_slice = if (@import("builtin").zig_backend == .stage1) + common.prepareParamText(&arena, arg) catch unreachable + else + try common.prepareParamText(&arena, arg); + + params[i] = if (maybe_slice) |slice| slice.ptr else null; diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index d1a73c2..e2bb697 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -118,7 +118,10 @@ pub const Db = struct { }; if (@TypeOf(args) != void) { - inline for (args) |arg, i| { + // TODO: Fix for stage1 compiler + //inline for (args) |arg, i| { + inline for (std.meta.fields(@TypeOf(args))) |field, i| { + const arg = @field(args, field.name); // SQLite treats $NNN args as having the name NNN, not index NNN. // As such, if you reference $2 and not $1 in your query (such as // when dynamically constructing queries), it could assign $2 the diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 3124863..9b1f599 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -25,6 +25,54 @@ pub const Engine = enum { sqlite, }; +/// Helper for building queries at runtime. All constituent parts of the +/// query should be defined at comptime, however the choice of whether +/// or not to include them can occur at runtime. +pub const QueryBuilder = struct { + array: std.ArrayList(u8), + where_clauses_appended: usize = 0, + + pub fn init(alloc: std.mem.Allocator) QueryBuilder { + return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) }; + } + + pub fn deinit(self: *const QueryBuilder) void { + self.array.deinit(); + } + + /// Add a chunk of sql to the query without processing + pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void { + try self.array.appendSlice(sql); + } + + /// Add a where clause to the query. Clauses are assumed to be components + /// in an overall expression in Conjunctive Normal Form (AND of OR's). + /// https://en.wikipedia.org/wiki/Conjunctive_normal_form + /// All calls to andWhere must be contiguous, that is, they cannot be + /// interspersed with calls to appendSlice + pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void { + if (self.where_clauses_appended == 0) { + try self.array.appendSlice("WHERE "); + } else { + try self.array.appendSlice(" AND "); + } + + try self.array.appendSlice(clause); + self.where_clauses_appended += 1; + } + + pub fn str(self: *const QueryBuilder) []const u8 { + return self.array.items; + } + + pub fn terminate(self: *QueryBuilder) ![:0]const u8 { + std.debug.assert(self.array.items.len != 0); + if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0); + + return std.meta.assumeSentinel(self.array.items, 0); + } +}; + // TODO: make this suck less pub const Config = union(Engine) { postgres: struct { @@ -410,7 +458,7 @@ fn Tx(comptime tx_level: u8) type { args: anytype, alloc: ?Allocator, ) QueryError!Results(RowType) { - return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc }); + return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc }); } /// Runs a query to completion and returns a row of results, unless the query @@ -439,6 +487,45 @@ fn Tx(comptime tx_level: u8) type { return row; } + // Runs a query to completion and returns the results as a slice + pub fn queryRowsWithOptions( + self: Self, + comptime RowType: type, + q: [:0]const u8, + args: anytype, + max_items: ?usize, + options: QueryOptions, + ) QueryRowError![]RowType { + var results = try self.queryWithOptions(RowType, q, args, options); + defer results.finish(); + + const alloc = options.allocator orelse return error.AllocatorRequired; + + var result_array = std.ArrayList(RowType).init(alloc); + errdefer result_array.deinit(); + if (max_items) |max| try result_array.ensureTotalCapacity(max); + + errdefer for (result_array.items) |r| util.deepFree(alloc, r); + + var too_many: bool = false; + while (try results.row(alloc)) |row| { + errdefer util.deepFree(alloc, row); + if (max_items) |max| { + if (result_array.items.len >= max) { + util.deepFree(alloc, row); + too_many = true; + continue; + } + } + + try result_array.append(row); + } + + if (too_many) return error.TooManyRows; + + return result_array.toOwnedSlice(); + } + // Inserts a single value into a table pub fn insert( self: Self, @@ -455,7 +542,7 @@ fn Tx(comptime tx_level: u8) type { inline for (fields) |field, i| { // This causes a compiler crash. Why? //const F = field.field_type; - const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name)); + const F = @TypeOf(@field(value, field.name)); // causes issues if F is @TypeOf(null), use dummy type types[i] = if (F == @TypeOf(null)) ?i64 else F; table_spec = comptime (table_spec ++ field.name ++ ","); @@ -499,7 +586,7 @@ fn Tx(comptime tx_level: u8) type { alloc: ?std.mem.Allocator, comptime check_tx: bool, ) !void { - var results = try self.runSql(sql, args, .{ .prep_allocator = alloc }, check_tx); + var results = try self.runSql(sql, args, .{ .allocator = alloc }, check_tx); defer results.finish(); while (try results.row()) |_| {}