Clean up header handling

This commit is contained in:
jaina heartles 2022-11-05 00:26:53 -07:00
parent 5d94f6874d
commit 4cb574bc91
11 changed files with 417 additions and 341 deletions

View file

@ -53,16 +53,16 @@ pub fn build(b: *std.build.Builder) void {
exe.linkSystemLibrary("pq"); exe.linkSystemLibrary("pq");
exe.linkLibC(); exe.linkLibC();
const util_tests = b.addTest("src/util/lib.zig"); //const util_tests = b.addTest("src/util/lib.zig");
const http_tests = b.addTest("src/http/lib.zig"); const http_tests = b.addTest("src/http/test.zig");
const sql_tests = b.addTest("src/sql/lib.zig"); //const sql_tests = b.addTest("src/sql/lib.zig");
http_tests.addPackage(util_pkg); http_tests.addPackage(util_pkg);
sql_tests.addPackage(util_pkg); //sql_tests.addPackage(util_pkg);
const unit_tests = b.step("unit-tests", "Run tests"); const unit_tests = b.step("unit-tests", "Run tests");
unit_tests.dependOn(&util_tests.step); //unit_tests.dependOn(&util_tests.step);
unit_tests.dependOn(&http_tests.step); unit_tests.dependOn(&http_tests.step);
unit_tests.dependOn(&sql_tests.step); //unit_tests.dependOn(&sql_tests.step);
const api_integration = b.addTest("./tests/api_integration/lib.zig"); const api_integration = b.addTest("./tests/api_integration/lib.zig");
api_integration.addPackage(sql_pkg); api_integration.addPackage(sql_pkg);

82
src/http/headers.zig Normal file
View file

@ -0,0 +1,82 @@
const std = @import("std");
pub const Fields = struct {
const HashContext = struct {
const hash_seed = 1;
pub fn eql(_: @This(), lhs: []const u8, rhs: []const u8) bool {
return std.ascii.eqlIgnoreCase(lhs, rhs);
}
pub fn hash(_: @This(), s: []const u8) u64 {
var h = std.hash.Wyhash.init(hash_seed);
for (s) |ch| {
const c = [1]u8{std.ascii.toLower(ch)};
h.update(&c);
}
return h.final();
}
};
const HashMap = std.HashMapUnmanaged(
[]const u8,
[]const u8,
HashContext,
std.hash_map.default_max_load_percentage,
);
unmanaged: HashMap,
allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator) Fields {
return Fields{
.unmanaged = .{},
.allocator = allocator,
};
}
pub fn deinit(self: *Fields) void {
var hash_iter = self.unmanaged.iterator();
while (hash_iter.next()) |entry| {
self.allocator.free(entry.key_ptr.*);
self.allocator.free(entry.value_ptr.*);
}
self.unmanaged.deinit(self.allocator);
}
pub fn iterator(self: Fields) HashMap.Iterator {
return self.unmanaged.iterator();
}
pub fn get(self: Fields, key: []const u8) ?[]const u8 {
return self.unmanaged.get(key);
}
pub fn put(self: *Fields, key: []const u8, val: []const u8) !void {
const key_clone = try self.allocator.alloc(u8, key.len);
std.mem.copy(u8, key_clone, key);
errdefer self.allocator.free(key_clone);
const val_clone = try self.allocator.alloc(u8, val.len);
std.mem.copy(u8, val_clone, val);
errdefer self.allocator.free(val_clone);
if (try self.unmanaged.fetchPut(self.allocator, key_clone, val_clone)) |entry| {
self.allocator.free(entry.key);
self.allocator.free(entry.value);
}
}
pub fn append(self: *Fields, key: []const u8, val: []const u8) !void {
if (self.unmanaged.getEntry(key)) |entry| {
const new_val = try std.mem.join(self.allocator, ", ", &.{ entry.value_ptr.*, val });
self.allocator.free(entry.value_ptr.*);
entry.value_ptr.* = new_val;
} else {
try self.put(key, val);
}
}
pub fn count(self: Fields) usize {
return self.unmanaged.count();
}
};

View file

@ -15,6 +15,8 @@ pub const serveConn = server.serveConn;
pub const Response = server.Response; pub const Response = server.Response;
pub const Handler = server.Handler; pub const Handler = server.Handler;
pub const Fields = @import("./headers.zig").Fields;
pub const Headers = std.HashMap([]const u8, []const u8, struct { pub const Headers = std.HashMap([]const u8, []const u8, struct {
pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {
return ciutf8.eql(a, b); return ciutf8.eql(a, b);

View file

@ -7,21 +7,22 @@ pub const Request = struct {
pub const Protocol = enum { pub const Protocol = enum {
http_1_0, http_1_0,
http_1_1, http_1_1,
http_1_x,
}; };
protocol: Protocol, protocol: Protocol,
source_address: ?std.net.Address,
method: http.Method, method: http.Method,
uri: []const u8, uri: []const u8,
headers: http.Headers, headers: http.Fields,
body: ?[]const u8 = null, body: ?[]const u8 = null,
pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request { pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request {
return parser.parse(alloc, reader, addr); return parser.parse(alloc, reader);
} }
pub fn parseFree(self: Request, alloc: std.mem.Allocator) void { pub fn parseFree(self: *Request, alloc: std.mem.Allocator) void {
parser.parseFree(alloc, self); parser.parseFree(alloc, self);
} }
}; };

View file

@ -1,9 +1,8 @@
const std = @import("std"); const std = @import("std");
const util = @import("util");
const http = @import("../lib.zig"); const http = @import("../lib.zig");
const Method = http.Method; const Method = http.Method;
const Headers = http.Headers; const Fields = http.Fields;
const Request = @import("../request.zig").Request; const Request = @import("../request.zig").Request;
@ -22,7 +21,7 @@ const Encoding = enum {
chunked, chunked,
}; };
pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request { pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request {
const method = try parseMethod(reader); const method = try parseMethod(reader);
const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
error.StreamTooLong => return error.RequestUriTooLong, error.StreamTooLong => return error.RequestUriTooLong,
@ -33,11 +32,14 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address
const proto = try parseProto(reader); const proto = try parseProto(reader);
// discard \r\n // discard \r\n
_ = try reader.readByte(); switch (try reader.readByte()) {
_ = try reader.readByte(); '\r' => if (try reader.readByte() != '\n') return error.BadRequest,
'\n' => {},
else => return error.BadRequest,
}
var headers = try parseHeaders(alloc, reader); var headers = try parseHeaders(alloc, reader);
errdefer freeHeaders(alloc, &headers); errdefer headers.deinit();
const body = if (method.requestHasBody()) const body = if (method.requestHasBody())
try readBody(alloc, headers, reader) try readBody(alloc, headers, reader)
@ -45,16 +47,8 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address
null; null;
errdefer if (body) |b| alloc.free(b); errdefer if (body) |b| alloc.free(b);
const eff_addr = if (headers.get("X-Real-IP")) |ip|
std.net.Address.parseIp(ip, address.getPort()) catch {
return error.BadRequest;
}
else
address;
return Request{ return Request{
.protocol = proto, .protocol = proto,
.source_address = eff_addr,
.method = method, .method = method,
.uri = uri, .uri = uri,
@ -99,71 +93,49 @@ fn parseProto(reader: anytype) !Request.Protocol {
return switch (buf[2]) { return switch (buf[2]) {
'0' => .http_1_0, '0' => .http_1_0,
'1' => .http_1_1, '1' => .http_1_1,
else => error.HttpVersionNotSupported, else => .http_1_x,
}; };
} }
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
var map = Headers.init(allocator); var headers = Fields.init(allocator);
errdefer map.deinit();
errdefer {
var iter = map.iterator();
while (iter.next()) |it| {
allocator.free(it.key_ptr.*);
allocator.free(it.value_ptr.*);
}
}
// todo:
//errdefer {
//var iter = map.iterator();
//while (iter.next()) |it| {
//allocator.free(it.key_ptr);
//allocator.free(it.value_ptr);
//}
//}
var buf: [1024]u8 = undefined;
var buf: [4096]u8 = undefined;
while (true) { while (true) {
const line = try reader.readUntilDelimiter(&buf, '\n'); const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; error.StreamTooLong => return error.HeaderLineTooLong,
else => return err,
};
const line = std.mem.trimRight(u8, full_line, "\r");
if (line.len == 0) break;
// TODO: handle multi-line headers const name = std.mem.sliceTo(line, ':');
const name = extractHeaderName(line) orelse continue; if (!isTokenValid(name)) return error.BadRequest;
const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len; if (name.len == line.len) return error.BadRequest;
const value = line[name.len + 1 + 1 .. value_end];
if (name.len == 0 or value.len == 0) return error.BadRequest; const value = std.mem.trim(u8, line[name.len + 1 ..], " \t");
const name_alloc = try allocator.alloc(u8, name.len); try headers.put(name, value);
errdefer allocator.free(name_alloc);
const value_alloc = try allocator.alloc(u8, value.len);
errdefer allocator.free(value_alloc);
@memcpy(name_alloc.ptr, name.ptr, name.len);
@memcpy(value_alloc.ptr, value.ptr, value.len);
try map.put(name_alloc, value_alloc);
} }
return map; return headers;
} }
fn extractHeaderName(line: []const u8) ?[]const u8 { fn isTokenValid(token: []const u8) bool {
var index: usize = 0; if (token.len == 0) return false;
for (token) |ch| {
switch (ch) {
'"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false,
// TODO: handle whitespace '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => {},
while (index < line.len) : (index += 1) { else => if (!std.ascii.isAlphanumeric(ch)) return false,
if (line[index] == ':') {
if (index == 0) return null;
return line[0..index];
} }
} }
return null; return true;
} }
fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 { fn readBody(alloc: std.mem.Allocator, headers: Fields, reader: anytype) !?[]const u8 {
const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding")); const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
if (xfer_encoding != .identity) return error.UnsupportedMediaType; if (xfer_encoding != .identity) return error.UnsupportedMediaType;
const content_encoding = try parseEncoding(headers.get("Content-Encoding")); const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
@ -190,254 +162,6 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding {
pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void { pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void {
allocator.free(request.uri); allocator.free(request.uri);
freeHeaders(allocator, &request.headers); request.headers.deinit();
if (request.body) |body| allocator.free(body); if (request.body) |body| allocator.free(body);
} }
fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void {
var iter = headers.iterator();
while (iter.next()) |it| {
allocator.free(it.key_ptr.*);
allocator.free(it.value_ptr.*);
}
headers.deinit();
}
const _test = struct {
const expectEqual = std.testing.expectEqual;
const expectEqualStrings = std.testing.expectEqualStrings;
fn toCrlf(comptime str: []const u8) []const u8 {
comptime {
var buf: [str.len * 2]u8 = undefined;
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
var buf_len: usize = 0;
for (str) |ch| {
if (ch == '\n') {
buf[buf_len] = '\r';
buf_len += 1;
}
buf[buf_len] = ch;
buf_len += 1;
}
return buf[0..buf_len];
}
}
fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers {
var result = Headers.init(alloc);
inline for (headers) |tup| {
try result.put(tup[0], tup[1]);
}
return result;
}
fn areEqualHeaders(lhs: Headers, rhs: Headers) bool {
if (lhs.count() != rhs.count()) return false;
var iter = lhs.iterator();
while (iter.next()) |it| {
const rhs_val = rhs.get(it.key_ptr.*) orelse return false;
if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false;
}
return true;
}
fn printHeaders(headers: Headers) void {
var iter = headers.iterator();
while (iter.next()) |it| {
std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* });
}
}
fn expectEqualHeaders(expected: Headers, actual: Headers) !void {
if (!areEqualHeaders(expected, actual)) {
std.debug.print("\nexpected: \n", .{});
printHeaders(expected);
std.debug.print("\n\nfound: \n", .{});
printHeaders(actual);
std.debug.print("\n\n", .{});
return error.TestExpectedEqual;
}
}
fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void {
var stream = std.io.fixedBufferStream(toCrlf(request));
const result = try parse(alloc, stream.reader());
try expectEqual(expected.method, result.method);
try expectEqualStrings(expected.path, result.path);
try expectEqualHeaders(expected.headers, result.headers);
if ((expected.body == null) != (result.body == null)) {
const null_str: []const u8 = "(null)";
const exp = expected.body orelse null_str;
const act = result.body orelse null_str;
std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act });
return error.TestExpectedEqual;
}
if (expected.body != null) {
try expectEqualStrings(expected.body.?, result.body.?);
}
}
};
// TOOD: failure test cases
test "parse" {
const testCase = _test.parseTestCase;
var buf = [_]u8{0} ** (1 << 16);
var fba = std.heap.FixedBufferAllocator.init(&buf);
const alloc = fba.allocator();
try testCase(alloc, (
\\GET / HTTP/1.1
\\
\\
), .{
.method = .GET,
.headers = try _test.makeHeaders(alloc, .{}),
.path = "/",
});
fba.reset();
try testCase(alloc, (
\\POST / HTTP/1.1
\\
\\
), .{
.method = .POST,
.headers = try _test.makeHeaders(alloc, .{}),
.path = "/",
});
fba.reset();
try testCase(alloc, (
\\HEAD / HTTP/1.1
\\Authorization: bearer <token>
\\
\\
), .{
.method = .HEAD,
.headers = try _test.makeHeaders(alloc, .{
.{ "Authorization", "bearer <token>" },
}),
.path = "/",
});
fba.reset();
try testCase(alloc, (
\\POST /nonsense HTTP/1.1
\\Authorization: bearer <token>
\\Content-Length: 5
\\
\\12345
), .{
.method = .POST,
.headers = try _test.makeHeaders(alloc, .{
.{ "Authorization", "bearer <token>" },
.{ "Content-Length", "5" },
}),
.path = "/nonsense",
.body = "12345",
});
fba.reset();
try std.testing.expectError(
error.MethodNotImplemented,
testCase(alloc, (
\\FOO /nonsense HTTP/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.MethodNotImplemented,
testCase(alloc, (
\\FOOBARBAZ /nonsense HTTP/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.RequestUriTooLong,
testCase(alloc, (
\\GET /
++ ("a" ** 2048)), undefined),
);
fba.reset();
try std.testing.expectError(
error.UnknownProtocol,
testCase(alloc, (
\\GET /nonsense SPECIALHTTP/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.UnknownProtocol,
testCase(alloc, (
\\GET /nonsense JSON/1.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.HttpVersionNotSupported,
testCase(alloc, (
\\GET /nonsense HTTP/1.9
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.HttpVersionNotSupported,
testCase(alloc, (
\\GET /nonsense HTTP/8.1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.BadRequest,
testCase(alloc, (
\\GET /nonsense HTTP/blah blah blah
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.BadRequest,
testCase(alloc, (
\\GET /nonsense HTTP/1/1
\\
\\
), undefined),
);
fba.reset();
try std.testing.expectError(
error.BadRequest,
testCase(alloc, (
\\GET /nonsense HTTP/1/1
\\
\\
), undefined),
);
}

View file

@ -0,0 +1,264 @@
const std = @import("std");
const parser = @import("./parser.zig");
const http = @import("../lib.zig");
const t = std.testing;
const test_case = struct {
fn parse(text: []const u8, expected: struct {
protocol: http.Request.Protocol = .http_1_1,
method: http.Method = .GET,
headers: []const std.meta.Tuple(&.{ []const u8, []const u8 }) = &.{},
uri: []const u8 = "",
}) !void {
var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(text) };
var actual = try parser.parse(t.allocator, stream.reader());
defer actual.parseFree(t.allocator);
try t.expectEqual(expected.protocol, actual.protocol);
try t.expectEqual(expected.method, actual.method);
try t.expectEqualStrings(expected.uri, actual.uri);
try t.expectEqual(expected.headers.len, actual.headers.count());
for (expected.headers) |hdr| {
if (actual.headers.get(hdr[0])) |val| {
try t.expectEqualStrings(hdr[1], val);
} else {
std.debug.print("Error: Header {s} expected to be present, was not.\n", .{hdr[0]});
try t.expect(false);
}
}
}
};
fn toCrlf(comptime str: []const u8) []const u8 {
comptime {
var buf: [str.len * 2]u8 = undefined;
@setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
var buf_len: usize = 0;
for (str) |ch| {
if (ch == '\n') {
buf[buf_len] = '\r';
buf_len += 1;
}
buf[buf_len] = ch;
buf_len += 1;
}
return buf[0..buf_len];
}
}
test "HTTP/1.x parse - No body" {
try test_case.parse(
toCrlf(
\\GET / HTTP/1.1
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\POST / HTTP/1.1
\\
\\
),
.{
.protocol = .http_1_1,
.method = .POST,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\GET /url/abcd HTTP/1.1
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
},
);
try test_case.parse(
toCrlf(
\\GET / HTTP/1.0
\\
\\
),
.{
.protocol = .http_1_0,
.method = .GET,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
.headers = &.{.{ "Content-Type", "application/json" }},
},
);
try test_case.parse(
toCrlf(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\Authorization: bearer <token>
\\
\\
),
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
.headers = &.{
.{ "Content-Type", "application/json" },
.{ "Authorization", "bearer <token>" },
},
},
);
// Test without CRLF
try test_case.parse(
\\GET /url/abcd HTTP/1.1
\\Content-Type: application/json
\\Authorization: bearer <token>
\\
\\
,
.{
.protocol = .http_1_1,
.method = .GET,
.uri = "/url/abcd",
.headers = &.{
.{ "Content-Type", "application/json" },
.{ "Authorization", "bearer <token>" },
},
},
);
try test_case.parse(
\\POST / HTTP/1.1
\\
\\
,
.{
.protocol = .http_1_1,
.method = .POST,
.uri = "/",
},
);
try test_case.parse(
toCrlf(
\\GET / HTTP/1.2
\\
\\
),
.{
.protocol = .http_1_x,
.method = .GET,
.uri = "/",
},
);
}
test "HTTP/1.x parse - unsupported protocol" {
try t.expectError(error.UnknownProtocol, test_case.parse(
\\GET / JSON/1.1
\\
\\
,
.{},
));
try t.expectError(error.UnknownProtocol, test_case.parse(
\\GET / SOMETHINGELSE/3.5
\\
\\
,
.{},
));
try t.expectError(error.UnknownProtocol, test_case.parse(
\\GET / /1.1
\\
\\
,
.{},
));
try t.expectError(error.HttpVersionNotSupported, test_case.parse(
\\GET / HTTP/2.1
\\
\\
,
.{},
));
}
test "HTTP/1.x parse - Unknown method" {
try t.expectError(error.MethodNotImplemented, test_case.parse(
\\ABCD / HTTP/1.1
\\
\\
,
.{},
));
try t.expectError(error.MethodNotImplemented, test_case.parse(
\\PATCHPATCHPATCH / HTTP/1.1
\\
\\
,
.{},
));
}
test "HTTP/1.x parse - Too long" {
try t.expectError(error.RequestUriTooLong, test_case.parse(
std.fmt.comptimePrint("GET {s} HTTP/1.1\n\n", .{"a" ** 8192}),
.{},
));
try t.expectError(error.HeaderLineTooLong, test_case.parse(
std.fmt.comptimePrint("GET / HTTP/1.1\r\n{s}: abcd", .{"a" ** 8192}),
.{},
));
try t.expectError(error.HeaderLineTooLong, test_case.parse(
std.fmt.comptimePrint("GET / HTTP/1.1\r\nabcd: {s}", .{"a" ** 8192}),
.{},
));
}
test "HTTP/1.x parse - bad requests" {
try t.expectError(error.BadRequest, test_case.parse(
\\GET / HTTP/1.1 blah blah
\\
\\
,
.{},
));
try t.expectError(error.BadRequest, test_case.parse(
\\GET / HTTP/1.1
\\abcd : lksjdfkl
\\
,
.{},
));
try t.expectError(error.BadRequest, test_case.parse(
\\GET / HTTP/1.1
\\ lksjfklsjdfklj
\\
,
.{},
));
}

View file

@ -9,7 +9,7 @@ pub const Response = struct {
stream: std.net.Stream, stream: std.net.Stream,
should_close: bool = false, should_close: bool = false,
pub const Stream = response.ResponseStream(std.net.Stream.Writer); pub const Stream = response.ResponseStream(std.net.Stream.Writer);
pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream { pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream {
if (headers.get("Connection")) |hdr| { if (headers.get("Connection")) |hdr| {
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
} }
@ -17,7 +17,7 @@ pub const Response = struct {
return response.open(self.alloc, self.stream.writer(), headers, status); return response.open(self.alloc, self.stream.writer(), headers, status);
} }
pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream { pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream {
try response.writeRequestHeader(self.stream.writer(), headers, status); try response.writeRequestHeader(self.stream.writer(), headers, status);
return self.stream; return self.stream;
} }
@ -37,7 +37,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
var arena = std.heap.ArenaAllocator.init(alloc); var arena = std.heap.ArenaAllocator.init(alloc);
defer arena.deinit(); defer arena.deinit();
const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| { const req = Request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
return handleError(conn.stream.writer(), err) catch {}; return handleError(conn.stream.writer(), err) catch {};
}; };
std.log.debug("done parsing", .{}); std.log.debug("done parsing", .{});

View file

@ -2,20 +2,20 @@ const std = @import("std");
const http = @import("../lib.zig"); const http = @import("../lib.zig");
const Status = http.Status; const Status = http.Status;
const Headers = http.Headers; const Fields = http.Fields;
const chunk_size = 16 * 1024; const chunk_size = 16 * 1024;
pub fn open( pub fn open(
alloc: std.mem.Allocator, alloc: std.mem.Allocator,
writer: anytype, writer: anytype,
headers: *const Headers, headers: *const Fields,
status: Status, status: Status,
) !ResponseStream(@TypeOf(writer)) { ) !ResponseStream(@TypeOf(writer)) {
const buf = try alloc.alloc(u8, chunk_size); const buf = try alloc.alloc(u8, chunk_size);
errdefer alloc.free(buf); errdefer alloc.free(buf);
try writeStatusLine(writer, status); try writeStatusLine(writer, status);
try writeHeaders(writer, headers); try writeFields(writer, headers);
return ResponseStream(@TypeOf(writer)){ return ResponseStream(@TypeOf(writer)){
.allocator = alloc, .allocator = alloc,
@ -25,9 +25,9 @@ pub fn open(
}; };
} }
pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void { pub fn writeRequestHeader(writer: anytype, headers: *const Fields, status: Status) !void {
try writeStatusLine(writer, status); try writeStatusLine(writer, status);
try writeHeaders(writer, headers); try writeFields(writer, headers);
try writer.writeAll("\r\n"); try writer.writeAll("\r\n");
} }
@ -36,7 +36,7 @@ fn writeStatusLine(writer: anytype, status: Status) !void {
try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text }); try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text });
} }
fn writeHeaders(writer: anytype, headers: *const Headers) !void { fn writeFields(writer: anytype, headers: *const Fields) !void {
var iter = headers.iterator(); var iter = headers.iterator();
while (iter.next()) |header| { while (iter.next()) |header| {
for (header.value_ptr.*) |ch| { for (header.value_ptr.*) |ch| {
@ -65,7 +65,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
base_writer: BaseWriter, base_writer: BaseWriter,
headers: *const Headers, headers: *const Fields,
buffer: []u8, buffer: []u8,
buffer_pos: usize = 0, buffer_pos: usize = 0,
chunked: bool = false, chunked: bool = false,
@ -177,7 +177,7 @@ const _tests = struct {
test "ResponseStream no headers empty body" { test "ResponseStream no headers empty body" {
var buffer: [test_buffer_size]u8 = undefined; var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer); var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator); var headers = Fields.init(std.testing.allocator);
defer headers.deinit(); defer headers.deinit();
{ {
@ -205,7 +205,7 @@ const _tests = struct {
test "ResponseStream empty body" { test "ResponseStream empty body" {
var buffer: [test_buffer_size]u8 = undefined; var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer); var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator); var headers = Fields.init(std.testing.allocator);
defer headers.deinit(); defer headers.deinit();
try headers.put("Content-Type", "text/plain"); try headers.put("Content-Type", "text/plain");
@ -236,7 +236,7 @@ const _tests = struct {
test "ResponseStream not 200 OK" { test "ResponseStream not 200 OK" {
var buffer: [test_buffer_size]u8 = undefined; var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer); var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator); var headers = Fields.init(std.testing.allocator);
defer headers.deinit(); defer headers.deinit();
try headers.put("Content-Type", "text/plain"); try headers.put("Content-Type", "text/plain");
@ -266,7 +266,7 @@ const _tests = struct {
test "ResponseStream small body" { test "ResponseStream small body" {
var buffer: [test_buffer_size]u8 = undefined; var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer); var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator); var headers = Fields.init(std.testing.allocator);
defer headers.deinit(); defer headers.deinit();
try headers.put("Content-Type", "text/plain"); try headers.put("Content-Type", "text/plain");
@ -300,7 +300,7 @@ const _tests = struct {
test "ResponseStream large body" { test "ResponseStream large body" {
var buffer: [test_buffer_size]u8 = undefined; var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer); var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator); var headers = Fields.init(std.testing.allocator);
defer headers.deinit(); defer headers.deinit();
try headers.put("Content-Type", "text/plain"); try headers.put("Content-Type", "text/plain");
@ -341,7 +341,7 @@ const _tests = struct {
test "ResponseStream large body ending on chunk boundary" { test "ResponseStream large body ending on chunk boundary" {
var buffer: [test_buffer_size]u8 = undefined; var buffer: [test_buffer_size]u8 = undefined;
var test_stream = std.io.fixedBufferStream(&buffer); var test_stream = std.io.fixedBufferStream(&buffer);
var headers = Headers.init(std.testing.allocator); var headers = Fields.init(std.testing.allocator);
defer headers.deinit(); defer headers.deinit();
try headers.put("Content-Type", "text/plain"); try headers.put("Content-Type", "text/plain");

View file

@ -37,7 +37,7 @@ pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Respons
const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake; const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
if (!std.mem.eql(u8, "13", version)) return error.BadHandshake; if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;
var headers = http.Headers.init(alloc); var headers = http.Fields.init(alloc);
defer headers.deinit(); defer headers.deinit();
try headers.put("Upgrade", "websocket"); try headers.put("Upgrade", "websocket");

3
src/http/test.zig Normal file
View file

@ -0,0 +1,3 @@
test {
_ = @import("./request/test_parser.zig");
}

View file

@ -16,7 +16,7 @@ pub const streaming = @import("./controllers/streaming.zig");
pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
// TODO: hashmaps? // TODO: hashmaps?
var response = Response{ .headers = http.Headers.init(alloc), .res = res }; var response = Response{ .headers = http.Fields.init(alloc), .res = res };
defer response.headers.deinit(); defer response.headers.deinit();
const found = routeRequestInternal(api_source, req, &response, alloc); const found = routeRequestInternal(api_source, req, &response, alloc);
@ -64,7 +64,7 @@ pub fn Context(comptime Route: type) type {
method: http.Method, method: http.Method,
uri: []const u8, uri: []const u8,
headers: http.Headers, headers: http.Fields,
args: Args, args: Args,
body: Body, body: Body,
@ -191,7 +191,7 @@ pub fn Context(comptime Route: type) type {
pub const Response = struct { pub const Response = struct {
const Self = @This(); const Self = @This();
headers: http.Headers, headers: http.Fields,
res: *http.Response, res: *http.Response,
opened: bool = false, opened: bool = false,