diff --git a/src/util/lib.zig b/src/util/lib.zig index b022210..1958922 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -8,8 +8,9 @@ pub const Url = @import("./Url.zig"); pub const PathIter = iters.PathIter; pub const QueryIter = iters.QueryIter; pub const SqlStmtIter = iters.Separator(';'); -pub const deserialize = @import("./deserialize.zig"); -pub const Deserializer = deserialize.Deserializer; +pub const serialize = @import("./serialize.zig"); +pub const Deserializer = serialize.Deserializer; +pub const DeserializerContext = serialize.DeserializerContext; /// Joins an array of strings, prefixing every entry with `prefix`, /// and putting `separator` in between each pair diff --git a/src/util/deserialize.zig b/src/util/serialize.zig similarity index 52% rename from src/util/deserialize.zig rename to src/util/serialize.zig index af68b35..2c26500 100644 --- a/src/util/deserialize.zig +++ b/src/util/serialize.zig @@ -3,125 +3,40 @@ const util = @import("./lib.zig"); const FieldRef = []const []const u8; -const QueryStringOptions = struct { - fn isScalar(comptime T: type) bool { - if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true; - if (comptime std.meta.trait.isZigString(T)) return true; - if (comptime std.meta.trait.isIntegral(T)) return true; - if (comptime std.meta.trait.isFloat(T)) return true; - if (comptime std.meta.trait.is(.Enum)(T)) return true; - if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true; - if (comptime std.meta.trait.hasFn("parse")(T)) return true; - if (T == bool) return true; +fn defaultIsScalar(comptime T: type) bool { + if (comptime std.meta.trait.is(.Optional)(T) and defaultIsScalar(std.meta.Child(T))) return true; + if (comptime std.meta.trait.isZigString(T)) return true; + if (comptime std.meta.trait.isIntegral(T)) return true; + if (comptime std.meta.trait.isFloat(T)) return true; + if (comptime std.meta.trait.is(.Enum)(T)) return true; + if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true; + if (comptime std.meta.trait.hasFn("parse")(T)) return true; + if (T == bool) return true; - return false; + return false; +} + +pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T { + if (comptime std.meta.trait.is(.Optional)(T)) { + if (value.len == 0) return null; + return try deserializeString(allocator, std.meta.Child(T), value); } - const embed_unions = true; - const ParsedField = ?[]const u8; + if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value); + if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0); + if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value); + if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value); - fn deserializeScalar(comptime T: type, maybe_value: ParsedField) !T { - const is_optional = comptime std.meta.trait.is(.Optional)(T); - if (maybe_value) |value| { - const Eff = if (is_optional) std.meta.Child(T) else T; + var buf: [64]u8 = undefined; + const lowered = std.ascii.lowerString(&buf, value); - // Treat all empty values as nulls if possible - if (value.len == 0 and is_optional) return null; - - // TODO: - //const decoded = try decodeString(alloc, value); - const decoded = value; - - if (comptime std.meta.trait.isZigString(Eff)) return decoded; - - // TOOD: - //defer alloc.free(decoded); - - if (comptime std.meta.trait.isIntegral(Eff)) return try std.fmt.parseInt(Eff, decoded, 0); - if (comptime std.meta.trait.isFloat(Eff)) return try std.fmt.parseFloat(Eff, decoded); - if (comptime std.meta.trait.is(.Enum)(Eff)) { - //_ = std.ascii.lowerString(decoded, decoded); - return std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue; - } - if (Eff == bool) { - //_ = std.ascii.lowerString(decoded, decoded); - return bool_map.get(decoded) orelse return error.InvalidBool; - } - if (comptime std.meta.trait.hasFn("parse")(Eff)) return try Eff.parse(decoded); - - @compileError("Invalid type " ++ @typeName(T)); - } else { - // Parameter is present, but no associated value - return if (is_optional) - null - else if (T == bool) - true - else - error.MissingValue; - } + if (T == bool) return bool_map.get(lowered) orelse return error.InvalidBool; + if (comptime std.meta.trait.is(.Enum)(T)) { + return std.meta.stringToEnum(T, lowered) orelse return error.InvalidEnumTag; } -}; -const FormDataField = struct { - filename: ?[]const u8 = null, - charset: ?[]const u8 = null, - value: []const u8, -}; - -const FormFile = struct { - filename: []const u8, - value: []const u8, -}; - -const FormDataOptions = struct { - const embed_unions = true; - - fn isScalar(comptime T: type) bool { - if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true; - if (comptime std.meta.trait.isZigString(T)) return true; - if (comptime std.meta.trait.isIntegral(T)) return true; - if (comptime std.meta.trait.isFloat(T)) return true; - if (comptime std.meta.trait.is(.Enum)(T)) return true; - if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true; - if (comptime std.meta.trait.hasFn("parse")(T)) return true; - if (T == bool) return true; - - return false; - } - const ParsedField = FormDataField; - - fn deserializeScalar(comptime T: type, field: FormDataField) !T { - // TODO: allocation?? - - if (T == FormFile) { - return FormFile{ - .filename = field.filename orelse "untitled", - .data = field.value, - }; - } - - const decoded = field.value; - - if (comptime std.meta.trait.isZigString(T)) return decoded; - - // TOOD: - //defer alloc.free(decoded); - - if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, decoded, 0); - if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, decoded); - if (comptime std.meta.trait.is(.Enum)(T)) { - //_ = std.ascii.lowerString(decoded, decoded); - return std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue; - } - if (T == bool) { - //_ = std.ascii.lowerString(decoded, decoded); - return bool_map.get(decoded) orelse return error.InvalidBool; - } - if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(decoded); - - @compileError("Invalid type " ++ @typeName(T)); - } -}; + @compileError("Invalid type " ++ @typeName(T)); +} fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef { comptime { @@ -149,9 +64,14 @@ fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime o } } -const SerializationOptions = struct { - embed_unions: bool = true, - isScalar: fn (type) bool = QueryStringOptions.isScalar, +pub const SerializationOptions = struct { + embed_unions: bool, + isScalar: fn (type) bool, +}; + +pub const default_options = SerializationOptions{ + .embed_unions = true, + .isScalar = defaultIsScalar, }; fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type { @@ -159,7 +79,6 @@ fn Intermediary(comptime Result: type, comptime From: type, comptime options: Se var fields: [field_refs.len]std.builtin.Type.StructField = undefined; for (field_refs) |ref, i| { - //@compileLog(i, ref, util.comptimeJoin(".", ref)); fields[i] = .{ .name = util.comptimeJoin(".", ref), .field_type = ?From, @@ -177,15 +96,16 @@ fn Intermediary(comptime Result: type, comptime From: type, comptime options: Se } }); } -pub fn Deserializer(comptime Result: type, comptime From: type) type { - return DeserializerContext(Result, From, struct { - const options = SerializationOptions{}; - const deserializeScalar = QueryStringOptions.deserializeScalar; - const isScalar = QueryStringOptions.isScalar; +pub fn Deserializer(comptime Result: type) type { + return DeserializerContext(Result, []const u8, struct { + const options = default_options; + fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T { + return try deserializeString(alloc, T, val); + } }); } -fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type { +pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type { return struct { const Data = Intermediary(Result, From, Context.options); @@ -206,8 +126,12 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont return error.UnknownField; } - pub fn finish(self: *@This()) !Result { - return (try self.deserialize(Result, &.{})) orelse error.MissingField; + pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void { + util.deepFree(allocator, val); + } + + pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result { + return (try self.deserialize(allocator, Result, &.{})) orelse error.MissingField; } fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From { @@ -215,9 +139,13 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont return @field(self.data, util.comptimeJoin(".", field_ref)); } - fn deserialize(self: *@This(), comptime T: type, comptime field_ref: FieldRef) !?T { - if (comptime Context.isScalar(T)) { - return try Context.deserializeScalar(T, self.getSerializedField(field_ref) orelse return null); + fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void { + util.deepFree(allocator, val); + } + + fn deserialize(self: *@This(), allocator: std.mem.Allocator, comptime T: type, comptime field_ref: FieldRef) !?T { + if (comptime Context.options.isScalar(T)) { + return try self.context.deserializeScalar(allocator, T, self.getSerializedField(field_ref) orelse return null); } switch (@typeInfo(T)) { @@ -225,14 +153,16 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont // in its parent container .Union => |info| { var result: ?T = null; + errdefer if (result) |v| self.deserializeFree(allocator, v); // TODO: errdefer cleanup const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref; inline for (info.fields) |field| { const F = field.field_type; const new_field_ref = union_ref ++ &[_][]const u8{field.name}; - const maybe_value = try self.deserialize(F, new_field_ref); + const maybe_value = try self.deserialize(allocator, F, new_field_ref); if (maybe_value) |value| { // TODO: errdefer cleanup + errdefer self.deserializeFree(allocator, value); if (result != null) return error.DuplicateUnionMember; result = @unionInit(T, field.name, value); } @@ -245,16 +175,23 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont var any_explicit = false; var any_missing = false; - inline for (info.fields) |field| { + var fields_alloced = [1]bool{false} ** info.fields.len; + errdefer inline for (info.fields) |field, i| { + if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name)); + }; + inline for (info.fields) |field, i| { const F = field.field_type; const new_field_ref = field_ref ++ &[_][]const u8{field.name}; - const maybe_value = try self.deserialize(F, new_field_ref); + const maybe_value = try self.deserialize(allocator, F, new_field_ref); if (maybe_value) |v| { @field(result, field.name) = v; + fields_alloced[i] = true; any_explicit = true; } else if (field.default_value) |ptr| { if (@sizeOf(F) != 0) { - @field(result, field.name) = @ptrCast(*const F, @alignCast(field.alignment, ptr)).*; + const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr)); + @field(result, field.name) = try util.deepClone(allocator, cast_ptr.*); + fields_alloced[i] = true; } } else { any_missing = true; @@ -268,7 +205,7 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont }, // Specifically non-scalar optionals - .Optional => |info| return try self.deserialize(info.child, field_ref), + .Optional => |info| return try self.deserialize(allocator, info.child, field_ref), else => @compileError("Unsupported type"), } @@ -294,19 +231,22 @@ test "Deserializer" { // Happy case - simple { - const T = struct { foo: usize, bar: bool }; + const T = struct { foo: []const u8, bar: bool }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try ds.setSerializedField("foo", "123"); try ds.setSerializedField("bar", "true"); - try std.testing.expectEqual(T{ .foo = 123, .bar = true }, try ds.finish()); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val); } // Returns error if nonexistent field set { - const T = struct { foo: usize, bar: bool }; + const T = struct { foo: []const u8, bar: bool }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123")); } @@ -316,10 +256,13 @@ test "Deserializer" { foo: struct { bar: bool, baz: bool }, }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try ds.setSerializedField("foo.bar", "true"); try ds.setSerializedField("foo.baz", "true"); - try std.testing.expectEqual(T{ .foo = .{ .bar = true, .baz = true } }, try ds.finish()); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true, .baz = true } }, val); } // Union embedding @@ -328,9 +271,12 @@ test "Deserializer" { foo: union(enum) { bar: bool, baz: bool }, }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try ds.setSerializedField("bar", "true"); - try std.testing.expectEqual(T{ .foo = .{ .bar = true } }, try ds.finish()); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true } }, val); } // Returns error if multiple union fields specified @@ -339,27 +285,32 @@ test "Deserializer" { foo: union(enum) { bar: bool, baz: bool }, }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try ds.setSerializedField("bar", "true"); try ds.setSerializedField("baz", "true"); - try std.testing.expectError(error.DuplicateUnionMember, ds.finish()); + + try std.testing.expectError(error.DuplicateUnionMember, ds.finish(std.testing.allocator)); } // Uses default values if fields aren't provided { - const T = struct { foo: usize = 123, bar: bool = true }; + const T = struct { foo: []const u8 = "123", bar: bool = true }; - var ds = Deserializer(T, []const u8){}; - try std.testing.expectEqual(T{ .foo = 123, .bar = true }, try ds.finish()); + var ds = Deserializer(T){}; + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val); } // Returns an error if fields aren't provided and no default exists { - const T = struct { foo: usize, bar: bool }; + const T = struct { foo: []const u8, bar: bool }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try ds.setSerializedField("foo", "123"); - try std.testing.expectError(error.MissingField, ds.finish()); + + try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator)); } // Handles optional containers @@ -369,8 +320,11 @@ test "Deserializer" { qux: ?union(enum) { quux: usize } = null, }; - var ds = Deserializer(T, []const u8){}; - try std.testing.expectEqual(T{ .foo = null, .qux = null }, try ds.finish()); + var ds = Deserializer(T){}; + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = null, .qux = null }, val); } { @@ -379,10 +333,13 @@ test "Deserializer" { qux: ?union(enum) { quux: usize } = null, }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try ds.setSerializedField("foo.baz", "3"); try ds.setSerializedField("quux", "3"); - try std.testing.expectEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, try ds.finish()); + + const val = try ds.finish(std.testing.allocator); + defer ds.finishFree(std.testing.allocator, val); + try util.testing.expectDeepEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, val); } { @@ -391,9 +348,10 @@ test "Deserializer" { qux: ?union(enum) { quux: usize } = null, }; - var ds = Deserializer(T, []const u8){}; + var ds = Deserializer(T){}; try ds.setSerializedField("foo.bar", "3"); try ds.setSerializedField("quux", "3"); - try std.testing.expectError(error.MissingField, ds.finish()); + + try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator)); } }