fediglam/src/util/serialize.zig

387 lines
14 KiB
Zig

const std = @import("std");
const util = @import("./lib.zig");
/// A field path: the sequence of field names leading from a root type down to a
/// (possibly nested) field. Paths are joined with "." to form serialized keys,
/// e.g. `&.{ "foo", "bar" }` becomes "foo.bar".
const FieldRef = []const []const u8;
/// Default scalar predicate: reports whether `T` is deserialized as a single
/// leaf value rather than being recursed into.
///
/// Leaf types are: `bool`, Zig strings, integers, floats, enums, enum literals,
/// container types declaring a `parse` function, and optionals of any of those.
pub fn defaultIsScalar(comptime T: type) bool {
    const trait = std.meta.trait;
    return comptime (T == bool or
        trait.isZigString(T) or
        trait.isIntegral(T) or
        trait.isFloat(T) or
        trait.is(.Enum)(T) or
        trait.is(.EnumLiteral)(T) or
        trait.hasFn("parse")(T) or
        // `and` short-circuits, so Child() is only evaluated for optionals.
        (trait.is(.Optional)(T) and defaultIsScalar(std.meta.Child(T))));
}
/// Parses a single scalar value of type `T` from its string form.
///
/// - Optionals: an empty string yields `null`; otherwise the child type is parsed.
/// - `[]u8` / `[]const u8`: the string is deep-cloned with `allocator`; the
///   caller owns the returned copy.
/// - Integers: `std.fmt.parseInt` with base 0 (radix taken from the prefix).
/// - Floats: `std.fmt.parseFloat`.
/// - Types declaring `parse`: delegated to `T.parse(value)`.
/// - `bool`: the lowercased input is looked up in `bool_map`;
///   error.InvalidBool if it is not listed.
/// - Enums: the lowercased input is matched against tag names;
///   error.InvalidEnumTag if none matches.
/// Any other `T` is a compile error.
pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T {
    if (comptime std.meta.trait.is(.Optional)(T)) {
        if (value.len == 0) return null;
        return try deserializeString(allocator, std.meta.Child(T), value);
    }
    if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value);
    if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0);
    if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value);
    if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value);

    // Bools and enums are matched case-insensitively by lowercasing into a
    // fixed stack buffer. std.ascii.lowerString asserts that the output buffer
    // is at least as long as the input, so oversized (untrusted) input must be
    // rejected up front instead of tripping that assertion (a panic in safe
    // builds, undefined behavior in ReleaseFast).
    var buf: [64]u8 = undefined;
    if (T == bool) {
        if (value.len > buf.len) return error.InvalidBool;
        return bool_map.get(std.ascii.lowerString(&buf, value)) orelse return error.InvalidBool;
    }
    if (comptime std.meta.trait.is(.Enum)(T)) {
        if (value.len > buf.len) return error.InvalidEnumTag;
        return std.meta.stringToEnum(T, std.ascii.lowerString(&buf, value)) orelse return error.InvalidEnumTag;
    }
    @compileError("Invalid type " ++ @typeName(T));
}
/// Computes, at compile time, the flattened list of leaf field paths for `T`.
///
/// Scalars (per `options.isScalar`) contribute `prefix` itself. Optionals are
/// transparent. Structs and unions recurse into each field, extending the path
/// with the field name. When `options.embed_unions` is set, a union's own field
/// name is dropped, so its members appear directly under the union's parent.
/// Embedding a union at the root (empty `prefix`) is a compile error.
pub fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef {
    comptime {
        const is_embedded_union = std.meta.trait.is(.Union)(T) and options.embed_unions;
        if (is_embedded_union and prefix.len == 0) {
            @compileError("Cannot embed a union into nothing");
        }

        if (options.isScalar(T)) return &.{prefix};

        if (std.meta.trait.is(.Optional)(T)) {
            return getRecursiveFieldList(std.meta.Child(T), prefix, options);
        }

        // An embedded union does not get its own path segment: strip the
        // union field's name so members hang off the enclosing container.
        const base: FieldRef = if (is_embedded_union) prefix[0 .. prefix.len - 1] else prefix;

        var collected: []const FieldRef = &.{};
        for (std.meta.fields(T)) |member| {
            const member_path = base ++ &[_][]const u8{member.name};
            collected = collected ++ getRecursiveFieldList(member.field_type, member_path, options);
        }
        return collected;
    }
}
/// Compile-time knobs controlling how a type is flattened into serialized fields.
pub const SerializationOptions = struct {
    /// When true, a union's own field name is omitted from member paths, so
    /// each member is addressed directly under the union's parent container.
    embed_unions: bool,
    /// Predicate deciding whether a type is a single leaf value (true) or a
    /// container whose fields are recursed into (false).
    isScalar: fn (type) bool,
};
/// Default options: embed union members into their parent and classify leaf
/// types with `defaultIsScalar`.
pub const default_options = SerializationOptions{
    .embed_unions = true,
    .isScalar = defaultIsScalar,
};
/// Builds the flat storage struct used by `DeserializerContext`. It has one
/// field per path returned by `getRecursiveFieldList(Result, &.{}, options)`.
/// Each field is named by joining the path with ".", is typed `?From`, and
/// defaults to null (meaning "not set").
fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
    const Slot = ?From;
    const paths = getRecursiveFieldList(Result, &.{}, options);

    var struct_fields: [paths.len]std.builtin.Type.StructField = undefined;
    for (paths) |path, idx| {
        struct_fields[idx] = std.builtin.Type.StructField{
            .name = util.comptimeJoin(".", path),
            .field_type = Slot,
            .default_value = &@as(Slot, null),
            .is_comptime = false,
            .alignment = @alignOf(Slot),
        };
    }

    const struct_info = std.builtin.Type.Struct{
        .layout = .Auto,
        .fields = &struct_fields,
        .decls = &.{},
        .is_tuple = false,
    };
    return @Type(.{ .Struct = struct_info });
}
/// Convenience deserializer for `Result` whose raw field values are strings.
/// It uses `default_options` and parses scalars with `deserializeString`.
pub fn Deserializer(comptime Result: type) type {
    const StringContext = struct {
        const options = default_options;

        fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T {
            return try deserializeString(alloc, T, val);
        }
    };
    return DeserializerContext(Result, []const u8, StringContext);
}
/// Returns a deserializer type that collects flattened key/value pairs of type
/// `From` and assembles them into a `Result`.
///
/// `Context` must be a struct type that:
///   - can be default-initialized with `.{}`,
///   - declares `options: SerializationOptions`,
///   - declares `deserializeScalar(self, allocator, comptime T: type, value: From) !T`.
///
/// Keys are the dotted field paths produced by `getRecursiveFieldList` for
/// `Result` (e.g. "foo.bar"). When `options.embed_unions` is set, union members
/// are addressed without the union field's own name.
pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type {
    return struct {
        const Data = Intermediary(Result, From, Context.options);

        data: Data = .{},
        context: Context = .{},

        /// Stores `value` under the dotted field path `key`. Any earlier value
        /// for that key is replaced. The value is stored as-is, not copied.
        /// Returns error.UnknownField if `key` is not a field path of `Result`.
        pub fn setSerializedField(self: *@This(), key: []const u8, value: From) !void {
            const field = std.meta.stringToEnum(std.meta.FieldEnum(Data), key) orelse return error.UnknownField;
            inline for (comptime std.meta.fieldNames(Data)) |field_name| {
                @setEvalBranchQuota(10000);
                const f = comptime std.meta.stringToEnum(std.meta.FieldEnum(Data), field_name);
                if (field == f) {
                    @field(self.data, field_name) = value;
                    return;
                }
            }
            // `key` resolved to a field of `Data`, so the loop above always returns.
            unreachable;
        }

        /// Iterates over the fields that have been set. It yields each field's
        /// dotted key and stored value, in `Data` field order.
        pub const Iter = struct {
            data: *const Data,
            field_index: usize,

            pub const Item = struct {
                key: []const u8,
                value: From,
            };

            pub fn next(self: *Iter) ?Item {
                while (self.field_index < std.meta.fields(Data).len) {
                    const idx = self.field_index;
                    self.field_index += 1;
                    inline for (comptime std.meta.fieldNames(Data)) |field, i| {
                        if (i == idx) {
                            const maybe_value = @field(self.data.*, field);
                            if (maybe_value) |value| return Item{ .key = field, .value = value };
                        }
                    }
                }
                return null;
            }
        };

        pub fn iterator(self: *const @This()) Iter {
            return .{ .data = &self.data, .field_index = 0 };
        }

        /// Releases a value previously returned by `finish`.
        pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
            util.deepFree(allocator, val);
        }

        /// Assembles a `Result` from the collected fields. The caller owns the
        /// result and should release it with `finishFree`.
        ///
        /// Errors:
        ///   - error.MissingField: required fields (those without defaults) were not set.
        ///   - error.DuplicateUnionMember: more than one member of a union was set.
        ///   - Any error from `Context.deserializeScalar` or from allocation.
        pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result {
            return (try self.deserialize(allocator, Result, &.{})) orelse error.MissingField;
        }

        fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From {
            return @field(self.data, util.comptimeJoin(".", field_ref));
        }

        fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
            util.deepFree(allocator, val);
        }

        /// Recursively builds a `T` from the fields stored under `field_ref`.
        ///
        /// Returns null when nothing relevant to `T` was set, so the caller can
        /// fall back to a default. Returns the value when it could be assembled.
        /// Returns error.MissingField when some, but not all, required fields
        /// were set.
        fn deserialize(self: *@This(), allocator: std.mem.Allocator, comptime T: type, comptime field_ref: FieldRef) !?T {
            if (comptime Context.options.isScalar(T)) {
                return try self.context.deserializeScalar(allocator, T, self.getSerializedField(field_ref) orelse return null);
            }
            switch (@typeInfo(T)) {
                // At most one union member may be set. With `embed_unions`, the
                // members are addressed under the union's parent rather than
                // under the union field itself.
                .Union => |info| {
                    var result: ?T = null;
                    errdefer if (result) |v| self.deserializeFree(allocator, v);
                    const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref;
                    inline for (info.fields) |field| {
                        const F = field.field_type;
                        const new_field_ref = union_ref ++ &[_][]const u8{field.name};
                        const maybe_value = try self.deserialize(allocator, F, new_field_ref);
                        if (maybe_value) |value| {
                            // Release this member if it is rejected as a duplicate.
                            errdefer self.deserializeFree(allocator, value);
                            if (result != null) return error.DuplicateUnionMember;
                            result = @unionInit(T, field.name, value);
                        }
                    }
                    return result;
                },
                .Struct => |info| {
                    var result: T = undefined;
                    var any_explicit = false;
                    var any_missing = false;
                    // Records which fields of `result` were initialized (explicitly
                    // or by cloning a default) and so must be freed on bail-out.
                    var fields_alloced = [1]bool{false} ** info.fields.len;
                    errdefer inline for (info.fields) |field, i| {
                        if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name));
                    };
                    inline for (info.fields) |field, i| {
                        const F = field.field_type;
                        const new_field_ref = field_ref ++ &[_][]const u8{field.name};
                        const maybe_value = try self.deserialize(allocator, F, new_field_ref);
                        if (maybe_value) |v| {
                            @field(result, field.name) = v;
                            fields_alloced[i] = true;
                            any_explicit = true;
                        } else if (field.default_value) |ptr| {
                            if (@sizeOf(F) != 0) {
                                const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr));
                                @field(result, field.name) = try util.deepClone(allocator, cast_ptr.*);
                                fields_alloced[i] = true;
                            }
                        } else {
                            any_missing = true;
                        }
                    }
                    if (any_missing) {
                        if (any_explicit) return error.MissingField;
                        // Nothing for this struct was provided, so report it as absent.
                        // Returning null is not an error return, so the errdefer above
                        // does not run. The default values already cloned into `result`
                        // must be released here, or they leak.
                        inline for (info.fields) |field, i| {
                            if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name));
                        }
                        return null;
                    }
                    return result;
                },
                // Specifically non-scalar optionals.
                // NOTE(review): the `?Child` returned here is coerced into this
                // function's `??Child` result. That appears to wrap an absent child
                // as a present `null`, so the parent sees it as explicitly set and
                // never consults its default. Confirm this is the intended semantics.
                .Optional => |info| return try self.deserialize(allocator, info.child, field_ref),
                else => @compileError("Unsupported type"),
            }
        }
    };
}
/// Accepted boolean spellings, matched exactly. `deserializeString` lowercases
/// input before the lookup, so only lowercase keys are listed.
const bool_map = std.ComptimeStringMap(bool, .{
    .{ "1", true },
    .{ "0", false },
    .{ "t", true },
    .{ "f", false },
    .{ "y", true },
    .{ "n", false },
    .{ "yes", true },
    .{ "no", false },
    .{ "true", true },
    .{ "false", false },
});
test "Deserializer" {
    const alloc = std.testing.allocator;
    const expectDeepEqual = util.testing.expectDeepEqual;

    // Result shapes reused by several of the cases below.
    const Flat = struct { foo: []const u8, bar: bool };
    const Tagged = struct {
        foo: union(enum) { bar: bool, baz: bool },
    };
    const OptionalContainers = struct {
        foo: ?struct { bar: usize = 3, baz: usize } = null,
        qux: ?union(enum) { quux: usize } = null,
    };

    // Flat scalar fields are parsed from their string form.
    {
        var des = Deserializer(Flat){};
        try des.setSerializedField("foo", "123");
        try des.setSerializedField("bar", "true");
        const got = try des.finish(alloc);
        defer des.finishFree(alloc, got);
        try expectDeepEqual(Flat{ .foo = "123", .bar = true }, got);
    }

    // Setting a key that names no field is rejected.
    {
        var des = Deserializer(Flat){};
        try std.testing.expectError(error.UnknownField, des.setSerializedField("baz", "123"));
    }

    // Nested struct fields are addressed with dotted keys.
    {
        const Nested = struct {
            foo: struct { bar: bool, baz: bool },
        };
        var des = Deserializer(Nested){};
        try des.setSerializedField("foo.bar", "true");
        try des.setSerializedField("foo.baz", "true");
        const got = try des.finish(alloc);
        defer des.finishFree(alloc, got);
        try expectDeepEqual(Nested{ .foo = .{ .bar = true, .baz = true } }, got);
    }

    // Union members are embedded in the parent: "bar" selects `foo.bar`.
    {
        var des = Deserializer(Tagged){};
        try des.setSerializedField("bar", "true");
        const got = try des.finish(alloc);
        defer des.finishFree(alloc, got);
        try expectDeepEqual(Tagged{ .foo = .{ .bar = true } }, got);
    }

    // Supplying two members of the same union is an error.
    {
        var des = Deserializer(Tagged){};
        try des.setSerializedField("bar", "true");
        try des.setSerializedField("baz", "true");
        try std.testing.expectError(error.DuplicateUnionMember, des.finish(alloc));
    }

    // Unset fields fall back to their declared defaults.
    {
        const Defaulted = struct { foo: []const u8 = "123", bar: bool = true };
        var des = Deserializer(Defaulted){};
        const got = try des.finish(alloc);
        defer des.finishFree(alloc, got);
        try expectDeepEqual(Defaulted{ .foo = "123", .bar = true }, got);
    }

    // A required field (no default) left unset fails the whole result.
    {
        var des = Deserializer(Flat){};
        try des.setSerializedField("foo", "123");
        try std.testing.expectError(error.MissingField, des.finish(alloc));
    }

    // Optional containers: nothing set, so both stay null.
    {
        var des = Deserializer(OptionalContainers){};
        const got = try des.finish(alloc);
        defer des.finishFree(alloc, got);
        try expectDeepEqual(OptionalContainers{ .foo = null, .qux = null }, got);
    }

    // Optional containers: enough set to build both.
    {
        var des = Deserializer(OptionalContainers){};
        try des.setSerializedField("foo.baz", "3");
        try des.setSerializedField("quux", "3");
        const got = try des.finish(alloc);
        defer des.finishFree(alloc, got);
        try expectDeepEqual(OptionalContainers{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, got);
    }

    // Optional containers: a partially-set struct is still an error.
    {
        var des = Deserializer(OptionalContainers){};
        try des.setSerializedField("foo.bar", "3");
        try des.setSerializedField("quux", "3");
        try std.testing.expectError(error.MissingField, des.finish(alloc));
    }
}