Chunked transfer encoding

This commit is contained in:
jaina heartles 2022-11-06 23:38:21 -08:00
parent 438c72b7e9
commit 2d464f0820
6 changed files with 112 additions and 30 deletions

View file

@ -3,7 +3,7 @@ const http = @import("./lib.zig");
const parser = @import("./request/parser.zig"); const parser = @import("./request/parser.zig");
pub fn Request(comptime _: type) type { pub fn Request(comptime Reader: type) type {
return struct { return struct {
protocol: http.Protocol, protocol: http.Protocol,
@ -11,12 +11,11 @@ pub fn Request(comptime _: type) type {
uri: []const u8, uri: []const u8,
headers: http.Fields, headers: http.Fields,
body: ?[]const u8 = null, body: ?parser.TransferStream(Reader),
pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void { pub fn parseFree(self: *@This(), allocator: std.mem.Allocator) void {
allocator.free(self.uri); allocator.free(self.uri);
self.headers.deinit(); self.headers.deinit();
if (self.body) |body| allocator.free(body);
} }
}; };
} }

View file

@ -8,7 +8,6 @@ const Request = @import("../request.zig").Request;
const request_buf_size = 1 << 16; const request_buf_size = 1 << 16;
const max_path_len = 1 << 10; const max_path_len = 1 << 10;
const max_body_len = 1 << 12;
fn ParseError(comptime Reader: type) type { fn ParseError(comptime Reader: type) type {
return error{ return error{
@ -41,11 +40,8 @@ pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request(@TypeOf(reader)
var headers = try parseHeaders(alloc, reader); var headers = try parseHeaders(alloc, reader);
errdefer headers.deinit(); errdefer headers.deinit();
const body = if (method.requestHasBody()) const body = try prepareBody(headers, reader);
try readBody(alloc, headers, reader) if (body != null and !method.requestHasBody()) return error.BadRequest;
else
null;
errdefer if (body) |b| alloc.free(b);
return Request(@TypeOf(reader)){ return Request(@TypeOf(reader)){
.protocol = proto, .protocol = proto,
@ -135,21 +131,99 @@ fn isTokenValid(token: []const u8) bool {
return true; return true;
} }
fn readBody(alloc: std.mem.Allocator, headers: Fields, reader: anytype) !?[]const u8 { fn prepareBody(headers: Fields, reader: anytype) !?TransferStream(@TypeOf(reader)) {
const xfer_encoding = try parseEncoding(headers.get("Transfer-Encoding")); const hdr = headers.get("Transfer-Encoding");
if (xfer_encoding != .identity) return error.UnsupportedMediaType; // TODO:
// if (hder != null and protocol == .http_1_0) return error.BadRequest;
const xfer_encoding = try parseEncoding(hdr);
const content_encoding = try parseEncoding(headers.get("Content-Encoding")); const content_encoding = try parseEncoding(headers.get("Content-Encoding"));
if (content_encoding != .identity) return error.UnsupportedMediaType; if (content_encoding != .identity) return error.UnsupportedMediaType;
switch (xfer_encoding) {
.identity => {
const len_str = headers.get("Content-Length") orelse return null; const len_str = headers.get("Content-Length") orelse return null;
const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest; const len = std.fmt.parseInt(usize, len_str, 10) catch return error.BadRequest;
if (len > max_body_len) return error.RequestEntityTooLarge;
const body = try alloc.alloc(u8, len);
errdefer alloc.free(body);
reader.readNoEof(body) catch return error.BadRequest; return TransferStream(@TypeOf(reader)){ .underlying = .{ .identity = std.io.limitedReader(reader, len) } };
},
.chunked => {
if (headers.get("Content-Length") != null) return error.BadRequest;
return TransferStream(@TypeOf(reader)){
.underlying = .{
.chunked = ChunkedStream(@TypeOf(reader)){
.underlying = reader,
},
},
};
},
}
}
return body; fn ChunkedStream(comptime R: type) type {
return struct {
const Self = @This();
remaining: ?usize = null,
underlying: R,
const Error = R.Error || error{ Unexpected, InvalidChunkHeader, StreamTooLong, EndOfStream };
fn read(self: *Self, buf: []u8) !usize {
if (self.remaining) |*remaining| {
var count: usize = 0;
while (count < buf.len) {
const max_read = std.math.min(buf.len, remaining.*);
const amt = try self.underlying.read(buf[count .. count + max_read]);
if (amt != max_read) return error.EndOfStream;
count += amt;
remaining.* -= amt;
if (count == buf.len) return count;
self.remaining = try self.readChunkHeader();
}
} else {
return 0;
}
unreachable;
}
fn readChunkHeader(self: *Self) !?usize {
// TODO: Pick a reasonable limit for this
var buf = std.mem.zeroes([10]u8);
const line = self.underlying.readUntilDelimiter(&buf, '\n') catch |err| {
return if (err == error.StreamTooLong) error.InvalidChunkHeader else err;
};
if (line.len < 2 or line[line.len - 1] != '\r') return error.InvalidChunkHeader;
const size = std.fmt.parseInt(usize, line[0 .. line.len - 1], 16) catch return error.InvalidChunkHeader;
return if (size != 0) size else null;
}
};
}
pub fn TransferStream(comptime R: type) type {
return struct {
const Error = R.Error || ChunkedStream(R).Error;
const Reader = std.io.Reader(*@This(), Error, read);
underlying: union(enum) {
identity: std.io.LimitedReader(R),
chunked: ChunkedStream(R),
},
pub fn read(self: *@This(), buf: []u8) Error!usize {
return switch (self.underlying) {
.identity => |*r| try r.read(buf),
.chunked => |*r| try r.read(buf),
};
}
pub fn reader(self: *@This()) Reader {
return .{ .context = self };
}
};
} }
// TODO: assumes that there's only one encoding, not layered encodings // TODO: assumes that there's only one encoding, not layered encodings

View file

@ -34,7 +34,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
var arena = std.heap.ArenaAllocator.init(alloc); var arena = std.heap.ArenaAllocator.init(alloc);
defer arena.deinit(); defer arena.deinit();
const req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| { var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| {
return handleError(conn.stream.writer(), err) catch {}; return handleError(conn.stream.writer(), err) catch {};
}; };
std.log.debug("done parsing", .{}); std.log.debug("done parsing", .{});
@ -44,7 +44,7 @@ pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: a
.stream = conn.stream, .stream = conn.stream,
}; };
handler(ctx, req, &res); handler(ctx, &req, &res);
std.log.debug("done handling", .{}); std.log.debug("done handling", .{});
if (req.headers.get("Connection")) |hdr| { if (req.headers.get("Connection")) |hdr| {

View file

@ -23,7 +23,7 @@ const Opcode = enum(u4) {
} }
}; };
pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket { pub fn handshake(alloc: std.mem.Allocator, req: *http.Request, res: *http.Response) !Socket {
const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake; const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake;
const connection = req.headers.get("Connection") orelse return error.BadHandshake; const connection = req.headers.get("Connection") orelse return error.BadHandshake;
if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake; if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;

View file

@ -14,7 +14,7 @@ pub const users = @import("./controllers/users.zig");
pub const notes = @import("./controllers/notes.zig"); pub const notes = @import("./controllers/notes.zig");
pub const streaming = @import("./controllers/streaming.zig"); pub const streaming = @import("./controllers/streaming.zig");
pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { pub fn routeRequest(api_source: anytype, req: *http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
// TODO: hashmaps? // TODO: hashmaps?
var response = Response{ .headers = http.Fields.init(alloc), .res = res }; var response = Response{ .headers = http.Fields.init(alloc), .res = res };
defer response.headers.deinit(); defer response.headers.deinit();
@ -24,7 +24,7 @@ pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response,
if (!found) response.status(.not_found) catch {}; if (!found) response.status(.not_found) catch {};
} }
fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { fn routeRequestInternal(api_source: anytype, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
inline for (routes) |route| { inline for (routes) |route| {
if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true; if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true;
} }
@ -58,7 +58,7 @@ pub fn Context(comptime Route: type) type {
// leave it as a simple string instead of void // leave it as a simple string instead of void
pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void;
base_request: http.Request, base_request: *http.Request,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
@ -70,6 +70,9 @@ pub fn Context(comptime Route: type) type {
body: Body, body: Body,
query: Query, query: Query,
// TODO
body_buf: ?[]const u8 = null,
fn parseArgs(path: []const u8) ?Args { fn parseArgs(path: []const u8) ?Args {
var args: Args = undefined; var args: Args = undefined;
var path_iter = util.PathIter.from(path); var path_iter = util.PathIter.from(path);
@ -94,7 +97,7 @@ pub fn Context(comptime Route: type) type {
@compileError("Unsupported Type " ++ @typeName(T)); @compileError("Unsupported Type " ++ @typeName(T));
} }
pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { pub fn matchAndHandle(api_source: *api.ApiSource, req: *http.Request, res: *Response, alloc: std.mem.Allocator) bool {
if (req.method != Route.method) return false; if (req.method != Route.method) return false;
var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?');
var args: Args = parseArgs(path) orelse return false; var args: Args = parseArgs(path) orelse return false;
@ -112,6 +115,8 @@ pub fn Context(comptime Route: type) type {
.query = undefined, .query = undefined,
}; };
std.log.debug("Matched route {s}", .{path});
self.prepareAndHandle(api_source, req, res); self.prepareAndHandle(api_source, req, res);
return true; return true;
@ -129,7 +134,7 @@ pub fn Context(comptime Route: type) type {
}; };
} }
fn prepareAndHandle(self: *Self, api_source: anytype, req: http.Request, response: *Response) void { fn prepareAndHandle(self: *Self, api_source: anytype, req: *http.Request, response: *Response) void {
self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err); self.parseBody(req) catch |err| return errorHandler(response, .bad_request, err);
defer self.freeBody(); defer self.freeBody();
@ -141,16 +146,20 @@ pub fn Context(comptime Route: type) type {
self.handle(response, &api_conn); self.handle(response, &api_conn);
} }
fn parseBody(self: *Self, req: http.Request) !void { fn parseBody(self: *Self, req: *http.Request) !void {
if (Body != void) { if (Body != void) {
const body = req.body orelse return error.NoBody; var stream = req.body orelse return error.NoBody;
const body = try stream.reader().readAllAlloc(self.allocator, 1 << 16);
errdefer self.allocator.free(body);
self.body = try json_utils.parse(Body, body, self.allocator); self.body = try json_utils.parse(Body, body, self.allocator);
self.body_buf = body;
} }
} }
fn freeBody(self: *Self) void { fn freeBody(self: *Self) void {
if (Body != void) { if (Body != void) {
json_utils.parseFree(self.body, self.allocator); json_utils.parseFree(self.body, self.allocator);
self.allocator.free(self.body_buf.?);
} }
} }

View file

@ -87,7 +87,7 @@ fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void {
} }
} }
fn handle(ctx: anytype, req: http.Request, res: *http.Response) void { fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void {
c.routeRequest(ctx.src, req, res, ctx.allocator); c.routeRequest(ctx.src, req, res, ctx.allocator);
} }