//! HTTP/1.x request parsing: request line, headers and (identity-encoded) body.

const std = @import("std");
const util = @import("util");
const http = @import("../lib.zig");

const Method = http.Method;
const Headers = http.Headers;
const Request = @import("../request.zig").Request;

const request_buf_size = 1 << 16;
const max_path_len = 1 << 10;
const max_body_len = 1 << 12;

/// Error set of the request-line parser for a given reader type.
/// Not referenced elsewhere in this file.
fn ParseError(comptime Reader: type) type {
    // Error sets are merged with `||`; `std.io.Reader` exposes its read
    // error set as `Error`.
    return error{
        MethodNotImplemented,
    } || Reader.Error;
}

const Encoding = enum {
    identity,
    chunked,
};

/// Parses one HTTP/1.x request from `reader`.
///
/// `address` is the peer address of the connection; it is replaced by the
/// value of an `X-Real-IP` header when one is present.
///
/// The returned request owns its `uri`, header keys/values and `body`, all
/// allocated with `alloc`; release them with `parseFree`.
///
/// Errors include `MethodNotImplemented`, `RequestUriTooLong`,
/// `UnknownProtocol`, `HttpVersionNotSupported`, `BadRequest`,
/// `UnsupportedMediaType`, `RequestEntityTooLarge`, allocation failures and
/// any error of the underlying reader.
pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request {
    const method = try parseMethod(reader);

    const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
        error.StreamTooLong => return error.RequestUriTooLong,
        else => return err,
    };
    errdefer alloc.free(uri);

    const proto = try parseProto(reader);
    // The request line must end right after the version; previously the next
    // two bytes were discarded unchecked, silently accepting e.g. "HTTP/1.10".
    try expectLineEnd(reader);

    var headers = try parseHeaders(alloc, reader);
    errdefer freeHeaders(alloc, &headers);

    const body = if (method.requestHasBody())
        try readBody(alloc, headers, reader)
    else
        null;
    errdefer if (body) |b| alloc.free(b);

    // NOTE(review): X-Real-IP is taken verbatim from the request; it is only
    // trustworthy when a reverse proxy sets/overwrites it. Confirm deployment.
    const eff_addr = if (headers.get("X-Real-IP")) |ip|
        std.net.Address.parseIp(ip, address.getPort()) catch {
            return error.BadRequest;
        }
    else
        address;

    return Request{
        .protocol = proto,
        .source_address = eff_addr,

        .method = method,
        .uri = uri,
        .headers = headers,
        .body = body,
    };
}

/// Consumes the terminator of the request line.
///
/// Accepts CRLF and, leniently, a bare LF; any other byte sequence is
/// rejected with `error.BadRequest`.
fn expectLineEnd(reader: anytype) !void {
    switch (try reader.readByte()) {
        '\n' => {},
        '\r' => if ((try reader.readByte()) != '\n') return error.BadRequest,
        else => return error.BadRequest,
    }
}

/// Reads the method token (up to the first space) and maps it to `Method`.
///
/// Tokens that do not fit the 8-byte buffer, or that do not name a `Method`
/// field exactly, yield `error.MethodNotImplemented`.
fn parseMethod(reader: anytype) !Method {
    var buf: [8]u8 = undefined;
    const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) {
        error.StreamTooLong => return error.MethodNotImplemented,
        else => return err,
    };

    if (std.meta.stringToEnum(Method, str)) |method| return method;
    return error.MethodNotImplemented;
}

/// Parses the `HTTP/<major>.<minor>` protocol token; only HTTP/1.0 and
/// HTTP/1.1 are accepted. Does not consume the line terminator.
fn parseProto(reader: anytype) !Request.Protocol {
    var buf: [8]u8 = undefined;
    const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
        error.StreamTooLong => return error.UnknownProtocol,
        else => return err,
    };
    if (!std.mem.eql(u8, proto, "HTTP")) {
        return error.UnknownProtocol;
    }

    // `read` may legally return fewer bytes than requested (short read on a
    // socket) even though more data follows; `readAll` only stops early at
    // end of stream.
    const count = try reader.readAll(buf[0..3]);
    if (count != 3 or buf[1] != '.') {
        return error.BadRequest;
    }

    if (buf[0] != '1') return error.HttpVersionNotSupported;
    return switch (buf[2]) {
        '0' => .http_1_0,
        '1' => .http_1_1,
        else => error.HttpVersionNotSupported,
    };
}

/// Reads header lines until an empty line and returns them as a map whose
/// keys and values are owned copies allocated with `allocator`.
///
/// Lines without a `:` (or starting with one) are skipped. Leading spaces
/// and tabs of a value are stripped; trailing whitespace is kept. An empty
/// value yields `error.BadRequest`. When a header name repeats, the last
/// value wins and the superseded allocations are freed.
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers {
    var map = Headers.init(allocator);
    // Frees every key/value already inserted, then the map itself.
    errdefer freeHeaders(allocator, &map);

    var buf: [1024]u8 = undefined;
    while (true) {
        const raw_line = try reader.readUntilDelimiter(&buf, '\n');
        // Drop the optional CR so both CRLF and bare LF terminate a line.
        const line = if (raw_line.len > 0 and raw_line[raw_line.len - 1] == '\r')
            raw_line[0 .. raw_line.len - 1]
        else
            raw_line;
        if (line.len == 0) break;

        // TODO: handle multi-line headers
        const name = extractHeaderName(line) orelse continue;
        // Do not assume exactly one space after the colon: "Name:value" and
        // "Name:" are possible, and the latter used to slice out of bounds.
        const value = std.mem.trimLeft(u8, line[name.len + 1 ..], " \t");
        if (value.len == 0) return error.BadRequest;

        const name_alloc = try allocator.dupe(u8, name);
        errdefer allocator.free(name_alloc);
        const value_alloc = try allocator.dupe(u8, value);
        errdefer allocator.free(value_alloc);

        // `put` would keep the old key and drop the old value on a duplicate
        // name, leaking both the new key copy and the previous value.
        const entry = try map.getOrPut(name_alloc);
        if (entry.found_existing) {
            allocator.free(name_alloc);
            allocator.free(entry.value_ptr.*);
        }
        entry.value_ptr.* = value_alloc;
    }

    return map;
}

/// Returns the part of `line` before the first `:`, or null when there is
/// no colon or the name would be empty.
fn extractHeaderName(line: []const u8) ?[]const u8 {
    // TODO: handle whitespace
    const colon = std.mem.indexOfScalar(u8, line, ':') orelse return null;
    if (colon == 0) return null;
    return line[0..colon];
}

/// Reads a body of exactly `Content-Length` bytes.
///
/// Returns null when there is no `Content-Length` header. Non-identity
/// transfer or content encodings yield `error.UnsupportedMediaType`; bodies
/// over `max_body_len` yield `error.RequestEntityTooLarge`. Caller owns the
/// returned slice.
fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 {
    const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
    if (xfer_encoding != .identity) return error.UnsupportedMediaType;
    const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
    if (content_encoding != .identity) return error.UnsupportedMediaType;

    const len_str = headers.get("Content-Length") orelse return null;
    const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
    if (len > max_body_len) return error.RequestEntityTooLarge;

    const body = try alloc.alloc(u8, len);
    errdefer alloc.free(body);

    reader.readNoEof(body) catch return error.BadRequest;

    return body;
}

/// Maps an encoding header value to `Encoding`; a missing header means
/// `.identity`, unknown names yield `error.UnsupportedMediaType`.
// TODO: assumes that there's only one encoding, not layered encodings
fn parseEncoding(encoding: ?[]const u8) !Encoding {
    const name = encoding orelse return .identity;
    if (std.mem.eql(u8, name, "identity")) return .identity;
    if (std.mem.eql(u8, name, "chunked")) return .chunked;
    return error.UnsupportedMediaType;
}

/// Releases everything `parse` allocated for `request`.
pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void {
    allocator.free(request.uri);
    freeHeaders(allocator, &request.headers);
    if (request.body) |body| allocator.free(body);
}

/// Frees every key and value in `headers`, then the map itself.
fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void {
    var iter = headers.iterator();
    while (iter.next()) |it| {
        allocator.free(it.key_ptr.*);
        allocator.free(it.value_ptr.*);
    }

    headers.deinit();
}

const _test = struct {
    const expectEqual = std.testing.expectEqual;
    const expectEqualStrings = std.testing.expectEqualStrings;

    /// The parts of a parsed request that the test cases check.
    const Expected = struct {
        method: Method,
        uri: []const u8,
        headers: Headers,
        body: ?[]const u8 = null,
    };

    /// Peer address handed to `parse` in tests.
    fn testAddress() std.net.Address {
        return std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 8080);
    }

    /// Converts LF line endings to CRLF at compile time.
    fn toCrlf(comptime str: []const u8) []const u8 {
        comptime {
            var buf: [str.len * 2]u8 = undefined;

            @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2

            var buf_len: usize = 0;
            for (str) |ch| {
                if (ch == '\n') {
                    buf[buf_len] = '\r';
                    buf_len += 1;
                }

                buf[buf_len] = ch;
                buf_len += 1;
            }

            return buf[0..buf_len];
        }
    }

    fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers {
        var result = Headers.init(alloc);
        inline for (headers) |tup| {
            try result.put(tup[0], tup[1]);
        }
        return result;
    }

    fn areEqualHeaders(lhs: Headers, rhs: Headers) bool {
        if (lhs.count() != rhs.count()) return false;
        var iter = lhs.iterator();
        while (iter.next()) |it| {
            const rhs_val = rhs.get(it.key_ptr.*) orelse return false;
            if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false;
        }

        return true;
    }

    fn printHeaders(headers: Headers) void {
        var iter = headers.iterator();
        while (iter.next()) |it| {
            std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* });
        }
    }

    fn expectEqualHeaders(expected: Headers, actual: Headers) !void {
        if (!areEqualHeaders(expected, actual)) {
            std.debug.print("\nexpected: \n", .{});
            printHeaders(expected);
            std.debug.print("\n\nfound: \n", .{});
            printHeaders(actual);
            std.debug.print("\n\n", .{});
            return error.TestExpectedEqual;
        }
    }

    /// Parses `request` (LF line endings, converted to CRLF) and compares the
    /// result against `expected`. Results are not freed; use a resettable
    /// allocator.
    fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: Expected) !void {
        var stream = std.io.fixedBufferStream(toCrlf(request));

        const result = try parse(alloc, stream.reader(), testAddress());

        try expectEqual(expected.method, result.method);
        try expectEqualStrings(expected.uri, result.uri);
        try expectEqualHeaders(expected.headers, result.headers);
        if ((expected.body == null) != (result.body == null)) {
            const null_str: []const u8 = "(null)";
            const exp = expected.body orelse null_str;
            const act = result.body orelse null_str;
            std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act });
            return error.TestExpectedEqual;
        }
        if (expected.body != null) {
            try expectEqualStrings(expected.body.?, result.body.?);
        }
    }
};

// TODO: more failure test cases
test "parse" {
    const testCase = _test.parseTestCase;
    var buf = [_]u8{0} ** (1 << 16);
    var fba = std.heap.FixedBufferAllocator.init(&buf);
    const alloc = fba.allocator();
    try testCase(alloc, (
        \\GET / HTTP/1.1
        \\
        \\
    ), .{
        .method = .GET,
        .headers = try _test.makeHeaders(alloc, .{}),
        .uri = "/",
    });
    fba.reset();

    try testCase(alloc, (
        \\POST / HTTP/1.1
        \\
        \\
    ), .{
        .method = .POST,
        .headers = try _test.makeHeaders(alloc, .{}),
        .uri = "/",
    });
    fba.reset();

    try testCase(alloc, (
        \\HEAD / HTTP/1.1
        \\Authorization: bearer 
        \\
        \\
    ), .{
        .method = .HEAD,
        .headers = try _test.makeHeaders(alloc, .{
            .{ "Authorization", "bearer " },
        }),
        .uri = "/",
    });
    fba.reset();

    try testCase(alloc, (
        \\POST /nonsense HTTP/1.1
        \\Authorization: bearer 
        \\Content-Length: 5
        \\
        \\12345
    ), .{
        .method = .POST,
        .headers = try _test.makeHeaders(alloc, .{
            .{ "Authorization", "bearer " },
            .{ "Content-Length", "5" },
        }),
        .uri = "/nonsense",
        .body = "12345",
    });
    fba.reset();

    // No space after the colon must not eat the first value byte.
    try testCase(alloc, (
        \\GET / HTTP/1.1
        \\Host:example.com
        \\
        \\
    ), .{
        .method = .GET,
        .headers = try _test.makeHeaders(alloc, .{
            .{ "Host", "example.com" },
        }),
        .uri = "/",
    });
    fba.reset();

    try std.testing.expectError(
        error.MethodNotImplemented,
        testCase(alloc, (
            \\FOO /nonsense HTTP/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.MethodNotImplemented,
        testCase(alloc, (
            \\FOOBARBAZ /nonsense HTTP/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.RequestUriTooLong,
        testCase(alloc, (
            \\GET /
        ++ ("a" ** 2048)), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.UnknownProtocol,
        testCase(alloc, (
            \\GET /nonsense SPECIALHTTP/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.UnknownProtocol,
        testCase(alloc, (
            \\GET /nonsense JSON/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.HttpVersionNotSupported,
        testCase(alloc, (
            \\GET /nonsense HTTP/1.9
            \\
            \\
        ), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.HttpVersionNotSupported,
        testCase(alloc, (
            \\GET /nonsense HTTP/8.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.BadRequest,
        testCase(alloc, (
            \\GET /nonsense HTTP/blah blah blah
            \\
            \\
        ), undefined),
    );
    fba.reset();

    try std.testing.expectError(
        error.BadRequest,
        testCase(alloc, (
            \\GET /nonsense HTTP/1/1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // Junk after the version must be rejected, not silently discarded.
    try std.testing.expectError(
        error.BadRequest,
        testCase(alloc, (
            \\GET /nonsense HTTP/1.10
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // An empty header value used to slice out of bounds.
    try std.testing.expectError(
        error.BadRequest,
        testCase(alloc, (
            \\GET /nonsense HTTP/1.1
            \\X-Empty:
            \\
            \\
        ), undefined),
    );
}

test "parse replaces duplicate headers without leaking" {
    const alloc = std.testing.allocator;
    var stream = std.io.fixedBufferStream(_test.toCrlf(
        \\GET / HTTP/1.1
        \\Accept: text/html
        \\Accept: text/plain
        \\
        \\
    ));

    var request = try parse(alloc, stream.reader(), _test.testAddress());
    defer parseFree(alloc, &request);

    const accept = request.headers.get("Accept") orelse return error.TestUnexpectedResult;
    try std.testing.expectEqualStrings("text/plain", accept);
}