From 5e796d3b3de257217fdf4570910b317d78a7516c Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Thu, 23 Jun 2022 00:25:50 -0700 Subject: [PATCH] Add request parsing to http lib --- src/http/lib.zig | 9 + src/http/server.zig | 435 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 444 insertions(+) create mode 100644 src/http/server.zig diff --git a/src/http/lib.zig b/src/http/lib.zig index e2d0052..bb7279e 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -2,6 +2,7 @@ const std = @import("std"); const ciutf8 = @import("util").ciutf8; const routing = @import("./routing.zig"); +const server = @import("./server.zig"); const response_stream = @import("./response_stream.zig"); pub const Status = std.http.Status; @@ -19,7 +20,15 @@ pub const Headers = std.HashMap([]const u8, []const u8, struct { } }, std.hash_map.default_max_load_percentage); +pub const Request = struct { + method: Method, + path: []const u8, + headers: Headers, + body: ?[]const u8 = null, +}; + test { _ = ResponseStream; _ = routing; + _ = server; } diff --git a/src/http/server.zig b/src/http/server.zig new file mode 100644 index 0000000..2f9515d --- /dev/null +++ b/src/http/server.zig @@ -0,0 +1,435 @@ +const std = @import("std"); +const util = @import("util"); +const http = @import("./lib.zig"); + +const Address = std.net.Address; +const Method = http.Method; +const Request = http.Request; +const Headers = http.Headers; +const Id = u64; + +const Connection = struct { + id: Id, + address: Address, + stream: std.net.Stream, +}; + +const ConnectionServer = struct { + alloc: std.mem.Allocator, + next_conn_id: std.atomic.Atomic(Id) = std.atomic.Atomic(Id).init(1), + + // todo accept is a bad name + fn accept(self: *ConnectionServer, stream: std.net.Stream, address: Address) void { + const conn = Connection{ + .id = self.next_conn_id.fetchAdd(1, .SeqCst), + .address = address, + .stream = stream, + }; + + defer conn.stream.close(); + std.log.debug("new connection conn_id={}", .{conn.id}); + + async handleConnection(conn); + + std.log.debug("terminating connection conn_id={}", .{conn.id}); + } +}; + +const request_buf_size = 1 << 16; +const max_path_len = 1 << 10; +const max_body_len = 1 << 12; + +fn ParseError(comptime Reader: type) type { + return error{ + MethodNotImplemented, + } | Reader.ReadError; +} + +const Encoding = enum { + identity, + chunked, +}; + +fn handleConnection(conn: Connection) void { + var request_buf: [request_buf_size]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&request_buf); + + const request = try parseRequest(fba.allocator(), conn.stream.reader()); + _ = request; +} + +fn parseRequest(alloc: std.mem.Allocator, reader: anytype) !Request { + var request: Request = undefined; + + try parseRequestLine(alloc, &request, reader); + request.headers = try parseHeaders(alloc, reader); + + if (request.method.requestHasBody()) { + request.body = try readBody(alloc, request.headers, reader); + } else { + request.body = null; + } + + return request; +} + +fn parseRequestLine(alloc: std.mem.Allocator, request: *Request, reader: anytype) !void { + request.method = try parseMethod(reader); + request.path = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { + error.StreamTooLong => return error.RequestUriTooLong, + else => return err, + }; + + try checkProto(reader); + + // discard \r\n + _ = try reader.readByte(); + _ = try reader.readByte(); +} + +fn parseMethod(reader: anytype) !Method { + var buf: [8]u8 = undefined; + const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) { + error.StreamTooLong => return error.MethodNotImplemented, + else => return err, + }; + + inline for (@typeInfo(Method).Enum.fields) |method| { + if (std.mem.eql(u8, method.name, str)) { + return @intToEnum(Method, method.value); + } + } + + return error.MethodNotImplemented; +} + +fn checkProto(reader: anytype) !void { + var buf: [8]u8 = undefined; + const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { + error.StreamTooLong => return error.UnknownProtocol, + else => return err, + }; + + if (!std.mem.eql(u8, proto, "HTTP")) { + return error.UnknownProtocol; + } + + const count = try reader.read(buf[0..3]); + if (count != 3 or buf[1] != '.') { + return error.BadRequest; + } + + if (buf[0] != '1' or buf[2] != '1') { + return error.HttpVersionNotSupported; + } +} + +fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { + var map = Headers.init(allocator); + errdefer map.deinit(); + // TODO: free map keys/values + + var buf: [1024]u8 = undefined; + + while (true) { + const line = try reader.readUntilDelimiter(&buf, '\n'); + if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; + + // TODO: handle multi-line headers + const name = extractHeaderName(line) orelse continue; + const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len; + const value = line[name.len + 1 + 1 .. value_end]; + + if (name.len == 0 or value.len == 0) return error.BadRequest; + + const name_alloc = try allocator.alloc(u8, name.len); + errdefer allocator.free(name_alloc); + const value_alloc = try allocator.alloc(u8, value.len); + errdefer allocator.free(value_alloc); + + @memcpy(name_alloc.ptr, name.ptr, name.len); + @memcpy(value_alloc.ptr, value.ptr, value.len); + + try map.put(name_alloc, value_alloc); + } + + return map; +} + +fn extractHeaderName(line: []const u8) ?[]const u8 { + var index: usize = 0; + + // TODO: handle whitespace + while (index < line.len) : (index += 1) { + if (line[index] == ':') { + if (index == 0) return null; + return line[0..index]; + } + } + + return null; +} + +fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 { + const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding")); + if (xfer_encoding != .identity) return error.UnsupportedMediaType; + const content_encoding = try parseEncoding(headers.get("Content-Encoding")); + if (content_encoding != .identity) return error.UnsupportedMediaType; + + const len_str = headers.get("Content-Length") orelse return null; + const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest; + if (len > max_body_len) return error.RequestEntityTooLarge; + const body = try alloc.alloc(u8, len); + errdefer alloc.free(body); + + reader.readNoEof(body) catch return error.BadRequest; + + return body; +} + +// TODO: assumes that there's only one encoding, not layered encodings +fn parseEncoding(encoding: ?[]const u8) !Encoding { + if (encoding == null) return .identity; + if (std.mem.eql(u8, encoding.?, "identity")) return .identity; + if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked; + return error.UnsupportedMediaType; +} + +const _test = struct { + const expectEqual = std.testing.expectEqual; + const expectEqualStrings = std.testing.expectEqualStrings; + + fn toCrlf(comptime str: []const u8) []const u8 { + comptime { + var buf: [str.len * 2]u8 = undefined; + + @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2 + + var buf_len: usize = 0; + for (str) |ch| { + if (ch == '\n') { + buf[buf_len] = '\r'; + buf_len += 1; + } + + buf[buf_len] = ch; + buf_len += 1; + } + + return buf[0..buf_len]; + } + } + + fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers { + var result = Headers.init(alloc); + inline for (headers) |tup| { + try result.put(tup[0], tup[1]); + } + return result; + } + + fn areEqualHeaders(lhs: Headers, rhs: Headers) bool { + if (lhs.count() != rhs.count()) return false; + var iter = lhs.iterator(); + while (iter.next()) |it| { + const rhs_val = rhs.get(it.key_ptr.*) orelse return false; + if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false; + } + return true; + } + + fn printHeaders(headers: Headers) void { + var iter = headers.iterator(); + while (iter.next()) |it| { + std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* }); + } + } + + fn expectEqualHeaders(expected: Headers, actual: Headers) !void { + if (!areEqualHeaders(expected, actual)) { + std.debug.print("\nexpected: \n", .{}); + printHeaders(expected); + std.debug.print("\n\nfound: \n", .{}); + printHeaders(actual); + std.debug.print("\n\n", .{}); + return error.TestExpectedEqual; + } + } + + fn parseRequestTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void { + var stream = std.io.fixedBufferStream(toCrlf(request)); + + const result = try parseRequest(alloc, stream.reader()); + + try expectEqual(expected.method, result.method); + try expectEqualStrings(expected.path, result.path); + try expectEqualHeaders(expected.headers, result.headers); + if ((expected.body == null) != (result.body == null)) { + const null_str: []const u8 = "(null)"; + const exp = expected.body orelse null_str; + const act = result.body orelse null_str; + std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act }); + return error.TestExpectedEqual; + } + if (expected.body != null) { + try expectEqualStrings(expected.body.?, result.body.?); + } + } +}; + +// TOOD: failure test cases +test "parseRequest" { + const testCase = _test.parseRequestTestCase; + var buf = [_]u8{0} ** (1 << 16); + var fba = std.heap.FixedBufferAllocator.init(&buf); + const alloc = fba.allocator(); + try testCase(alloc, ( + \\GET / HTTP/1.1 + \\ + \\ + ), .{ + .method = .GET, + .headers = try _test.makeHeaders(alloc, .{}), + .path = "/", + }); + + fba.reset(); + try testCase(alloc, ( + \\POST / HTTP/1.1 + \\ + \\ + ), .{ + .method = .POST, + .headers = try _test.makeHeaders(alloc, .{}), + .path = "/", + }); + + fba.reset(); + try testCase(alloc, ( + \\HEAD / HTTP/1.1 + \\Authorization: bearer + \\ + \\ + ), .{ + .method = .HEAD, + .headers = try _test.makeHeaders(alloc, .{ + .{ "Authorization", "bearer " }, + }), + .path = "/", + }); + + fba.reset(); + try testCase(alloc, ( + \\POST /nonsense HTTP/1.1 + \\Authorization: bearer + \\Content-Length: 5 + \\ + \\12345 + ), .{ + .method = .POST, + .headers = try _test.makeHeaders(alloc, .{ + .{ "Authorization", "bearer " }, + .{ "Content-Length", "5" }, + }), + .path = "/nonsense", + .body = "12345", + }); + + fba.reset(); + try std.testing.expectError( + error.MethodNotImplemented, + testCase(alloc, ( + \\FOO /nonsense HTTP/1.1 + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.MethodNotImplemented, + testCase(alloc, ( + \\FOOBARBAZ /nonsense HTTP/1.1 + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.RequestUriTooLong, + testCase(alloc, ( + \\GET / + ++ ("a" ** 2048)), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.UnknownProtocol, + testCase(alloc, ( + \\GET /nonsense SPECIALHTTP/1.1 + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.UnknownProtocol, + testCase(alloc, ( + \\GET /nonsense JSON/1.1 + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.HttpVersionNotSupported, + testCase(alloc, ( + \\GET /nonsense HTTP/1.9 + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.HttpVersionNotSupported, + testCase(alloc, ( + \\GET /nonsense HTTP/8.1 + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.BadRequest, + testCase(alloc, ( + \\GET /nonsense HTTP/blah blah blah + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.BadRequest, + testCase(alloc, ( + \\GET /nonsense HTTP/1/1 + \\ + \\ + ), undefined), + ); + + fba.reset(); + try std.testing.expectError( + error.BadRequest, + testCase(alloc, ( + \\GET /nonsense HTTP/1/1 + \\ + \\ + ), undefined), + ); +}