From 4cb574bc91c84162b8f0e0c6e0dc79031ae71bee Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sat, 5 Nov 2022 00:26:53 -0700 Subject: [PATCH] Clean up header handling --- build.zig | 12 +- src/http/headers.zig | 82 ++++++++ src/http/lib.zig | 2 + src/http/request.zig | 11 +- src/http/request/parser.zig | 344 +++---------------------------- src/http/request/test_parser.zig | 264 ++++++++++++++++++++++++ src/http/server.zig | 6 +- src/http/server/response.zig | 26 +-- src/http/socket.zig | 2 +- src/http/test.zig | 3 + src/main/controllers.zig | 6 +- 11 files changed, 417 insertions(+), 341 deletions(-) create mode 100644 src/http/headers.zig create mode 100644 src/http/request/test_parser.zig create mode 100644 src/http/test.zig diff --git a/build.zig b/build.zig index 108f884..9cbd45c 100644 --- a/build.zig +++ b/build.zig @@ -53,16 +53,16 @@ pub fn build(b: *std.build.Builder) void { exe.linkSystemLibrary("pq"); exe.linkLibC(); - const util_tests = b.addTest("src/util/lib.zig"); - const http_tests = b.addTest("src/http/lib.zig"); - const sql_tests = b.addTest("src/sql/lib.zig"); + //const util_tests = b.addTest("src/util/lib.zig"); + const http_tests = b.addTest("src/http/test.zig"); + //const sql_tests = b.addTest("src/sql/lib.zig"); http_tests.addPackage(util_pkg); - sql_tests.addPackage(util_pkg); + //sql_tests.addPackage(util_pkg); const unit_tests = b.step("unit-tests", "Run tests"); - unit_tests.dependOn(&util_tests.step); + //unit_tests.dependOn(&util_tests.step); unit_tests.dependOn(&http_tests.step); - unit_tests.dependOn(&sql_tests.step); + //unit_tests.dependOn(&sql_tests.step); const api_integration = b.addTest("./tests/api_integration/lib.zig"); api_integration.addPackage(sql_pkg); diff --git a/src/http/headers.zig b/src/http/headers.zig new file mode 100644 index 0000000..8065ce3 --- /dev/null +++ b/src/http/headers.zig @@ -0,0 +1,82 @@ +const std = @import("std"); + +pub const Fields = struct { + const HashContext = struct { + const hash_seed = 1; + pub fn eql(_: @This(), lhs: []const u8, rhs: []const u8) bool { + return std.ascii.eqlIgnoreCase(lhs, rhs); + } + pub fn hash(_: @This(), s: []const u8) u64 { + var h = std.hash.Wyhash.init(hash_seed); + for (s) |ch| { + const c = [1]u8{std.ascii.toLower(ch)}; + h.update(&c); + } + return h.final(); + } + }; + + const HashMap = std.HashMapUnmanaged( + []const u8, + []const u8, + HashContext, + std.hash_map.default_max_load_percentage, + ); + + unmanaged: HashMap, + allocator: std.mem.Allocator, + + pub fn init(allocator: std.mem.Allocator) Fields { + return Fields{ + .unmanaged = .{}, + .allocator = allocator, + }; + } + + pub fn deinit(self: *Fields) void { + var hash_iter = self.unmanaged.iterator(); + while (hash_iter.next()) |entry| { + self.allocator.free(entry.key_ptr.*); + self.allocator.free(entry.value_ptr.*); + } + + self.unmanaged.deinit(self.allocator); + } + + pub fn iterator(self: Fields) HashMap.Iterator { + return self.unmanaged.iterator(); + } + + pub fn get(self: Fields, key: []const u8) ?[]const u8 { + return self.unmanaged.get(key); + } + + pub fn put(self: *Fields, key: []const u8, val: []const u8) !void { + const key_clone = try self.allocator.alloc(u8, key.len); + std.mem.copy(u8, key_clone, key); + errdefer self.allocator.free(key_clone); + + const val_clone = try self.allocator.alloc(u8, val.len); + std.mem.copy(u8, val_clone, val); + errdefer self.allocator.free(val_clone); + + if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| { + self.allocator.free(entry.key); + self.allocator.free(entry.value); + } + } + + pub fn append(self: *Fields, key: []const u8, val: []const u8) !void { + if (self.unmanaged.getEntry(key)) |entry| { + const new_val = try std.mem.join(self.allocator, ", ", &.{ entry.value_ptr.*, val }); + self.allocator.free(entry.value_ptr.*); + entry.value_ptr.* = new_val; + } else { + try self.put(key, val); + } + } + + pub fn count(self: Fields) usize { + return self.unmanaged.count(); + } +}; diff --git a/src/http/lib.zig b/src/http/lib.zig index eff1e45..e234e27 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -15,6 +15,8 @@ pub const serveConn = server.serveConn; pub const Response = server.Response; pub const Handler = server.Handler; +pub const Fields = @import("./headers.zig").Fields; + pub const Headers = std.HashMap([]const u8, []const u8, struct { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { return ciutf8.eql(a, b); diff --git a/src/http/request.zig b/src/http/request.zig index e6fd79c..cc16b8b 100644 --- a/src/http/request.zig +++ b/src/http/request.zig @@ -7,21 +7,22 @@ pub const Request = struct { pub const Protocol = enum { http_1_0, http_1_1, + http_1_x, }; protocol: Protocol, - source_address: ?std.net.Address, method: http.Method, uri: []const u8, - headers: http.Headers, + headers: http.Fields, + body: ?[]const u8 = null, - pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request { - return parser.parse(alloc, reader, addr); + pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { + return parser.parse(alloc, reader); } - pub fn parseFree(self: Request, alloc: std.mem.Allocator) void { + pub fn parseFree(self: *Request, alloc: std.mem.Allocator) void { parser.parseFree(alloc, self); } }; diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index 8a7c19f..a2100d8 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -1,9 +1,8 @@ const std = @import("std"); -const util = @import("util"); const http = @import("../lib.zig"); const Method = http.Method; -const Headers = http.Headers; +const Fields = http.Fields; const Request = @import("../request.zig").Request; @@ -22,7 +21,7 @@ const Encoding = enum { chunked, }; -pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request { +pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { const method = try parseMethod(reader); const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { error.StreamTooLong => return error.RequestUriTooLong, @@ -33,11 +32,14 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address const proto = try parseProto(reader); // discard \r\n - _ = try reader.readByte(); - _ = try reader.readByte(); + switch (try reader.readByte()) { + '\r' => if (try reader.readByte() != '\n') return error.BadRequest, + '\n' => {}, + else => return error.BadRequest, + } var headers = try parseHeaders(alloc, reader); - errdefer freeHeaders(alloc, &headers); + errdefer headers.deinit(); const body = if (method.requestHasBody()) try readBody(alloc, headers, reader) @@ -45,16 +47,8 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address null; errdefer if (body) |b| alloc.free(b); - const eff_addr = if (headers.get("X-Real-IP")) |ip| - std.net.Address.parseIp(ip, address.getPort()) catch { - return error.BadRequest; - } - else - address; - return Request{ .protocol = proto, - .source_address = eff_addr, .method = method, .uri = uri, @@ -99,71 +93,49 @@ fn parseProto(reader: anytype) !Request.Protocol { return switch (buf[2]) { '0' => .http_1_0, '1' => .http_1_1, - else => error.HttpVersionNotSupported, + else => .http_1_x, }; } -fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { - var map = Headers.init(allocator); - errdefer map.deinit(); - errdefer { - var iter = map.iterator(); - while (iter.next()) |it| { - allocator.free(it.key_ptr.*); - allocator.free(it.value_ptr.*); - } - } - // todo: - //errdefer { - //var iter = map.iterator(); - //while (iter.next()) |it| { - //allocator.free(it.key_ptr); - //allocator.free(it.value_ptr); - //} - //} - - var buf: [1024]u8 = undefined; +fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields { + var headers = Fields.init(allocator); + var buf: [4096]u8 = undefined; while (true) { - const line = try reader.readUntilDelimiter(&buf, '\n'); - if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; + const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) { + error.StreamTooLong => return error.HeaderLineTooLong, + else => return err, + }; + const line = std.mem.trimRight(u8, full_line, "\r"); + if (line.len == 0) break; - // TODO: handle multi-line headers - const name = extractHeaderName(line) orelse continue; - const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len; - const value = line[name.len + 1 + 1 .. value_end]; + const name = std.mem.sliceTo(line, ':'); + if (!isTokenValid(name)) return error.BadRequest; + if (name.len == line.len) return error.BadRequest; - if (name.len == 0 or value.len == 0) return error.BadRequest; + const value = std.mem.trim(u8, line[name.len + 1 ..], " \t"); - const name_alloc = try allocator.alloc(u8, name.len); - errdefer allocator.free(name_alloc); - const value_alloc = try allocator.alloc(u8, value.len); - errdefer allocator.free(value_alloc); - - @memcpy(name_alloc.ptr, name.ptr, name.len); - @memcpy(value_alloc.ptr, value.ptr, value.len); - - try map.put(name_alloc, value_alloc); + try headers.put(name, value); } - return map; + return headers; } -fn extractHeaderName(line: []const u8) ?[]const u8 { - var index: usize = 0; +fn isTokenValid(token: []const u8) bool { + if (token.len == 0) return false; + for (token) |ch| { + switch (ch) { + '"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false, - // TODO: handle whitespace - while (index < line.len) : (index += 1) { - if (line[index] == ':') { - if (index == 0) return null; - return line[0..index]; + '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => {}, + else => if (!std.ascii.isAlphanumeric(ch)) return false, } } - return null; + return true; } -fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 { +fn readBody(alloc: std.mem.Allocator, headers: Fields, reader: anytype) !?[]const u8 { const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding")); if (xfer_encoding != .identity) return error.UnsupportedMediaType; const content_encoding = try parseEncoding(headers.get("Content-Encoding")); @@ -190,254 +162,6 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding { pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void { allocator.free(request.uri); - freeHeaders(allocator, &request.headers); + request.headers.deinit(); if (request.body) |body| allocator.free(body); } - -fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void { - var iter = headers.iterator(); - while (iter.next()) |it| { - allocator.free(it.key_ptr.*); - allocator.free(it.value_ptr.*); - } - headers.deinit(); -} - -const _test = struct { - const expectEqual = std.testing.expectEqual; - const expectEqualStrings = std.testing.expectEqualStrings; - - fn toCrlf(comptime str: []const u8) []const u8 { - comptime { - var buf: [str.len * 2]u8 = undefined; - - @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 - - var buf_len: usize = 0; - for (str) |ch| { - if (ch == '\n') { - buf[buf_len] = '\r'; - buf_len += 1; - } - - buf[buf_len] = ch; - buf_len += 1; - } - - return buf[0..buf_len]; - } - } - - fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers { - var result = Headers.init(alloc); - inline for (headers) |tup| { - try result.put(tup[0], tup[1]); - } - return result; - } - - fn areEqualHeaders(lhs: Headers, rhs: Headers) bool { - if (lhs.count() != rhs.count()) return false; - var iter = lhs.iterator(); - while (iter.next()) |it| { - const rhs_val = rhs.get(it.key_ptr.*) orelse return false; - if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false; - } - return true; - } - - fn printHeaders(headers: Headers) void { - var iter = headers.iterator(); - while (iter.next()) |it| { - std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* }); - } - } - - fn expectEqualHeaders(expected: Headers, actual: Headers) !void { - if (!areEqualHeaders(expected, actual)) { - std.debug.print("\nexpected: \n", .{}); - printHeaders(expected); - std.debug.print("\n\nfound: \n", .{}); - printHeaders(actual); - std.debug.print("\n\n", .{}); - return error.TestExpectedEqual; - } - } - - fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void { - var stream = std.io.fixedBufferStream(toCrlf(request)); - - const result = try parse(alloc, stream.reader()); - - try expectEqual(expected.method, result.method); - try expectEqualStrings(expected.path, result.path); - try expectEqualHeaders(expected.headers, result.headers); - if ((expected.body == null) != (result.body == null)) { - const null_str: []const u8 = "(null)"; - const exp = expected.body orelse null_str; - const act = result.body orelse null_str; - std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act }); - return error.TestExpectedEqual; - } - if (expected.body != null) { - try expectEqualStrings(expected.body.?, result.body.?); - } - } -}; - -// TOOD: failure test cases -test "parse" { - const testCase = _test.parseTestCase; - var buf = [_]u8{0} ** (1 << 16); - var fba = std.heap.FixedBufferAllocator.init(&buf); - const alloc = fba.allocator(); - try testCase(alloc, ( - \\GET / HTTP/1.1 - \\ - \\ - ), .{ - .method = .GET, - .headers = try _test.makeHeaders(alloc, .{}), - .path = "/", - }); - - fba.reset(); - try testCase(alloc, ( - \\POST / HTTP/1.1 - \\ - \\ - ), .{ - .method = .POST, - .headers = try _test.makeHeaders(alloc, .{}), - .path = "/", - }); - - fba.reset(); - try testCase(alloc, ( - \\HEAD / HTTP/1.1 - \\Authorization: bearer - \\ - \\ - ), .{ - .method = .HEAD, - .headers = try _test.makeHeaders(alloc, .{ - .{ "Authorization", "bearer " }, - }), - .path = "/", - }); - - fba.reset(); - try testCase(alloc, ( - \\POST /nonsense HTTP/1.1 - \\Authorization: bearer - \\Content-Length: 5 - \\ - \\12345 - ), .{ - .method = .POST, - .headers = try _test.makeHeaders(alloc, .{ - .{ "Authorization", "bearer " }, - .{ "Content-Length", "5" }, - }), - .path = "/nonsense", - .body = "12345", - }); - - fba.reset(); - try std.testing.expectError( - error.MethodNotImplemented, - testCase(alloc, ( - \\FOO /nonsense HTTP/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.MethodNotImplemented, - testCase(alloc, ( - \\FOOBARBAZ /nonsense HTTP/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.RequestUriTooLong, - testCase(alloc, ( - \\GET / - ++ ("a" ** 2048)), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.UnknownProtocol, - testCase(alloc, ( - \\GET /nonsense SPECIALHTTP/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.UnknownProtocol, - testCase(alloc, ( - \\GET /nonsense JSON/1.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.HttpVersionNotSupported, - testCase(alloc, ( - \\GET /nonsense HTTP/1.9 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.HttpVersionNotSupported, - testCase(alloc, ( - \\GET /nonsense HTTP/8.1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.BadRequest, - testCase(alloc, ( - \\GET /nonsense HTTP/blah blah blah - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.BadRequest, - testCase(alloc, ( - \\GET /nonsense HTTP/1/1 - \\ - \\ - ), undefined), - ); - - fba.reset(); - try std.testing.expectError( - error.BadRequest, - testCase(alloc, ( - \\GET /nonsense HTTP/1/1 - \\ - \\ - ), undefined), - ); -} diff --git a/src/http/request/test_parser.zig b/src/http/request/test_parser.zig new file mode 100644 index 0000000..c4d34ff --- /dev/null +++ b/src/http/request/test_parser.zig @@ -0,0 +1,264 @@ +const std = @import("std"); +const parser = @import("./parser.zig"); +const http = @import("../lib.zig"); +const t = std.testing; + +const test_case = struct { + fn parse(text: []const u8, expected: struct { + protocol: http.Request.Protocol = .http_1_1, + method: http.Method = .GET, + headers: []const std.meta.Tuple(&.{ []const u8, []const u8 }) = &.{}, + uri: []const u8 = "", + }) !void { + var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(text) }; + var actual = try parser.parse(t.allocator, stream.reader()); + defer actual.parseFree(t.allocator); + + try t.expectEqual(expected.protocol, actual.protocol); + try t.expectEqual(expected.method, actual.method); + try t.expectEqualStrings(expected.uri, actual.uri); + + try t.expectEqual(expected.headers.len, actual.headers.count()); + for (expected.headers) |hdr| { + if (actual.headers.get(hdr[0])) |val| { + try t.expectEqualStrings(hdr[1], val); + } else { + std.debug.print("Error: Header {s} expected to be present, was not.\n", .{hdr[0]}); + try t.expect(false); + } + } + } +}; + +fn toCrlf(comptime str: []const u8) []const u8 { + comptime { + var buf: [str.len * 2]u8 = undefined; + + @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 + + var buf_len: usize = 0; + for (str) |ch| { + if (ch == '\n') { + buf[buf_len] = '\r'; + buf_len += 1; + } + + buf[buf_len] = ch; + buf_len += 1; + } + + return buf[0..buf_len]; + } +} + +test "HTTP/1.x parse - No body" { + try test_case.parse( + toCrlf( + \\GET / HTTP/1.1 + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\POST / HTTP/1.1 + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .POST, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\GET /url/abcd HTTP/1.1 + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + }, + ); + try test_case.parse( + toCrlf( + \\GET / HTTP/1.0 + \\ + \\ + ), + .{ + .protocol = .http_1_0, + .method = .GET, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\GET /url/abcd HTTP/1.1 + \\Content-Type: application/json + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + .headers = &.{.{ "Content-Type", "application/json" }}, + }, + ); + try test_case.parse( + toCrlf( + \\GET /url/abcd HTTP/1.1 + \\Content-Type: application/json + \\Authorization: bearer + \\ + \\ + ), + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + .headers = &.{ + .{ "Content-Type", "application/json" }, + .{ "Authorization", "bearer " }, + }, + }, + ); + + // Test without CRLF + try test_case.parse( + \\GET /url/abcd HTTP/1.1 + \\Content-Type: application/json + \\Authorization: bearer + \\ + \\ + , + .{ + .protocol = .http_1_1, + .method = .GET, + .uri = "/url/abcd", + .headers = &.{ + .{ "Content-Type", "application/json" }, + .{ "Authorization", "bearer " }, + }, + }, + ); + try test_case.parse( + \\POST / HTTP/1.1 + \\ + \\ + , + .{ + .protocol = .http_1_1, + .method = .POST, + .uri = "/", + }, + ); + try test_case.parse( + toCrlf( + \\GET / HTTP/1.2 + \\ + \\ + ), + .{ + .protocol = .http_1_x, + .method = .GET, + .uri = "/", + }, + ); +} + +test "HTTP/1.x parse - unsupported protocol" { + try t.expectError(error.UnknownProtocol, test_case.parse( + \\GET / JSON/1.1 + \\ + \\ + , + .{}, + )); + try t.expectError(error.UnknownProtocol, test_case.parse( + \\GET / SOMETHINGELSE/3.5 + \\ + \\ + , + .{}, + )); + try t.expectError(error.UnknownProtocol, test_case.parse( + \\GET / /1.1 + \\ + \\ + , + .{}, + )); + try t.expectError(error.HttpVersionNotSupported, test_case.parse( + \\GET / HTTP/2.1 + \\ + \\ + , + .{}, + )); +} + +test "HTTP/1.x parse - Unknown method" { + try t.expectError(error.MethodNotImplemented, test_case.parse( + \\ABCD / HTTP/1.1 + \\ + \\ + , + .{}, + )); + try t.expectError(error.MethodNotImplemented, test_case.parse( + \\PATCHPATCHPATCH / HTTP/1.1 + \\ + \\ + , + .{}, + )); +} + +test "HTTP/1.x parse - Too long" { + try t.expectError(error.RequestUriTooLong, test_case.parse( + std.fmt.comptimePrint("GET {s} HTTP/1.1\n\n", .{"a" ** 8192}), + .{}, + )); + try t.expectError(error.HeaderLineTooLong, test_case.parse( + std.fmt.comptimePrint("GET / HTTP/1.1\r\n{s}: abcd", .{"a" ** 8192}), + .{}, + )); + try t.expectError(error.HeaderLineTooLong, test_case.parse( + std.fmt.comptimePrint("GET / HTTP/1.1\r\nabcd: {s}", .{"a" ** 8192}), + .{}, + )); +} + +test "HTTP/1.x parse - bad requests" { + try t.expectError(error.BadRequest, test_case.parse( + \\GET / HTTP/1.1 blah blah + \\ + \\ + , + .{}, + )); + try t.expectError(error.BadRequest, test_case.parse( + \\GET / HTTP/1.1 + \\abcd : lksjdfkl + \\ + , + .{}, + )); + try t.expectError(error.BadRequest, test_case.parse( + \\GET / HTTP/1.1 + \\ lksjfklsjdfklj + \\ + , + .{}, + )); +} diff --git a/src/http/server.zig b/src/http/server.zig index 269fc56..91bde68 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -9,7 +9,7 @@ pub const Response = struct { stream: std.net.Stream, should_close: bool = false, pub const Stream = response.ResponseStream(std.net.Stream.Writer); - pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream { + pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream { if (headers.get("Connection")) |hdr| { if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; } @@ -17,7 +17,7 @@ pub const Response = struct { return response.open(self.alloc, self.stream.writer(), headers, status); } - pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream { + pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream { try response.writeRequestHeader(self.stream.writer(), headers, status); return self.stream; } @@ -37,7 +37,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a var arena = std.heap.ArenaAllocator.init(alloc); defer arena.deinit(); - const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| { + const req = Request.parse(arena.allocator(), conn.stream.reader()) catch |err| { return handleError(conn.stream.writer(), err) catch {}; }; std.log.debug("done parsing", .{}); diff --git a/src/http/server/response.zig b/src/http/server/response.zig index 615f0c1..b50b9f1 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -2,20 +2,20 @@ const std = @import("std"); const http = @import("../lib.zig"); const Status = http.Status; -const Headers = http.Headers; +const Fields = http.Fields; const chunk_size = 16 * 1024; pub fn open( alloc: std.mem.Allocator, writer: anytype, - headers: *const Headers, + headers: *const Fields, status: Status, ) !ResponseStream(@TypeOf(writer)) { const buf = try alloc.alloc(u8, chunk_size); errdefer alloc.free(buf); try writeStatusLine(writer, status); - try writeHeaders(writer, headers); + try writeFields(writer, headers); return ResponseStream(@TypeOf(writer)){ .allocator = alloc, @@ -25,9 +25,9 @@ pub fn open( }; } -pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void { +pub fn writeRequestHeader(writer: anytype, headers: *const Fields, status: Status) !void { try writeStatusLine(writer, status); - try writeHeaders(writer, headers); + try writeFields(writer, headers); try writer.writeAll("\r\n"); } @@ -36,7 +36,7 @@ fn writeStatusLine(writer: anytype, status: Status) !void { try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text }); } -fn writeHeaders(writer: anytype, headers: *const Headers) !void { +fn writeFields(writer: anytype, headers: *const Fields) !void { var iter = headers.iterator(); while (iter.next()) |header| { for (header.value_ptr.*) |ch| { @@ -65,7 +65,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type { allocator: std.mem.Allocator, base_writer: BaseWriter, - headers: *const Headers, + headers: *const Fields, buffer: []u8, buffer_pos: usize = 0, chunked: bool = false, @@ -177,7 +177,7 @@ const _tests = struct { test "ResponseStream no headers empty body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); { @@ -205,7 +205,7 @@ const _tests = struct { test "ResponseStream empty body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -236,7 +236,7 @@ const _tests = struct { test "ResponseStream not 200 OK" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -266,7 +266,7 @@ const _tests = struct { test "ResponseStream small body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -300,7 +300,7 @@ const _tests = struct { test "ResponseStream large body" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); @@ -341,7 +341,7 @@ const _tests = struct { test "ResponseStream large body ending on chunk boundary" { var buffer: [test_buffer_size]u8 = undefined; var test_stream = std.io.fixedBufferStream(&buffer); - var headers = Headers.init(std.testing.allocator); + var headers = Fields.init(std.testing.allocator); defer headers.deinit(); try headers.put("Content-Type", "text/plain"); diff --git a/src/http/socket.zig b/src/http/socket.zig index 1a99ce4..faac74f 100644 --- a/src/http/socket.zig +++ b/src/http/socket.zig @@ -37,7 +37,7 @@ pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Respons const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake; if (!std.mem.eql(u8, "13", version)) return error.BadHandshake; - var headers = http.Headers.init(alloc); + var headers = http.Fields.init(alloc); defer headers.deinit(); try headers.put("Upgrade", "websocket"); diff --git a/src/http/test.zig b/src/http/test.zig new file mode 100644 index 0000000..1441ec2 --- /dev/null +++ b/src/http/test.zig @@ -0,0 +1,3 @@ +test { + _ = @import("./request/test_parser.zig"); +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 31fbe90..a20e9ee 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -16,7 +16,7 @@ pub const streaming = @import("./controllers/streaming.zig"); pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? - var response = Response{ .headers = http.Headers.init(alloc), .res = res }; + var response = Response{ .headers = http.Fields.init(alloc), .res = res }; defer response.headers.deinit(); const found = routeRequestInternal(api_source, req, &response, alloc); @@ -64,7 +64,7 @@ pub fn Context(comptime Route: type) type { method: http.Method, uri: []const u8, - headers: http.Headers, + headers: http.Fields, args: Args, body: Body, @@ -191,7 +191,7 @@ pub fn Context(comptime Route: type) type { pub const Response = struct { const Self = @This(); - headers: http.Headers, + headers: http.Fields, res: *http.Response, opened: bool = false,