From aa632ace8b506811e875ffc6ebeaa7e6ff65560b Mon Sep 17 00:00:00 2001
From: jaina heartles
Date: Wed, 30 Nov 2022 19:21:55 -0800
Subject: [PATCH] Work on deserialization refactor

---
 src/util/deserialize.zig | 445 +++++++++++++++++++++++++++++++++++++++
 src/util/lib.zig         |   2 +
 2 files changed, 447 insertions(+)
 create mode 100644 src/util/deserialize.zig

diff --git a/src/util/deserialize.zig b/src/util/deserialize.zig
new file mode 100644
index 0000000..794869a
--- /dev/null
+++ b/src/util/deserialize.zig
@@ -0,0 +1,445 @@
+//! Deserialization of flat key/value input into nested Zig types.
+//!
+//! Nested fields are addressed by joining the field names along the path
+//! with `.` (for example `sub_struct.foo`).
+
+const std = @import("std");
+const util = @import("./lib.zig");
+
+/// Path to a (possibly nested) field: one entry per field name on the way.
+const FieldRef = []const []const u8;
+
+/// Scalar handling for sources whose raw field values are optional strings.
+const QueryStringOptions = struct {
+    /// Returns whether `T` is converted from a single raw value rather than
+    /// being decomposed into sub-fields.
+    fn isScalar(comptime T: type) bool {
+        if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
+        if (comptime std.meta.trait.isZigString(T)) return true;
+        if (comptime std.meta.trait.isIntegral(T)) return true;
+        if (comptime std.meta.trait.isFloat(T)) return true;
+        if (comptime std.meta.trait.is(.Enum)(T)) return true;
+        if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true;
+        if (comptime std.meta.trait.hasFn("parse")(T)) return true;
+        if (T == bool) return true;
+
+        return false;
+    }
+
+    const embed_unions = true;
+    const ParsedField = ?[]const u8;
+
+    /// Converts the raw value of a field into a `T`.
+    ///
+    /// A missing value (`null`) yields `null` for optional types, `true` for
+    /// `bool`, and `error.MissingValue` otherwise. An empty value yields
+    /// `null` for optional types.
+    fn deserializeScalar(comptime T: type, maybe_value: ParsedField) !T {
+        const is_optional = comptime std.meta.trait.is(.Optional)(T);
+        if (maybe_value) |value| {
+            const Eff = if (is_optional) std.meta.Child(T) else T;
+
+            // Treat all empty values as nulls if possible
+            if (is_optional and value.len == 0) return null;
+
+            // TODO:
+            //const decoded = try decodeString(alloc, value);
+            const decoded = value;
+
+            if (comptime std.meta.trait.isZigString(Eff)) return decoded;
+
+            // TODO:
+            //defer alloc.free(decoded);
+
+            if (comptime std.meta.trait.isIntegral(Eff)) return try std.fmt.parseInt(Eff, decoded, 0);
+            if (comptime std.meta.trait.isFloat(Eff)) return try std.fmt.parseFloat(Eff, decoded);
+            if (comptime std.meta.trait.is(.Enum)(Eff)) {
+                //_ = std.ascii.lowerString(decoded, decoded);
+                return std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
+            }
+            if (Eff == bool) {
+                //_ = std.ascii.lowerString(decoded, decoded);
+                return bool_map.get(decoded) orelse return error.InvalidBool;
+            }
+            if (comptime std.meta.trait.hasFn("parse")(Eff)) return try Eff.parse(decoded);
+
+            @compileError("Invalid type " ++ @typeName(T));
+        } else {
+            // Parameter is present, but no associated value
+            return if (is_optional)
+                null
+            else if (T == bool)
+                true
+            else
+                error.MissingValue;
+        }
+    }
+};
+
+/// Raw form-data field: its value plus optional filename and charset.
+const FormDataField = struct {
+    filename: ?[]const u8 = null,
+    charset: ?[]const u8 = null,
+    value: []const u8,
+};
+
+/// File-like form field: a filename together with the field's value.
+const FormFile = struct {
+    filename: []const u8,
+    value: []const u8,
+};
+
+/// Scalar handling for sources whose raw field values are `FormDataField`s.
+const FormDataOptions = struct {
+    const embed_unions = true;
+
+    /// Returns whether `T` is converted from a single `FormDataField` rather
+    /// than being decomposed into sub-fields.
+    fn isScalar(comptime T: type) bool {
+        // A file is built from one form field even though it is a struct.
+        if (T == FormFile) return true;
+        if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
+        if (comptime std.meta.trait.isZigString(T)) return true;
+        if (comptime std.meta.trait.isIntegral(T)) return true;
+        if (comptime std.meta.trait.isFloat(T)) return true;
+        if (comptime std.meta.trait.is(.Enum)(T)) return true;
+        if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true;
+        if (comptime std.meta.trait.hasFn("parse")(T)) return true;
+        if (T == bool) return true;
+
+        return false;
+    }
+    const ParsedField = FormDataField;
+
+    /// Converts a raw form field into a `T`.
+    ///
+    /// `FormFile` takes the field's filename (or "untitled" when none was
+    /// sent) and its value. For other optional types an empty value yields
+    /// `null`.
+    fn deserializeScalar(comptime T: type, field: FormDataField) !T {
+        // TODO: allocation??
+        const is_optional = comptime std.meta.trait.is(.Optional)(T);
+        const Eff = if (is_optional) std.meta.Child(T) else T;
+
+        if (Eff == FormFile) {
+            return FormFile{
+                .filename = field.filename orelse "untitled",
+                .value = field.value,
+            };
+        }
+
+        // Treat all empty values as nulls if possible
+        if (is_optional and field.value.len == 0) return null;
+
+        const decoded = field.value;
+
+        if (comptime std.meta.trait.isZigString(Eff)) return decoded;
+
+        // TODO:
+        //defer alloc.free(decoded);
+
+        if (comptime std.meta.trait.isIntegral(Eff)) return try std.fmt.parseInt(Eff, decoded, 0);
+        if (comptime std.meta.trait.isFloat(Eff)) return try std.fmt.parseFloat(Eff, decoded);
+        if (comptime std.meta.trait.is(.Enum)(Eff)) {
+            //_ = std.ascii.lowerString(decoded, decoded);
+            return std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
+        }
+        if (Eff == bool) {
+            //_ = std.ascii.lowerString(decoded, decoded);
+            return bool_map.get(decoded) orelse return error.InvalidBool;
+        }
+        if (comptime std.meta.trait.hasFn("parse")(Eff)) return try Eff.parse(decoded);
+
+        @compileError("Invalid type " ++ @typeName(T));
+    }
+};
+
+/// Lists the references of all scalar fields reachable from `T`, each
+/// prefixed with `prefix`.
+///
+/// When `options.embed_unions` is set, the fields of a union replace the last
+/// component of the path leading to that union instead of nesting under it.
+fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef {
+    comptime {
+        if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) {
+            @compileError("Cannot embed a union into nothing");
+        }
+
+        if (options.isScalar(T)) return &.{prefix};
+        if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options);
+
+        const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions)
+            prefix[0 .. prefix.len - 1]
+        else
+            prefix;
+
+        var fields: []const FieldRef = &.{};
+
+        for (std.meta.fields(T)) |f| {
+            const new_prefix = eff_prefix ++ &[_][]const u8{f.name};
+            const F = f.field_type;
+            fields = fields ++ getRecursiveFieldList(F, new_prefix, options);
+        }
+
+        return fields;
+    }
+}
+
+/// Comptime configuration shared by field enumeration and deserialization.
+const SerializationOptions = struct {
+    /// Whether union fields are addressed as siblings of the union itself
+    /// rather than nested below it.
+    embed_unions: bool = true,
+    /// Predicate deciding which types are deserialized from one raw value.
+    isScalar: fn (type) bool = QueryStringOptions.isScalar,
+};
+
+/// Builds a struct type with one `?From` field, defaulting to `null`, for
+/// every field reference reported by `getRecursiveFieldList(Result, ...)`.
+/// Each field is named `util.comptimeJoin(".", ref)`.
+fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
+    const field_refs = getRecursiveFieldList(Result, &.{}, options);
+
+    var fields: [field_refs.len]std.builtin.Type.StructField = undefined;
+    for (field_refs) |ref, i| {
+        //@compileLog(i, ref, util.comptimeJoin(".", ref));
+        fields[i] = .{
+            .name = util.comptimeJoin(".", ref),
+            .field_type = ?From,
+            .default_value = &@as(?From, null),
+            .is_comptime = false,
+            .alignment = @alignOf(?From),
+        };
+    }
+
+    return @Type(.{ .Struct = .{
+        .layout = .Auto,
+        .fields = &fields,
+        .decls = &.{},
+        .is_tuple = false,
+    } });
+}
+
+/// Returns a deserializer producing `Result` from raw values of type `From`,
+/// using the query-string scalar rules and default `SerializationOptions`.
+pub fn Deserializer(comptime Result: type, comptime From: type) type {
+    return DeserializerContext(Result, From, struct {
+        const options = SerializationOptions{};
+        const deserializeScalar = QueryStringOptions.deserializeScalar;
+        const isScalar = QueryStringOptions.isScalar;
+    });
+}
+
+/// Returns a deserializer type producing `Result` from raw values of type
+/// `From`. `Context` supplies `options`, `isScalar` and `deserializeScalar`.
+///
+/// Usage: call `setSerializedField` once per key/value pair, then `finish`.
+fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type {
+    return struct {
+        const Data = Intermediary(Result, From, Context.options);
+
+        data: Data = .{},
+        context: Context = .{},
+
+        /// Stores the raw `value` for the field named `key`.
+        /// Returns `error.UnknownField` if `Result` has no such field.
+        pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void {
+            const field = std.meta.stringToEnum(std.meta.FieldEnum(Data), key);
+            inline for (comptime std.meta.fieldNames(Data)) |field_name| {
+                @setEvalBranchQuota(10000);
+                const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(Data), field_name);
+                if (field == f) {
+                    @field(self.data, field_name) = value;
+                    return;
+                }
+            }
+
+            return error.UnknownField;
+        }
+
+        /// Builds the `Result` from the fields stored so far.
+        /// Returns `error.MissingStructField` when a required field was
+        /// never set, including when no field was set at all.
+        pub fn finish(self: *@This()) !Result {
+            // `deserialize` yields null when nothing was provided; report
+            // that as an error instead of unwrapping and panicking.
+            return (try self.deserialize(Result, &.{})) orelse return error.MissingStructField;
+        }
+
+        /// Returns the raw value stored for `field_ref`, if any.
+        fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From {
+            //inline for (comptime std.meta.fieldNames(Data)) |f| @compileLog(f.ptr);
+            return @field(self.data, util.comptimeJoin(".", field_ref));
+        }
+
+        /// Deserializes the value of type `T` located at `field_ref`.
+        /// Returns null when none of its raw values were provided.
+        fn deserialize(self: *@This(), comptime T: type, comptime field_ref: FieldRef) !?T {
+            if (comptime Context.isScalar(T)) {
+                return try Context.deserializeScalar(T, self.getSerializedField(field_ref) orelse return null);
+            }
+
+            switch (@typeInfo(T)) {
+                // At most one of any union field can be active at a time, and it is embedded
+                // in its parent container
+                .Union => |info| {
+                    var result: ?T = null;
+                    // TODO: errdefer cleanup
+                    const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref;
+                    inline for (info.fields) |field| {
+                        const F = field.field_type;
+                        const new_field_ref = union_ref ++ &[_][]const u8{field.name};
+                        const maybe_value = try self.deserialize(F, new_field_ref);
+                        if (maybe_value) |value| {
+                            // TODO: errdefer cleanup
+                            if (result != null) return error.DuplicateUnionField;
+                            result = @unionInit(T, field.name, value);
+                        }
+                    }
+                    return result;
+                },
+
+                .Struct => |info| {
+                    var result: T = undefined;
+
+                    var any_explicit = false;
+                    var any_missing = false;
+                    inline for (info.fields) |field| {
+                        const F = field.field_type;
+                        const new_field_ref = field_ref ++ &[_][]const u8{field.name};
+                        const maybe_value = try self.deserialize(F, new_field_ref);
+                        if (maybe_value) |v| {
+                            @field(result, field.name) = v;
+                            any_explicit = true;
+                        } else if (field.default_value) |ptr| {
+                            if (@sizeOf(F) != 0) {
+                                @field(result, field.name) = @ptrCast(*const F, @alignCast(field.alignment, ptr)).*;
+                            }
+                        } else {
+                            any_missing = true;
+                            std.log.debug("Missing field {s}", .{util.comptimeJoin(".", new_field_ref)});
+                            //return error.MissingStructField;
+                        }
+                    }
+                    if (any_missing) {
+                        return if (any_explicit) error.MissingStructField else null;
+                    }
+
+                    return result;
+                },
+
+                // Specifically non-scalar optionals
+                .Optional => |info| return try self.deserialize(info.child, field_ref),
+
+                else => @compileError("Unsupported type"),
+            }
+        }
+    };
+}
+
+/// Accepted spellings for boolean values; input is looked up as given.
+const bool_map = std.ComptimeStringMap(bool, .{
+    .{ "true", true },
+    .{ "t", true },
+    .{ "yes", true },
+    .{ "y", true },
+    .{ "1", true },
+
+    .{ "false", false },
+    .{ "f", false },
+    .{ "no", false },
+    .{ "n", false },
+    .{ "0", false },
+});
+
+test {
+    const T = struct {
+        foo: usize,
+        bar: bool,
+    };
+
+    const TDefault = struct {
+        foo: usize = 0,
+        bar: bool = false,
+    };
+    _ = TDefault;
+
+    {
+        var ds = Deserializer(T, []const u8){};
+        try ds.setSerializedField("foo", "123");
+        try ds.setSerializedField("bar", "true");
+        try std.testing.expectEqual(T{ .foo = 123, .bar = true }, try ds.finish());
+    }
+    {
+        var ds = Deserializer(T, []const u8){};
+        try ds.setSerializedField("foo", "123");
+        try std.testing.expectError(error.MissingStructField, ds.finish());
+    }
+    {
+        // Nothing was set at all: must be an error, not a panic.
+        var ds = Deserializer(T, []const u8){};
+        try std.testing.expectError(error.MissingStructField, ds.finish());
+    }
+    {
+        var ds = Deserializer(T, []const u8){};
+        try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "1"));
+    }
+
+    //const int = Intermediary(T, QueryStringOptions){
+    //.age = "123",
+    //.@"sub_struct.foo" = "abc",
+    //.@"sub_struct.bar" = "abc",
+    //.@"foo_union" = "abc",
+    //};
+    //std.debug.print("{any}\n", .{int});
+    //std.debug.print("{any}\n", .{deserialize(T, QueryStringOptions, &.{}, int)});
+
+    //const int2 = Intermediary(T, FormDataOptions){
+    //.age = .{ .value = "123" },
+    //.@"sub_struct.foo" = .{ .value = "123" },
+    //.@"sub_struct.bar" = .{ .value = "123" },
+    ////.@"foo_union" = .{
+    //.value = "abc",
+    //},
+    //};
+    //std.debug.print("{any}\n", .{int2});
+    //
+
+    // const T = struct {
+    //     age: usize,
+    //     sub_struct: ?struct {
+    //         x: union(enum) {
+    //             foo: []const u8,
+    //             bar: []const u8,
+    //         },
+    //     },
+    //     sub_union: union(enum) {
+    //         foo_union: []const u8,
+    //         bar_union: []const u8,
+    //     },
+    //     time: ?@import("./DateTime.zig"),
+    // };
+    // inline for (comptime getRecursiveFieldList(T, &.{}, .{})) |list| {
+    //     std.debug.print("{s}\n", .{util.comptimeJoin(".", list)});
+    // }
+    // // var ds = Deserializer(struct {
+    //     age: usize,
+    //     sub_struct: ?struct {
+    //         x: union(enum) {
+    //             foo: []const u8,
+    //             bar: []const u8,
+    //         },
+    //     },
+    //     sub_union: union(enum) {
+    //         foo_union: []const u8,
+    //         bar_union: []const u8,
+    //     },
+    //     time: ?@import("./DateTime.zig"),
+    // }, []const u8){ .context = .{} };
+    //try ds.setSerializedField("age", "123");
+    //try ds.setSerializedField("foo_union", "123");
+    //try ds.setSerializedField("sub_struct.foo", "123");
+    //try ds.setSerializedField("sub_struct.bar", "123");
+
+    //std.debug.print("{any}\n", .{try ds.finish()});
+}
diff --git a/src/util/lib.zig b/src/util/lib.zig
index 9829ae3..b022210 100644
--- a/src/util/lib.zig
+++ b/src/util/lib.zig
@@ -8,6 +8,8 @@ pub const Url = @import("./Url.zig");
 pub const PathIter = iters.PathIter;
 pub const QueryIter = iters.QueryIter;
 pub const SqlStmtIter = iters.Separator(';');
+pub const deserialize = @import("./deserialize.zig");
+pub const Deserializer = deserialize.Deserializer;
 
 /// Joins an array of strings, prefixing every entry with `prefix`,
 /// and putting `separator` in between each pair