diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index a11c315..54ca6bd 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -265,3 +265,87 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding { if (std.mem.eql(u8, encoding.?, "chunked")) return .chunked; return error.UnsupportedMediaType; } + +fn isTokenChar(ch: u8) bool { + switch (ch) { + '"', '(', ')', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '{', '}' => return false, + + '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => return true, + else => return std.ascii.isAlphanumeric(ch), + } +} + +fn parseToken(alloc: std.mem.Allocator, peek_stream: anytype) ![]const u8 { + var data = std.ArrayList(u8).init(alloc); + errdefer data.deinit(); + + const reader = peek_stream.reader(); + while (reader.readByte()) |ch| { + if (!isTokenChar(ch)) { + try peek_stream.putBackByte(ch); + break; + } + try data.append(ch); + } else |err| if (err != error.EndOfStream) return err; + + return data.toOwnedSlice(); +} + +test "parseToken" { + const testCase = struct { + fn func(data: []const u8, err: ?anyerror, expected: anyerror![]const u8, remaining: []const u8) !void { + var fbs = std.io.fixedBufferStream(data); + var stream = errorReader(err orelse error.EndOfStream, fbs.reader()); + var peeker = std.io.peekStream(1, stream.reader()); + + const result = parseToken(std.testing.allocator, &peeker); + defer if (result) |v| std.testing.allocator.free(v) else |_| {}; + + if (expected) |val| + try std.testing.expectEqualStrings(val, try result) + else |expected_err| + try std.testing.expectError(expected_err, result); + + try std.testing.expect(try peeker.reader().isBytes(remaining)); + try std.testing.expectError(err orelse error.EndOfStream, peeker.reader().readByte()); + } + }.func; + + try testCase("abcdefg", null, "abcdefg", ""); + try testCase("abc defg", null, "abc", " defg"); + try testCase("abc;defg", null, "abc", ";defg"); + try testCase("abc%defg$; ", null, "abc%defg$", "; "); + + try testCase(" ", null, "", " "); + try testCase(";", null, "", ";"); + + try testCase("abcdefg", error.ClosedPipe, error.ClosedPipe, ""); +} + +fn ErrorReader(comptime E: type, comptime ReaderType: type) type { + return struct { + inner_reader: ReaderType, + err: E, + + pub const Error = ReaderType.Error || E; + pub const Reader = std.io.Reader(*@This(), Error, read); + + pub fn read(self: *@This(), dest: []u8) Error!usize { + const count = try self.inner_reader.readAll(dest); + if (count == 0) return self.err; + return dest.len; + } + + pub fn reader(self: *@This()) Reader { + return .{ .context = self }; + } + }; +} + +/// Returns the given error after the underlying stream is finished +fn errorReader(err: anytype, reader: anytype) ErrorReader(@TypeOf(err), @TypeOf(reader)) { + return .{ + .inner_reader = reader, + .err = err, + }; +}