diff --git a/build.zig b/build.zig index 9cbd45c..108f884 100644 --- a/build.zig +++ b/build.zig @@ -53,16 +53,16 @@ pub fn build(b: *std.build.Builder) void { exe.linkSystemLibrary("pq"); exe.linkLibC(); - //const util_tests = b.addTest("src/util/lib.zig"); - const http_tests = b.addTest("src/http/test.zig"); - //const sql_tests = b.addTest("src/sql/lib.zig"); + const util_tests = b.addTest("src/util/lib.zig"); + const http_tests = b.addTest("src/http/lib.zig"); + const sql_tests = b.addTest("src/sql/lib.zig"); http_tests.addPackage(util_pkg); - //sql_tests.addPackage(util_pkg); + sql_tests.addPackage(util_pkg); const unit_tests = b.step("unit-tests", "Run tests"); - //unit_tests.dependOn(&util_tests.step); + unit_tests.dependOn(&util_tests.step); unit_tests.dependOn(&http_tests.step); - //unit_tests.dependOn(&sql_tests.step); + unit_tests.dependOn(&sql_tests.step); const api_integration = b.addTest("./tests/api_integration/lib.zig"); api_integration.addPackage(sql_pkg); diff --git a/src/api/lib.zig b/src/api/lib.zig index fb8f67e..a48170a 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -276,10 +276,7 @@ fn ApiConn(comptime DbConn: type) type { username, password, self.community.id, - .{ - .invite_id = if (maybe_invite) |inv| @as(?Uuid, inv.id) else null, - .email = opt.email, - }, + .{ .invite_id = if (maybe_invite) |inv| inv.id else null, .email = opt.email }, self.arena.allocator(), ); @@ -351,19 +348,5 @@ fn ApiConn(comptime DbConn: type) type { if (!self.isAdmin()) return error.PermissionDenied; return try services.communities.query(self.db, args, self.arena.allocator()); } - - pub fn globalTimeline(self: *Self) ![]services.notes.Note { - const result = try services.notes.query(self.db, .{}, self.arena.allocator()); - return result.items; - } - - pub fn localTimeline(self: *Self) ![]services.notes.Note { - const result = try services.notes.query( - self.db, - .{ .community_id = self.community.id }, - self.arena.allocator(), - ); - return result.items; - } }; } diff --git a/src/api/services/actors.zig b/src/api/services/actors.zig index 2a7f964..ee7c38d 100644 --- a/src/api/services/actors.zig +++ b/src/api/services/actors.zig @@ -94,8 +94,7 @@ pub const Actor = struct { created_at: DateTime, }; -pub const GetError = error{ NotFound, DatabaseFailure }; -pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Actor { +pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Actor { return db.queryRow( Actor, \\SELECT diff --git a/src/api/services/common.zig b/src/api/services/common.zig deleted file mode 100644 index 8a407d5..0000000 --- a/src/api/services/common.zig +++ /dev/null @@ -1,16 +0,0 @@ -const std = @import("std"); -const util = @import("util"); - -pub const Direction = enum { - ascending, - descending, - - pub const jsonStringify = util.jsonSerializeEnumAsString; -}; - -pub const PageDirection = enum { - forward, - backward, - - pub const jsonStringify = util.jsonSerializeEnumAsString; -}; diff --git a/src/api/services/communities.zig b/src/api/services/communities.zig index 780b7d5..84f07b6 100644 --- a/src/api/services/communities.zig +++ b/src/api/services/communities.zig @@ -2,7 +2,6 @@ const std = @import("std"); const builtin = @import("builtin"); const util = @import("util"); const sql = @import("sql"); -const common = @import("./common.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; @@ -83,12 +82,11 @@ pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: st else => return error.DatabaseFailure, } - const 
name = options.name orelse host; db.insert("community", .{ .id = id, .owner_id = null, .host = host, - .name = name, + .name = options.name orelse host, .scheme = scheme, .kind = options.kind, .created_at = DateTime.now(), @@ -155,8 +153,20 @@ pub const QueryArgs = struct { pub const jsonStringify = util.jsonSerializeEnumAsString; }; - pub const Direction = common.Direction; - pub const PageDirection = common.PageDirection; + pub const Direction = enum { + ascending, + descending, + + pub const jsonStringify = util.jsonSerializeEnumAsString; + }; + + pub const PageDirection = enum { + forward, + backward, + + pub const jsonStringify = util.jsonSerializeEnumAsString; + }; + pub const Prev = std.meta.Child(std.meta.fieldInfo(QueryArgs, .prev).field_type); pub const OrderVal = std.meta.fieldInfo(Prev, .order_val).field_type; @@ -201,6 +211,30 @@ pub const QueryResult = struct { next_page: QueryArgs, }; +const QueryBuilder = struct { + array: std.ArrayList(u8), + where_clauses_appended: usize = 0, + + pub fn init(alloc: std.mem.Allocator) QueryBuilder { + return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) }; + } + + pub fn deinit(self: *const QueryBuilder) void { + self.array.deinit(); + } + + pub fn andWhere(self: *QueryBuilder, clause: []const u8) !void { + if (self.where_clauses_appended == 0) { + try self.array.appendSlice("WHERE "); + } else { + try self.array.appendSlice(" AND "); + } + + try self.array.appendSlice(clause); + self.where_clauses_appended += 1; + } +}; + const max_max_items = 100; pub const QueryError = error{ @@ -212,7 +246,7 @@ pub const QueryError = error{ // arguments. // `args.max_items` is only a request, and fewer entries may be returned. pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult { - var builder = sql.QueryBuilder.init(alloc); + var builder = QueryBuilder.init(alloc); defer builder.deinit(); try builder.array.appendSlice( @@ -232,21 +266,21 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul if (args.prev) |prev| { if (prev.order_val != args.order_by) return error.PageArgMismatch; - switch (args.order_by) { - .name => try builder.andWhere("(name, id)"), - .host => try builder.andWhere("(host, id)"), - .created_at => try builder.andWhere("(created_at, id)"), - } - switch (args.direction) { + try builder.andWhere(switch (args.order_by) { + .name => "(name, id)", + .host => "(host, id)", + .created_at => "(created_at, id)", + }); + _ = try builder.array.appendSlice(switch (args.direction) { .ascending => switch (args.page_direction) { - .forward => try builder.appendSlice(" > "), - .backward => try builder.appendSlice(" < "), + .forward => " > ", + .backward => " < ", }, .descending => switch (args.page_direction) { - .forward => try builder.appendSlice(" < "), - .backward => try builder.appendSlice(" > "), + .forward => " < ", + .backward => " > ", }, - } + }); _ = try builder.array.appendSlice("($5, $6)"); } @@ -263,52 +297,57 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResul _ = try builder.array.appendSlice("\nLIMIT $7"); - const query_args = blk: { - const ord_val = - if (args.prev) |prev| @as(?QueryArgs.OrderVal, prev.order_val) else null; - const id = - if (args.prev) |prev| @as(?Uuid, prev.id) else null; - break :blk .{ - args.owner_id, - args.like, - args.created_before, - args.created_after, - ord_val, - id, - max_items, - }; + const query_args = .{ + args.owner_id, + args.like, + args.created_before, + args.created_after, + if (args.prev) |prev| 
prev.order_val else null, + if (args.prev) |prev| prev.id else null, + max_items, }; try builder.array.append(0); - var results = try db.queryRowsWithOptions( + var results = try db.queryWithOptions( Community, std.meta.assumeSentinel(builder.array.items, 0), query_args, - max_items, - .{ .allocator = alloc, .ignore_unused_arguments = true }, + .{ .prep_allocator = alloc, .ignore_unused_arguments = true }, ); - errdefer util.deepFree(alloc, results); + defer results.finish(); + + const result_buf = try alloc.alloc(Community, args.max_items); + errdefer alloc.free(result_buf); + + var count: usize = 0; + errdefer for (result_buf[0..count]) |c| util.deepFree(alloc, c); + + for (result_buf) |*c| { + c.* = (try results.row(alloc)) orelse break; + + count += 1; + } var next_page = args; var prev_page = args; prev_page.page_direction = .backward; next_page.page_direction = .forward; - if (results.len != 0) { + if (count != 0) { prev_page.prev = .{ - .id = results[0].id, - .order_val = getOrderVal(results[0], args.order_by), + .id = result_buf[0].id, + .order_val = getOrderVal(result_buf[0], args.order_by), }; next_page.prev = .{ - .id = results[results.len - 1].id, - .order_val = getOrderVal(results[results.len - 1], args.order_by), + .id = result_buf[count - 1].id, + .order_val = getOrderVal(result_buf[count - 1], args.order_by), }; } // TODO: This will give incorrect links on an empty page return QueryResult{ - .items = results, + .items = result_buf[0..count], .next_page = next_page, .prev_page = prev_page, diff --git a/src/api/services/invites.zig b/src/api/services/invites.zig index 2a58354..09a156c 100644 --- a/src/api/services/invites.zig +++ b/src/api/services/invites.zig @@ -71,7 +71,7 @@ pub fn create(db: anytype, created_by: Uuid, community_id: ?Uuid, options: Invit .max_uses = options.max_uses, .created_at = created_at, .expires_at = if (options.lifespan) |lifespan| - @as(?DateTime, created_at.add(lifespan)) + created_at.add(lifespan) else null, diff --git a/src/api/services/notes.zig b/src/api/services/notes.zig index 6fda4f6..edb8c32 100644 --- a/src/api/services/notes.zig +++ b/src/api/services/notes.zig @@ -1,7 +1,6 @@ const std = @import("std"); const util = @import("util"); const sql = @import("sql"); -const common = @import("./common.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; @@ -43,7 +42,7 @@ const selectStarFromNote = std.fmt.comptimePrint( \\SELECT {s} \\FROM note \\ -, .{util.comptimeJoinWithPrefix(",", "note.", std.meta.fieldNames(Note))}); +, .{util.comptimeJoin(",", std.meta.fieldNames(Note))}); pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note { return db.queryRow( Note, @@ -58,108 +57,3 @@ pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Note { else => error.DatabaseFailure, }; } - -const max_max_items = 100; - -pub const QueryArgs = struct { - pub const PageDirection = common.PageDirection; - pub const Prev = std.meta.Child(std.meta.field(@This(), .prev).field_type); - - max_items: usize = 20, - - created_before: ?DateTime = null, - created_after: ?DateTime = null, - community_id: ?Uuid = null, - - prev: ?struct { - id: Uuid, - created_at: DateTime, - } = null, - - page_direction: PageDirection = .forward, -}; - -pub const QueryResult = struct { - items: []Note, - - prev_page: QueryArgs, - next_page: QueryArgs, -}; - -pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) !QueryResult { - var builder = sql.QueryBuilder.init(alloc); - defer builder.deinit(); - - try 
builder.appendSlice(selectStarFromNote ++ - \\ JOIN actor ON actor.id = note.author_id - \\ - ); - - if (args.created_before != null) try builder.andWhere("note.created_at < $1"); - if (args.created_after != null) try builder.andWhere("note.created_at > $2"); - if (args.prev != null) { - try builder.andWhere("(note.created_at, note.id)"); - - switch (args.page_direction) { - .forward => try builder.appendSlice(" < "), - .backward => try builder.appendSlice(" > "), - } - try builder.appendSlice("($3, $4)"); - } - if (args.community_id != null) try builder.andWhere("actor.community_id = $5"); - - try builder.appendSlice( - \\ - \\ORDER BY note.created_at DESC - \\LIMIT $6 - \\ - ); - - const max_items = if (args.max_items > max_max_items) max_max_items else args.max_items; - - const query_args = blk: { - const prev_created_at = if (args.prev) |prev| @as(?DateTime, prev.created_at) else null; - const prev_id = if (args.prev) |prev| @as(?Uuid, prev.id) else null; - - break :blk .{ - args.created_before, - args.created_after, - prev_created_at, - prev_id, - args.community_id, - max_items, - }; - }; - - const results = try db.queryRowsWithOptions( - Note, - try builder.terminate(), - query_args, - max_items, - .{ .allocator = alloc, .ignore_unused_arguments = true }, - ); - errdefer util.deepFree(results); - - var next_page = args; - var prev_page = args; - prev_page.page_direction = .backward; - next_page.page_direction = .forward; - if (results.len != 0) { - prev_page.prev = .{ - .id = results[0].id, - .created_at = results[0].created_at, - }; - - next_page.prev = .{ - .id = results[results.len - 1].id, - .created_at = results[results.len - 1].created_at, - }; - } - // TODO: this will give incorrect links on an empty page - - return QueryResult{ - .items = results, - .next_page = next_page, - .prev_page = prev_page, - }; -} diff --git a/src/http/headers.zig b/src/http/headers.zig deleted file mode 100644 index 4f45029..0000000 --- a/src/http/headers.zig +++ /dev/null @@ -1,124 +0,0 @@ -const std = @import("std"); - -pub const Fields = struct { - const HashContext = struct { - const hash_seed = 1; - pub fn eql(_: @This(), lhs: []const u8, rhs: []const u8, _: usize) bool { - return std.ascii.eqlIgnoreCase(lhs, rhs); - } - pub fn hash(_: @This(), s: []const u8) u32 { - var h = std.hash.Wyhash.init(hash_seed); - for (s) |ch| { - const c = [1]u8{std.ascii.toLower(ch)}; - h.update(&c); - } - return @truncate(u32, h.final()); - } - }; - - const HashMap = std.ArrayHashMapUnmanaged( - []const u8, - []const u8, - HashContext, - true, - ); - - unmanaged: HashMap, - allocator: std.mem.Allocator, - - pub fn init(allocator: std.mem.Allocator) Fields { - return Fields{ - .unmanaged = .{}, - .allocator = allocator, - }; - } - - pub fn deinit(self: *Fields) void { - var hash_iter = self.unmanaged.iterator(); - while (hash_iter.next()) |entry| { - self.allocator.free(entry.key_ptr.*); - self.allocator.free(entry.value_ptr.*); - } - - self.unmanaged.deinit(self.allocator); - } - - pub fn iterator(self: Fields) HashMap.Iterator { - return self.unmanaged.iterator(); - } - - pub fn get(self: Fields, key: []const u8) ?[]const u8 { - return self.unmanaged.get(key); - } - - pub const ListIterator = struct { - remaining: []const u8, - - fn extractElement(self: *ListIterator) ?[]const u8 { - if (self.remaining.len == 0) return null; - - var start: usize = 0; - var is_quoted = false; - const end = for (self.remaining) |ch, i| { - if (start == i and std.ascii.isWhitespace(ch)) { - start += 1; - } else if (ch == '"') { - 
is_quoted = !is_quoted; - } - if (ch == ',' and !is_quoted) { - break i; - } - } else self.remaining.len; - - const str = self.remaining[start..end]; - if (end == self.remaining.len) { - self.remaining = ""; - } else { - self.remaining = self.remaining[end + 1 ..]; - } - - return std.mem.trim(u8, str, " \t"); - } - - pub fn next(self: *ListIterator) ?[]const u8 { - while (self.extractElement()) |elem| { - if (elem.len != 0) return elem; - } - - return null; - } - }; - - pub fn getList(self: Fields, key: []const u8) ?ListIterator { - return if (self.unmanaged.get(key)) |hdr| ListIterator{ .remaining = hdr } else null; - } - - pub fn put(self: *Fields, key: []const u8, val: []const u8) !void { - const key_clone = try self.allocator.alloc(u8, key.len); - std.mem.copy(u8, key_clone, key); - errdefer self.allocator.free(key_clone); - - const val_clone = try self.allocator.alloc(u8, val.len); - std.mem.copy(u8, val_clone, val); - errdefer self.allocator.free(val_clone); - - if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| { - self.allocator.free(entry.key); - self.allocator.free(entry.value); - } - } - - pub fn append(self: *Fields, key: []const u8, val: []const u8) !void { - if (self.unmanaged.getEntry(key)) |entry| { - const new_val = try std.mem.join(self.allocator, ", ", &.{ entry.value_ptr.*, val }); - self.allocator.free(entry.value_ptr.*); - entry.value_ptr.* = new_val; - } else { - try self.put(key, val); - } - } - - pub fn count(self: Fields) usize { - return self.unmanaged.count(); - } -}; diff --git a/src/http/lib.zig b/src/http/lib.zig index ce97ad1..eff1e45 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -10,15 +10,22 @@ pub const socket = @import("./socket.zig"); pub const Method = std.http.Method; pub const Status = std.http.Status; -pub const Request = request.Request(std.net.Stream.Reader); +pub const Request = request.Request; pub const serveConn = server.serveConn; pub const Response = server.Response; pub const Handler = server.Handler; -pub const Fields = @import("./headers.zig").Fields; +pub const Headers = std.HashMap([]const u8, []const u8, struct { + pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { + return ciutf8.eql(a, b); + } -pub const Protocol = enum { - http_1_0, - http_1_1, - http_1_x, -}; + pub fn hash(_: @This(), str: []const u8) u64 { + return ciutf8.hash(str); + } +}, std.hash_map.default_max_load_percentage); + +test { + _ = server; + _ = request; +} diff --git a/src/http/request.zig b/src/http/request.zig index b713d16..e6fd79c 100644 --- a/src/http/request.zig +++ b/src/http/request.zig @@ -3,23 +3,29 @@ const http = @import("./lib.zig"); const parser = @import("./request/parser.zig"); -pub fn Request(comptime Reader: type) type { - return struct { - protocol: http.Protocol, - - method: http.Method, - uri: []const u8, - headers: http.Fields, - - body: ?parser.TransferStream(Reader), - - pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void { - allocator.free(self.uri); - self.headers.deinit(); - } +pub const Request = struct { + pub const Protocol = enum { + http_1_0, + http_1_1, }; -} -pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) { - return parser.parse(alloc, reader); + protocol: Protocol, + source_address: ?std.net.Address, + + method: http.Method, + uri: []const u8, + headers: http.Headers, + body: ?[]const u8 = null, + + pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request { + return parser.parse(alloc, reader, addr); + 
}
+
+    pub fn parseFree(self: *Request, alloc: std.mem.Allocator) void {
+        parser.parseFree(alloc, self);
+    }
+};
+
+test {
+    _ = parser;
+}
diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig
index 57fae33..8a7c19f 100644
--- a/src/http/request/parser.zig
+++ b/src/http/request/parser.zig
@@ -1,13 +1,15 @@
 const std = @import("std");
+const util = @import("util");
 const http = @import("../lib.zig");
 
 const Method = http.Method;
-const Fields = http.Fields;
+const Headers = http.Headers;
 const Request = @import("../request.zig").Request;
 
 const request_buf_size = 1 << 16;
 const max_path_len = 1 << 10;
+const max_body_len = 1 << 12;
 
 fn ParseError(comptime Reader: type) type {
     return error{
@@ -20,7 +22,7 @@ const Encoding = enum {
     chunked,
 };
 
-pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
+pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request {
     const method = try parseMethod(reader);
     const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
         error.StreamTooLong => return error.RequestUriTooLong,
@@ -31,20 +33,28 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)
     const proto = try parseProto(reader);
 
     // discard \r\n
-    switch (try reader.readByte()) {
-        '\r' => if ((try reader.readByte()) != '\n') return error.BadRequest,
-        '\n' => {},
-        else => return error.BadRequest,
-    }
+    _ = try reader.readByte();
+    _ = try reader.readByte();
 
     var headers = try parseHeaders(alloc, reader);
-    errdefer headers.deinit();
+    errdefer freeHeaders(alloc, &headers);
 
-    const body = try prepareBody(headers, reader);
-    if (body != null and !method.requestHasBody()) return error.BadRequest;
+    const body = if (method.requestHasBody())
+        try readBody(alloc, headers, reader)
+    else
+        null;
+    errdefer if (body) |b| alloc.free(b);
 
-    return Request(@TypeOf(reader)){
+    const eff_addr = if (headers.get("X-Real-IP")) |ip|
+        std.net.Address.parseIp(ip, address.getPort()) catch {
+            return error.BadRequest;
+        }
+    else
+        address;
+
+    return Request{
         .protocol = proto,
+        .source_address = eff_addr,
 
         .method = method,
         .uri = uri,
@@ -69,7 +79,7 @@ fn parseMethod(reader: anytype) !Method {
     return error.MethodNotImplemented;
 }
 
-fn parseProto(reader: anytype) !http.Protocol {
+fn parseProto(reader: anytype) !Request.Protocol {
     var buf: [8]u8 = undefined;
     const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
         error.StreamTooLong => return error.UnknownProtocol,
@@ -89,145 +99,85 @@ fn parseProto(reader: anytype) !http.Protocol {
     return switch (buf[2]) {
         '0' => .http_1_0,
         '1' => .http_1_1,
-        else => .http_1_x,
+        else => error.HttpVersionNotSupported,
     };
 }
 
-fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
-    var headers = Fields.init(allocator);
+fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers {
+    var map = Headers.init(allocator);
+    errdefer map.deinit();
+    errdefer {
+        var iter = map.iterator();
+        while (iter.next()) |it| {
+            allocator.free(it.key_ptr.*);
+            allocator.free(it.value_ptr.*);
+        }
+    }
+
+    var buf: [1024]u8 = undefined;
 
-    var buf: [4096]u8 = undefined;
     while (true) {
-        const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
-            error.StreamTooLong => return error.HeaderLineTooLong,
-            else => return err,
-        };
-        const line = std.mem.trimRight(u8, full_line, "\r");
-        if (line.len == 0) break;
+        const line = try reader.readUntilDelimiter(&buf, '\n');
+        if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break;
 
-        const name = std.mem.sliceTo(line, ':');
-        if (!isTokenValid(name)) return error.BadRequest;
-        if (name.len == line.len) return error.BadRequest;
+        // TODO: handle multi-line headers
+        const name = extractHeaderName(line) orelse continue;
+        const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len;
+        // Reject a bare "name:" line before slicing, so a malformed header
+        // cannot index past the end of the line.
+        if (value_end < name.len + 2) return error.BadRequest;
+        const value = line[name.len + 1 + 1 .. value_end];
 
-        const value = std.mem.trim(u8, line[name.len + 1 ..], " \t");
+        if (name.len == 0 or value.len == 0) return error.BadRequest;
 
-        try headers.append(name, value);
+        const name_alloc = try allocator.alloc(u8, name.len);
+        errdefer allocator.free(name_alloc);
+        const value_alloc = try allocator.alloc(u8, value.len);
+        errdefer allocator.free(value_alloc);
+
+        @memcpy(name_alloc.ptr, name.ptr, name.len);
+        @memcpy(value_alloc.ptr, value.ptr, value.len);
+
+        try map.put(name_alloc, value_alloc);
     }
 
-    return headers;
+    return map;
}
 
-fn isTokenValid(token: []const u8) bool {
-    if (token.len == 0) return false;
-    for (token) |ch| {
-        switch (ch) {
-            '"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false,
+fn extractHeaderName(line: []const u8) ?[]const u8 {
+    var index: usize = 0;
 
-            '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => {},
-            else => if (!std.ascii.isAlphanumeric(ch)) return false,
+    // TODO: handle whitespace
+    while (index < line.len) : (index += 1) {
+        if (line[index] == ':') {
+            if (index == 0) return null;
+            return line[0..index];
         }
     }
 
-    return true;
+    return null;
 }
 
-fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) {
-    const hdr = headers.get("Transfer-Encoding");
-    // TODO:
-    // if (hder != null and protocol == .http_1_0) return error.BadRequest;
-    const xfer_encoding = try parseEncoding(hdr);
+fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 {
+    const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
+    if (xfer_encoding != .identity) return error.UnsupportedMediaType;
 
     const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
     if (content_encoding != .identity) return error.UnsupportedMediaType;
 
-    switch (xfer_encoding) {
-        .identity => {
-            const len_str = headers.get("Content-Length") orelse return null;
-            const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
+    const len_str = headers.get("Content-Length") orelse return null;
+    const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
+    if (len > max_body_len) return error.RequestEntityTooLarge;
+    const body = try alloc.alloc(u8, len);
+    errdefer alloc.free(body);
 
-            return TransferStream(@TypeOf(reader)){ .underlying = .{ .identity = std.io.limitedReader(reader, len) } };
-        },
-        .chunked => {
-            if (headers.get("Content-Length") != null) return error.BadRequest;
-            return TransferStream(@TypeOf(reader)){
-                .underlying = .{
-                    .chunked = try ChunkedStream(@TypeOf(reader)).init(reader),
-                },
-            };
-        },
-    }
-}
+    reader.readNoEof(body) catch return error.BadRequest;
 
-fn ChunkedStream(comptime R: type) type {
-    return struct {
-        const Self = @This();
-
-        remaining: ?usize = 0,
-        underlying: R,
-
-        const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream };
-        fn init(reader: R) !Self {
-            var self: Self = .{ 
.underlying = reader }; - return self; - } - - fn read(self: *Self, buf: []u8) !usize { - var count: usize = 0; - while (true) { - if (count == buf.len) return count; - if (self.remaining == null) return count; - if (self.remaining.? == 0) self.remaining = try self.readChunkHeader(); - - const max_read = std.math.min(buf.len, self.remaining.?); - const amt = try self.underlying.read(buf[count .. count + max_read]); - if (amt != max_read) return error.EndOfStream; - count += amt; - self.remaining.? -= amt; - if (self.remaining.? == 0) { - var crlf: [2]u8 = undefined; - _ = try self.underlying.readUntilDelimiter(&crlf, '\n'); - self.remaining = try self.readChunkHeader(); - } - - if (count == buf.len) return count; - } - } - - fn readChunkHeader(self: *Self) !?usize { - // TODO: Pick a reasonable limit for this - var buf = std.mem.zeroes([10]u8); - const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| { - return if (err == error.StreamTooLong) error.InvalidChunkHeader else err; - }; - if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader; - - const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader; - - return if (size != 0) size else null; - } - }; -} - -pub fn TransferStream(comptime R: type) type { - return struct { - const Error = R.Error || ChunkedStream(R).Error; - const Reader = std.io.Reader(*@This(), Error, read); - - underlying: union(enum) { - identity: std.io.LimitedReader(R), - chunked: ChunkedStream(R), - }, - - pub fn read(self: *@This(), buf: []u8) Error!usize { - return switch (self.underlying) { - .identity => |*r| try r.read(buf), - .chunked => |*r| try r.read(buf), - }; - } - - pub fn reader(self: *@This()) Reader { - return .{ .context = self }; - } - }; + return body; } // TODO: assumes that there's only one encoding, not layered encodings @@ -237,3 +187,257 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding { if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked; return error.UnsupportedMediaType; } + +pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void { + allocator.free(request.uri); + freeHeaders(allocator, &request.headers); + if (request.body) |body| allocator.free(body); +} + +fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void { + var iter = headers.iterator(); + while (iter.next()) |it| { + allocator.free(it.key_ptr.*); + allocator.free(it.value_ptr.*); + } + headers.deinit(); +} + +const _test = struct { + const expectEqual = std.testing.expectEqual; + const expectEqualStrings = std.testing.expectEqualStrings; + + fn toCrlf(comptime str: []const u8) []const u8 { + comptime { + var buf: [str.len * 2]u8 = undefined; + + @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 + + var buf_len: usize = 0; + for (str) |ch| { + if (ch == '\n') { + buf[buf_len] = '\r'; + buf_len += 1; + } + + buf[buf_len] = ch; + buf_len += 1; + } + + return buf[0..buf_len]; + } + } + + fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers { + var result = Headers.init(alloc); + inline for (headers) |tup| { + try result.put(tup[0], tup[1]); + } + return result; + } + + fn areEqualHeaders(lhs: Headers, rhs: Headers) bool { + if (lhs.count() != rhs.count()) return false; + var iter = lhs.iterator(); + while (iter.next()) |it| { + const rhs_val = rhs.get(it.key_ptr.*) orelse return false; + if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false; + } + return true; + } + + fn printHeaders(headers: 
Headers) void {
+        var iter = headers.iterator();
+        while (iter.next()) |it| {
+            std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* });
+        }
+    }
+
+    fn expectEqualHeaders(expected: Headers, actual: Headers) !void {
+        if (!areEqualHeaders(expected, actual)) {
+            std.debug.print("\nexpected: \n", .{});
+            printHeaders(expected);
+            std.debug.print("\n\nfound: \n", .{});
+            printHeaders(actual);
+            std.debug.print("\n\n", .{});
+            return error.TestExpectedEqual;
+        }
+    }
+
+    fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void {
+        var stream = std.io.fixedBufferStream(toCrlf(request));
+        // parse() takes the peer address now; any value will do in tests.
+        const addr = try std.net.Address.parseIp("127.0.0.1", 8080);
+
+        const result = try parse(alloc, stream.reader(), addr);
+
+        try expectEqual(expected.method, result.method);
+        try expectEqualStrings(expected.uri, result.uri);
+        try expectEqualHeaders(expected.headers, result.headers);
+        if ((expected.body == null) != (result.body == null)) {
+            const null_str: []const u8 = "(null)";
+            const exp = expected.body orelse null_str;
+            const act = result.body orelse null_str;
+            std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act });
+            return error.TestExpectedEqual;
+        }
+        if (expected.body != null) {
+            try expectEqualStrings(expected.body.?, result.body.?);
+        }
+    }
+};
+
+// TODO: failure test cases
+test "parse" {
+    const testCase = _test.parseTestCase;
+    var buf = [_]u8{0} ** (1 << 16);
+    var fba = std.heap.FixedBufferAllocator.init(&buf);
+    const alloc = fba.allocator();
+    try testCase(alloc, (
+        \\GET / HTTP/1.1
+        \\
+        \\
+    ), .{
+        .protocol = .http_1_1,
+        .source_address = null,
+        .method = .GET,
+        .headers = try _test.makeHeaders(alloc, .{}),
+        .uri = "/",
+    });
+
+    fba.reset();
+    try testCase(alloc, (
+        \\POST / HTTP/1.1
+        \\
+        \\
+    ), .{
+        .protocol = .http_1_1,
+        .source_address = null,
+        .method = .POST,
+        .headers = try _test.makeHeaders(alloc, .{}),
+        .uri = "/",
+    });
+
+    fba.reset();
+    try testCase(alloc, (
+        \\HEAD / HTTP/1.1
+        \\Authorization: bearer 
+        \\
+        \\
+    ), .{
+        .protocol = .http_1_1,
+        .source_address = null,
+        .method = .HEAD,
+        .headers = try _test.makeHeaders(alloc, .{
+            .{ "Authorization", "bearer " },
+        }),
+        .uri = "/",
+    });
+
+    fba.reset();
+    try testCase(alloc, (
+        \\POST /nonsense HTTP/1.1
+        \\Authorization: bearer 
+        \\Content-Length: 5
+        \\
+        \\12345
+    ), .{
+        .protocol = .http_1_1,
+        .source_address = null,
+        .method = .POST,
+        .headers = try _test.makeHeaders(alloc, .{
+            .{ "Authorization", "bearer " },
+            .{ "Content-Length", "5" },
+        }),
+        .uri = "/nonsense",
+        .body = "12345",
+    });
+
+    fba.reset();
+    try std.testing.expectError(
+        error.MethodNotImplemented,
+        testCase(alloc, (
+            \\FOO /nonsense HTTP/1.1
+            \\
+            \\
+        ), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.MethodNotImplemented,
+        testCase(alloc, (
+            \\FOOBARBAZ /nonsense HTTP/1.1
+            \\
+            \\
+        ), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.RequestUriTooLong,
+        testCase(alloc, (
+            \\GET /
+        ++ ("a" ** 2048)), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.UnknownProtocol,
+        testCase(alloc, (
+            \\GET /nonsense SPECIALHTTP/1.1
+            \\
+            \\
+        ), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.UnknownProtocol,
+        testCase(alloc, (
+            \\GET /nonsense JSON/1.1
+            \\
+            \\
+        ), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.HttpVersionNotSupported,
+        testCase(alloc, (
+            \\GET /nonsense HTTP/1.9
+            \\
+            \\
+        ), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.HttpVersionNotSupported,
+        testCase(alloc, (
+            \\GET /nonsense HTTP/8.1
+            \\
+            \\
+        ), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.BadRequest,
+        testCase(alloc, (
+            \\GET /nonsense HTTP/blah blah blah
+            \\
+            \\
+        ), undefined),
+    );
+
+    fba.reset();
+    try std.testing.expectError(
+        error.BadRequest,
+        testCase(alloc, (
+            \\GET /nonsense HTTP/1/1
+            \\
+            \\
+        ), undefined),
+    );
+}
diff --git a/src/http/request/test_parser.zig b/src/http/request/test_parser.zig
deleted file mode 100644
index 55a66d6..0000000
--- a/src/http/request/test_parser.zig
+++ /dev/null
@@ -1,282 +0,0 @@
-const std = @import("std");
-const parser = @import("./parser.zig");
-const http = @import("../lib.zig");
-const t = std.testing;
-
-const test_case = struct {
-    fn parse(text: []const u8, expected: struct {
-        protocol: http.Protocol = .http_1_1,
-        method: http.Method = .GET,
-        headers: []const std.meta.Tuple(&.{ []const u8, []const u8 }) = &.{},
-        uri: []const u8 = "",
-    }) !void {
-        var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(text) };
-        var actual = try parser.parse(t.allocator, stream.reader());
-        defer actual.parseFree(t.allocator);
-
-        try t.expectEqual(expected.protocol, actual.protocol);
-        try t.expectEqual(expected.method, actual.method);
-        try t.expectEqualStrings(expected.uri, actual.uri);
-
-        try t.expectEqual(expected.headers.len, actual.headers.count());
-        for (expected.headers) |hdr| {
-            if (actual.headers.get(hdr[0])) |val| {
-                try t.expectEqualStrings(hdr[1], val);
-            } else {
-                std.debug.print("Error: Header {s} expected to be present, was not.\n", .{hdr[0]});
-                try t.expect(false);
-            }
-        }
-    }
-};
-
-fn toCrlf(comptime str: []const u8) []const u8 {
-    comptime {
-        var buf: [str.len * 2]u8 = undefined;
-
-        @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
-
-        var buf_len: usize = 0;
-        for (str) |ch| {
-            if (ch == '\n') {
-                buf[buf_len] = '\r';
-                buf_len += 1;
-            }
-
-            buf[buf_len] = ch;
-            buf_len += 1;
-        }
-
-        return buf[0..buf_len];
-    }
-}
-
-test "HTTP/1.x parse - No body" {
-    try test_case.parse(
-        toCrlf(
-            \\GET / HTTP/1.1
-            \\
-            \\
-        ),
-        .{
-            .protocol = .http_1_1,
-            .method = .GET,
-            .uri = "/",
-        },
-    );
-    try test_case.parse(
-        toCrlf(
-            \\POST / HTTP/1.1
-            \\
-            \\
-        ),
-        .{
-            .protocol = .http_1_1,
-            .method = .POST,
-            .uri = "/",
-        },
-    );
-    try test_case.parse(
-        toCrlf(
-            \\GET /url/abcd HTTP/1.1
-            \\
-            \\
-        ),
-        .{
-            .protocol = .http_1_1,
-            .method = .GET,
-            .uri = "/url/abcd",
-        },
-    );
-    try test_case.parse(
-        toCrlf(
-            \\GET / HTTP/1.0
-            \\
-            \\
-        ),
-        .{
-            .protocol = .http_1_0,
-            .method = .GET,
-            .uri = "/",
-        },
-    );
-    try test_case.parse(
-        toCrlf(
-            \\GET /url/abcd HTTP/1.1
-            \\Content-Type: application/json
-            \\
-            \\
-        ),
-        .{
-            .protocol = .http_1_1,
-            .method = .GET,
-            .uri = "/url/abcd",
-            .headers = &.{.{ "Content-Type", "application/json" }},
-        },
-    );
-    try test_case.parse(
-        toCrlf(
-            \\GET /url/abcd HTTP/1.1
-            \\Content-Type: application/json
-            \\Authorization: bearer 
-            \\
-            \\
-        ),
-        .{
-            .protocol = .http_1_1,
-            .method = .GET,
-            .uri = "/url/abcd",
-            .headers = &.{
-                .{ "Content-Type", "application/json" },
-                .{ "Authorization", "bearer " },
-            },
-        },
-    );
-
-    // Test without CRLF
-    try test_case.parse(
-        \\GET /url/abcd HTTP/1.1
-        \\Content-Type: application/json
-        \\Authorization: bearer 
-        \\
-        \\
-    ,
-        .{
-            .protocol = .http_1_1,
-            .method = .GET,
-            .uri = "/url/abcd",
-            .headers = &.{
-                .{ "Content-Type", "application/json" },
-                .{ "Authorization", "bearer " },
-            },
-        },
-    );
-    try test_case.parse(
-        \\POST / HTTP/1.1
-        \\
-        \\
-    ,
-        .{
-            
.protocol = .http_1_1, - .method = .POST, - .uri = "/", - }, - ); - try test_case.parse( - toCrlf( - \\GET / HTTP/1.2 - \\ - \\ - ), - .{ - .protocol = .http_1_x, - .method = .GET, - .uri = "/", - }, - ); -} - -test "HTTP/1.x parse - unsupported protocol" { - try t.expectError(error.UnknownProtocol, test_case.parse( - \\GET / JSON/1.1 - \\ - \\ - , - .{}, - )); - try t.expectError(error.UnknownProtocol, test_case.parse( - \\GET / SOMETHINGELSE/3.5 - \\ - \\ - , - .{}, - )); - try t.expectError(error.UnknownProtocol, test_case.parse( - \\GET / /1.1 - \\ - \\ - , - .{}, - )); - try t.expectError(error.HttpVersionNotSupported, test_case.parse( - \\GET / HTTP/2.1 - \\ - \\ - , - .{}, - )); -} - -test "HTTP/1.x parse - Unknown method" { - try t.expectError(error.MethodNotImplemented, test_case.parse( - \\ABCD / HTTP/1.1 - \\ - \\ - , - .{}, - )); - try t.expectError(error.MethodNotImplemented, test_case.parse( - \\PATCHPATCHPATCH / HTTP/1.1 - \\ - \\ - , - .{}, - )); -} - -test "HTTP/1.x parse - Too long" { - try t.expectError(error.RequestUriTooLong, test_case.parse( - std.fmt.comptimePrint("GET {s} HTTP/1.1\n\n", .{"a" ** 8192}), - .{}, - )); - try t.expectError(error.HeaderLineTooLong, test_case.parse( - std.fmt.comptimePrint("GET / HTTP/1.1\r\n{s}: abcd", .{"a" ** 8192}), - .{}, - )); - try t.expectError(error.HeaderLineTooLong, test_case.parse( - std.fmt.comptimePrint("GET / HTTP/1.1\r\nabcd: {s}", .{"a" ** 8192}), - .{}, - )); -} - -test "HTTP/1.x parse - bad requests" { - try t.expectError(error.BadRequest, test_case.parse( - \\GET / HTTP/1.1 blah blah - \\ - \\ - , - .{}, - )); - try t.expectError(error.BadRequest, test_case.parse( - \\GET / HTTP/1.1 - \\abcd : lksjdfkl - \\ - , - .{}, - )); - try t.expectError(error.BadRequest, test_case.parse( - \\GET / HTTP/1.1 - \\ lksjfklsjdfklj - \\ - , - .{}, - )); -} - -test "HTTP/1.x parse - Headers" { - try test_case.parse( - toCrlf( - \\GET /url/abcd HTTP/1.1 - \\Content-Type: application/json - \\Content-Type: application/xml - \\ - \\ - ), - .{ - .protocol = .http_1_1, - .method = .GET, - .uri = "/url/abcd", - .headers = &.{.{ "Content-Type", "application/json, application/xml" }}, - }, - ); -} diff --git a/src/http/server.zig b/src/http/server.zig index d81d77a..269fc56 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -3,14 +3,13 @@ const util = @import("util"); const http = @import("./lib.zig"); const response = @import("./server/response.zig"); -const request = @import("./request.zig"); pub const Response = struct { alloc: std.mem.Allocator, stream: std.net.Stream, should_close: bool = false, pub const Stream = response.ResponseStream(std.net.Stream.Writer); - pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream { + pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream { if (headers.get("Connection")) |hdr| { if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; } @@ -18,7 +17,7 @@ pub const Response = struct { return response.open(self.alloc, self.stream.writer(), headers, status); } - pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream { + pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream { try response.writeRequestHeader(self.stream.writer(), headers, status); return self.stream; } @@ -27,6 +26,10 @@ pub const Response = struct { const Request = http.Request; const request_buf_size = 1 << 16; +pub fn Handler(comptime Ctx: type) type { + 
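// Handlers receive the fully-parsed request by value (the body is read up
+    // front during parsing, so nothing here borrows the connection's reader)
+    // and write their reply through *Response.
+    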
return fn (Ctx, Request, *Response) void;
+}
+
 pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void {
     // TODO: Timeouts
     while (true) {
@@ -34,7 +37,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
         var arena = std.heap.ArenaAllocator.init(alloc);
         defer arena.deinit();
 
-        var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
+        const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| {
             return handleError(conn.stream.writer(), err) catch {};
         };
         std.log.debug("done parsing", .{});
@@ -44,7 +47,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
             .stream = conn.stream,
         };
 
-        handler(ctx, &req, &res);
+        handler(ctx, req, &res);
         std.log.debug("done handling", .{});
 
         if (req.headers.get("Connection")) |hdr| {
diff --git a/src/http/server/response.zig b/src/http/server/response.zig
index 296382d..615f0c1 100644
--- a/src/http/server/response.zig
+++ b/src/http/server/response.zig
@@ -2,20 +2,20 @@ const std = @import("std");
 const http = @import("../lib.zig");
 
 const Status = http.Status;
-const Fields = http.Fields;
+const Headers = http.Headers;
 
 const chunk_size = 16 * 1024;
 
 pub fn open(
     alloc: std.mem.Allocator,
     writer: anytype,
-    headers: *const Fields,
+    headers: *const Headers,
     status: Status,
 ) !ResponseStream(@TypeOf(writer)) {
     const buf = try alloc.alloc(u8, chunk_size);
     errdefer alloc.free(buf);
 
     try writeStatusLine(writer, status);
-    try writeFields(writer, headers);
+    try writeHeaders(writer, headers);
 
     return ResponseStream(@TypeOf(writer)){
         .allocator = alloc,
@@ -25,9 +25,9 @@ pub fn open(
-pub fn writeRequestHeader(writer: anytype, headers: *const Fields, status: Status) !void {
+pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void {
     try writeStatusLine(writer, status);
-    try writeFields(writer, headers);
+    try writeHeaders(writer, headers);
     try writer.writeAll("\r\n");
 }
@@ -36,7 +36,7 @@ fn writeStatusLine(writer: anytype, status: Status) !void {
     try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text });
 }
 
-fn writeFields(writer: anytype, headers: *const Fields) !void {
+fn writeHeaders(writer: anytype, headers: *const Headers) !void {
     var iter = headers.iterator();
     while (iter.next()) |header| {
         for (header.value_ptr.*) |ch| {
@@ -65,7 +65,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
         allocator: std.mem.Allocator,
         base_writer: BaseWriter,
-        headers: *const Fields,
+        headers: *const Headers,
         buffer: []u8,
         buffer_pos: usize = 0,
         chunked: bool = false,
@@ -176,7 +177,7 @@ const _tests = struct {
     test "ResponseStream no headers empty body" {
         var buffer: [test_buffer_size]u8 = undefined;
         var test_stream = std.io.fixedBufferStream(&buffer);
-        var headers = Fields.init(std.testing.allocator);
+        var headers = Headers.init(std.testing.allocator);
         defer headers.deinit();
 
         {
@@ -204,7 +205,7 @@ const _tests = struct {
     test "ResponseStream empty body" {
         var buffer: [test_buffer_size]u8 = undefined;
         var test_stream = std.io.fixedBufferStream(&buffer);
-        var headers = Fields.init(std.testing.allocator);
+        var headers = Headers.init(std.testing.allocator);
         defer headers.deinit();
 
         try headers.put("Content-Type", "text/plain");
@@ -235,7 +236,7 @@ const _tests = struct {
     test "ResponseStream not 200 OK" {
         var buffer: [test_buffer_size]u8 = undefined;
         var test_stream = std.io.fixedBufferStream(&buffer);
-        var headers = Fields.init(std.testing.allocator);
+        var headers = Headers.init(std.testing.allocator);
         defer headers.deinit();
 
         try headers.put("Content-Type", "text/plain");
@@ -265,7 +266,7 @@ const _tests = struct {
     test "ResponseStream small body" {
         var buffer: [test_buffer_size]u8 = undefined;
         var test_stream = std.io.fixedBufferStream(&buffer);
-        var headers = Fields.init(std.testing.allocator);
+        var headers = Headers.init(std.testing.allocator);
         defer headers.deinit();
 
         try headers.put("Content-Type", "text/plain");
@@ -299,7 +300,7 @@ const _tests = struct {
     test "ResponseStream large body" {
         var buffer: [test_buffer_size]u8 = undefined;
         var test_stream = std.io.fixedBufferStream(&buffer);
-        var headers = Fields.init(std.testing.allocator);
+        var headers = Headers.init(std.testing.allocator);
         defer headers.deinit();
 
         try headers.put("Content-Type", "text/plain");
@@ -340,7 +341,7 @@ const _tests = struct {
     test "ResponseStream large body ending on chunk boundary" {
         var buffer: [test_buffer_size]u8 = undefined;
         var test_stream = std.io.fixedBufferStream(&buffer);
-        var headers = Fields.init(std.testing.allocator);
+        var headers = Headers.init(std.testing.allocator);
         defer headers.deinit();
 
         try headers.put("Content-Type", "text/plain");
diff --git a/src/http/socket.zig b/src/http/socket.zig
index eab1a1d..24c10d4 100644
--- a/src/http/socket.zig
+++ b/src/http/socket.zig
@@ -23,21 +23,21 @@ const Opcode = enum(u4) {
     }
 };
 
-pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
+pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket {
     const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
     const connection = req.headers.get("Connection") orelse return error.BadHandshake;
     if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
     if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;
 
     const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
-    if ((try std.base64.standard.Decoder.calcSizeForSlice(key_hdr)) != 16) return error.BadHandshake;
+    if (try std.base64.standard.Decoder.calcSizeForSlice(key_hdr) != 16) return error.BadHandshake;
     var key: [16]u8 = undefined;
     std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;
 
     const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
     if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
 
-    var headers = http.Fields.init(alloc);
+    var headers = http.Headers.init(alloc);
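
+    // Headers for the 101 Switching Protocols reply; the map only borrows
+    // these strings, so the deinit below frees just the table.
     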
defer headers.deinit();
 
     try headers.put("Upgrade", "websocket");
@@ -164,15 +164,15 @@ fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void {
     const initial_len: u7 = if (header.len < 126)
         @intCast(u7, header.len)
     else if (std.math.cast(u16, header.len)) |_|
-        @as(u7, 126)
+        126
     else
-        @as(u7, 127);
+        127;
 
     var hdr_buf = [2]u8{ 0, 0 };
-    hdr_buf[0] |= if (header.is_final) @as(u8, 0b1000_0000) else 0;
+    hdr_buf[0] |= if (header.is_final) 0b1000_0000 else 0;
     hdr_buf[0] |= @as(u8, header.rsv) << 4;
     hdr_buf[0] |= @enumToInt(header.opcode);
-    hdr_buf[1] |= if (header.masking_key) |_| @as(u8, 0b1000_0000) else 0;
+    hdr_buf[1] |= if (header.masking_key) |_| 0b1000_0000 else 0;
     hdr_buf[1] |= initial_len;
     try writer.writeAll(&hdr_buf);
     if (initial_len == 126)
diff --git a/src/http/test.zig b/src/http/test.zig
deleted file mode 100644
index 1441ec2..0000000
--- a/src/http/test.zig
+++ /dev/null
@@ -1,3 +0,0 @@
-test {
-    _ = @import("./request/test_parser.zig");
-}
diff --git a/src/main/controllers.zig b/src/main/controllers.zig
index 902711e..31fbe90 100644
--- a/src/main/controllers.zig
+++ b/src/main/controllers.zig
@@ -13,11 +13,10 @@ pub const invites = @import("./controllers/invites.zig");
 pub const users = @import("./controllers/users.zig");
 pub const notes = @import("./controllers/notes.zig");
 pub const streaming = @import("./controllers/streaming.zig");
-pub const timelines = @import("./controllers/timelines.zig");
 
-pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
+pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
     // TODO: hashmaps?
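
+    // Routing is a linear scan: the first entry in `routes` whose method and
+    // path match handles the request (see routeRequestInternal below).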
- var response = Response{ .headers = http.Fields.init(alloc), .res = res }; + var response = Response{ .headers = http.Headers.init(alloc), .res = res }; defer response.headers.deinit(); const found = routeRequestInternal(api_source, req, &response, alloc); @@ -25,7 +24,7 @@ pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response if (!found) response.status(.not_found) catch {}; } -fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { +fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { inline for (routes) |route| { if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true; } @@ -43,8 +42,6 @@ const routes = .{ notes.create, notes.get, streaming.streaming, - timelines.global, - timelines.local, }; pub fn Context(comptime Route: type) type { @@ -61,21 +58,18 @@ pub fn Context(comptime Route: type) type { // leave it as a simple string instead of void pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; - base_request: *http.Request, + base_request: http.Request, allocator: std.mem.Allocator, method: http.Method, uri: []const u8, - headers: http.Fields, + headers: http.Headers, args: Args, body: Body, query: Query, - // TODO - body_buf: ?[]const u8 = null, - fn parseArgs(path: []const u8) ?Args { var args: Args = undefined; var path_iter = util.PathIter.from(path); @@ -100,7 +94,7 @@ pub fn Context(comptime Route: type) type { @compileError("Unsupported Type " ++ @typeName(T)); } - pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool { + pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { if (req.method != Route.method) return false; var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); var args: Args = parseArgs(path) orelse return false; @@ -118,8 +112,6 @@ pub fn Context(comptime Route: type) type { .query = undefined, }; - std.log.debug("Matched route {s}", .{path}); - self.prepareAndHandle(api_source, req, res); return true; @@ -137,7 +129,7 @@ pub fn Context(comptime Route: type) type { }; } - fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void { + fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void { self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err); defer self.freeBody(); @@ -149,20 +141,16 @@ pub fn Context(comptime Route: type) type { self.handle(response, &api_conn); } - fn parseBody(self: *Self, req: *http.Request) !void { + fn parseBody(self: *Self, req: http.Request) !void { if (Body != void) { - var stream = req.body orelse return error.NoBody; - const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16); - errdefer self.allocator.free(body); + const body = req.body orelse return error.NoBody; self.body = try json_utils.parse(Body, body, self.allocator); - self.body_buf = body; } } fn freeBody(self: *Self) void { if (Body != void) { json_utils.parseFree(self.body, self.allocator); - self.allocator.free(self.body_buf.?); } } @@ -203,7 +191,7 @@ pub fn Context(comptime Route: type) type { pub const Response = struct { const Self = @This(); - headers: http.Fields, + headers: http.Headers, res: *http.Response, opened: bool = false, diff --git a/src/main/controllers/auth.zig b/src/main/controllers/auth.zig index 1b6e652..95be918 
100644 --- a/src/main/controllers/auth.zig +++ b/src/main/controllers/auth.zig @@ -1,5 +1,4 @@ const api = @import("api"); -const std = @import("std"); pub const login = struct { pub const method = .POST; @@ -13,8 +12,6 @@ pub const login = struct { pub fn handler(req: anytype, res: anytype, srv: anytype) !void { const token = try srv.login(req.body.username, req.body.password); - std.log.debug("{any}", .{res.headers}); - try res.json(.ok, token); } }; diff --git a/src/main/controllers/communities.zig b/src/main/controllers/communities.zig index 3aeec1f..f6b4c69 100644 --- a/src/main/controllers/communities.zig +++ b/src/main/controllers/communities.zig @@ -1,7 +1,6 @@ const std = @import("std"); const api = @import("api"); const util = @import("util"); -const query_utils = @import("../query.zig"); const QueryArgs = api.CommunityQueryArgs; const Uuid = util.Uuid; @@ -26,18 +25,89 @@ pub const query = struct { pub const method = .GET; pub const path = "/communities"; - pub const Query = QueryArgs; + // NOTE: This has to match QueryArgs + // TODO: Support union fields in query strings natively, so we don't + // have to keep these in sync + pub const Query = struct { + const OrderBy = QueryArgs.OrderBy; + const Direction = QueryArgs.Direction; + const PageDirection = QueryArgs.PageDirection; + + // Max items to fetch + max_items: usize = 20, + + // Selection filters + owner_id: ?Uuid = null, + like: ?[]const u8 = null, + created_before: ?DateTime = null, + created_after: ?DateTime = null, + + // Ordering parameter + order_by: OrderBy = .created_at, + direction: Direction = .ascending, + + // the `prev` struct has a slightly different format to QueryArgs + prev: struct { + id: ?Uuid = null, + + // Only one of these can be present, and must match order_by above + name: ?[]const u8 = null, + host: ?[]const u8 = null, + created_at: ?DateTime = null, + } = .{}, + + // What direction to scan the page window + page_direction: PageDirection = .forward, + + pub const format = formatQueryParams; + }; pub fn handler(req: anytype, res: anytype, srv: anytype) !void { - const results = try srv.queryCommunities(req.query); + const q = req.query; + const query_matches = if (q.prev.id) |_| switch (q.order_by) { + .name => q.prev.name != null and q.prev.host == null and q.prev.created_at == null, + .host => q.prev.name == null and q.prev.host != null and q.prev.created_at == null, + .created_at => q.prev.name == null and q.prev.host == null and q.prev.created_at != null, + } else (q.prev.name == null and q.prev.host == null and q.prev.created_at == null); + + if (!query_matches) return res.err(.bad_request, "prev.* parameters do not match", {}); + + const prev_arg: ?QueryArgs.Prev = if (q.prev.id) |id| .{ + .id = id, + .order_val = switch (q.order_by) { + .name => .{ .name = q.prev.name.? }, + .host => .{ .host = q.prev.host.? }, + .created_at => .{ .created_at = q.prev.created_at.? 
}, + }, + } else null; + + const query_args = QueryArgs{ + .max_items = q.max_items, + .owner_id = q.owner_id, + .like = q.like, + .created_before = q.created_before, + .created_after = q.created_after, + + .order_by = q.order_by, + .direction = q.direction, + + .prev = prev_arg, + + .page_direction = q.page_direction, + }; + + const results = try srv.queryCommunities(query_args); var link = std.ArrayList(u8).init(req.allocator); const link_writer = link.writer(); defer link.deinit(); - try writeLink(link_writer, srv.community, path, results.next_page, "next"); + const next_page = queryArgsToControllerQuery(results.next_page); + const prev_page = queryArgsToControllerQuery(results.prev_page); + + try writeLink(link_writer, srv.community, path, next_page, "next"); try link_writer.writeByte(','); - try writeLink(link_writer, srv.community, path, results.prev_page, "prev"); + try writeLink(link_writer, srv.community, path, prev_page, "prev"); try res.headers.put("Link", link.items); @@ -59,7 +129,7 @@ fn writeLink( .{ @tagName(community.scheme), community.host, path }, ); - try query_utils.formatQuery(params, writer); + try std.fmt.format(writer, "{}", .{params}); try std.fmt.format( writer, @@ -67,3 +137,70 @@ fn writeLink( .{rel}, ); } + +fn formatQueryParams( + params: anytype, + comptime fmt: []const u8, + opt: std.fmt.FormatOptions, + writer: anytype, +) !void { + if (comptime std.meta.trait.is(.Pointer)(@TypeOf(params))) { + return formatQueryParams(params.*, fmt, opt, writer); + } + + return formatRecursive("", params, writer); +} + +fn formatRecursive(comptime prefix: []const u8, params: anytype, writer: anytype) !void { + inline for (std.meta.fields(@TypeOf(params))) |field| { + const val = @field(params, field.name); + const is_optional = comptime std.meta.trait.is(.Optional)(field.field_type); + const present = if (comptime is_optional) val != null else true; + if (present) { + const unwrapped = if (is_optional) val.? 
else val; + // TODO: percent-encode this + _ = try switch (@TypeOf(unwrapped)) { + []const u8 => blk: { + break :blk std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, unwrapped }); + }, + + else => |U| blk: { + if (comptime std.meta.trait.isContainer(U) and std.meta.trait.hasFn("format")(U)) { + break :blk std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }); + } + + break :blk switch (@typeInfo(U)) { + .Enum => std.fmt.format(writer, "{s}{s}={s}&", .{ prefix, field.name, @tagName(unwrapped) }), + .Struct => formatRecursive(field.name ++ ".", unwrapped, writer), + else => std.fmt.format(writer, "{s}{s}={}&", .{ prefix, field.name, unwrapped }), + }; + }, + }; + } + } +} + +fn queryArgsToControllerQuery(args: QueryArgs) query.Query { + var result = query.Query{ + .max_items = args.max_items, + .owner_id = args.owner_id, + .like = args.like, + .created_before = args.created_before, + .created_after = args.created_after, + .order_by = args.order_by, + .direction = args.direction, + .prev = .{}, + .page_direction = args.page_direction, + }; + + if (args.prev) |prev| { + result.prev = .{ + .id = prev.id, + .name = if (prev.order_val == .name) prev.order_val.name else null, + .host = if (prev.order_val == .host) prev.order_val.host else null, + .created_at = if (prev.order_val == .created_at) prev.order_val.created_at else null, + }; + } + + return result; +} diff --git a/src/main/controllers/timelines.zig b/src/main/controllers/timelines.zig deleted file mode 100644 index 092d7c8..0000000 --- a/src/main/controllers/timelines.zig +++ /dev/null @@ -1,21 +0,0 @@ -pub const global = struct { - pub const method = .GET; - pub const path = "/timelines/global"; - - pub fn handler(_: anytype, res: anytype, srv: anytype) !void { - const results = try srv.globalTimeline(); - - try res.json(.ok, results); - } -}; - -pub const local = struct { - pub const method = .GET; - pub const path = "/timelines/local"; - - pub fn handler(_: anytype, res: anytype, srv: anytype) !void { - const results = try srv.localTimeline(); - - try res.json(.ok, results); - } -}; diff --git a/src/main/json.zig b/src/main/json.zig index 21474cc..b6bdff0 100644 --- a/src/main/json.zig +++ b/src/main/json.zig @@ -499,7 +499,7 @@ fn parseInternal( if (!fields_seen[i]) { if (field.default_value) |default_ptr| { if (!field.is_comptime) { - const default = @ptrCast(*align(1) const field.field_type, default_ptr).*; + const default = @ptrCast(*const field.field_type, default_ptr).*; @field(r, field.name) = default; } } else { diff --git a/src/main/main.zig b/src/main/main.zig index 9ff01fe..5870bc2 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -87,7 +87,7 @@ fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void { } } -fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void { +fn handle(ctx: anytype, req: http.Request, res: *http.Response) void { c.routeRequest(ctx.src, req, res, ctx.allocator); } diff --git a/src/main/migrations.zig b/src/main/migrations.zig index b0abc02..459e75a 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -53,7 +53,7 @@ pub fn up(db: anytype) !void { std.log.info("Running migration {s}", .{migration.name}); try execScript(tx, migration.up, gpa.allocator()); try tx.insert("migration", .{ - .name = @as([]const u8, migration.name), + .name = migration.name, .applied_at = DateTime.now(), }, gpa.allocator()); } diff --git a/src/main/query.zig b/src/main/query.zig index d79ada8..e3bc725 100644 --- a/src/main/query.zig +++ 
diff --git a/src/main/query.zig b/src/main/query.zig
index d79ada8..e3bc725 100644
--- a/src/main/query.zig
+++ b/src/main/query.zig
@@ -71,132 +71,37 @@ const QueryIter = @import("util").QueryIter;
 
 /// TODO: values are currently case-sensitive, and are not url-decoded properly.
 /// This should be fixed.
 pub fn parseQuery(comptime T: type, query: []const u8) !T {
-    if (comptime !std.meta.trait.isContainer(T)) @compileError("T must be a struct");
+    //if (!std.meta.trait.isContainer(T)) @compileError("T must be a struct");
 
     var iter = QueryIter.from(query);
-
-    var fields = Intermediary(T){};
+    var result = T{};
 
     while (iter.next()) |pair| {
-        // TODO: Hash map
-        inline for (std.meta.fields(Intermediary(T))) |field| {
-            if (std.ascii.eqlIgnoreCase(field.name[2..], pair.key)) {
-                @field(fields, field.name) = if (pair.value) |v| .{ .value = v } else .{ .no_value = {} };
-                break;
-            }
-        } else std.log.debug("unknown param {s}", .{pair.key});
+        try parseQueryPair(T, &result, pair.key, pair.value);
     }
 
-    return (try parse(T, "", "", fields)).?;
+    return result;
 }
 
-fn parseScalar(comptime T: type, comptime name: []const u8, fields: anytype) !?T {
-    const param = @field(fields, name);
-    return switch (param) {
-        .not_specified => null,
-        .no_value => try parseQueryValue(T, null),
-        .value => |v| try parseQueryValue(T, v),
-    };
-}
+fn parseQueryPair(comptime T: type, result: *T, key: []const u8, value: ?[]const u8) !void {
+    const key_part = std.mem.sliceTo(key, '.');
+    const field_idx = std.meta.stringToEnum(std.meta.FieldEnum(T), key_part) orelse return error.UnknownField;
 
-fn parse(comptime T: type, comptime prefix: []const u8, comptime name: []const u8, fields: anytype) !?T {
-    if (comptime isScalar(T)) return parseScalar(T, prefix ++ "." ++ name, fields);
-    switch (@typeInfo(T)) {
-        .Union => |info| {
-            var result: ?T = null;
-            inline for (info.fields) |field| {
-                const F = field.field_type;
-
-                const maybe_value = try parse(F, prefix, field.name, fields);
-                if (maybe_value) |value| {
-                    if (result != null) return error.DuplicateUnionField;
-
-                    result = @unionInit(T, field.name, value);
+    inline for (std.meta.fields(T)) |info, idx| {
+        if (@enumToInt(field_idx) == idx) {
+            if (comptime isScalar(info.field_type)) {
+                if (key_part.len == key.len) {
+                    @field(result, info.name) = try parseQueryValue(info.field_type, value);
+                    return;
+                } else {
+                    return error.UnknownField;
                 }
-            }
-            std.log.debug("{any}", .{result});
-            return result;
-        },
-
-        .Struct => |info| {
-            var result: T = undefined;
-            var fields_specified: usize = 0;
-
-            inline for (info.fields) |field| {
-                const F = field.field_type;
-
-                var maybe_value: ?F = null;
-                if (try parse(F, prefix ++ "." ++ name, field.name, fields)) |v| {
-                    maybe_value = v;
-                } else if (field.default_value) |default| {
-                    maybe_value = @ptrCast(*const F, @alignCast(@alignOf(F), default)).*;
-                }
-
-                if (maybe_value) |v| {
-                    fields_specified += 1;
-                    @field(result, field.name) = v;
-                }
-            }
-
-            if (fields_specified == 0) {
-                return null;
-            } else if (fields_specified != info.fields.len) {
-                return error.PartiallySpecifiedStruct;
             } else {
-                return result;
-            }
-        },
-
-        // Only applies to non-scalar optionals
-        .Optional => |info| return try parse(info.child, prefix, name, fields),
-
-        else => @compileError("tmp"),
-    }
-}
+                const remaining = std.mem.trimLeft(u8, key[key_part.len..], ".");
+                return try parseQueryPair(info.field_type, &@field(result, info.name), remaining, value);
+            }
+        }
+    }
 
-fn recursiveFieldPaths(comptime T: type, comptime prefix: []const u8) []const []const u8 {
-    comptime {
-        if (std.meta.trait.is(.Optional)(T)) return recursiveFieldPaths(std.meta.Child(T), prefix);
-
-        var fields: []const []const u8 = &.{};
-
-        for (std.meta.fields(T)) |f| {
-            const full_name = prefix ++ f.name;
-
-            if (isScalar(f.field_type)) {
-                fields = fields ++ @as([]const []const u8, &.{full_name});
-            } else {
-                const field_prefix = if (@typeInfo(f.field_type) == .Union) prefix else full_name ++ ".";
-                fields = fields ++ recursiveFieldPaths(f.field_type, field_prefix);
-            }
-        }
-
-        return fields;
-    }
-}
-
-const QueryParam = union(enum) {
-    not_specified: void,
-    no_value: void,
-    value: []const u8,
-};
-
-fn Intermediary(comptime T: type) type {
-    const field_names = recursiveFieldPaths(T, "..");
-
-    var fields: [field_names.len]std.builtin.Type.StructField = undefined;
-    for (field_names) |name, i| fields[i] = .{
-        .name = name,
-        .field_type = QueryParam,
-        .default_value = &QueryParam{ .not_specified = {} },
-        .is_comptime = false,
-        .alignment = @alignOf(QueryParam),
-    };
-
-    return @Type(.{ .Struct = .{
-        .layout = .Auto,
-        .fields = &fields,
-        .decls = &.{},
-        .is_tuple = false,
-    } });
+
+    return error.UnknownField;
 }
 
 fn parseQueryValue(comptime T: type, value: ?[]const u8) !T {
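In the rewritten parser each dotted key segment peels off one struct level, and an unknown key is now a hard error.UnknownField rather than a logged skip. A usage sketch against illustrative struct shapes, assuming parseQueryValue still accepts integers and optional strings as it did before this change:

    test "parseQuery walks dotted keys into nested structs" {
        const Inner = struct { name: ?[]const u8 = null };
        const Q = struct { max_items: usize = 20, prev: Inner = .{} };

        const parsed = try parseQuery(Q, "max_items=10&prev.name=abc");
        try std.testing.expectEqual(@as(usize, 10), parsed.max_items);
        try std.testing.expectEqualStrings("abc", parsed.prev.name.?);

        // Keys that name no field now fail loudly.
        try std.testing.expectError(error.UnknownField, parseQuery(Q, "nope=1"));
    }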
@@ -252,50 +157,6 @@ fn isScalar(comptime T: type) bool {
     return false;
 }
 
-pub fn formatQuery(params: anytype, writer: anytype) !void {
-    try format("", "", params, writer);
-}
-
-fn formatScalar(comptime name: []const u8, val: anytype, writer: anytype) !void {
-    const T = @TypeOf(val);
-    if (comptime std.meta.trait.isZigString(T)) return std.fmt.format(writer, "{s}={s}&", .{ name, val });
-    _ = try switch (@typeInfo(T)) {
-        .Enum => std.fmt.format(writer, "{s}={s}&", .{ name, @tagName(val) }),
-        .Optional => if (val) |v| formatScalar(name, v, writer),
-        else => std.fmt.format(writer, "{s}={}&", .{ name, val }),
-    };
-}
-
-fn format(comptime prefix: []const u8, comptime name: []const u8, params: anytype, writer: anytype) !void {
-    const T = @TypeOf(params);
-    const eff_prefix = if (prefix.len == 0) "" else prefix ++ ".";
-    if (comptime isScalar(T)) return formatScalar(eff_prefix ++ name, params, writer);
-
-    switch (@typeInfo(T)) {
-        .Struct => {
-            inline for (std.meta.fields(T)) |field| {
-                const val = @field(params, field.name);
-                try format(eff_prefix ++ name, field.name, val, writer);
-            }
-        },
-        .Union => {
-            //inline for (std.meta.tags(T)) |tag| {
-            inline for (std.meta.fields(T)) |field| {
-                const tag = @field(std.meta.Tag(T), field.name);
-                const tag_name = field.name;
-                if (@as(std.meta.Tag(T), params) == tag) {
-                    const val = @field(params, tag_name);
-                    try format(prefix, tag_name, val, writer);
-                }
-            }
-        },
-        .Optional => {
-            if (params) |p| try format(prefix, name, p, writer);
-        },
-        else => @compileError("Unsupported query type"),
-    }
-}
-
 test {
     const TestQuery = struct {
         int: usize = 3,
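The comptime Intermediary machinery deleted above existed largely to map incoming key names onto struct fields; std.meta.FieldEnum plus stringToEnum now does that lookup directly. The trick in isolation (sketch with an illustrative struct):

    const std = @import("std");

    test "stringToEnum over FieldEnum resolves keys to fields" {
        const Q = struct { max_items: usize = 0, like: ?[]const u8 = null };
        // One enum tag per field, in declaration order.
        const E = std.meta.FieldEnum(Q);
        try std.testing.expectEqual(@as(?E, .max_items), std.meta.stringToEnum(E, "max_items"));
        try std.testing.expectEqual(@as(?E, null), std.meta.stringToEnum(E, "nope"));
    }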
diff --git a/src/sql/engines/common.zig b/src/sql/engines/common.zig
index b50b7d0..67fef2f 100644
--- a/src/sql/engines/common.zig
+++ b/src/sql/engines/common.zig
@@ -68,7 +68,7 @@ pub const QueryOptions = struct {
     // do not require allocators for prep. If an allocator is needed but not
     // provided, `error.AllocatorRequired` will be returned.
     // Only used with the postgres backend.
-    allocator: ?Allocator = null,
+    prep_allocator: ?Allocator = null,
 };
 
 // Turns a value into its appropriate textual value (or null)
diff --git a/src/sql/engines/postgres.zig b/src/sql/engines/postgres.zig
index bb24415..94ef674 100644
--- a/src/sql/engines/postgres.zig
+++ b/src/sql/engines/postgres.zig
@@ -180,25 +180,14 @@ pub const Db = struct {
     const format_text = 0;
     const format_binary = 1;
     pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results {
-        const alloc = opt.allocator;
+        const alloc = opt.prep_allocator;
         const result = blk: {
             if (@TypeOf(args) != void and args.len > 0) {
                 var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
                 defer arena.deinit();
                 const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
-                // TODO: The following is a fix for the stage1 compiler. remove this
-                //inline for (args) |arg, i| {
-                inline for (std.meta.fields(@TypeOf(args))) |field, i| {
-                    const arg = @field(args, field.name);
-
-                    // The stage1 compiler has issues with runtime branches that in any
-                    // way involve compile time values
-                    const maybe_slice = if (@import("builtin").zig_backend == .stage1)
-                        common.prepareParamText(&arena, arg) catch unreachable
-                    else
-                        try common.prepareParamText(&arena, arg);
-
-                    params[i] = if (maybe_slice) |slice|
+                inline for (args) |arg, i| {
+                    params[i] = if (try common.prepareParamText(&arena, arg)) |slice|
                         slice.ptr
                     else
                         null;
diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig
index e2bb697..d1a73c2 100644
--- a/src/sql/engines/sqlite.zig
+++ b/src/sql/engines/sqlite.zig
@@ -118,10 +118,7 @@ pub const Db = struct {
         };
 
         if (@TypeOf(args) != void) {
-            // TODO: Fix for stage1 compiler
-            //inline for (args) |arg, i| {
-            inline for (std.meta.fields(@TypeOf(args))) |field, i| {
-                const arg = @field(args, field.name);
+            inline for (args) |arg, i| {
                 // SQLite treats $NNN args as having the name NNN, not index NNN.
                 // As such, if you reference $2 and not $1 in your query (such as
                 // when dynamically constructing queries), it could assign $2 the
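With the rename threaded through both engines, call sites have to spell out which allocator they are passing; the lib.zig changes below do exactly that. A hedged call-site sketch (duck-typed db as this codebase does it; the row type and SQL are placeholders, and rows are assumed to come from an arena so they are not individually freed):

    const std = @import("std");

    const Row = struct { id: i64 };

    fn countRows(db: anytype, arena: std.mem.Allocator) !usize {
        var results = try db.queryWithOptions(
            Row,
            "SELECT id FROM actor",
            .{},
            // prep_allocator only funds statement/parameter preparation,
            // not row storage.
            .{ .prep_allocator = arena },
        );
        defer results.finish();

        var count: usize = 0;
        while (try results.row(arena)) |_| count += 1;
        return count;
    }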
diff --git a/src/sql/lib.zig b/src/sql/lib.zig
index 9b1f599..3124863 100644
--- a/src/sql/lib.zig
+++ b/src/sql/lib.zig
@@ -25,54 +25,6 @@ pub const Engine = enum {
     sqlite,
 };
 
-/// Helper for building queries at runtime. All constituent parts of the
-/// query should be defined at comptime, however the choice of whether
-/// or not to include them can occur at runtime.
-pub const QueryBuilder = struct {
-    array: std.ArrayList(u8),
-    where_clauses_appended: usize = 0,
-
-    pub fn init(alloc: std.mem.Allocator) QueryBuilder {
-        return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
-    }
-
-    pub fn deinit(self: *const QueryBuilder) void {
-        self.array.deinit();
-    }
-
-    /// Add a chunk of sql to the query without processing
-    pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void {
-        try self.array.appendSlice(sql);
-    }
-
-    /// Add a where clause to the query. Clauses are assumed to be components
-    /// in an overall expression in Conjunctive Normal Form (AND of OR's).
-    /// https://en.wikipedia.org/wiki/Conjunctive_normal_form
-    /// All calls to andWhere must be contiguous, that is, they cannot be
-    /// interspersed with calls to appendSlice
-    pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void {
-        if (self.where_clauses_appended == 0) {
-            try self.array.appendSlice("WHERE ");
-        } else {
-            try self.array.appendSlice(" AND ");
-        }
-
-        try self.array.appendSlice(clause);
-        self.where_clauses_appended += 1;
-    }
-
-    pub fn str(self: *const QueryBuilder) []const u8 {
-        return self.array.items;
-    }
-
-    pub fn terminate(self: *QueryBuilder) ![:0]const u8 {
-        std.debug.assert(self.array.items.len != 0);
-        if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0);
-
-        return std.meta.assumeSentinel(self.array.items, 0);
-    }
-};
-
 // TODO: make this suck less
 pub const Config = union(Engine) {
     postgres: struct {
@@ -458,7 +410,7 @@ fn Tx(comptime tx_level: u8) type {
             args: anytype,
             alloc: ?Allocator,
         ) QueryError!Results(RowType) {
-            return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc });
+            return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
         }
 
         /// Runs a query to completion and returns a row of results, unless the query
@@ -487,45 +439,6 @@ fn Tx(comptime tx_level: u8) type {
             return row;
         }
 
-        // Runs a query to completion and returns the results as a slice
-        pub fn queryRowsWithOptions(
-            self: Self,
-            comptime RowType: type,
-            q: [:0]const u8,
-            args: anytype,
-            max_items: ?usize,
-            options: QueryOptions,
-        ) QueryRowError![]RowType {
-            var results = try self.queryWithOptions(RowType, q, args, options);
-            defer results.finish();
-
-            const alloc = options.allocator orelse return error.AllocatorRequired;
-
-            var result_array = std.ArrayList(RowType).init(alloc);
-            errdefer result_array.deinit();
-            if (max_items) |max| try result_array.ensureTotalCapacity(max);
-
-            errdefer for (result_array.items) |r| util.deepFree(alloc, r);
-
-            var too_many: bool = false;
-            while (try results.row(alloc)) |row| {
-                errdefer util.deepFree(alloc, row);
-                if (max_items) |max| {
-                    if (result_array.items.len >= max) {
-                        util.deepFree(alloc, row);
-                        too_many = true;
-                        continue;
-                    }
-                }
-
-                try result_array.append(row);
-            }
-
-            if (too_many) return error.TooManyRows;
-
-            return result_array.toOwnedSlice();
-        }
-
         // Inserts a single value into a table
         pub fn insert(
             self: Self,
@@ -542,7 +455,7 @@ fn Tx(comptime tx_level: u8) type {
             inline for (fields) |field, i| {
                 // This causes a compiler crash. Why?
                 //const F = field.field_type;
-                const F = @TypeOf(@field(value, field.name));
+                const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name)); // causes issues if F is @TypeOf(null), use dummy type
                 types[i] = if (F == @TypeOf(null)) ?i64 else F;
                 table_spec = comptime (table_spec ++ field.name ++ ",");
@@ -586,7 +499,7 @@ fn Tx(comptime tx_level: u8) type {
             alloc: ?std.mem.Allocator,
             comptime check_tx: bool,
         ) !void {
-            var results = try self.runSql(sql, args, .{ .allocator = alloc }, check_tx);
+            var results = try self.runSql(sql, args, .{ .prep_allocator = alloc }, check_tx);
             defer results.finish();
 
             while (try results.row()) |_| {}