// HTTP/1.x request parsing: request line, header fields, and a reader for
// the message body (identity or chunked transfer coding).

const std = @import("std");
const http = @import("../lib.zig");
const Method = http.Method;
const Fields = http.Fields;
const Request = @import("../request.zig").Request;

const request_buf_size = 1 << 16;
/// Upper bound, in bytes, on the request target read by `parse`.
const max_path_len = 1 << 10;

/// Error set for parsing from `Reader`: the parser's own errors merged with
/// the reader's error set.
fn ParseError(comptime Reader: type) type {
    return error{
        MethodNotImplemented,
    } || Reader.Error;
}

/// Transfer/content codings recognised by `parseEncoding`.
const Encoding = enum {
    identity,
    chunked,
};

/// Parses a request line and header section from `reader` and prepares a body
/// stream for whatever follows.
///
/// The request target is allocated with `alloc` and handed to the returned
/// request in `.uri`; it is freed here only when parsing fails. The header
/// fields are likewise released on failure.
///
/// Returns `error.RequestUriTooLong` when the target exceeds `max_path_len`,
/// `error.BadRequest` for a malformed request line or for a body on a method
/// whose `requestHasBody()` is false, plus anything the helpers below or the
/// reader return.
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)) {
    const method = try parseMethod(reader);

    const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
        error.StreamTooLong => return error.RequestUriTooLong,
        else => return err,
    };
    errdefer alloc.free(uri);

    const proto = try parseProto(reader);

    // The request line ends with CRLF; a bare LF is tolerated.
    switch (try reader.readByte()) {
        '\r' => if ((try reader.readByte()) != '\n') return error.BadRequest,
        '\n' => {},
        else => return error.BadRequest,
    }

    var headers = try parseHeaders(alloc, reader);
    errdefer headers.deinit();

    const body = try prepareBody(headers, reader);
    if (body != null and !method.requestHasBody()) return error.BadRequest;

    return Request(@TypeOf(reader)){
        .protocol = proto,
        .method = method,
        .uri = uri,
        .headers = headers,
        .body = body,
    };
}

/// Reads the method token up to (and consuming) the following space and maps
/// it to a `Method` by exact field name.
///
/// Returns `error.MethodNotImplemented` when the token does not fit the
/// 8-byte buffer or names no `Method` field.
fn parseMethod(reader: anytype) !Method {
    var buf: [8]u8 = undefined;
    const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) {
        error.StreamTooLong => return error.MethodNotImplemented,
        else => return err,
    };
    return std.meta.stringToEnum(Method, str) orelse error.MethodNotImplemented;
}

/// Parses `HTTP/<major>.<minor>` from `reader`; the line terminator is left
/// unread.
///
/// Returns `error.UnknownProtocol` when the name before `/` is not `HTTP`,
/// `error.BadRequest` when the version is truncated or lacks the `.`, and
/// `error.HttpVersionNotSupported` when the major version is not `1`.
fn parseProto(reader: anytype) !http.Protocol {
    var buf: [8]u8 = undefined;
    const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
        error.StreamTooLong => return error.UnknownProtocol,
        else => return err,
    };
    if (!std.mem.eql(u8, proto, "HTTP")) return error.UnknownProtocol;

    // readAll keeps reading until all three bytes arrived or the stream
    // ended; a single read() may legitimately return fewer bytes.
    const count = try reader.readAll(buf[0..3]);
    if (count != 3 or buf[1] != '.') return error.BadRequest;
    if (buf[0] != '1') return error.HttpVersionNotSupported;

    return switch (buf[2]) {
        '0' => .http_1_0,
        '1' => .http_1_1,
        else => .http_1_x,
    };
}

/// Reads header lines from `reader` until the first empty line and collects
/// them into a `Fields` created with `allocator`.
///
/// Each line must be `name:value` with a valid token as name. The value is
/// percent-decoded and stripped of surrounding spaces and tabs. Lines may end
/// in CRLF or a bare LF.
///
/// Returns `error.HeaderLineTooLong` for lines that do not fit the 4096-byte
/// line buffer and `error.BadRequest` for malformed lines. The fields built
/// so far are released on any error.
pub fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Fields {
    var headers = Fields.init(allocator);
    errdefer headers.deinit();

    var buf: [4096]u8 = undefined;
    while (true) {
        const full_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
            error.StreamTooLong => return error.HeaderLineTooLong,
            else => return err,
        };
        const line = if (full_line.len != 0 and full_line[full_line.len - 1] == '\r')
            full_line[0 .. full_line.len - 1]
        else
            full_line;
        if (line.len == 0) break;

        const name = std.mem.sliceTo(line, ':');
        if (!isTokenValid(name)) return error.BadRequest;
        // No ':' anywhere on the line.
        if (name.len == line.len) return error.BadRequest;

        const decoded_value = try percentDecodeInPlace(line[name.len + 1 ..]);
        const val = std.mem.trim(u8, decoded_value, " \t");

        // NOTE(review): `name` and `val` point into `buf`, which the next
        // line overwrites. This assumes `Fields.append` copies its
        // arguments - confirm against its implementation.
        try headers.append(name, val);
    }
    return headers;
}

/// Decodes `%XX` escapes in `value` in place and returns the decoded prefix
/// of the same buffer.
///
/// Returns `error.BadRequest` on a raw CR, LF or NUL byte or on a truncated
/// escape, and `error.InvalidCharacter` when an escape holds a non-hex digit.
fn percentDecodeInPlace(value: []u8) ![]u8 {
    // NOTE(review): the CR/LF/NUL check looks at raw bytes only; an escape
    // can still decode to one of them.
    var in: usize = 0;
    var out: usize = 0;
    while (in < value.len) : (out += 1) {
        switch (value[in]) {
            '\r', '\n', 0 => return error.BadRequest,
            '%' => {
                // Both hex digits must be present after the '%'.
                if (value.len - in < 3) return error.BadRequest;
                const hi = try std.fmt.charToDigit(value[in + 1], 16);
                const lo = try std.fmt.charToDigit(value[in + 2], 16);
                value[out] = hi * 16 + lo;
                in += 3;
            },
            else => {
                value[out] = value[in];
                in += 1;
            },
        }
    }
    return value[0..out];
}

/// Returns true when `token` is non-empty and consists only of characters
/// accepted by `isTokenChar`.
fn isTokenValid(token: []const u8) bool {
    if (token.len == 0) return false;
    for (token) |ch| {
        if (!isTokenChar(ch)) return false;
    }
    return true;
}

/// Chooses how the message body is to be read, based on `headers`.
///
/// Returns null when there is neither a chunked `Transfer-Encoding` nor a
/// `Content-Length`. Returns `error.UnsupportedMediaType` for a non-identity
/// `Content-Encoding` or an unknown coding, and `error.BadRequest` for an
/// unparsable `Content-Length` or for `Content-Length` combined with chunked
/// transfer coding.
fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) {
    const hdr = headers.get("Transfer-Encoding");
    // TODO:
    // if (hdr != null and protocol == .http_1_0) return error.BadRequest;

    const xfer_encoding = try parseEncoding(hdr);
    const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
    if (content_encoding != .identity) return error.UnsupportedMediaType;

    switch (xfer_encoding) {
        .identity => {
            const len_str = headers.get("Content-Length") orelse return null;
            const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
            return TransferStream(@TypeOf(reader)){
                .underlying = .{ .identity = std.io.limitedReader(reader, len) },
            };
        },
        .chunked => {
            if (headers.get("Content-Length") != null) return error.BadRequest;
            return TransferStream(@TypeOf(reader)){
                .underlying = .{
                    .chunked = try ChunkedStream(@TypeOf(reader)).init(reader),
                },
            };
        },
    }
}

/// Reader adapter that decodes a chunked message body read from an `R`.
fn ChunkedStream(comptime R: type) type {
    return struct {
        const Self = @This();

        /// Bytes left in the current chunk. 0 means no chunk header has been
        /// read yet; null means the terminating zero-size chunk was seen.
        remaining: ?usize = 0,
        underlying: R,

        const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream };

        fn init(reader: R) !Self {
            return Self{ .underlying = reader };
        }

        /// Fills `buf` with decoded body bytes, crossing chunk boundaries as
        /// needed, and returns the number of bytes written. Returns less than
        /// `buf.len` (possibly 0) only once the terminating chunk is reached.
        ///
        /// Returns `error.EndOfStream` when the underlying stream ends inside
        /// a chunk and `error.InvalidChunkHeader` for a malformed size line.
        fn read(self: *Self, buf: []u8) !usize {
            var count: usize = 0;
            while (count < buf.len) {
                var left = self.remaining orelse break;
                if (left == 0) {
                    // First call: nothing has been read from the stream yet.
                    self.remaining = try self.readChunkHeader();
                    left = self.remaining orelse break;
                }

                // Never read past the end of `buf` or of the current chunk.
                const space = buf.len - count;
                const want = if (space < left) space else left;
                const amt = try self.underlying.readAll(buf[count .. count + want]);
                if (amt != want) return error.EndOfStream;
                count += amt;
                left -= amt;
                self.remaining = left;

                if (left == 0) {
                    // Skip the line terminator after the chunk data, then
                    // load the size of the next chunk.
                    var crlf: [2]u8 = undefined;
                    _ = try self.underlying.readUntilDelimiter(&crlf, '\n');
                    self.remaining = try self.readChunkHeader();
                }
            }
            return count;
        }

        /// Reads one chunk-size line. Returns the chunk size, or null for the
        /// terminating zero-size chunk, after which the trailer section is
        /// skipped as well.
        fn readChunkHeader(self: *Self) !?usize {
            // TODO: Pick a reasonable limit for this
            var buf = std.mem.zeroes([10]u8);
            const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| {
                return if (err == error.StreamTooLong) error.InvalidChunkHeader else err;
            };
            if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader;
            const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader;
            if (size != 0) return size;

            try self.skipTrailers();
            return null;
        }

        /// Consumes lines up to and including the first empty one, so that
        /// the bytes following the body are not left behind a stray CRLF.
        /// A stream that ends before the empty line is accepted.
        fn skipTrailers(self: *Self) !void {
            var line_len: usize = 0;
            while (true) {
                const ch = self.underlying.readByte() catch |err| {
                    if (err == error.EndOfStream) return;
                    return err;
                };
                switch (ch) {
                    '\n' => {
                        if (line_len == 0) return;
                        line_len = 0;
                    },
                    '\r' => {},
                    else => line_len += 1,
                }
            }
        }
    };
}

/// Body reader over an `R`, backed either by a length-limited reader
/// (identity) or by a `ChunkedStream`.
pub fn TransferStream(comptime R: type) type {
    return struct {
        const Error = R.Error || ChunkedStream(R).Error;
        const Reader = std.io.Reader(*@This(), Error, read);

        underlying: union(enum) {
            identity: std.io.LimitedReader(R),
            chunked: ChunkedStream(R),
        },

        /// Reads body bytes into `buf` through whichever variant is active.
        pub fn read(self: *@This(), buf: []u8) Error!usize {
            return switch (self.underlying) {
                .identity => |*r| try r.read(buf),
                .chunked => |*r| try r.read(buf),
            };
        }

        pub fn reader(self: *@This()) Reader {
            return .{ .context = self };
        }
    };
}

// TODO: assumes that there's only one encoding, not layered encodings
/// Maps a coding name to an `Encoding`; a missing header means identity.
/// Names are compared without regard to ASCII case. Returns
/// `error.UnsupportedMediaType` for any other name.
fn parseEncoding(encoding: ?[]const u8) !Encoding {
    const name = encoding orelse return .identity;
    if (std.ascii.eqlIgnoreCase(name, "identity")) return .identity;
    if (std.ascii.eqlIgnoreCase(name, "chunked")) return .chunked;
    return error.UnsupportedMediaType;
}

/// Returns true when `ch` may appear in a token: ASCII alphanumerics and the
/// punctuation listed in the second prong; the separators listed in the
/// first prong, and everything else, are rejected.
fn isTokenChar(ch: u8) bool {
    switch (ch) {
        '"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false,
        '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => return true,
        else => return std.ascii.isAlphanumeric(ch),
    }
}

/// Reads the longest run of token characters from `peek_stream` and returns
/// it as a slice allocated with `alloc`; the caller frees it.
///
/// The first non-token byte is put back onto `peek_stream`. Reaching the end
/// of the stream ends the token; any other read error is returned and the
/// partial token is freed.
fn parseToken(alloc: std.mem.Allocator, peek_stream: anytype) ![]const u8 {
    var data = std.ArrayList(u8).init(alloc);
    errdefer data.deinit();

    const reader = peek_stream.reader();
    while (reader.readByte()) |ch| {
        if (!isTokenChar(ch)) {
            try peek_stream.putBackByte(ch);
            break;
        }
        try data.append(ch);
    } else |err| if (err != error.EndOfStream) return err;

    return data.toOwnedSlice();
}

test "parseToken" {
    const testCase = struct {
        fn func(data: []const u8, err: ?anyerror, expected: anyerror![]const u8, remaining: []const u8) !void {
            var fbs = std.io.fixedBufferStream(data);
            var stream = errorReader(err orelse error.EndOfStream, fbs.reader());
            var peeker = std.io.peekStream(1, stream.reader());

            const result = parseToken(std.testing.allocator, &peeker);
            defer if (result) |v| std.testing.allocator.free(v) else |_| {};

            if (expected) |val|
                try std.testing.expectEqualStrings(val, try result)
            else |expected_err|
                try std.testing.expectError(expected_err, result);

            try std.testing.expect(try peeker.reader().isBytes(remaining));
            try std.testing.expectError(err orelse error.EndOfStream, peeker.reader().readByte());
        }
    }.func;

    try testCase("abcdefg", null, "abcdefg", "");
    try testCase("abc defg", null, "abc", " defg");
    try testCase("abc;defg", null, "abc", ";defg");
    try testCase("abc%defg$; ", null, "abc%defg$", "; ");
    try testCase(" ", null, "", " ");
    try testCase(";", null, "", ";");
    try testCase("abcdefg", error.ClosedPipe, error.ClosedPipe, "");
}

test "ChunkedStream reassembles chunks across small reads" {
    var fbs = std.io.fixedBufferStream("3\r\nabc\r\n2\r\nde\r\n0\r\n\r\nNEXT");
    var chunked = try ChunkedStream(@TypeOf(fbs.reader())).init(fbs.reader());

    var out: [4]u8 = undefined;
    var n = try chunked.read(&out);
    try std.testing.expectEqualStrings("abcd", out[0..n]);
    n = try chunked.read(&out);
    try std.testing.expectEqualStrings("e", out[0..n]);
    n = try chunked.read(&out);
    try std.testing.expectEqual(@as(usize, 0), n);

    // Everything after the terminating chunk and its CRLF is left unread.
    try std.testing.expect(try fbs.reader().isBytes("NEXT"));
}

/// Reader wrapper that returns `err` once `inner_reader` has no more data.
fn ErrorReader(comptime E: type, comptime ReaderType: type) type {
    return struct {
        inner_reader: ReaderType,
        err: E,

        pub const Error = ReaderType.Error || E;
        pub const Reader = std.io.Reader(*@This(), Error, read);

        /// Fills as much of `dest` as the inner reader can supply and returns
        /// the number of bytes read; returns `err` when none are left.
        pub fn read(self: *@This(), dest: []u8) Error!usize {
            // A zero-length request says nothing about end of stream.
            if (dest.len == 0) return 0;
            const count = try self.inner_reader.readAll(dest);
            if (count == 0) return self.err;
            return count;
        }

        pub fn reader(self: *@This()) Reader {
            return .{ .context = self };
        }
    };
}

/// Returns the given error after the underlying stream is finished
fn errorReader(err: anytype, reader: anytype) ErrorReader(@TypeOf(err), @TypeOf(reader)) {
    return .{
        .inner_reader = reader,
        .err = err,
    };
}