From 2d464f0820f48508a5fdb2972e96c54e1c451417 Mon Sep 17 00:00:00 2001
From: jaina heartles
Date: Sun, 6 Nov 2022 23:38:21 -0800
Subject: [PATCH] Chunked transfer encoding

---
 src/http/request.zig        |   5 +-
 src/http/request/parser.zig | 153 ++++++++++++++++++++++++++++++++----
 src/http/server.zig         |   4 +-
 src/http/socket.zig         |   2 +-
 src/main/controllers.zig    |  23 ++++--
 src/main/main.zig           |   2 +-
 6 files changed, 159 insertions(+), 30 deletions(-)

diff --git a/src/http/request.zig b/src/http/request.zig
index f3166a8..b713d16 100644
--- a/src/http/request.zig
+++ b/src/http/request.zig
@@ -3,7 +3,7 @@ const http = @import("./lib.zig");
 
 const parser = @import("./request/parser.zig");
 
-pub fn Request(comptime _: type) type {
+pub fn Request(comptime Reader: type) type {
     return struct {
         protocol: http.Protocol,
 
@@ -11,12 +11,11 @@ pub fn Request(comptime _: type) type {
         uri: []const u8,
         headers: http.Fields,
 
-        body: ?[]const u8 = null,
+        body: ?parser.TransferStream(Reader),
 
         pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void {
             allocator.free(self.uri);
             self.headers.deinit();
-            if (self.body) |body| allocator.free(body);
         }
     };
 }
diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig
index 6bdf514..ffac9fd 100644
--- a/src/http/request/parser.zig
+++ b/src/http/request/parser.zig
@@ -8,7 +8,6 @@ const Request = @import("../request.zig").Request;
 
 const request_buf_size = 1 << 16;
 const max_path_len = 1 << 10;
-const max_body_len = 1 << 12;
 
 fn ParseError(comptime Reader: type) type {
     return error{
@@ -41,11 +40,8 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)
     var headers = try parseHeaders(alloc, reader);
     errdefer headers.deinit();
 
-    const body = if (method.requestHasBody())
-        try readBody(alloc, headers, reader)
-    else
-        null;
-    errdefer if (body) |b| alloc.free(b);
+    const body = try prepareBody(headers, reader);
+    if (body != null and !method.requestHasBody()) return error.BadRequest;
 
     return Request(@TypeOf(reader)){
         .protocol = proto,
@@ -135,21 +131,146 @@ fn isTokenValid(token: []const u8) bool {
     return true;
 }
 
-fn readBody(alloc: std.mem.Allocator, headers: Fields, reader: anytype) !?[]const u8 {
-    const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding"));
-    if (xfer_encoding != .identity) return error.UnsupportedMediaType;
+/// Builds a stream over the request body from the Transfer-Encoding and
+/// Content-Length headers. Returns null for an identity-encoded request
+/// without a Content-Length. Nothing is read from `reader` here.
+fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) {
+    const hdr = headers.get("Transfer-Encoding");
+    // TODO:
+    // if (hdr != null and protocol == .http_1_0) return error.BadRequest;
+    const xfer_encoding = try parseEncoding(hdr);
     const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
     if (content_encoding != .identity) return error.UnsupportedMediaType;
 
-    const len_str = headers.get("Content-Length") orelse return null;
-    const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
-    if (len > max_body_len) return error.RequestEntityTooLarge;
-    const body = try alloc.alloc(u8, len);
-    errdefer alloc.free(body);
+    switch (xfer_encoding) {
+        .identity => {
+            const len_str = headers.get("Content-Length") orelse return null;
+            const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
 
-    reader.readNoEof(body) catch return error.BadRequest;
+            return TransferStream(@TypeOf(reader)){ .underlying = .{ .identity = std.io.limitedReader(reader, len) } };
+        },
+        .chunked => {
+            if (headers.get("Content-Length") != null) return error.BadRequest;
+            return TransferStream(@TypeOf(reader)){
+                .underlying = .{
+                    .chunked = ChunkedStream(@TypeOf(reader)){
+                        .underlying = reader,
+                    },
+                },
+            };
+        },
+    }
+}
 
-    return body;
+/// Reader adapter that decodes a `Transfer-Encoding: chunked` body read from `R`.
+fn ChunkedStream(comptime R: type) type {
+    return struct {
+        const Self = @This();
+
+        // TODO: Pick a reasonable limit for this
+        const max_trailer_len = 1 << 12;
+
+        /// Data bytes left in the current chunk; null when a chunk header is next.
+        remaining: ?usize = null,
+        /// False once the first chunk header has been read.
+        first_chunk: bool = true,
+        /// True once the last chunk and the trailer section have been consumed.
+        done: bool = false,
+        underlying: R,
+
+        const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream };
+
+        /// Copies up to `buf.len` decoded body bytes into `buf`. Returns 0 only
+        /// for an empty `buf` or once the whole body has been consumed.
+        fn read(self: *Self, buf: []u8) !usize {
+            if (buf.len == 0 or self.done) return 0;
+
+            if (self.remaining == null) self.remaining = try self.beginChunk();
+            const left = self.remaining orelse return 0;
+
+            // A short read is not an error; only EOF inside a chunk is.
+            const amt = try self.underlying.read(buf[0..std.math.min(buf.len, left)]);
+            if (amt == 0) return error.EndOfStream;
+
+            self.remaining = if (amt == left) null else left - amt;
+            return amt;
+        }
+
+        /// Consumes the framing in front of the next chunk's data and returns
+        /// that chunk's size, or null (and sets `done`) after the last chunk.
+        fn beginChunk(self: *Self) !?usize {
+            if (!self.first_chunk) {
+                // Every chunk's data is terminated by its own CRLF.
+                var crlf: [2]u8 = undefined;
+                try self.underlying.readNoEof(&crlf);
+                if (!std.mem.eql(u8, &crlf, "\r\n")) return error.InvalidChunkHeader;
+            }
+            self.first_chunk = false;
+
+            if (try self.readChunkHeader()) |size| return size;
+
+            try self.skipTrailers();
+            self.done = true;
+            return null;
+        }
+
+        /// Reads one "<hex size>\r\n" line. Returns null for the zero-size last chunk.
+        fn readChunkHeader(self: *Self) !?usize {
+            // TODO: Pick a reasonable limit for this
+            var buf = std.mem.zeroes([10]u8);
+            const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| {
+                return if (err == error.StreamTooLong) error.InvalidChunkHeader else err;
+            };
+            if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader;
+
+            const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader;
+
+            return if (size != 0) size else null;
+        }
+
+        /// Discards trailer fields up to and including the blank line that ends the body.
+        fn skipTrailers(self: *Self) !void {
+            var line_len: usize = 0;
+            var prev: u8 = 0;
+            var total: usize = 0;
+            while (total < max_trailer_len) : (total += 1) {
+                const byte = try self.underlying.readByte();
+                if (byte == '\n') {
+                    if (prev != '\r') return error.InvalidChunkHeader;
+                    if (line_len == 1) return;
+                    line_len = 0;
+                } else {
+                    line_len += 1;
+                }
+                prev = byte;
+            }
+            return error.StreamTooLong;
+        }
+    };
+}
+
+/// Request body stream that reads through either a length-limited reader or a chunked decoder.
+pub fn TransferStream(comptime R: type) type {
+    return struct {
+        const Error = R.Error || ChunkedStream(R).Error;
+        const Reader = std.io.Reader(*@This(), Error, read);
+
+        underlying: union(enum) {
+            identity: std.io.LimitedReader(R),
+            chunked: ChunkedStream(R),
+        },
+
+        pub fn read(self: *@This(), buf: []u8) Error!usize {
+            return switch (self.underlying) {
+                .identity => |*r| try r.read(buf),
+                .chunked => |*r| try r.read(buf),
+            };
+        }
+
+        pub fn reader(self: *@This()) Reader {
+            return .{ .context = self };
+        }
+    };
 }
 
 // TODO: assumes that there's only one encoding, not layered encodings
diff --git a/src/http/server.zig b/src/http/server.zig
index 91b5747..d81d77a 100644
--- a/src/http/server.zig
+++ b/src/http/server.zig
@@ -34,7 +34,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
         var arena = std.heap.ArenaAllocator.init(alloc);
         defer arena.deinit();
 
-        const req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
+        var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
             return handleError(conn.stream.writer(), err) catch {};
         };
         std.log.debug("done parsing", .{});
@@ -44,7 +44,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
             .stream = conn.stream,
         };
 
-        handler(ctx, req, &res);
+        handler(ctx, &req, &res);
         std.log.debug("done handling", .{});
 
         if (req.headers.get("Connection")) |hdr| {
diff --git a/src/http/socket.zig b/src/http/socket.zig
index faac74f..cae032a 100644
--- a/src/http/socket.zig
+++ b/src/http/socket.zig
@@ -23,7 +23,7 @@ const Opcode = enum(u4) {
     }
 };
 
-pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket {
+pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
     const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
     const connection = req.headers.get("Connection") orelse return error.BadHandshake;
     if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
diff --git a/src/main/controllers.zig b/src/main/controllers.zig
index a20e9ee..810a16c 100644
--- a/src/main/controllers.zig
+++ b/src/main/controllers.zig
@@ -14,7 +14,7 @@ pub const users = @import("./controllers/users.zig");
 pub const notes = @import("./controllers/notes.zig");
 pub const streaming = @import("./controllers/streaming.zig");
 
-pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
+pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
     // TODO: hashmaps?
     var response = Response{ .headers = http.Fields.init(alloc), .res = res };
     defer response.headers.deinit();
@@ -24,7 +24,7 @@ pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response,
     if (!found) response.status(.not_found) catch {};
 }
 
-fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
+fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
     inline for (routes) |route| {
         if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true;
     }
@@ -58,7 +58,7 @@ pub fn Context(comptime Route: type) type {
         // leave it as a simple string instead of void
         pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void;
 
-        base_request: http.Request,
+        base_request: *http.Request,
 
         allocator: std.mem.Allocator,
 
@@ -70,6 +70,9 @@ pub fn Context(comptime Route: type) type {
         body: Body,
         query: Query,
 
+        // TODO
+        body_buf: ?[]const u8 = null,
+
         fn parseArgs(path: []const u8) ?Args {
             var args: Args = undefined;
             var path_iter = util.PathIter.from(path);
@@ -94,7 +97,7 @@ pub fn Context(comptime Route: type) type {
             @compileError("Unsupported Type " ++ @typeName(T));
         }
 
-        pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
+        pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
             if (req.method != Route.method) return false;
             var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?');
             var args: Args = parseArgs(path) orelse return false;
@@ -112,6 +115,8 @@ pub fn Context(comptime Route: type) type {
                 .query = undefined,
             };
 
+            std.log.debug("Matched route {s}", .{path});
+
             self.prepareAndHandle(api_source, req, res);
 
             return true;
@@ -129,7 +134,7 @@ pub fn Context(comptime Route: type) type {
             };
         }
 
-        fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void {
+        fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void {
             self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err);
             defer self.freeBody();
 
@@ -141,16 +146,20 @@ pub fn Context(comptime Route: type) type {
             self.handle(response, &api_conn);
         }
 
-        fn parseBody(self: *Self, req: http.Request) !void {
+        fn parseBody(self: *Self, req: *http.Request) !void {
             if (Body != void) {
-                const body = req.body orelse return error.NoBody;
+                var stream = req.body orelse return error.NoBody;
+                const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16);
+                errdefer self.allocator.free(body);
                 self.body = try json_utils.parse(Body, body, self.allocator);
+                self.body_buf = body;
             }
         }
 
         fn freeBody(self: *Self) void {
             if (Body != void) {
                 json_utils.parseFree(self.body, self.allocator);
+                self.allocator.free(self.body_buf.?);
             }
         }
 
diff --git a/src/main/main.zig b/src/main/main.zig
index 5870bc2..9ff01fe 100644
--- a/src/main/main.zig
+++ b/src/main/main.zig
@@ -87,7 +87,7 @@ fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void {
     }
 }
 
-fn handle(ctx: anytype, req: http.Request, res: *http.Response) void {
+fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void {
     c.routeRequest(ctx.src, req, res, ctx.allocator);
 }
 