const std = @import("std");
const util = @import("util");
const http = @import("./lib.zig");

/// Longest boundary accepted by `openMultipart` (without the leading "--").
const max_boundary = 70;
/// Peek capacity: "\n--" ++ boundary (at most `max_boundary + 3` bytes) that is
/// read after a '\r', plus room to push that '\r' back as well.
const read_ahead = max_boundary + 4;

/// Reports whether `line` is exactly "\n--" ++ `boundary`, i.e. the rest of a
/// delimiter line after its leading '\r'. The length check rejects short reads
/// at end of input that merely start with "\n--" and end with `boundary`.
fn isBoundaryLine(line: []const u8, boundary: []const u8) bool {
    return line.len == boundary.len + 3 and
        std.mem.startsWith(u8, line, "\n--") and
        std.mem.endsWith(u8, line, boundary);
}

/// Streaming reader over a multipart body whose parts are separated by
/// "\r\n--" ++ boundary lines. Construct it with `openMultipart`.
pub fn MultipartStream(comptime ReaderType: type) type {
    return struct {
        const Multipart = @This();

        /// `std.io.Reader` over the content of a single `Part`.
        pub const PartReader = std.io.Reader(*Part, ReaderType.Error, Part.read);

        /// Source reader wrapped so that look-ahead bytes can be pushed back.
        stream: std.io.PeekStream(.{ .Static = read_ahead }, ReaderType),
        /// Boundary without the leading "--"; stored by reference, not copied.
        boundary: []const u8,

        /// Advances to the next part, skipping whatever is left of the current one.
        ///
        /// Returns null when the closing delimiter ("--" ++ boundary ++ "--") is
        /// found or the input ends. Returns `error.EndOfStream` if the input ends
        /// fewer than two bytes after a delimiter. The returned part's headers are
        /// parsed with `alloc`; `Part.close` deinitializes them. The part keeps a
        /// pointer to `self`, so `self` must not move while the part is in use.
        pub fn next(self: *Multipart, alloc: std.mem.Allocator) !?Part {
            const reader = self.stream.reader();
            while (true) {
                try reader.skipUntilDelimiterOrEof('\r');

                var line_buf: [read_ahead]u8 = undefined;
                const len = try reader.readAll(line_buf[0 .. self.boundary.len + 3]);
                const line = line_buf[0..len];
                if (line.len == 0) return null;

                if (!isBoundaryLine(line, self.boundary)) {
                    // Not a delimiter: hand the bytes back so that a '\r' inside
                    // them (e.g. the CR of the real delimiter) is still examined
                    // by the next skip. Cannot overflow: these bytes were just
                    // taken out of a buffer of `read_ahead` capacity.
                    self.stream.putBack(line) catch unreachable;
                    continue;
                }

                // A delimiter followed by "--" closes the body; anything else
                // (normally "\r\n") introduces another part.
                var more_buf: [2]u8 = undefined;
                if (try reader.readAll(&more_buf) != 2) return error.EndOfStream;
                const more = !std.mem.eql(u8, &more_buf, "--");
                try self.stream.putBack(&more_buf);
                try reader.skipUntilDelimiterOrEof('\n');

                if (!more) return null;
                return try Part.open(self, alloc);
            }
        }

        /// One part of the body: its header fields plus a reader over its content.
        pub const Part = struct {
            /// Owning stream; set to null once this part's closing delimiter (or
            /// end of input) is reached, after which `read` returns 0.
            base: ?*Multipart,
            fields: http.Fields,

            /// Parses the part's headers from the current stream position.
            /// NOTE(review): assumes `parseHeaders` stops right after the blank
            /// line that ends the header block, leaving the content unread.
            pub fn open(base: *Multipart, alloc: std.mem.Allocator) !Part {
                const fields = try @import("./request/parser.zig").parseHeaders(alloc, base.stream.reader());
                return .{ .base = base, .fields = fields };
            }

            /// Returns a `std.io.Reader` over this part's content.
            pub fn reader(self: *Part) PartReader {
                return .{ .context = self };
            }

            /// Releases the parsed header fields.
            pub fn close(self: *Part) void {
                self.fields.deinit();
            }

            /// Reads content into `buf`, stopping before the "\r\n--" ++ boundary
            /// delimiter. The delimiter is pushed back so that `next` can find it.
            pub fn read(self: *Part, buf: []u8) ReaderType.Error!usize {
                const base = self.base orelse return 0;
                const r = base.stream.reader();

                var count: usize = 0;
                while (count < buf.len) {
                    const byte = r.readByte() catch |err| switch (err) {
                        error.EndOfStream => {
                            self.base = null;
                            return count;
                        },
                        else => |e| return e,
                    };
                    buf[count] = byte;
                    count += 1;
                    if (byte != '\r') continue;

                    // Peek at what follows the '\r' to see whether it starts the delimiter.
                    var line_buf: [read_ahead]u8 = undefined;
                    const line = line_buf[0..try r.readAll(line_buf[0 .. base.boundary.len + 3])];
                    // Cannot overflow: `read_ahead` leaves room for these bytes plus the '\r'.
                    base.stream.putBack(line) catch unreachable;

                    if (isBoundaryLine(line, base.boundary)) {
                        base.stream.putBackByte('\r') catch unreachable;
                        self.base = null;
                        // The '\r' stored in `buf` belongs to the delimiter, not the content.
                        return count - 1;
                    }
                }
                return count;
            }
        };
    };
}

/// Wraps `reader` in a `MultipartStream` for the given `boundary` (without the
/// leading "--"). The boundary slice is stored by reference and must outlive
/// the stream. Returns `error.BoundaryTooLarge` if `boundary` exceeds
/// `max_boundary` bytes.
pub fn openMultipart(boundary: []const u8, reader: anytype) !MultipartStream(@TypeOf(reader)) {
    if (boundary.len > max_boundary) return error.BoundaryTooLarge;
    var stream = MultipartStream(@TypeOf(reader)){
        .stream = std.io.peekStream(read_ahead, reader),
        .boundary = boundary,
    };
    // The first delimiter is not preceded by CRLF; inject one so it is matched
    // by the same "\r\n--" ++ boundary logic as all later delimiters.
    stream.stream.putBack("\r\n") catch unreachable;
    return stream;
}

/// Iterates over the `; name=value` parameters of a header value such as
/// `form-data; name="field"`.
const ParamIter = struct {
    str: []const u8,
    /// Position of the ';' that precedes the next parameter (or `str.len`).
    index: usize = 0,

    const Param = struct {
        name: []const u8,
        value: []const u8,
    };

    pub fn from(str: []const u8) ParamIter {
        return .{ .str = str, .index = std.mem.indexOfScalar(u8, str, ';') orelse str.len };
    }

    /// Returns the value before the first ';', with surrounding whitespace trimmed.
    pub fn fieldValue(self: *ParamIter) []const u8 {
        return std.mem.trim(u8, std.mem.sliceTo(self.str, ';'), " \t");
    }

    /// Returns the next parameter, or null when none remain. Surrounding
    /// whitespace is trimmed from the name and value, and one pair of
    /// enclosing double quotes is removed from the value.
    pub fn next(self: *ParamIter) ?Param {
        if (self.index >= self.str.len) return null;

        const start = self.index + 1;
        const end = std.mem.indexOfScalarPos(u8, self.str, start, ';') orelse self.str.len;
        self.index = end;

        const param = std.mem.trim(u8, self.str[start..end], " \t");
        var split = std.mem.split(u8, param, "=");
        const name = std.mem.trimRight(u8, split.first(), " \t");
        const value = unquote(std.mem.trimLeft(u8, split.rest(), " \t"));

        // TODO: handle backslash escapes inside quoted values
        // TODO: handle parse errors
        return Param{
            .name = name,
            .value = value,
        };
    }

    /// Strips one pair of enclosing double quotes, if present.
    fn unquote(value: []const u8) []const u8 {
        if (value.len >= 2 and value[0] == '"' and value[value.len - 1] == '"')
            return value[1 .. value.len - 1];
        return value;
    }
};

/// Parses a multipart/form-data body from `reader` into a `T` using
/// `util.Deserializer`, feeding each part's `name` and content to it.
/// Returns `error.InvalidForm` if a part lacks a `form-data` Content-Disposition
/// or a `name` parameter.
pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T {
    var multipart = try openMultipart(boundary, reader);
    var ds = util.Deserializer(T){};
    while (true) {
        var part = (try multipart.next(alloc)) orelse break;
        defer part.close();

        const disposition = part.fields.get("Content-Disposition") orelse return error.InvalidForm;
        var iter = ParamIter.from(disposition);
        if (!std.ascii.eqlIgnoreCase("form-data", iter.fieldValue())) return error.InvalidForm;

        // Parameters other than `name` (e.g. `filename`) are skipped rather than
        // rejected. They may legitimately come first, and the body is client-supplied,
        // so a panic here would let any request crash the process.
        const name = while (iter.next()) |param| {
            if (std.ascii.eqlIgnoreCase("name", param.name)) break param.value;
        } else return error.InvalidForm;

        // maxInt(u32) keeps the ~4 GiB cap while still compiling on 32-bit targets.
        // NOTE(review): `name` presumably points into `part.fields`, which
        // `part.close()` releases at the end of this iteration. `value` is allocated
        // with `alloc` and not freed here. Confirm that `setSerializedField`
        // copies/takes ownership accordingly.
        const value = try part.reader().readAllAlloc(alloc, std.math.maxInt(u32));
        try ds.setSerializedField(name, value);
    }
    return try ds.finish(alloc);
}

// TODO: Fix these tests
test "MultipartStream" {
    const body = util.comptimeToCrlf(
        \\--abcd
        \\Content-Disposition: form-data; name=first; charset=utf8
        \\
        \\content
        \\--abcd
        \\content-Disposition: form-data; name=second
        \\
        \\no content
        \\--abcd
        \\content-disposition: form-data; name=third
        \\
        \\
        \\--abcd--
        \\
    );
    var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
    var stream = try openMultipart("abcd", src.reader());
    while (try stream.next(std.testing.allocator)) |p| {
        var part = p;
        defer part.close();
        std.debug.print("\n{?s}\n", .{part.fields.get("content-disposition")});
        var buf: [64]u8 = undefined;
        std.debug.print("\"{s}\"\n", .{buf[0..try part.reader().readAll(&buf)]});
    }
}

test "MultipartStream.next finds the delimiter when skipping an unread part" {
    // The first part's last line ("y") is shorter than the delimiter, so the
    // look-ahead taken after the preceding '\r' overlaps the real delimiter.
    const body = "--abcd\r\n" ++
        "Content-Disposition: form-data; name=a\r\n" ++
        "\r\n" ++
        "x\r\ny\r\n" ++
        "--abcd\r\n" ++
        "Content-Disposition: form-data; name=b\r\n" ++
        "\r\n" ++
        "z\r\n" ++
        "--abcd--\r\n";
    var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
    var stream = try openMultipart("abcd", src.reader());
    var count: usize = 0;
    while (try stream.next(std.testing.allocator)) |p| {
        var part = p;
        part.close();
        count += 1;
    }
    try std.testing.expectEqual(@as(usize, 2), count);
}

test "parseFormData" {
    const body = util.comptimeToCrlf(
        \\--abcd
        \\Content-Disposition: form-data; name=foo
        \\
        \\content
        \\--abcd--
        \\
    );
    var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
    const val = try parseFormData(struct {
        foo: []const u8,
    }, "abcd", src.reader(), std.testing.allocator);
    std.debug.print("\n\n\n\"{any}\"\n\n\n", .{val});
}