//! fediglam/src/util/serialize.zig
//!
//! (Source listing metadata: 556 lines, 21 KiB, Zig.)

const std = @import("std");
const util = @import("./lib.zig");
pub const FieldRef = []const []const u8;
/// Default classification of "scalar" types, i.e. types read from a single serialized value
/// instead of being decomposed into sub-fields.
///
/// Scalars are: `bool`, strings (per `std.meta.trait.isZigString`), integers, floats, enums,
/// enum literals, any container declaring a `parse` function, and optionals of any of these.
pub fn defaultIsScalar(comptime T: type) bool {
    comptime {
        const trait = std.meta.trait;
        // `std.meta.Child` is only valid on optionals, so check the tag first.
        if (trait.is(.Optional)(T)) {
            if (defaultIsScalar(std.meta.Child(T))) return true;
        }
        return T == bool or
            trait.isZigString(T) or
            trait.isIntegral(T) or
            trait.isFloat(T) or
            trait.is(.Enum)(T) or
            trait.is(.EnumLiteral)(T) or
            trait.hasFn("parse")(T);
    }
}
/// Parse the string `value` into a value of type `T`.
///
/// - `?X`: an empty string yields `null`; otherwise `value` is parsed as `X`.
/// - `[]u8` / `[]const u8`: returns a copy made with `util.deepClone(allocator, value)`.
/// - integers: `std.fmt.parseInt` with base 0, so `0x`/`0o`/`0b` prefixes are accepted.
/// - floats: `std.fmt.parseFloat`.
/// - types declaring `parse`: delegated to `T.parse(value)`.
/// - `bool`: lowercased, then looked up in `bool_map`; `error.InvalidBool` if absent.
/// - enums: lowercased, then matched against tag names with `std.meta.stringToEnum`
///   (so only lowercase tag names can match); `error.InvalidEnumTag` otherwise.
/// Any other `T` is a compile error.
pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T {
    if (comptime std.meta.trait.is(.Optional)(T)) {
        if (value.len == 0) return null;
        return try deserializeString(allocator, std.meta.Child(T), value);
    }
    if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value);
    if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0);
    if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value);
    if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value);

    // bool/enum matching lowercases into a fixed buffer. `std.ascii.lowerString` asserts the
    // output is at least as long as the input, so over-long (possibly untrusted) input must be
    // rejected up front instead of tripping that assertion / overflowing `buf`.
    var buf: [64]u8 = undefined;
    if (T == bool) {
        if (value.len > buf.len) return error.InvalidBool;
        const lowered = std.ascii.lowerString(&buf, value);
        return bool_map.get(lowered) orelse return error.InvalidBool;
    }
    if (comptime std.meta.trait.is(.Enum)(T)) {
        if (value.len > buf.len) return error.InvalidEnumTag;
        const lowered = std.ascii.lowerString(&buf, value);
        return std.meta.stringToEnum(T, lowered) orelse return error.InvalidEnumTag;
    }
    @compileError("Invalid type " ++ @typeName(T));
}
/// Collect, at comptime, the path of every scalar leaf reachable from `T`.
///
/// Each entry is `prefix` followed by the struct field names walked to reach the leaf.
/// Optionals are unwrapped; union members add no path component (they share the union's
/// prefix); non-string slices contribute nothing here (see `getDynamicFieldList`).
fn getStaticFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef {
    comptime {
        if (options.isScalar(T)) return &[_]FieldRef{prefix};
        if (std.meta.trait.is(.Optional)(T)) return getStaticFieldList(std.meta.Child(T), prefix, options);
        if (std.meta.trait.isSlice(T) and !std.meta.trait.isZigString(T)) return &.{};

        const is_union = std.meta.trait.is(.Union)(T);
        var leaves: []const FieldRef = &.{};
        for (std.meta.fields(T)) |field| {
            const child_prefix = if (is_union) prefix else prefix ++ &[_][]const u8{field.name};
            leaves = leaves ++ getStaticFieldList(field.field_type, child_prefix, options);
        }
        return leaves;
    }
}
/// Collect, at comptime, every non-string slice reachable from `T`.
///
/// Each entry records the slice's path (`prefix` plus the struct field names walked to
/// reach it) and its element type. Scalars contribute nothing, optionals are unwrapped,
/// and union members share the union's own prefix. The walk stops at a slice and does
/// not descend into its element type.
fn getDynamicFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const DynamicField {
    comptime {
        if (options.isScalar(T)) return &.{};
        if (std.meta.trait.is(.Optional)(T)) return getDynamicFieldList(std.meta.Child(T), prefix, options);
        if (std.meta.trait.isSlice(T) and !std.meta.trait.isZigString(T)) {
            return &[_]DynamicField{.{ .ref = prefix, .child_type = std.meta.Child(T) }};
        }

        const is_union = std.meta.trait.is(.Union)(T);
        var found: []const DynamicField = &.{};
        for (std.meta.fields(T)) |field| {
            const child_prefix = if (is_union) prefix else prefix ++ &[_][]const u8{field.name};
            found = found ++ getDynamicFieldList(field.field_type, child_prefix, options);
        }
        return found;
    }
}
/// A non-string slice found while walking a result type at comptime.
const DynamicField = struct {
    /// Path of field names from the root type to the slice.
    ref: FieldRef,
    /// Element type of the slice (`std.meta.Child` of the slice type).
    child_type: type,
};
/// Comptime knobs controlling how result types are decomposed into serialized fields.
pub const SerializationOptions = struct {
    /// Returns true when a type should be read from a single serialized value (a leaf)
    /// rather than being broken down into its fields.
    isScalar: fn (type) bool,
};
/// Options used by `Deserializer`: leaves are classified with `defaultIsScalar`.
pub const default_options: SerializationOptions = .{ .isScalar = defaultIsScalar };
/// Build a struct type holding one raw `?From` slot (default `null`) per distinct scalar
/// leaf path of `Result`, as listed by `getStaticFieldList`.
///
/// Field names are produced by `util.comptimeJoin(".", path)`. Paths that map to the same
/// name (e.g. union members, which share the union's prefix) produce a single field.
/// A void `__dummy` field is always present as the first field.
fn StaticIntermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
    const field_refs = getStaticFieldList(Result, &.{}, options);
    // avert compiler crash by having at least one field
    // (every slot starts as the void `__dummy` field; the array is sized for the worst case
    // of no duplicate names, and only `fields[0..count]` is used below)
    var fields = [_]std.builtin.Type.StructField{.{
        .name = "__dummy",
        .default_value = &{},
        .field_type = void,
        .is_comptime = false,
        .alignment = 0,
    }} ** (field_refs.len + 1);
    // Slot 0 is `__dummy`; real fields start at index 1.
    var count: usize = 1;
    outer: for (field_refs) |ref| {
        const name = util.comptimeJoin(".", ref);
        // Skip paths whose joined name was already emitted.
        for (fields[0..count]) |f| if (std.mem.eql(u8, f.name, name)) continue :outer;
        fields[count] = .{
            .name = name,
            .field_type = ?From,
            .default_value = &@as(?From, null),
            .is_comptime = false,
            .alignment = @alignOf(?From),
        };
        count += 1;
    }
    return @Type(.{ .Struct = .{
        .layout = .Auto,
        .fields = fields[0..count],
        .decls = &.{},
        .is_tuple = false,
    } });
}
/// Build a struct type holding, for each distinct non-string slice path of `Result`
/// (as listed by `getDynamicFieldList`), an `ArrayListUnmanaged` of nested
/// `Intermediary(element_type, From, options)` values, defaulting to an empty list.
///
/// Field names are produced by `util.comptimeJoin(".", path)`; duplicate names are emitted
/// once. A void `__dummy` field is always present as the first field.
fn DynamicIntermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
    const field_refs = getDynamicFieldList(Result, &.{}, options);
    // Every slot starts as the void `__dummy` field so the generated struct is never empty;
    // only `fields[0..count]` is used below.
    var fields = [_]std.builtin.Type.StructField{.{
        .name = "__dummy",
        .default_value = &{},
        .field_type = void,
        .is_comptime = false,
        .alignment = 0,
    }} ** (field_refs.len + 1);
    // Slot 0 is `__dummy`; real fields start at index 1.
    var count: usize = 1;
    outer: for (field_refs) |ref| {
        const name = util.comptimeJoin(".", ref.ref);
        // Skip paths whose joined name was already emitted.
        for (fields[0..count]) |f| if (std.mem.eql(u8, f.name, name)) continue :outer;
        // Each slice element gets its own full intermediary, allowing nested paths.
        const T = std.ArrayListUnmanaged(Intermediary(ref.child_type, From, options));
        fields[count] = .{
            .name = name,
            .default_value = &T{},
            .field_type = T,
            .is_comptime = false,
            .alignment = @alignOf(T),
        };
        count += 1;
    }
    return @Type(.{ .Struct = .{
        .layout = .Auto,
        .fields = fields[0..count],
        .decls = &.{},
        .is_tuple = false,
    } });
}
/// Per-field settings that a result type may override through a `serialization_info`
/// declaration; the field defaults here apply otherwise.
const SerializationInfo = struct {
    /// Maximum number of elements for a slice field: setting an index at or beyond this
    /// value fails with `error.SliceTooLong`.
    max_slice_len: usize = 16,
};
/// Look up a per-field serialization setting for the type `info`.
///
/// If `info` (or, when `info` is optional, its child type) is a container declaring
/// `serialization_info`, and that declaration has a decl named `field_name` which itself
/// declares `@tagName(info_key)`, that value is returned. Otherwise the default declared
/// on `SerializationInfo` is returned.
fn getSerializationInfo(
    comptime info: anytype,
    comptime field_name: []const u8,
    comptime info_key: std.meta.FieldEnum(SerializationInfo),
) std.meta.fieldInfo(SerializationInfo, info_key).field_type {
    // Optional results carry their declarations on the child type.
    const Container = if (comptime std.meta.trait.is(.Optional)(info)) std.meta.Child(info) else info;
    const key = @tagName(info_key);
    // `@hasDecl` is a compile error on non-container types (slices, pointers, ints, ...),
    // which can occur e.g. for a top-level slice `Result`; only inspect containers.
    if (comptime std.meta.trait.isContainer(Container)) {
        if (@hasDecl(Container, "serialization_info") and
            @hasDecl(Container.serialization_info, field_name) and
            @hasDecl(@field(Container.serialization_info, field_name), key))
        {
            return @field(@field(Container.serialization_info, field_name), key);
        }
    }
    // Take the fallback from SerializationInfo's own defaults so the two cannot drift apart.
    return @field(SerializationInfo{}, key);
}
/// Build the intermediate storage type used while collecting serialized fields for `Result`.
///
/// `static` holds one `?From` slot per scalar leaf path (see `StaticIntermediary`);
/// `dynamic` holds one list of nested intermediaries per non-string slice path
/// (see `DynamicIntermediary`).
fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
    return struct {
        const StaticData = StaticIntermediary(Result, From, options);
        const DynamicData = DynamicIntermediary(Result, From, options);
        static: StaticData = .{},
        dynamic: DynamicData = .{},

        /// Store `value` under the serialized `key`.
        ///
        /// A key without `[` names a scalar leaf directly (e.g. `foo.bar`). A key of the form
        /// `name[idx].rest` addresses element `idx` (base 10) of the slice field `name`, and
        /// `rest` is stored recursively into that element; missing elements up to `idx` are
        /// appended empty. Errors: `error.UnknownField` for keys that name no field (including
        /// the internal `__dummy` placeholder) or lack the `.` after `]`, `error.SliceTooLong`
        /// when `idx` reaches the field's `max_slice_len`, integer-parse errors for `idx`,
        /// and allocation failures.
        fn setSerializedField(self: *@This(), allocator: std.mem.Allocator, key: []const u8, value: From) !void {
            var split = std.mem.split(u8, key, "[");
            const first = split.first();
            const rest = split.rest();
            if (rest.len == 0) {
                const field = std.meta.stringToEnum(std.meta.FieldEnum(StaticData), key) orelse return error.UnknownField;
                inline for (comptime std.meta.fieldNames(StaticData)) |field_name| {
                    @setEvalBranchQuota(10000);
                    const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(StaticData), field_name).?;
                    if (f != .__dummy and field == f) {
                        @field(self.static, field_name) = value;
                        return;
                    }
                }
                // Only reached when the caller names the internal `__dummy` placeholder.
                // That is caller-controlled input, so reject it instead of hitting `unreachable`.
                return error.UnknownField;
            } else {
                split = std.mem.split(u8, rest, "]");
                const idx_str = split.first();
                const idx = try std.fmt.parseInt(usize, idx_str, 10);
                var next = split.rest();
                if (next.len == 0 or next[0] != '.') return error.UnknownField;
                next = next[1..];
                const field = std.meta.stringToEnum(std.meta.FieldEnum(DynamicData), first) orelse return error.UnknownField;
                inline for (comptime std.meta.fieldNames(DynamicData)) |field_name| {
                    @setEvalBranchQuota(10000);
                    const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(DynamicData), field_name).?;
                    if (f != .__dummy and field == f) {
                        // Bound the index before growing the list so input can't force huge allocations.
                        const limits = getSerializationInfo(Result, field_name, .max_slice_len);
                        if (idx >= limits) return error.SliceTooLong;
                        const list = &@field(self.dynamic, field_name);
                        while (idx >= list.items.len) {
                            try list.append(allocator, .{});
                        }
                        try list.items[idx].setSerializedField(allocator, next, value);
                        return;
                    }
                }
                // `__dummy[...]` previously fell through silently; report it like any unknown key.
                return error.UnknownField;
            }
        }
    };
}
/// Deserializer for `Result` whose raw values are strings, using `default_options`.
/// Each scalar leaf is parsed with `deserializeString`.
pub fn Deserializer(comptime Result: type) type {
    const StringContext = struct {
        const options = default_options;

        fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T {
            return try deserializeString(alloc, T, val);
        }
    };
    return DeserializerContext(Result, []const u8, StringContext);
}
/// Deserializer for `Result` fed with raw values of type `From`, keyed by field path.
///
/// Values are collected with `setSerializedField` into arena-backed intermediate storage,
/// then converted with `finish`. `Context` must declare `options: SerializationOptions`,
/// provide `deserializeScalar(self, allocator, comptime T, From) !T`, and be
/// default-initializable with `.{}`.
pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type {
    return struct {
        const Data = Intermediary(Result, From, Context.options);
        arena: std.heap.ArenaAllocator,
        data: Data = .{},
        context: Context = .{},

        /// Create a deserializer whose intermediate storage lives in an arena over `alloc`.
        pub fn init(alloc: std.mem.Allocator) @This() {
            return .{ .arena = std.heap.ArenaAllocator.init(alloc) };
        }

        /// Free the arena holding the intermediate storage.
        pub fn deinit(self: *@This()) void {
            self.arena.deinit();
        }

        /// Record `value` for the serialized field path `key`.
        pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void {
            try self.data.setSerializedField(self.arena.allocator(), key, value);
        }

        /// Free a value returned by `finish`, using the same allocator passed to `finish`.
        pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
            util.deepFree(allocator, val);
        }

        /// Build a `Result` from the recorded fields, allocating it with `allocator`.
        ///
        /// Unset fields fall back to their declared defaults; a required field with neither
        /// a value nor a default yields `error.MissingField`. Release with `finishFree`.
        pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result {
            return (try self.deserialize(allocator, Result, self.data, &.{}, true)) orelse
                if (std.meta.fields(Result).len == 0)
                return .{}
            else
                return error.MissingField;
        }

        /// Raw value recorded for the scalar leaf at `field_ref`, if any.
        fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From {
            // Scalar leaves live in the `static` half of the intermediary.
            return @field(self.data.static, util.comptimeJoin(".", field_ref));
        }

        fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
            util.deepFree(allocator, val);
        }

        const DeserializeError = error{ ParseFailure, MissingField, DuplicateUnionMember, SparseSlice, OutOfMemory };

        /// Convert the part of `intermediary` rooted at `field_ref` into a `T`.
        ///
        /// Returns `null` when nothing was provided for this value. When `allow_default` is
        /// true, a struct with no explicit fields is still built from its defaults.
        fn deserialize(
            self: *@This(),
            allocator: std.mem.Allocator,
            comptime T: type,
            intermediary: anytype,
            comptime field_ref: FieldRef,
            allow_default: bool,
        ) DeserializeError!?T {
            if (comptime Context.options.isScalar(T)) {
                const val = @field(intermediary.static, util.comptimeJoin(".", field_ref));
                return self.context.deserializeScalar(allocator, T, val orelse return null) catch return error.ParseFailure;
            }
            switch (@typeInfo(T)) {
                // At most one union member may be provided; members share the union's path.
                .Union => |info| {
                    var result: ?T = null;
                    errdefer if (result) |v| self.deserializeFree(allocator, v);
                    var partial_match_found: bool = false;
                    inline for (info.fields) |field| {
                        const F = field.field_type;
                        const maybe_value = self.deserialize(allocator, F, intermediary, field_ref, false) catch |err| switch (err) {
                            error.MissingField => blk: {
                                partial_match_found = true;
                                break :blk @as(?F, null);
                            },
                            else => |e| return e,
                        };
                        if (maybe_value) |value| {
                            errdefer self.deserializeFree(allocator, value);
                            if (result != null) return error.DuplicateUnionMember;
                            result = @unionInit(T, field.name, value);
                        }
                    }
                    if (partial_match_found and result == null) return error.MissingField;
                    return result;
                },
                .Struct => |info| {
                    var result: T = undefined;
                    var any_explicit = false;
                    var any_missing = false;
                    // Tracks which fields own memory that must be freed on failure.
                    var fields_alloced = [1]bool{false} ** info.fields.len;
                    errdefer inline for (info.fields) |field, i| {
                        if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name));
                    };
                    inline for (info.fields) |field, i| {
                        const F = field.field_type;
                        const new_field_ref = field_ref ++ &[_][]const u8{field.name};
                        const maybe_value = try self.deserialize(allocator, F, intermediary, new_field_ref, false);
                        if (maybe_value) |v| {
                            @field(result, field.name) = v;
                            fields_alloced[i] = true;
                            any_explicit = true;
                        } else if (field.default_value) |ptr| {
                            if (@sizeOf(F) != 0) {
                                const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr));
                                @field(result, field.name) = try util.deepClone(allocator, cast_ptr.*);
                                fields_alloced[i] = true;
                            }
                        } else {
                            any_missing = true;
                        }
                    }
                    if (!any_explicit and !allow_default) {
                        // Nothing provided: report "absent" and drop the defaults we cloned.
                        inline for (info.fields) |field, i| {
                            if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name));
                        }
                        return null;
                    }
                    // A required field without value or default would otherwise be returned
                    // `undefined` (e.g. `finish()` with nothing set); the errdefer frees the rest.
                    if (any_missing) return error.MissingField;
                    return result;
                },
                .Pointer => |info| switch (info.size) {
                    .Slice => {
                        const name = comptime util.comptimeJoin(".", field_ref);
                        // Read from the intermediary being decoded, not `self.data`: for slices
                        // inside slice elements, `field_ref` is relative to that element.
                        const data = @field(intermediary.dynamic, name);
                        const result = try allocator.alloc(info.child, data.items.len);
                        errdefer allocator.free(result);
                        // Number of leading elements of `result` that own memory.
                        var count: usize = 0;
                        errdefer for (result[0..count]) |res| self.deserializeFree(allocator, res);
                        for (data.items) |sub, i| {
                            result[i] = (try self.deserialize(allocator, info.child, sub, &.{}, false)) orelse return error.SparseSlice;
                            count += 1;
                        }
                        return result;
                    },
                    else => @compileError("Unsupported type"),
                },
                // Specifically non-scalar optionals.
                // NOTE(review): the `?Child` result is coerced into `??Child`, so an absent
                // optional container appears as an explicit `null` to the caller; confirm intended.
                .Optional => |info| return try self.deserialize(allocator, info.child, intermediary, field_ref, allow_default),
                else => @compileError("Unsupported type"),
            }
        }
    };
}
/// Accepted boolean spellings, matched exactly. `deserializeString` lowercases input
/// before looking it up here.
pub const bool_map = std.ComptimeStringMap(bool, .{
    .{ "true", true },
    .{ "false", false },
    .{ "t", true },
    .{ "f", false },
    .{ "yes", true },
    .{ "no", false },
    .{ "y", true },
    .{ "n", false },
    .{ "1", true },
    .{ "0", false },
});
test "Deserializer" {
    const alloc = std.testing.allocator;

    // Result types shared by several cases below.
    const Pair = struct { foo: []const u8, bar: bool };
    const Nested = struct {
        foo: struct { bar: bool, baz: bool },
    };
    const Choice = struct {
        foo: union(enum) {
            bar: struct {
                bar: bool,
            },
            baz: struct {
                baz: bool,
            },
        },
    };
    const Defaults = struct { foo: []const u8 = "123", bar: bool = true };
    const OptionalContainers = struct {
        foo: ?struct { bar: usize = 3, baz: usize } = null,
        qux: ?union(enum) { quux: usize } = null,
    };

    // Happy case - simple
    {
        var ds = Deserializer(Pair).init(alloc);
        defer ds.deinit();
        try ds.setSerializedField("foo", "123");
        try ds.setSerializedField("bar", "true");
        const val = try ds.finish(alloc);
        defer ds.finishFree(alloc, val);
        try util.testing.expectDeepEqual(Pair{ .foo = "123", .bar = true }, val);
    }
    // Returns error if nonexistent field set
    {
        var ds = Deserializer(Pair).init(alloc);
        defer ds.deinit();
        try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123"));
    }
    // Substruct dereferencing
    {
        var ds = Deserializer(Nested).init(alloc);
        defer ds.deinit();
        try ds.setSerializedField("foo.bar", "true");
        try ds.setSerializedField("foo.baz", "true");
        const val = try ds.finish(alloc);
        defer ds.finishFree(alloc, val);
        try util.testing.expectDeepEqual(Nested{ .foo = .{ .bar = true, .baz = true } }, val);
    }
    // Union behavior
    {
        var ds = Deserializer(Choice).init(alloc);
        defer ds.deinit();
        try ds.setSerializedField("foo.bar", "true");
        const val = try ds.finish(alloc);
        defer ds.finishFree(alloc, val);
        try util.testing.expectDeepEqual(Choice{ .foo = .{ .bar = .{ .bar = true } } }, val);
    }
    // Returns error if multiple union fields specified
    {
        var ds = Deserializer(Choice).init(alloc);
        defer ds.deinit();
        try ds.setSerializedField("foo.bar", "true");
        try ds.setSerializedField("foo.baz", "true");
        try std.testing.expectError(error.DuplicateUnionMember, ds.finish(alloc));
    }
    // Uses default values if fields aren't provided
    {
        var ds = Deserializer(Defaults).init(alloc);
        defer ds.deinit();
        const val = try ds.finish(alloc);
        defer ds.finishFree(alloc, val);
        try util.testing.expectDeepEqual(Defaults{ .foo = "123", .bar = true }, val);
    }
    // Returns an error if fields aren't provided and no default exists
    {
        var ds = Deserializer(Pair).init(alloc);
        defer ds.deinit();
        try ds.setSerializedField("foo", "123");
        try std.testing.expectError(error.MissingField, ds.finish(alloc));
    }
    // Handles optional containers: nothing set
    {
        var ds = Deserializer(OptionalContainers).init(alloc);
        defer ds.deinit();
        const val = try ds.finish(alloc);
        defer ds.finishFree(alloc, val);
        try util.testing.expectDeepEqual(OptionalContainers{ .foo = null, .qux = null }, val);
    }
    // Handles optional containers: required members provided
    {
        var ds = Deserializer(OptionalContainers).init(alloc);
        defer ds.deinit();
        try ds.setSerializedField("foo.baz", "3");
        try ds.setSerializedField("qux", "3");
        const val = try ds.finish(alloc);
        defer ds.finishFree(alloc, val);
        try util.testing.expectDeepEqual(OptionalContainers{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, val);
    }
    // Handles optional containers: partially provided container is an error
    {
        var ds = Deserializer(OptionalContainers).init(alloc);
        defer ds.deinit();
        try ds.setSerializedField("foo.bar", "3");
        try ds.setSerializedField("qux", "3");
        try std.testing.expectError(error.MissingField, ds.finish(alloc));
    }
}