More serialization refactor

This commit is contained in:
jaina heartles 2022-11-30 21:11:54 -08:00
parent c7dcded04a
commit 83af6a40e4
2 changed files with 116 additions and 157 deletions

View file

@ -8,8 +8,9 @@ pub const Url = @import("./Url.zig");
pub const PathIter = iters.PathIter; pub const PathIter = iters.PathIter;
pub const QueryIter = iters.QueryIter; pub const QueryIter = iters.QueryIter;
pub const SqlStmtIter = iters.Separator(';'); pub const SqlStmtIter = iters.Separator(';');
pub const deserialize = @import("./deserialize.zig"); pub const serialize = @import("./serialize.zig");
pub const Deserializer = deserialize.Deserializer; pub const Deserializer = serialize.Deserializer;
pub const DeserializerContext = serialize.DeserializerContext;
/// Joins an array of strings, prefixing every entry with `prefix`, /// Joins an array of strings, prefixing every entry with `prefix`,
/// and putting `separator` in between each pair /// and putting `separator` in between each pair

View file

@ -3,9 +3,8 @@ const util = @import("./lib.zig");
const FieldRef = []const []const u8; const FieldRef = []const []const u8;
const QueryStringOptions = struct { fn defaultIsScalar(comptime T: type) bool {
fn isScalar(comptime T: type) bool { if (comptime std.meta.trait.is(.Optional)(T) and defaultIsScalar(std.meta.Child(T))) return true;
if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
if (comptime std.meta.trait.isZigString(T)) return true; if (comptime std.meta.trait.isZigString(T)) return true;
if (comptime std.meta.trait.isIntegral(T)) return true; if (comptime std.meta.trait.isIntegral(T)) return true;
if (comptime std.meta.trait.isFloat(T)) return true; if (comptime std.meta.trait.isFloat(T)) return true;
@ -17,111 +16,27 @@ const QueryStringOptions = struct {
return false; return false;
} }
const embed_unions = true; pub fn deserializeString(allocator: std.mem.Allocator, comptime T: type, value: []const u8) !T {
const ParsedField = ?[]const u8; if (comptime std.meta.trait.is(.Optional)(T)) {
if (value.len == 0) return null;
fn deserializeScalar(comptime T: type, maybe_value: ParsedField) !T { return try deserializeString(allocator, std.meta.Child(T), value);
const is_optional = comptime std.meta.trait.is(.Optional)(T);
if (maybe_value) |value| {
const Eff = if (is_optional) std.meta.Child(T) else T;
// Treat all empty values as nulls if possible
if (value.len == 0 and is_optional) return null;
// TODO:
//const decoded = try decodeString(alloc, value);
const decoded = value;
if (comptime std.meta.trait.isZigString(Eff)) return decoded;
// TOOD:
//defer alloc.free(decoded);
if (comptime std.meta.trait.isIntegral(Eff)) return try std.fmt.parseInt(Eff, decoded, 0);
if (comptime std.meta.trait.isFloat(Eff)) return try std.fmt.parseFloat(Eff, decoded);
if (comptime std.meta.trait.is(.Enum)(Eff)) {
//_ = std.ascii.lowerString(decoded, decoded);
return std.meta.stringToEnum(Eff, decoded) orelse return error.InvalidEnumValue;
}
if (Eff == bool) {
//_ = std.ascii.lowerString(decoded, decoded);
return bool_map.get(decoded) orelse return error.InvalidBool;
}
if (comptime std.meta.trait.hasFn("parse")(Eff)) return try Eff.parse(decoded);
@compileError("Invalid type " ++ @typeName(T));
} else {
// Parameter is present, but no associated value
return if (is_optional)
null
else if (T == bool)
true
else
error.MissingValue;
}
}
};
const FormDataField = struct {
filename: ?[]const u8 = null,
charset: ?[]const u8 = null,
value: []const u8,
};
const FormFile = struct {
filename: []const u8,
value: []const u8,
};
const FormDataOptions = struct {
const embed_unions = true;
fn isScalar(comptime T: type) bool {
if (comptime std.meta.trait.is(.Optional)(T) and isScalar(std.meta.Child(T))) return true;
if (comptime std.meta.trait.isZigString(T)) return true;
if (comptime std.meta.trait.isIntegral(T)) return true;
if (comptime std.meta.trait.isFloat(T)) return true;
if (comptime std.meta.trait.is(.Enum)(T)) return true;
if (comptime std.meta.trait.is(.EnumLiteral)(T)) return true;
if (comptime std.meta.trait.hasFn("parse")(T)) return true;
if (T == bool) return true;
return false;
}
const ParsedField = FormDataField;
fn deserializeScalar(comptime T: type, field: FormDataField) !T {
// TODO: allocation??
if (T == FormFile) {
return FormFile{
.filename = field.filename orelse "untitled",
.data = field.value,
};
} }
const decoded = field.value; if (T == []u8 or T == []const u8) return try util.deepClone(allocator, value);
if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, value, 0);
if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, value);
if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(value);
if (comptime std.meta.trait.isZigString(T)) return decoded; var buf: [64]u8 = undefined;
const lowered = std.ascii.lowerString(&buf, value);
// TOOD: if (T == bool) return bool_map.get(lowered) orelse return error.InvalidBool;
//defer alloc.free(decoded);
if (comptime std.meta.trait.isIntegral(T)) return try std.fmt.parseInt(T, decoded, 0);
if (comptime std.meta.trait.isFloat(T)) return try std.fmt.parseFloat(T, decoded);
if (comptime std.meta.trait.is(.Enum)(T)) { if (comptime std.meta.trait.is(.Enum)(T)) {
//_ = std.ascii.lowerString(decoded, decoded); return std.meta.stringToEnum(T, lowered) orelse return error.InvalidEnumTag;
return std.meta.stringToEnum(T, decoded) orelse return error.InvalidEnumValue;
} }
if (T == bool) {
//_ = std.ascii.lowerString(decoded, decoded);
return bool_map.get(decoded) orelse return error.InvalidBool;
}
if (comptime std.meta.trait.hasFn("parse")(T)) return try T.parse(decoded);
@compileError("Invalid type " ++ @typeName(T)); @compileError("Invalid type " ++ @typeName(T));
} }
};
fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef { fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: SerializationOptions) []const FieldRef {
comptime { comptime {
@ -149,9 +64,14 @@ fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime o
} }
} }
const SerializationOptions = struct { pub const SerializationOptions = struct {
embed_unions: bool = true, embed_unions: bool,
isScalar: fn (type) bool = QueryStringOptions.isScalar, isScalar: fn (type) bool,
};
pub const default_options = SerializationOptions{
.embed_unions = true,
.isScalar = defaultIsScalar,
}; };
fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type { fn Intermediary(comptime Result: type, comptime From: type, comptime options: SerializationOptions) type {
@ -159,7 +79,6 @@ fn Intermediary(comptime Result: type, comptime From: type, comptime options: Se
var fields: [field_refs.len]std.builtin.Type.StructField = undefined; var fields: [field_refs.len]std.builtin.Type.StructField = undefined;
for (field_refs) |ref, i| { for (field_refs) |ref, i| {
//@compileLog(i, ref, util.comptimeJoin(".", ref));
fields[i] = .{ fields[i] = .{
.name = util.comptimeJoin(".", ref), .name = util.comptimeJoin(".", ref),
.field_type = ?From, .field_type = ?From,
@ -177,15 +96,16 @@ fn Intermediary(comptime Result: type, comptime From: type, comptime options: Se
} }); } });
} }
pub fn Deserializer(comptime Result: type, comptime From: type) type { pub fn Deserializer(comptime Result: type) type {
return DeserializerContext(Result, From, struct { return DeserializerContext(Result, []const u8, struct {
const options = SerializationOptions{}; const options = default_options;
const deserializeScalar = QueryStringOptions.deserializeScalar; fn deserializeScalar(_: @This(), alloc: std.mem.Allocator, comptime T: type, val: []const u8) !T {
const isScalar = QueryStringOptions.isScalar; return try deserializeString(alloc, T, val);
}
}); });
} }
fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type { pub fn DeserializerContext(comptime Result: type, comptime From: type, comptime Context: type) type {
return struct { return struct {
const Data = Intermediary(Result, From, Context.options); const Data = Intermediary(Result, From, Context.options);
@ -206,8 +126,12 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont
return error.UnknownField; return error.UnknownField;
} }
pub fn finish(self: *@This()) !Result { pub fn finishFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
return (try self.deserialize(Result, &.{})) orelse error.MissingField; util.deepFree(allocator, val);
}
pub fn finish(self: *@This(), allocator: std.mem.Allocator) !Result {
return (try self.deserialize(allocator, Result, &.{})) orelse error.MissingField;
} }
fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From { fn getSerializedField(self: *@This(), comptime field_ref: FieldRef) ?From {
@ -215,9 +139,13 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont
return @field(self.data, util.comptimeJoin(".", field_ref)); return @field(self.data, util.comptimeJoin(".", field_ref));
} }
fn deserialize(self: *@This(), comptime T: type, comptime field_ref: FieldRef) !?T { fn deserializeFree(_: *@This(), allocator: std.mem.Allocator, val: anytype) void {
if (comptime Context.isScalar(T)) { util.deepFree(allocator, val);
return try Context.deserializeScalar(T, self.getSerializedField(field_ref) orelse return null); }
fn deserialize(self: *@This(), allocator: std.mem.Allocator, comptime T: type, comptime field_ref: FieldRef) !?T {
if (comptime Context.options.isScalar(T)) {
return try self.context.deserializeScalar(allocator, T, self.getSerializedField(field_ref) orelse return null);
} }
switch (@typeInfo(T)) { switch (@typeInfo(T)) {
@ -225,14 +153,16 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont
// in its parent container // in its parent container
.Union => |info| { .Union => |info| {
var result: ?T = null; var result: ?T = null;
errdefer if (result) |v| self.deserializeFree(allocator, v);
// TODO: errdefer cleanup // TODO: errdefer cleanup
const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref; const union_ref: FieldRef = if (Context.options.embed_unions) field_ref[0 .. field_ref.len - 1] else field_ref;
inline for (info.fields) |field| { inline for (info.fields) |field| {
const F = field.field_type; const F = field.field_type;
const new_field_ref = union_ref ++ &[_][]const u8{field.name}; const new_field_ref = union_ref ++ &[_][]const u8{field.name};
const maybe_value = try self.deserialize(F, new_field_ref); const maybe_value = try self.deserialize(allocator, F, new_field_ref);
if (maybe_value) |value| { if (maybe_value) |value| {
// TODO: errdefer cleanup // TODO: errdefer cleanup
errdefer self.deserializeFree(allocator, value);
if (result != null) return error.DuplicateUnionMember; if (result != null) return error.DuplicateUnionMember;
result = @unionInit(T, field.name, value); result = @unionInit(T, field.name, value);
} }
@ -245,16 +175,23 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont
var any_explicit = false; var any_explicit = false;
var any_missing = false; var any_missing = false;
inline for (info.fields) |field| { var fields_alloced = [1]bool{false} ** info.fields.len;
errdefer inline for (info.fields) |field, i| {
if (fields_alloced[i]) self.deserializeFree(allocator, @field(result, field.name));
};
inline for (info.fields) |field, i| {
const F = field.field_type; const F = field.field_type;
const new_field_ref = field_ref ++ &[_][]const u8{field.name}; const new_field_ref = field_ref ++ &[_][]const u8{field.name};
const maybe_value = try self.deserialize(F, new_field_ref); const maybe_value = try self.deserialize(allocator, F, new_field_ref);
if (maybe_value) |v| { if (maybe_value) |v| {
@field(result, field.name) = v; @field(result, field.name) = v;
fields_alloced[i] = true;
any_explicit = true; any_explicit = true;
} else if (field.default_value) |ptr| { } else if (field.default_value) |ptr| {
if (@sizeOf(F) != 0) { if (@sizeOf(F) != 0) {
@field(result, field.name) = @ptrCast(*const F, @alignCast(field.alignment, ptr)).*; const cast_ptr = @ptrCast(*const F, @alignCast(field.alignment, ptr));
@field(result, field.name) = try util.deepClone(allocator, cast_ptr.*);
fields_alloced[i] = true;
} }
} else { } else {
any_missing = true; any_missing = true;
@ -268,7 +205,7 @@ fn DeserializerContext(comptime Result: type, comptime From: type, comptime Cont
}, },
// Specifically non-scalar optionals // Specifically non-scalar optionals
.Optional => |info| return try self.deserialize(info.child, field_ref), .Optional => |info| return try self.deserialize(allocator, info.child, field_ref),
else => @compileError("Unsupported type"), else => @compileError("Unsupported type"),
} }
@ -294,19 +231,22 @@ test "Deserializer" {
// Happy case - simple // Happy case - simple
{ {
const T = struct { foo: usize, bar: bool }; const T = struct { foo: []const u8, bar: bool };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try ds.setSerializedField("foo", "123"); try ds.setSerializedField("foo", "123");
try ds.setSerializedField("bar", "true"); try ds.setSerializedField("bar", "true");
try std.testing.expectEqual(T{ .foo = 123, .bar = true }, try ds.finish());
const val = try ds.finish(std.testing.allocator);
defer ds.finishFree(std.testing.allocator, val);
try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val);
} }
// Returns error if nonexistent field set // Returns error if nonexistent field set
{ {
const T = struct { foo: usize, bar: bool }; const T = struct { foo: []const u8, bar: bool };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123")); try std.testing.expectError(error.UnknownField, ds.setSerializedField("baz", "123"));
} }
@ -316,10 +256,13 @@ test "Deserializer" {
foo: struct { bar: bool, baz: bool }, foo: struct { bar: bool, baz: bool },
}; };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try ds.setSerializedField("foo.bar", "true"); try ds.setSerializedField("foo.bar", "true");
try ds.setSerializedField("foo.baz", "true"); try ds.setSerializedField("foo.baz", "true");
try std.testing.expectEqual(T{ .foo = .{ .bar = true, .baz = true } }, try ds.finish());
const val = try ds.finish(std.testing.allocator);
defer ds.finishFree(std.testing.allocator, val);
try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true, .baz = true } }, val);
} }
// Union embedding // Union embedding
@ -328,9 +271,12 @@ test "Deserializer" {
foo: union(enum) { bar: bool, baz: bool }, foo: union(enum) { bar: bool, baz: bool },
}; };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try ds.setSerializedField("bar", "true"); try ds.setSerializedField("bar", "true");
try std.testing.expectEqual(T{ .foo = .{ .bar = true } }, try ds.finish());
const val = try ds.finish(std.testing.allocator);
defer ds.finishFree(std.testing.allocator, val);
try util.testing.expectDeepEqual(T{ .foo = .{ .bar = true } }, val);
} }
// Returns error if multiple union fields specified // Returns error if multiple union fields specified
@ -339,27 +285,32 @@ test "Deserializer" {
foo: union(enum) { bar: bool, baz: bool }, foo: union(enum) { bar: bool, baz: bool },
}; };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try ds.setSerializedField("bar", "true"); try ds.setSerializedField("bar", "true");
try ds.setSerializedField("baz", "true"); try ds.setSerializedField("baz", "true");
try std.testing.expectError(error.DuplicateUnionMember, ds.finish());
try std.testing.expectError(error.DuplicateUnionMember, ds.finish(std.testing.allocator));
} }
// Uses default values if fields aren't provided // Uses default values if fields aren't provided
{ {
const T = struct { foo: usize = 123, bar: bool = true }; const T = struct { foo: []const u8 = "123", bar: bool = true };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try std.testing.expectEqual(T{ .foo = 123, .bar = true }, try ds.finish());
const val = try ds.finish(std.testing.allocator);
defer ds.finishFree(std.testing.allocator, val);
try util.testing.expectDeepEqual(T{ .foo = "123", .bar = true }, val);
} }
// Returns an error if fields aren't provided and no default exists // Returns an error if fields aren't provided and no default exists
{ {
const T = struct { foo: usize, bar: bool }; const T = struct { foo: []const u8, bar: bool };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try ds.setSerializedField("foo", "123"); try ds.setSerializedField("foo", "123");
try std.testing.expectError(error.MissingField, ds.finish());
try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator));
} }
// Handles optional containers // Handles optional containers
@ -369,8 +320,11 @@ test "Deserializer" {
qux: ?union(enum) { quux: usize } = null, qux: ?union(enum) { quux: usize } = null,
}; };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try std.testing.expectEqual(T{ .foo = null, .qux = null }, try ds.finish());
const val = try ds.finish(std.testing.allocator);
defer ds.finishFree(std.testing.allocator, val);
try util.testing.expectDeepEqual(T{ .foo = null, .qux = null }, val);
} }
{ {
@ -379,10 +333,13 @@ test "Deserializer" {
qux: ?union(enum) { quux: usize } = null, qux: ?union(enum) { quux: usize } = null,
}; };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try ds.setSerializedField("foo.baz", "3"); try ds.setSerializedField("foo.baz", "3");
try ds.setSerializedField("quux", "3"); try ds.setSerializedField("quux", "3");
try std.testing.expectEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, try ds.finish());
const val = try ds.finish(std.testing.allocator);
defer ds.finishFree(std.testing.allocator, val);
try util.testing.expectDeepEqual(T{ .foo = .{ .bar = 3, .baz = 3 }, .qux = .{ .quux = 3 } }, val);
} }
{ {
@ -391,9 +348,10 @@ test "Deserializer" {
qux: ?union(enum) { quux: usize } = null, qux: ?union(enum) { quux: usize } = null,
}; };
var ds = Deserializer(T, []const u8){}; var ds = Deserializer(T){};
try ds.setSerializedField("foo.bar", "3"); try ds.setSerializedField("foo.bar", "3");
try ds.setSerializedField("quux", "3"); try ds.setSerializedField("quux", "3");
try std.testing.expectError(error.MissingField, ds.finish());
try std.testing.expectError(error.MissingField, ds.finish(std.testing.allocator));
} }
} }