Work on deserialization refactor

This commit is contained in:
jaina heartles 2022-11-30 19:21:55 -08:00
parent 96a46a98c9
commit aa632ace8b
2 changed files with 378 additions and 0 deletions

376
src/util/deserialize.zig Normal file
View File

@ -0,0 +1,376 @@
const std = @import("std");
const util = @import("./lib.zig");
const FieldRef = []const []const u8;
/// Scalar classification and parsing rules for values that arrive as plain
/// strings. `Deserializer` uses these rules by default.
const QueryStringOptions = struct {
    const embed_unions = true;

    /// Raw text of a single field; `null` stands for a parameter that is
    /// present but has no associated value.
    const ParsedField = ?[]const u8;

    /// Reports whether `T` is parsed from one string value instead of being
    /// broken down field by field.
    fn isScalar(comptime T: type) bool {
        const trait = std.meta.trait;

        if (T == bool) return true;
        if (comptime trait.isZigString(T)) return true;
        if (comptime trait.isIntegral(T)) return true;
        if (comptime trait.isFloat(T)) return true;
        if (comptime trait.is(.Enum)(T)) return true;
        if (comptime trait.is(.EnumLiteral)(T)) return true;
        if (comptime trait.hasFn("parse")(T)) return true;
        // Optionals qualify when their payload type does.
        if (comptime trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
        return false;
    }

    /// Parses `maybe_value` into a `T`.
    ///
    /// A `null` input yields `null` for optional `T`, `true` for `bool`, and
    /// `error.MissingValue` for everything else. An empty string yields `null`
    /// when `T` is optional. Otherwise the text is parsed according to the
    /// (unwrapped) type: strings are returned as-is, numbers go through
    /// `std.fmt`, enums through `std.meta.stringToEnum`, booleans through
    /// `bool_map`, and any type declaring `parse` through that function.
    fn deserializeScalar(comptime T: type, maybe_value: ParsedField) !T {
        const trait = std.meta.trait;
        const is_optional = comptime trait.is(.Optional)(T);
        // The type that actually gets parsed once the optional is peeled off.
        const Eff = if (is_optional) std.meta.Child(T) else T;

        const value = maybe_value orelse {
            // Parameter is present, but no associated value
            if (is_optional) return null;
            if (T == bool) return true;
            return error.MissingValue;
        };

        // Treat all empty values as nulls if possible
        if (is_optional and value.len == 0) return null;

        // TODO: decode the value first (`decodeString(alloc, value)`), and free
        // the decoded buffer again for every result that is not a string.
        const decoded = value;

        if (comptime trait.isZigString(Eff)) return decoded;
        if (comptime trait.isIntegral(Eff)) return try std.fmt.parseInt(Eff, decoded, 0);
        if (comptime trait.isFloat(Eff)) return try std.fmt.parseFloat(Eff, decoded);
        if (comptime trait.is(.Enum)(Eff)) {
            // Matching is exact; the input is not lower-cased first.
            return std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
        }
        if (Eff == bool) {
            // Matching is exact; the input is not lower-cased first.
            return bool_map.get(decoded) orelse return error.InvalidBool;
        }
        if (comptime trait.hasFn("parse")(Eff)) return try Eff.parse(decoded);

        @compileError("Invalid type " ++ @typeName(T));
    }
};
/// One parsed form value as handed to `FormDataOptions.deserializeScalar`
/// (it is that namespace's `ParsedField` type).
const FormDataField = struct {
/// Name of the uploaded file, if one was given. `deserializeScalar` falls
/// back to "untitled" when this is null and a `FormFile` is requested.
filename: ?[]const u8 = null,
/// Optional charset label for the value; not read by the code visible in this file.
charset: ?[]const u8 = null,
/// The field's content. Scalars are parsed from this text.
value: []const u8,
};
/// Result type for file-like form fields; `FormDataOptions.deserializeScalar`
/// has a dedicated branch that builds one of these from a `FormDataField`.
const FormFile = struct {
/// File name taken from the form field, or "untitled" when none was sent.
filename: []const u8,
/// Contents of the file (presumably the field's raw value -- TODO confirm).
value: []const u8,
};
/// Scalar classification and parsing rules for form-data fields, where every
/// value arrives as a `FormDataField`.
const FormDataOptions = struct {
    const embed_unions = true;

    /// The serialized representation consumed by `deserializeScalar`.
    const ParsedField = FormDataField;

    /// Reports whether `T` is produced from a single form field instead of
    /// being broken down member by member.
    fn isScalar(comptime T: type) bool {
        if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
        // A file is built from exactly one form field, so it must count as a
        // scalar; otherwise it would be split into `filename` and `value` leaves
        // and the FormFile branch of `deserializeScalar` could never be reached.
        if (T == FormFile) return true;
        if (comptime std.meta.trait.isZigString(T)) return true;
        if (comptime std.meta.trait.isIntegral(T)) return true;
        if (comptime std.meta.trait.isFloat(T)) return true;
        if (comptime std.meta.trait.is(.Enum)(T)) return true;
        if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true;
        if (comptime std.meta.trait.hasFn("parse")(T)) return true;
        if (T == bool) return true;
        return false;
    }

    /// Parses `field` into a `T`.
    ///
    /// Optional types are unwrapped before parsing, and an empty value yields
    /// `null` for them (the same rule `QueryStringOptions` applies). `FormFile`
    /// takes the field's file name (or "untitled") and its value. Everything
    /// else is parsed from `field.value`: strings are returned as-is, numbers
    /// go through `std.fmt`, enums through `std.meta.stringToEnum`, booleans
    /// through `bool_map`, and types declaring `parse` through that function.
    ///
    /// Errors: `error.InvalidEnumValue`, `error.InvalidBool`, plus whatever the
    /// number parsers or `parse` return.
    fn deserializeScalar(comptime T: type, field: FormDataField) !T {
        // `isScalar` accepts optionals, so they have to be unwrapped here too;
        // without this every optional field ended in the @compileError below.
        const is_optional = comptime std.meta.trait.is(.Optional)(T);
        const Eff = if (is_optional) std.meta.Child(T) else T;

        // Treat all empty values as nulls if possible
        if (is_optional and field.value.len == 0) return null;

        // TODO: allocation??
        if (Eff == FormFile) {
            return FormFile{
                .filename = field.filename orelse "untitled",
                .value = field.value,
            };
        }

        const decoded = field.value;
        if (comptime std.meta.trait.isZigString(Eff)) return decoded;

        // TODO: once values are decoded into an allocated buffer, free it here
        // for every result that is not a string.

        if (comptime std.meta.trait.isIntegral(Eff)) return try std.fmt.parseInt(Eff, decoded, 0);
        if (comptime std.meta.trait.isFloat(Eff)) return try std.fmt.parseFloat(Eff, decoded);
        if (comptime std.meta.trait.is(.Enum)(Eff)) {
            // Matching is exact; the input is not lower-cased first.
            return std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
        }
        if (Eff == bool) {
            // Matching is exact; the input is not lower-cased first.
            return bool_map.get(decoded) orelse return error.InvalidBool;
        }
        if (comptime std.meta.trait.hasFn("parse")(Eff)) return try Eff.parse(decoded);

        @compileError("Invalid type " ++ @typeName(T));
    }
};
/// Computes, at compile time, the path of every scalar leaf reachable from `T`.
///
/// `prefix` is the path of `T` itself. A type that `options.isScalar` accepts
/// contributes just `prefix`; an optional contributes whatever its payload
/// does; any other type contributes the leaves of each of its fields, with
/// the field name appended to the path. When `options.embed_unions` is set, a
/// union's fields replace the last path component instead of extending it,
/// which is why such a union cannot sit at the root.
fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef {
    comptime {
        const embedded_union = std.meta.trait.is(.Union)(T) and options.embed_unions;
        if (embedded_union and prefix.len == 0) {
            @compileError("Cannot embed a union into nothing");
        }

        if (options.isScalar(T)) return &[_]FieldRef{prefix};
        if (std.meta.trait.is(.Optional)(T)) {
            return getRecursiveFieldList(std.meta.Child(T), prefix, options);
        }

        // An embedded union takes over the slot of its parent field name.
        const base: FieldRef = if (embedded_union) prefix[0 .. prefix.len - 1] else prefix;

        var leaves: []const FieldRef = &.{};
        for (std.meta.fields(T)) |member| {
            const member_path = base ++ &[_][]const u8{member.name};
            leaves = leaves ++ getRecursiveFieldList(member.field_type, member_path, options);
        }
        return leaves;
    }
}
/// Compile-time knobs that decide how a result type is flattened into field paths.
const SerializationOptions = struct {
/// When true, the fields of a union take the place of the union's own name
/// in a field path instead of being nested beneath it.
embed_unions: bool = true,
/// Predicate that decides which types are leaves (parsed from one serialized
/// value). Defaults to the query-string rules.
isScalar: fn (type) bool = QueryStringOptions.isScalar,
};
/// Generates the storage struct for a deserializer: one `?From` field,
/// defaulting to `null`, for every leaf that `getRecursiveFieldList` reports
/// for `Result`. Each field is named by joining the leaf's path with ".".
fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
    const Slot = ?From;
    const leaves = getRecursiveFieldList(Result, &.{}, options);

    var slots: [leaves.len]std.builtin.Type.StructField = undefined;
    for (leaves) |leaf, idx| {
        slots[idx] = std.builtin.Type.StructField{
            .name = util.comptimeJoin(".", leaf),
            .field_type = Slot,
            // Every slot starts out empty so the struct can be created with `.{}`.
            .default_value = &@as(Slot, null),
            .is_comptime = false,
            .alignment = @alignOf(Slot),
        };
    }

    return @Type(std.builtin.Type{ .Struct = .{
        .layout = .Auto,
        .fields = &slots,
        .decls = &.{},
        .is_tuple = false,
    } });
}
/// Returns a deserializer type that assembles a `Result` out of `From` values,
/// using the default `SerializationOptions` and the `QueryStringOptions`
/// scalar rules.
pub fn Deserializer(comptime Result: type, comptime From: type) type {
    const QueryStringContext = struct {
        const options = SerializationOptions{};
        const isScalar = QueryStringOptions.isScalar;
        const deserializeScalar = QueryStringOptions.deserializeScalar;
    };
    return DeserializerContext(Result, From, QueryStringContext);
}
/// Builds a deserializer type that collects serialized values one key at a
/// time and then assembles a `Result` from them.
///
/// `Context` has to declare `options` (a `SerializationOptions`), `isScalar`
/// and `deserializeScalar`; the returned type reads all three.
fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type {
    return struct {
        /// One `?From` slot per scalar leaf of `Result`.
        const Data = Intermediary(Result, From, Context.options);

        data: Data = .{},
        context: Context = .{},

        /// Stores `value` in the slot whose name equals `key`.
        /// Returns `error.UnknownField` when `Data` has no field of that name.
        pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void {
            const field = std.meta.stringToEnum(std.meta.FieldEnum(Data), key);
            inline for (comptime std.meta.fieldNames(Data)) |field_name| {
                @setEvalBranchQuota(10000);
                const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(Data), field_name);
                if (field == f) {
                    @field(self.data, field_name) = value;
                    return;
                }
            }
            return error.UnknownField;
        }

        /// Assembles the final `Result` from the values stored so far.
        ///
        /// Returns `error.MissingStructField` when required values are absent,
        /// including the case where nothing usable was supplied at all (this
        /// used to be a forced unwrap, i.e. a panic in safe builds). Errors
        /// from the scalar parsers are passed through.
        pub fn finish(self: *@This()) !Result {
            return (try self.deserialize(Result, &.{})) orelse return error.MissingStructField;
        }

        /// Reads the slot that belongs to the leaf at `field_ref`.
        fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From {
            return @field(self.data, util.comptimeJoin(".", field_ref));
        }

        /// Deserializes the value of type `T` located at `field_ref`.
        /// Returns `null` when no value was supplied for it.
        fn deserialize(self: *@This(), comptime T: type, comptime field_ref: FieldRef) !?T {
            if (comptime Context.isScalar(T)) {
                return try Context.deserializeScalar(T, self.getSerializedField(field_ref) orelse return null);
            }

            switch (@typeInfo(T)) {
                // At most one of any union field can be active at a time, and it is embedded
                // in its parent container
                .Union => |info| {
                    var result: ?T = null;
                    // TODO: errdefer cleanup
                    const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref;
                    inline for (info.fields) |field| {
                        const F = field.field_type;
                        const new_field_ref = union_ref ++ &[_][]const u8{field.name};
                        const maybe_value = try self.deserialize(F, new_field_ref);
                        if (maybe_value) |value| {
                            // TODO: errdefer cleanup
                            if (result != null) return error.DuplicateUnionField;
                            result = @unionInit(T, field.name, value);
                        }
                    }
                    return result;
                },

                .Struct => |info| {
                    var result: T = undefined;
                    // Tracks whether the struct was supplied at all: a struct
                    // with nothing supplied is reported as absent (null), one
                    // that is only partially supplied is an error.
                    var any_explicit = false;
                    var any_missing = false;
                    inline for (info.fields) |field| {
                        const F = field.field_type;
                        const new_field_ref = field_ref ++ &[_][]const u8{field.name};
                        const maybe_value = try self.deserialize(F, new_field_ref);
                        if (maybe_value) |v| {
                            @field(result, field.name) = v;
                            any_explicit = true;
                        } else if (field.default_value) |ptr| {
                            // Fall back to the field's declared default.
                            if (@sizeOf(F) != 0) {
                                @field(result, field.name) = @ptrCast(*const F, @alignCast(field.alignment, ptr)).*;
                            }
                        } else {
                            any_missing = true;
                            // Diagnostic only; goes through std.log so the
                            // application decides whether it is shown.
                            std.log.debug("Missing field {s}", .{util.comptimeJoin(".", new_field_ref)});
                        }
                    }
                    if (any_missing) {
                        return if (any_explicit) error.MissingStructField else null;
                    }
                    return result;
                },

                // Specifically non-scalar optionals
                .Optional => |info| return try self.deserialize(info.child, field_ref),

                else => @compileError("Unsupported type"),
            }
        }
    };
}
/// Spellings accepted for boolean values. Lookups are exact matches, so
/// callers that want case-insensitivity have to lower-case the input first.
const bool_map = std.ComptimeStringMap(bool, .{
    .{ "true", true },
    .{ "false", false },

    .{ "t", true },
    .{ "f", false },

    .{ "yes", true },
    .{ "no", false },

    .{ "y", true },
    .{ "n", false },

    .{ "1", true },
    .{ "0", false },
});
test "Deserializer assembles flat structs from string fields" {
    const T = struct {
        foo: usize,
        bar: bool,
    };
    const TDefault = struct {
        foo: usize = 0,
        bar: bool = false,
    };

    // Every field supplied.
    {
        var ds = Deserializer(T, []const u8){};
        try ds.setSerializedField("foo", "123");
        try ds.setSerializedField("bar", "true");
        try std.testing.expectEqual(T{ .foo = 123, .bar = true }, try ds.finish());
    }

    // A required field without a default is left out.
    {
        var ds = Deserializer(T, []const u8){};
        try ds.setSerializedField("foo", "123");
        try std.testing.expectError(error.MissingStructField, ds.finish());
    }

    // Keys that do not name a leaf of the result type are rejected.
    {
        var ds = Deserializer(T, []const u8){};
        try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "1"));
    }

    // Fields that are left out fall back to their declared defaults.
    {
        var ds = Deserializer(TDefault, []const u8){};
        try ds.setSerializedField("foo", "7");
        try std.testing.expectEqual(TDefault{ .foo = 7, .bar = false }, try ds.finish());
    }

    // TODO: cover nested results once the refactor settles:
    //   - an optional sub-struct addressed with dotted keys ("sub_struct.foo"),
    //   - unions embedded into their parent (keys "foo_union" / "bar_union"),
    //   - a field whose type provides `parse` (e.g. ?DateTime),
    //   - the same cases with FormDataField values instead of strings.
}

View File

@ -8,6 +8,8 @@ pub const Url = @import("./Url.zig");
pub const PathIter = iters.PathIter;
pub const QueryIter = iters.QueryIter;
pub const SqlStmtIter = iters.Separator(';');
pub const deserialize = @import("./deserialize.zig");
pub const Deserializer = deserialize.Deserializer;
/// Joins an array of strings, prefixing every entry with `prefix`,
/// and putting `separator` in between each pair