diff --git a/src/http/json.zig b/src/http/json.zig index 21474cc..ee6a852 100644 --- a/src/http/json.zig +++ b/src/http/json.zig @@ -10,10 +10,10 @@ const Token = std.json.Token; const unescapeValidString = std.json.unescapeValidString; const UnescapeValidStringError = std.json.UnescapeValidStringError; -pub fn parse(comptime T: type, body: []const u8, alloc: std.mem.Allocator) !T { +pub fn parse(comptime T: type, allow_unknown_fields: bool, body: []const u8, alloc: std.mem.Allocator) !T { var tokens = TokenStream.init(body); - const options = ParseOptions{ .allocator = alloc }; + const options = ParseOptions{ .allocator = alloc, .ignore_unknown_fields = !allow_unknown_fields }; const token = (try tokens.next()) orelse return error.UnexpectedEndOfJson; const r = try parseInternal(T, token, &tokens, options); diff --git a/src/http/middleware.zig b/src/http/middleware.zig index ce4d307..f4b4630 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -672,7 +672,13 @@ const BaseContentType = enum { other, }; -fn parseBodyFromRequest(comptime T: type, content_type: ?[]const u8, reader: anytype, alloc: std.mem.Allocator) !T { +fn parseBodyFromRequest( + comptime T: type, + comptime options: ParseBodyOptions, + content_type: ?[]const u8, + reader: anytype, + alloc: std.mem.Allocator, +) !T { // Use json by default for now for testing purposes const eff_type = content_type orelse "application/json"; const parser_type = matchContentType(eff_type); @@ -681,7 +687,7 @@ fn parseBodyFromRequest(comptime T: type, content_type: ?[]const u8, reader: any .octet_stream, .json => { const buf = try reader.readAllAlloc(alloc, 1 << 16); defer alloc.free(buf); - const body = try json_utils.parse(T, buf, alloc); + const body = try json_utils.parse(T, options.allow_unknown_fields, buf, alloc); defer json_utils.parseFree(body, alloc); return try util.deepClone(alloc, body); @@ -689,14 +695,14 @@ fn parseBodyFromRequest(comptime T: type, content_type: ?[]const u8, reader: any .url_encoded => { const buf = try reader.readAllAlloc(alloc, 1 << 16); defer alloc.free(buf); - return urlencode.parse(alloc, T, buf) catch |err| switch (err) { + return urlencode.parse(alloc, options.allow_unknown_fields, T, buf) catch |err| switch (err) { //error.NoQuery => error.NoBody, else => err, }; }, .multipart_formdata => { const boundary = fields.getParam(eff_type, "boundary") orelse return error.MissingBoundary; - return try @import("./multipart.zig").parseFormData(T, boundary, reader, alloc); + return try @import("./multipart.zig").parseFormData(T, options.allow_unknown_fields, boundary, reader, alloc); }, else => return error.UnsupportedMediaType, } @@ -714,12 +720,16 @@ fn matchContentType(hdr: []const u8) BaseContentType { return .other; } +pub const ParseBodyOptions = struct { + allow_unknown_fields: bool = false, +}; + /// Parses a set of body arguments from the request body based on the request's Content-Type /// header. /// /// The exact method for parsing depends partially on the Content-Type. json types are preferred /// TODO: Need tests for this, including various Content-Type values -pub fn ParseBody(comptime Body: type) type { +pub fn ParseBody(comptime Body: type, comptime options: ParseBodyOptions) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { const content_type = req.headers.get("Content-Type"); @@ -731,7 +741,7 @@ pub fn ParseBody(comptime Body: type) type { } var stream = req.body orelse return error.NoBody; - const body = try parseBodyFromRequest(Body, content_type, stream.reader(), ctx.allocator); + const body = try parseBodyFromRequest(Body, options, content_type, stream.reader(), ctx.allocator); defer util.deepFree(ctx.allocator, body); return next.handle( @@ -751,7 +761,7 @@ test "parseBodyFromRequest" { const testCase = struct { fn case(content_type: []const u8, body: []const u8, expected: anytype) !void { var stream = std.io.StreamSource{ .const_buffer = std.io.fixedBufferStream(body) }; - const result = try parseBodyFromRequest(@TypeOf(expected), content_type, stream.reader(), std.testing.allocator); + const result = try parseBodyFromRequest(@TypeOf(expected), .{}, content_type, stream.reader(), std.testing.allocator); defer util.deepFree(std.testing.allocator, result); try util.testing.expectDeepEqual(expected, result); @@ -797,7 +807,7 @@ pub fn ParseQueryParams(comptime QueryParams: type) type { return struct { pub fn handle(_: @This(), req: anytype, res: anytype, ctx: anytype, next: anytype) !void { if (QueryParams == void) return next.handle(req, res, addField(ctx, "query_params", {}), {}); - const query = try urlencode.parse(ctx.allocator, QueryParams, ctx.query_string); + const query = try urlencode.parse(ctx.allocator, true, QueryParams, ctx.query_string); defer util.deepFree(ctx.allocator, query); return next.handle( diff --git a/src/http/multipart.zig b/src/http/multipart.zig index 815711d..e4ccf98 100644 --- a/src/http/multipart.zig +++ b/src/http/multipart.zig @@ -182,7 +182,7 @@ fn Deserializer(comptime Result: type) type { }); } -pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T { +pub fn parseFormData(comptime T: type, allow_unknown_fields: bool, boundary: []const u8, reader: anytype, alloc: std.mem.Allocator) !T { var form = openForm(try openMultipart(boundary, reader)); var ds = Deserializer(T){}; @@ -196,7 +196,13 @@ pub fn parseFormData(comptime T: type, boundary: []const u8, reader: anytype, al var part = (try form.next(alloc)) orelse break; errdefer util.deepFree(alloc, part); - try ds.setSerializedField(part.name, part); + ds.setSerializedField(part.name, part) catch |err| switch (err) { + error.UnknownField => if (allow_unknown_fields) { + util.deepFree(alloc, part); + continue; + } else return err, + else => |e| return e, + }; } return try ds.finish(alloc); diff --git a/src/http/urlencode.zig b/src/http/urlencode.zig index 3f49423..ee671b7 100644 --- a/src/http/urlencode.zig +++ b/src/http/urlencode.zig @@ -98,13 +98,17 @@ pub const Iter = struct { /// Would be used to parse a query string like /// `?foo.baz=12345` /// -pub fn parse(alloc: std.mem.Allocator, comptime T: type, query: []const u8) !T { +pub fn parse(alloc: std.mem.Allocator, allow_unknown_fields: bool, comptime T: type, query: []const u8) !T { var iter = Iter.from(query); var deserializer = Deserializer(T){}; while (iter.next()) |pair| { try deserializer.setSerializedField(pair.key, pair.value); + deserializer.setSerializedField(pair.key, pair.value) catch |err| switch (err) { + error.UnknownField => if (allow_unknown_fields) continue else return err, + else => |e| return e, + }; } return try deserializer.finish(alloc); diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 398424c..d60e0eb 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -73,6 +73,13 @@ pub fn EndpointRequest(comptime Endpoint: type) type { const Body = if (@hasDecl(Endpoint, "Body")) Endpoint.Body else void; const Query = if (@hasDecl(Endpoint, "Query")) Endpoint.Query else void; + const body_options = .{ + .allow_unknown_fields = if (@hasDecl(Endpoint, "allow_unknown_fields_in_body")) + Endpoint.allow_unknown_fields_in_body + else + false, + }; + allocator: std.mem.Allocator, method: http.Method, @@ -91,7 +98,7 @@ pub fn EndpointRequest(comptime Endpoint: type) type { const body_middleware = //if (Body == void) //mdw.injectContext(.{ .body = {} }) //else - mdw.ParseBody(Body){}; + mdw.ParseBody(Body, body_options){}; const query_middleware = //if (Query == void) //mdw.injectContext(.{ .query_params = {} })