//! fediglam/src/http/request/parser.zig
//!
//! Parser for HTTP/1.0 and HTTP/1.1 requests.

const std = @import("std");
const util = @import("util");
const http = @import("../lib.zig");
const Method = http.Method;
const Headers = http.Headers;
const Request = @import("../request.zig").Request;
/// Request buffer size in bytes (64 KiB).
/// NOTE(review): not referenced by the code visible in this file.
const request_buf_size = 64 * 1024;
/// Maximum request-URI length handed to `readUntilDelimiterAlloc` (1 KiB);
/// longer URIs are reported as `error.RequestUriTooLong`.
const max_path_len = 1024;
/// Largest accepted `Content-Length` in bytes (4 KiB); larger bodies are
/// reported as `error.RequestEntityTooLarge`.
const max_body_len = 4 * 1024;
/// Error set produced when parsing from a reader of type `Reader`: the
/// parser's own errors merged with the reader's read errors.
///
/// The reader's error set is taken from `Reader.ReadError` when that
/// declaration exists, otherwise from `Reader.Error` (the name used by
/// `std.io.Reader`).
fn ParseError(comptime Reader: type) type {
    const ReadError = if (@hasDecl(Reader, "ReadError")) Reader.ReadError else Reader.Error;
    // Error sets are merged with `||`; a single `|` is bitwise OR and is a
    // compile error when applied to types.
    return error{
        MethodNotImplemented,
    } || ReadError;
}
/// Body encodings recognised by `parseEncoding`.
const Encoding = enum { identity, chunked };
/// Parse one HTTP/1.x request from `reader`: the request line, the header
/// block, and, for methods whose `requestHasBody()` is true, the body.
///
/// `address` is the connection's peer address. If the request carries an
/// `X-Real-IP` header, that IP combined with `address`'s port is reported as
/// the source address instead. An unparsable value is `error.BadRequest`.
///
/// The URI, every header key and value, and the body are allocated with
/// `alloc` and owned by the caller; release them with `parseFree`. On error,
/// everything allocated here is released again.
///
/// Errors:
/// - `error.RequestUriTooLong` when the URI exceeds `max_path_len`.
/// - `error.BadRequest` when the request line is not terminated by CRLF/LF
///   right after the version.
/// - The errors of `parseMethod`, `parseProto`, `parseHeaders`, `readBody`
///   and the reader itself.
pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request {
    const method = try parseMethod(reader);
    const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
        error.StreamTooLong => return error.RequestUriTooLong,
        else => return err,
    };
    errdefer alloc.free(uri);

    const proto = try parseProto(reader);

    // The request line must end right after the version. CRLF is standard;
    // a bare LF is also accepted, matching how `parseHeaders` treats line
    // endings. Blindly discarding two bytes would accept garbage such as
    // `HTTP/1.1xy`. With a bare LF it would also swallow the first byte of
    // the next line.
    switch (try reader.readByte()) {
        '\r' => {
            if ((try reader.readByte()) != '\n') return error.BadRequest;
        },
        '\n' => {},
        else => return error.BadRequest,
    }

    var headers = try parseHeaders(alloc, reader);
    errdefer freeHeaders(alloc, &headers);

    const body = if (method.requestHasBody())
        try readBody(alloc, headers, reader)
    else
        null;
    errdefer if (body) |b| alloc.free(b);

    // NOTE(review): X-Real-IP is trusted unconditionally, so any client can
    // spoof its source address unless a reverse proxy always overwrites this
    // header. Confirm that the deployment guarantees this.
    const eff_addr = if (headers.get("X-Real-IP")) |ip|
        std.net.Address.parseIp(ip, address.getPort()) catch {
            return error.BadRequest;
        }
    else
        address;

    return Request{
        .protocol = proto,
        .source_address = eff_addr,
        .method = method,
        .uri = uri,
        .headers = headers,
        .body = body,
    };
}
/// Read the request method (the text before the first space) and map it to a
/// `Method` by exact, case-sensitive match against the enum's field names.
///
/// A method that is unknown, empty, or too long for the 8-byte scratch buffer
/// is reported as `error.MethodNotImplemented`. Other reader errors are
/// propagated unchanged.
fn parseMethod(reader: anytype) !Method {
    var scratch: [8]u8 = undefined;
    const token = reader.readUntilDelimiter(&scratch, ' ') catch |e| {
        if (e == error.StreamTooLong) return error.MethodNotImplemented;
        return e;
    };
    return std.meta.stringToEnum(Method, token) orelse error.MethodNotImplemented;
}
/// Parse the protocol part of the request line, `HTTP/<major>.<minor>`.
///
/// Consumes the protocol name, the `/`, and exactly three version bytes. The
/// line terminator is left for the caller.
///
/// Errors:
/// - `error.UnknownProtocol` if the name is not exactly `HTTP`, or no `/`
///   appears within 8 bytes.
/// - `error.BadRequest` if fewer than three version bytes are available, or
///   the middle one is not `.`.
/// - `error.HttpVersionNotSupported` for anything other than 1.0 or 1.1.
fn parseProto(reader: anytype) !Request.Protocol {
    var buf: [8]u8 = undefined;
    const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
        error.StreamTooLong => return error.UnknownProtocol,
        else => return err,
    };
    if (!std.mem.eql(u8, proto, "HTTP")) {
        return error.UnknownProtocol;
    }

    // `read` may legitimately return fewer bytes than requested, for example
    // when the request line is split across TCP segments. `readAll` keeps
    // reading until the buffer is full or the stream ends.
    const version = buf[0..3];
    const count = try reader.readAll(version);
    if (count != 3 or version[1] != '.') {
        return error.BadRequest;
    }
    if (version[0] != '1') return error.HttpVersionNotSupported;
    return switch (version[2]) {
        '0' => .http_1_0,
        '1' => .http_1_1,
        else => error.HttpVersionNotSupported,
    };
}
/// Read header lines from `reader` until an empty line (`"\r\n"` or `"\n"`).
///
/// Each `Name: value` line is stored in the returned map. The name and value
/// are copied into memory from `allocator`, and spaces/tabs around the value
/// are trimmed. Lines without a colon, or starting with one, are skipped.
///
/// Errors:
/// - `error.BadRequest` for a header whose value is empty.
/// - `error.StreamTooLong` (from the reader) for a line longer than the
///   1024-byte line buffer.
///
/// When a header name repeats, the last value wins and the superseded copies
/// are freed. The caller owns the returned map and every key and value in it.
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers {
    var map = Headers.init(allocator);
    errdefer map.deinit();
    // errdefers run LIFO, so this frees every inserted key and value before
    // `map.deinit()` above runs.
    errdefer {
        var iter = map.iterator();
        while (iter.next()) |it| {
            allocator.free(it.key_ptr.*);
            allocator.free(it.value_ptr.*);
        }
    }

    var buf: [1024]u8 = undefined;
    while (true) {
        const line = try reader.readUntilDelimiter(&buf, '\n');
        if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break;

        // TODO: handle multi-line (obs-fold) headers
        const name = extractHeaderName(line) orelse continue;
        const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len;
        // Take everything after the colon and strip optional whitespace.
        // Slicing at a fixed offset of two past the colon would drop a
        // character from `Name:value`. It would also have start > end (a
        // safety panic) for an empty value such as `Name:\r\n`.
        const value = std.mem.trim(u8, line[name.len + 1 .. value_end], " \t");
        if (name.len == 0 or value.len == 0) return error.BadRequest;

        const name_alloc = try allocator.dupe(u8, name);
        errdefer allocator.free(name_alloc);
        const value_alloc = try allocator.dupe(u8, value);
        errdefer allocator.free(value_alloc);

        // `put` keeps the existing key and overwrites the value, which would
        // leak both the new name copy and the old value on a repeated header.
        const entry = try map.getOrPut(name_alloc);
        if (entry.found_existing) {
            allocator.free(name_alloc);
            allocator.free(entry.value_ptr.*);
        }
        entry.value_ptr.* = value_alloc;
    }
    return map;
}
/// Return the header name of `line`, i.e. everything before its first `:`.
/// Yields `null` when the line contains no colon or starts with one.
fn extractHeaderName(line: []const u8) ?[]const u8 {
    // TODO: handle whitespace
    const colon = std.mem.indexOfScalar(u8, line, ':') orelse return null;
    return if (colon == 0) null else line[0..colon];
}
/// Read the request body described by `headers` from `reader`.
///
/// Returns `null` when there is no `Content-Length` header. Only the
/// `identity` transfer and content encodings are accepted.
///
/// Errors:
/// - `error.UnsupportedMediaType` for any other transfer or content encoding.
/// - `error.BadRequest` when `Content-Length` is not a plain run of decimal
///   digits (or overflows `usize`), or when the stream ends before the
///   advertised number of bytes.
/// - `error.RequestEntityTooLarge` when the length exceeds `max_body_len`.
/// - Any other reader error, propagated unchanged.
///
/// Caller owns the returned slice, which is allocated with `alloc`.
fn readBody(alloc: std.mem.Allocator, headers: Headers, reader: anytype) !?[]const u8 {
    const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
    if (xfer_encoding != .identity) return error.UnsupportedMediaType;
    const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
    if (content_encoding != .identity) return error.UnsupportedMediaType;

    const len_str = headers.get("Content-Length") orelse return null;
    // Content-Length is `1*DIGIT`. `parseInt` also accepts a leading `+`, and
    // disagreeing with an upstream proxy about body length enables request
    // smuggling, so reject anything that is not purely digits.
    if (len_str.len == 0) return error.BadRequest;
    for (len_str) |c| {
        if (!std.ascii.isDigit(c)) return error.BadRequest;
    }
    const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
    if (len > max_body_len) return error.RequestEntityTooLarge;

    const body = try alloc.alloc(u8, len);
    errdefer alloc.free(body);
    // A truncated body is the client's fault (400). Other reader failures,
    // such as a reset connection, are propagated rather than disguised.
    reader.readNoEof(body) catch |err| switch (err) {
        error.EndOfStream => return error.BadRequest,
        else => return err,
    };
    return body;
}
// TODO: assumes that there's only one encoding, not layered encodings
/// Map a `Transfer-Encoding` / `Content-Encoding` header value to an
/// `Encoding`.
///
/// A missing header means `.identity`. The value must exactly equal
/// `identity` or `chunked` (case-sensitive); anything else is
/// `error.UnsupportedMediaType`.
fn parseEncoding(encoding: ?[]const u8) !Encoding {
    const token = encoding orelse return .identity;
    return std.meta.stringToEnum(Encoding, token) orelse error.UnsupportedMediaType;
}
/// Release everything `parse` allocated for `request`: the URI, the header
/// map with its keys and values, and the body if one was read.
/// `allocator` must be the allocator that was passed to `parse`.
pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void {
    const uri = request.uri;
    allocator.free(uri);
    freeHeaders(allocator, &request.headers);
    const maybe_body = request.body;
    if (maybe_body) |bytes| {
        allocator.free(bytes);
    }
}
/// Free every key and value stored in `headers`, then deinitialize the map.
/// Every key and value must have been allocated with `allocator`, as
/// `parseHeaders` does.
fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void {
    // Deferred so the map itself is torn down only after its entries are freed.
    defer headers.deinit();
    var entries = headers.iterator();
    while (entries.next()) |entry| {
        allocator.free(entry.key_ptr.*);
        allocator.free(entry.value_ptr.*);
    }
}
/// Helpers used by the `parse` test.
const _test = struct {
    const expectEqual = std.testing.expectEqual;
    const expectEqualStrings = std.testing.expectEqualStrings;

    /// Convert every `\n` in `str` to `\r\n` at compile time, so test requests
    /// can be written as `\\` multi-line literals.
    fn toCrlf(comptime str: []const u8) []const u8 {
        comptime {
            // Worst case every byte is '\n', which doubles the length.
            var buf: [str.len * 2]u8 = undefined;
            @setEvalBranchQuota(@intCast(u32, str.len * 2)); // TODO: why does this need to be *2
            var buf_len: usize = 0;
            for (str) |ch| {
                if (ch == '\n') {
                    buf[buf_len] = '\r';
                    buf_len += 1;
                }
                buf[buf_len] = ch;
                buf_len += 1;
            }
            return buf[0..buf_len];
        }
    }

    /// Build a `Headers` map from a tuple of `.{ name, value }` pairs.
    /// NOTE(review): the name/value slices are stored as-is (not copied), so
    /// presumably this map must not be released with `freeHeaders`.
    fn makeHeaders(alloc: std.mem.Allocator, headers: anytype) !Headers {
        var result = Headers.init(alloc);
        inline for (headers) |tup| {
            try result.put(tup[0], tup[1]);
        }
        return result;
    }

    /// True when both maps have the same number of entries and every key of
    /// `lhs` maps to an equal value in `rhs`.
    fn areEqualHeaders(lhs: Headers, rhs: Headers) bool {
        if (lhs.count() != rhs.count()) return false;
        var iter = lhs.iterator();
        while (iter.next()) |it| {
            const rhs_val = rhs.get(it.key_ptr.*) orelse return false;
            if (!std.mem.eql(u8, it.value_ptr.*, rhs_val)) return false;
        }
        return true;
    }

    /// Print each header as `name: value` to stderr.
    fn printHeaders(headers: Headers) void {
        var iter = headers.iterator();
        while (iter.next()) |it| {
            std.debug.print("{s}: {s}\n", .{ it.key_ptr.*, it.value_ptr.* });
        }
    }

    /// Fail with `error.TestExpectedEqual` after dumping both maps when they
    /// differ.
    fn expectEqualHeaders(expected: Headers, actual: Headers) !void {
        if (!areEqualHeaders(expected, actual)) {
            std.debug.print("\nexpected: \n", .{});
            printHeaders(expected);
            std.debug.print("\n\nfound: \n", .{});
            printHeaders(actual);
            std.debug.print("\n\n", .{});
            return error.TestExpectedEqual;
        }
    }

    /// Parse `request` (written with `\n` line endings and converted to CRLF)
    /// and compare its method, path, headers and body against `expected`.
    ///
    /// NOTE(review): this calls `parse` with two arguments, but `parse` also
    /// takes a source `address`. It also reads `.path`, while `parse` fills in
    /// `.uri`. This helper looks stale; confirm it against `http.Request`.
    /// The parsed result is not freed here. The `parse` test instead resets
    /// its FixedBufferAllocator between cases.
    fn parseTestCase(alloc: std.mem.Allocator, comptime request: []const u8, expected: http.Request) !void {
        var stream = std.io.fixedBufferStream(toCrlf(request));
        const result = try parse(alloc, stream.reader());
        try expectEqual(expected.method, result.method);
        try expectEqualStrings(expected.path, result.path);
        try expectEqualHeaders(expected.headers, result.headers);
        // Bodies are optional: first check that both or neither are present,
        // printing "(null)" for a missing one.
        if ((expected.body == null) != (result.body == null)) {
            const null_str: []const u8 = "(null)";
            const exp = expected.body orelse null_str;
            const act = result.body orelse null_str;
            std.debug.print("\nexpected:\n{s}\n\nfound:\n{s}\n\n", .{ exp, act });
            return error.TestExpectedEqual;
        }
        if (expected.body != null) {
            try expectEqualStrings(expected.body.?, result.body.?);
        }
    }
};
// TODO: failure test cases
test "parse" {
    // NOTE(review): every case goes through `_test.parseTestCase`, which calls
    // `parse` without the `address` argument `parse` currently takes. This
    // test looks stale; confirm it still compiles.
    const testCase = _test.parseTestCase;
    // All cases allocate from one 64 KiB buffer that is reset between cases,
    // so parse results are never freed individually.
    var buf = [_]u8{0} ** (1 << 16);
    var fba = std.heap.FixedBufferAllocator.init(&buf);
    const alloc = fba.allocator();

    // Minimal GET with no headers.
    try testCase(alloc, (
        \\GET / HTTP/1.1
        \\
        \\
    ), .{
        .method = .GET,
        .headers = try _test.makeHeaders(alloc, .{}),
        .path = "/",
    });
    fba.reset();

    // POST with no headers.
    try testCase(alloc, (
        \\POST / HTTP/1.1
        \\
        \\
    ), .{
        .method = .POST,
        .headers = try _test.makeHeaders(alloc, .{}),
        .path = "/",
    });
    fba.reset();

    // A single header.
    try testCase(alloc, (
        \\HEAD / HTTP/1.1
        \\Authorization: bearer <token>
        \\
        \\
    ), .{
        .method = .HEAD,
        .headers = try _test.makeHeaders(alloc, .{
            .{ "Authorization", "bearer <token>" },
        }),
        .path = "/",
    });
    fba.reset();

    // Body read according to Content-Length.
    try testCase(alloc, (
        \\POST /nonsense HTTP/1.1
        \\Authorization: bearer <token>
        \\Content-Length: 5
        \\
        \\12345
    ), .{
        .method = .POST,
        .headers = try _test.makeHeaders(alloc, .{
            .{ "Authorization", "bearer <token>" },
            .{ "Content-Length", "5" },
        }),
        .path = "/nonsense",
        .body = "12345",
    });
    fba.reset();

    // Unknown method.
    try std.testing.expectError(
        error.MethodNotImplemented,
        testCase(alloc, (
            \\FOO /nonsense HTTP/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // Method longer than parseMethod's 8-byte buffer.
    try std.testing.expectError(
        error.MethodNotImplemented,
        testCase(alloc, (
            \\FOOBARBAZ /nonsense HTTP/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // URI longer than `max_path_len`.
    try std.testing.expectError(
        error.RequestUriTooLong,
        testCase(alloc, (
            \\GET /
        ++ ("a" ** 2048)), undefined),
    );
    fba.reset();

    // Protocol name longer than parseProto's 8-byte buffer.
    try std.testing.expectError(
        error.UnknownProtocol,
        testCase(alloc, (
            \\GET /nonsense SPECIALHTTP/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // Protocol name other than HTTP.
    try std.testing.expectError(
        error.UnknownProtocol,
        testCase(alloc, (
            \\GET /nonsense JSON/1.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // Unsupported minor version.
    try std.testing.expectError(
        error.HttpVersionNotSupported,
        testCase(alloc, (
            \\GET /nonsense HTTP/1.9
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // Unsupported major version.
    try std.testing.expectError(
        error.HttpVersionNotSupported,
        testCase(alloc, (
            \\GET /nonsense HTTP/8.1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // Version that is not `<x>.<y>`-shaped.
    try std.testing.expectError(
        error.BadRequest,
        testCase(alloc, (
            \\GET /nonsense HTTP/blah blah blah
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // Wrong separator between major and minor version.
    try std.testing.expectError(
        error.BadRequest,
        testCase(alloc, (
            \\GET /nonsense HTTP/1/1
            \\
            \\
        ), undefined),
    );
    fba.reset();

    // NOTE(review): identical to the previous case; a different malformed
    // version was probably intended here.
    try std.testing.expectError(
        error.BadRequest,
        testCase(alloc, (
            \\GET /nonsense HTTP/1/1
            \\
            \\
        ), undefined),
    );
}