fediglam/src/http/multipart.zig

212 lines
7.1 KiB
Zig

const std = @import("std");
const util = @import("util");
const http = @import("./lib.zig");
/// Longest boundary accepted by `openMultipart`.
const max_boundary = 70;
/// Look-ahead capacity: room for the "\r\n--" delimiter prefix (4 bytes)
/// followed by a boundary of `max_boundary` bytes.
const read_ahead = 4 + max_boundary;
/// Build a multipart stream type that reads body parts from a `ReaderType`.
///
/// Parts are separated by delimiter lines of the form "\r\n--" ++ boundary.
/// The underlying reader is wrapped in a peek stream with `read_ahead` bytes of
/// look-ahead, so a candidate delimiter can be examined and pushed back without
/// being consumed.
pub fn MultipartStream(comptime ReaderType: type) type {
    return struct {
        const Multipart = @This();
        pub const PartReader = std.io.Reader(*Part, ReaderType.Error, Part.read);

        stream: std.io.PeekStream(.{ .Static = read_ahead }, ReaderType),
        boundary: []const u8,

        /// Return true if `line` is exactly "\n--" ++ boundary, i.e. the remainder
        /// of a delimiter line whose leading '\r' has already been consumed.
        /// A shorter `line` (from a read truncated by end of input) never matches.
        fn isDelimiter(self: *const Multipart, line: []const u8) bool {
            return line.len == self.boundary.len + 3 and
                std.mem.startsWith(u8, line, "\n--") and
                std.mem.eql(u8, line[3..], self.boundary);
        }

        /// Advance to the next part, skipping any content before the next delimiter.
        ///
        /// Returns null at the closing delimiter ("--" ++ boundary ++ "--") or at end
        /// of input. Returns error.EndOfStream if the input ends immediately after a
        /// delimiter. The returned `Part` holds header fields parsed with `alloc`;
        /// call `close` on it when done.
        pub fn next(self: *Multipart, alloc: std.mem.Allocator) !?Part {
            const reader = self.stream.reader();
            while (true) {
                try reader.skipUntilDelimiterOrEof('\r');
                var line_buf: [read_ahead]u8 = undefined;
                const len = try reader.readAll(line_buf[0 .. self.boundary.len + 3]);
                const line = line_buf[0..len];
                if (line.len == 0) return null;

                if (!self.isDelimiter(line)) {
                    // Not a delimiter: return the look-ahead to the stream so a '\r'
                    // inside it (e.g. "\r\r\n--boundary") is found by the next skip.
                    // Progress is guaranteed because the skip consumes at least one byte.
                    try self.stream.putBack(line);
                    continue;
                }

                // Delimiter matched; a following "--" marks the closing delimiter.
                var more_buf: [2]u8 = undefined;
                if (try reader.readAll(&more_buf) != 2) return error.EndOfStream;
                const more = !(more_buf[0] == '-' and more_buf[1] == '-');
                try self.stream.putBack(&more_buf);
                // Discard the rest of the delimiter line (through its '\n').
                try reader.skipUntilDelimiterOrEof('\n');
                if (more) return try Part.open(self, alloc) else return null;
            }
        }

        /// One body part: its parsed header fields and a reader over its content.
        pub const Part = struct {
            /// Owning stream; set to null once this part's content is exhausted.
            base: ?*Multipart,
            fields: http.Fields,

            /// Parse the part's header block from `base`. `next` calls this right
            /// after consuming a delimiter line.
            pub fn open(base: *Multipart, alloc: std.mem.Allocator) !Part {
                const fields = try @import("./request/parser.zig").parseHeaders(alloc, base.stream.reader());
                return .{ .base = base, .fields = fields };
            }

            /// Return a `std.io.Reader` over this part's content.
            pub fn reader(self: *Part) PartReader {
                return .{ .context = self };
            }

            /// Release the part's header fields.
            pub fn close(self: *Part) void {
                self.fields.deinit();
            }

            /// Read part content into `buf`, stopping before the next delimiter.
            ///
            /// Returns the number of bytes written; 0 once the part is exhausted.
            /// When a delimiter is found, it (including its leading '\r') is left in
            /// the stream for `next` to consume and is not included in the output.
            pub fn read(self: *Part, buf: []u8) ReaderType.Error!usize {
                const base = self.base orelse return 0;
                const r = base.stream.reader();
                var count: usize = 0;
                while (count < buf.len) {
                    const byte = r.readByte() catch |err| switch (err) {
                        error.EndOfStream => {
                            self.base = null;
                            return count;
                        },
                        else => |e| return e,
                    };
                    buf[count] = byte;
                    count += 1;
                    if (byte != '\r') continue;

                    // A '\r' may start a delimiter: peek at the bytes that would follow it.
                    var line_buf: [read_ahead]u8 = undefined;
                    const line = line_buf[0..try r.readAll(line_buf[0 .. base.boundary.len + 3])];
                    // Cannot fail: at most read_ahead - 1 bytes are buffered after the
                    // readByte above, and we only return what readAll just consumed
                    // (plus, below, the '\r' readByte consumed).
                    base.stream.putBack(line) catch unreachable;
                    if (!base.isDelimiter(line)) continue;

                    // Delimiter found: push back its '\r' as well and drop that '\r'
                    // from the returned content.
                    base.stream.putBackByte('\r') catch unreachable;
                    self.base = null;
                    return count - 1;
                }
                return count;
            }
        };
    };
}
/// Wrap `reader` in a multipart stream whose parts are separated by `boundary`.
///
/// `boundary` is stored by reference (not copied) and must outlive the stream.
/// The look-ahead buffer is seeded with "\r\n", so the first delimiter may
/// appear at the very start of the input without a preceding CRLF.
///
/// Returns error.InvalidBoundary for an empty boundary and
/// error.BoundaryTooLarge for one longer than `max_boundary` bytes.
pub fn openMultipart(boundary: []const u8, reader: anytype) !MultipartStream(@TypeOf(reader)) {
    // RFC 2046 boundaries are 1-70 characters. An empty boundary would make
    // every "\r\n--" line look like a delimiter.
    if (boundary.len == 0) return error.InvalidBoundary;
    if (boundary.len > max_boundary) return error.BoundaryTooLarge;
    var stream = MultipartStream(@TypeOf(reader)){
        .stream = std.io.peekStream(read_ahead, reader),
        .boundary = boundary,
    };
    // Cannot fail: the look-ahead buffer is empty and larger than 2 bytes.
    stream.stream.putBack("\r\n") catch unreachable;
    return stream;
}
/// Iterator over the `; name=value` parameters of a header value such as
/// `form-data; name="field"; filename="a.txt"`.
///
/// Returned slices point into `str`; nothing is allocated. Quoted values have
/// their surrounding double quotes removed, and a ';' inside quotes does not end
/// the parameter. Backslash escapes inside quotes are left as-is (not decoded).
const ParamIter = struct {
    str: []const u8,
    /// Index of the ';' that precedes the next parameter, or `str.len` when exhausted.
    index: usize = 0,

    const Param = struct {
        name: []const u8,
        value: []const u8,
    };

    /// Create an iterator positioned at the first ';' of `str`.
    pub fn from(str: []const u8) ParamIter {
        return .{ .str = str, .index = std.mem.indexOfScalar(u8, str, ';') orelse str.len };
    }

    /// Return the leading field value (text before the first ';'), with
    /// surrounding spaces and tabs removed.
    pub fn fieldValue(self: *const ParamIter) []const u8 {
        return std.mem.trim(u8, std.mem.sliceTo(self.str, ';'), " \t");
    }

    /// Return the next parameter, or null when none remain.
    ///
    /// Name and value are trimmed of spaces/tabs around the '='. A parameter
    /// without '=' yields an empty value.
    pub fn next(self: *ParamIter) ?Param {
        if (self.index >= self.str.len) return null;
        const start = self.index + 1;
        const end = findSeparator(self.str, start);
        self.index = end;

        const param = std.mem.trim(u8, self.str[start..end], " \t");
        const eq = std.mem.indexOfScalar(u8, param, '=') orelse param.len;
        const name = std.mem.trimRight(u8, param[0..eq], " \t");
        const raw_value = if (eq < param.len) std.mem.trimLeft(u8, param[eq + 1 ..], " \t") else "";
        // TODO: handle parse errors (e.g. unterminated quotes)
        return Param{
            .name = name,
            .value = unquote(raw_value),
        };
    }

    /// Return the index of the first ';' at or after `pos` that is not inside a
    /// double-quoted string, or `str.len` if there is none.
    fn findSeparator(str: []const u8, pos: usize) usize {
        var in_quotes = false;
        var i = pos;
        while (i < str.len) : (i += 1) {
            switch (str[i]) {
                // Inside quotes, a backslash escapes the following character.
                '\\' => {
                    if (in_quotes) i += 1;
                },
                '"' => {
                    in_quotes = !in_quotes;
                },
                ';' => {
                    if (!in_quotes) return i;
                },
                else => {},
            }
        }
        return str.len;
    }

    /// Strip one pair of surrounding double quotes, if present.
    fn unquote(value: []const u8) []const u8 {
        if (value.len >= 2 and value[0] == '"' and value[value.len - 1] == '"') {
            return value[1 .. value.len - 1];
        }
        return value;
    }
};
/// Deserialize a `multipart/form-data` body read from `reader` into a `T`.
///
/// Each part must carry a `Content-Disposition: form-data` header with a `name`
/// parameter. The part's whole content is read with `alloc` and handed to
/// `util.Deserializer(T).setSerializedField` under that name. Other parameters
/// (e.g. `filename`) are ignored.
///
/// Returns error.InvalidForm if a part has no Content-Disposition header, its
/// disposition is not `form-data`, or it has no `name` parameter.
pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T {
    var multipart = try openMultipart(boundary, reader);
    var ds = util.Deserializer(T){};
    while (true) {
        var part = (try multipart.next(alloc)) orelse break;
        defer part.close();

        const disposition = part.fields.get("Content-Disposition") orelse return error.InvalidForm;
        var iter = ParamIter.from(disposition);
        if (!std.ascii.eqlIgnoreCase("form-data", iter.fieldValue())) return error.InvalidForm;

        // Take the first `name` parameter. Skip any others (e.g. `filename`)
        // rather than panicking on client-controlled input.
        const name = while (iter.next()) |param| {
            if (std.ascii.eqlIgnoreCase("name", param.name)) break param.value;
        } else return error.InvalidForm;

        // NOTE(review): who frees `value` is decided by util.Deserializer
        // (setSerializedField / finish) — confirm it does not leak.
        const value = try part.reader().readAllAlloc(alloc, 1 << 32);
        try ds.setSerializedField(name, value);
    }
    return try ds.finish(alloc);
}
// TODO: Fix these tests
test "MultipartStream" {
    const body = util.comptimeToCrlf(
        \\--abcd
        \\Content-Disposition: form-data; name=first; charset=utf8
        \\
        \\content
        \\--abcd
        \\content-Disposition: form-data; name=second
        \\
        \\no content
        \\--abcd
        \\content-disposition: form-data; name=third
        \\
        \\
        \\--abcd--
        \\
    );
    var src = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
    var stream = try openMultipart("abcd", src.reader());

    // Part contents in order. The third part's body is empty: its only CRLF
    // belongs to the closing delimiter.
    const expected = [_][]const u8{ "content", "no content", "" };
    var count: usize = 0;
    while (try stream.next(std.testing.allocator)) |p| : (count += 1) {
        var part = p;
        defer part.close();
        try std.testing.expect(count < expected.len);

        var buf: [64]u8 = undefined;
        const len = try part.reader().readAll(&buf);
        try std.testing.expectEqualStrings(expected[count], buf[0..len]);
    }
    try std.testing.expectEqual(expected.len, count);
}
test "parseFormData" {
    // A single form field `foo` whose content is "content".
    const body = util.comptimeToCrlf(
        \\--abcd
        \\Content-Disposition: form-data; name=foo
        \\
        \\content
        \\--abcd--
        \\
    );
    var source = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) };
    const input = source.reader();
    const parsed = try parseFormData(struct {
        foo: []const u8,
    }, "abcd", input, std.testing.allocator);
    // NOTE(review): this only prints the result and asserts nothing. Ownership of
    // `parsed.foo` depends on util.Deserializer — TODO confirm, then replace the
    // print with expectEqualStrings("content", parsed.foo) plus the matching free.
    std.debug.print("\n\n\n\"{any}\"\n\n\n", .{parsed});
}