create vtables
This commit is contained in:
parent
c25433ed01
commit
eb0d8b4ca0
1 changed files with 81 additions and 704 deletions
785
src/sql/lib.zig
785
src/sql/lib.zig
|
@ -1,726 +1,103 @@
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const util = @import("util");
|
|
||||||
const build_options = @import("build_options");
|
|
||||||
|
|
||||||
const postgres = if (build_options.enable_postgres)
|
|
||||||
@import("./engines/postgres.zig")
|
|
||||||
else
|
|
||||||
@import("./engines/null.zig");
|
|
||||||
|
|
||||||
const sqlite = if (build_options.enable_sqlite)
|
|
||||||
@import("./engines/sqlite.zig")
|
|
||||||
else
|
|
||||||
@import("./engines/null.zig");
|
|
||||||
const common = @import("./engines/common.zig");
|
|
||||||
const Allocator = std.mem.Allocator;
|
const Allocator = std.mem.Allocator;
|
||||||
|
|
||||||
const errors = @import("./errors.zig").library_errors;
|
pub const SqlValue = union(enum) {
|
||||||
|
int: i64,
|
||||||
pub const AcquireError = OpenError || error{NoConnectionsLeft};
|
uint: u64,
|
||||||
pub const OpenError = errors.OpenError;
|
str: []const u8,
|
||||||
pub const QueryError = errors.QueryError;
|
@"null": void,
|
||||||
pub const RowError = errors.RowError;
|
float: f64,
|
||||||
pub const QueryRowError = errors.QueryRowError;
|
|
||||||
pub const BeginError = errors.BeginError;
|
|
||||||
pub const CommitError = errors.CommitError;
|
|
||||||
|
|
||||||
pub const DatabaseError = QueryError || RowError || QueryRowError || BeginError || CommitError;
|
|
||||||
|
|
||||||
pub const QueryOptions = common.QueryOptions;
|
|
||||||
|
|
||||||
pub const Engine = enum {
|
|
||||||
postgres,
|
|
||||||
sqlite,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Helper for building queries at runtime. All constituent parts of the
|
pub const QueryOptions = struct {
|
||||||
/// query should be defined at comptime, however the choice of whether
|
// If true, then it will not return an error on the SQLite backend
|
||||||
/// or not to include them can occur at runtime.
|
// if an argument passed does not map to a parameter in the query.
|
||||||
pub const QueryBuilder = struct {
|
// Has no effect on the postgres backend.
|
||||||
array: std.ArrayList(u8),
|
ignore_unused_arguments: bool = false,
|
||||||
where_clauses_appended: usize = 0,
|
|
||||||
set_statements_appended: usize = 0,
|
|
||||||
|
|
||||||
pub fn init(alloc: std.mem.Allocator) QueryBuilder {
|
|
||||||
return QueryBuilder{ .array = std.ArrayList(u8).init(alloc) };
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deinit(self: *const QueryBuilder) void {
|
|
||||||
self.array.deinit();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a chunk of sql to the query without processing
|
|
||||||
pub fn appendSlice(self: *QueryBuilder, comptime sql: []const u8) !void {
|
|
||||||
try self.array.appendSlice(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a where clause to the query. Clauses are assumed to be components
|
|
||||||
/// in an overall expression in Conjunctive Normal Form (AND of OR's).
|
|
||||||
/// https://en.wikipedia.org/wiki/Conjunctive_normal_form
|
|
||||||
/// All calls to andWhere must be contiguous, that is, they cannot be
|
|
||||||
/// interspersed with calls to appendSlice
|
|
||||||
pub fn andWhere(self: *QueryBuilder, comptime clause: []const u8) !void {
|
|
||||||
if (self.where_clauses_appended == 0) {
|
|
||||||
try self.array.appendSlice("\nWHERE ");
|
|
||||||
} else {
|
|
||||||
try self.array.appendSlice(" AND ");
|
|
||||||
}
|
|
||||||
|
|
||||||
try self.array.appendSlice(clause);
|
|
||||||
self.where_clauses_appended += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set(self: *QueryBuilder, comptime col: []const u8, comptime val: []const u8) !void {
|
|
||||||
if (self.set_statements_appended == 0) {
|
|
||||||
try self.array.appendSlice("\nSET ");
|
|
||||||
} else {
|
|
||||||
try self.array.appendSlice(", ");
|
|
||||||
}
|
|
||||||
|
|
||||||
try self.array.appendSlice(col ++ " = " ++ val);
|
|
||||||
self.set_statements_appended += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn str(self: *const QueryBuilder) []const u8 {
|
|
||||||
return self.array.items;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn terminate(self: *QueryBuilder) ![:0]const u8 {
|
|
||||||
std.debug.assert(self.array.items.len != 0);
|
|
||||||
if (self.array.items[self.array.items.len - 1] != 0) try self.array.append(0);
|
|
||||||
|
|
||||||
return std.meta.assumeSentinel(self.array.items, 0);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: make this suck less
|
pub const UnexpectedError = error{Unexpected};
|
||||||
pub const Config = union(Engine) {
|
pub const ConstraintError = error{
|
||||||
postgres: struct {
|
NotNullViolation,
|
||||||
pg_conn_str: [:0]const u8,
|
ForeignKeyViolation,
|
||||||
},
|
UniqueViolation,
|
||||||
sqlite: struct {
|
CheckViolation,
|
||||||
sqlite_file_path: [:0]const u8,
|
|
||||||
sqlite_is_uri: bool = false,
|
/// Catchall for miscellaneous types of constraints
|
||||||
},
|
ConstraintViolation,
|
||||||
};
|
};
|
||||||
|
|
||||||
const RawResults = union(Engine) {
|
pub const OpenError = error{BadConnection} || UnexpectedError;
|
||||||
postgres: postgres.Results,
|
|
||||||
sqlite: sqlite.Results,
|
|
||||||
|
|
||||||
fn finish(self: RawResults) void {
|
pub const ExecError = error{
|
||||||
switch (self) {
|
Cancelled,
|
||||||
.postgres => |pg| pg.finish(),
|
BadConnection,
|
||||||
.sqlite => |lite| lite.finish(),
|
InternalException,
|
||||||
}
|
DatabaseBusy,
|
||||||
}
|
PermissionDenied,
|
||||||
|
SqlException,
|
||||||
|
|
||||||
fn columnCount(self: RawResults) !u15 {
|
/// Argument could not be marshalled for query
|
||||||
return try switch (self) {
|
BindException,
|
||||||
.postgres => |pg| pg.columnCount(),
|
|
||||||
.sqlite => |lite| lite.columnCount(),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn columnIndex(self: RawResults, name: []const u8) error{ NotFound, Unexpected }!u15 {
|
/// An argument was not used by the query (not checked in all DB engines)
|
||||||
return switch (self) {
|
UnusedArgument,
|
||||||
.postgres => |pg| pg.columnIndex(name),
|
|
||||||
.sqlite => |lite| lite.columnIndex(name),
|
|
||||||
} catch |err| switch (err) {
|
|
||||||
error.OutOfRange => error.Unexpected,
|
|
||||||
error.NotFound => error.NotFound,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn row(self: *RawResults) RowError!?Row {
|
/// Memory error when marshalling argument for query
|
||||||
return switch (self.*) {
|
OutOfMemory,
|
||||||
.postgres => |*pg| if (try pg.row()) |r| Row{ .postgres = r } else null,
|
AllocatorRequired,
|
||||||
.sqlite => |*lite| if (try lite.row()) |r| Row{ .sqlite = r } else null,
|
} || ConstraintError || UnexpectedError;
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const FieldRef = []const []const u8;
|
pub const Db = struct {
|
||||||
fn FieldPtr(comptime Ptr: type, comptime names: FieldRef) type {
|
const VTable = struct {
|
||||||
if (names.len == 0) return Ptr;
|
/// Closes the database connection
|
||||||
|
close: *const fn (ctx: *anyopaque) void,
|
||||||
|
|
||||||
const T = std.meta.Child(Ptr);
|
/// Executes a SQL query.
|
||||||
|
/// All memory allocated with this allocator must be freed before this function returns.
|
||||||
const field = for (@typeInfo(T).Struct.fields) |f| {
|
exec: *const fn (ctx: *anyopaque, sql: []const u8, args: []const SqlValue, opt: QueryOptions, allocator: Allocator) ExecError!Results,
|
||||||
if (std.mem.eql(u8, f.name, names[0])) break f;
|
|
||||||
} else @compileError("Unknown field " ++ names[0] ++ " in type " ++ @typeName(T));
|
|
||||||
|
|
||||||
return FieldPtr(*field.field_type, names[1..]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn fieldPtr(ptr: anytype, comptime names: FieldRef) FieldPtr(@TypeOf(ptr), names) {
|
|
||||||
if (names.len == 0) return ptr;
|
|
||||||
|
|
||||||
return fieldPtr(&@field(ptr.*, names[0]), names[1..]);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn getRecursiveFieldList(comptime T: type, comptime prefix: FieldRef, comptime options: anytype) []const FieldRef {
|
|
||||||
comptime {
|
|
||||||
if (std.meta.trait.is(.Union)(T) and prefix.len == 0 and options.embed_unions) {
|
|
||||||
@compileError("Cannot embed a union into nothing");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (options.isScalar(T)) return &.{prefix};
|
|
||||||
if (std.meta.trait.is(.Optional)(T)) return getRecursiveFieldList(std.meta.Child(T), prefix, options);
|
|
||||||
|
|
||||||
const eff_prefix: FieldRef = if (std.meta.trait.is(.Union)(T) and options.embed_unions)
|
|
||||||
prefix[0 .. prefix.len - 1]
|
|
||||||
else
|
|
||||||
prefix;
|
|
||||||
|
|
||||||
var fields: []const FieldRef = &.{};
|
|
||||||
|
|
||||||
for (std.meta.fields(T)) |f| {
|
|
||||||
const new_prefix = eff_prefix ++ &[_][]const u8{f.name};
|
|
||||||
if (@hasDecl(T, "sql_serialize") and @hasDecl(T.sql_serialize, f.name) and @field(T.sql_serialize, f.name) == .json) {
|
|
||||||
fields = fields ++ &[_]FieldRef{new_prefix};
|
|
||||||
} else {
|
|
||||||
const F = f.field_type;
|
|
||||||
fields = fields ++ getRecursiveFieldList(F, new_prefix, options);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fields;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Represents a set of results.
|
|
||||||
// row() must be called until it returns null, or the query may not complete
|
|
||||||
// Must be deallocated by a call to finish()
|
|
||||||
pub fn Results(comptime T: type) type {
|
|
||||||
// would normally make this a declaration of the struct, but it causes the compiler to crash
|
|
||||||
const fields = if (T == void) .{} else getRecursiveFieldList(
|
|
||||||
T,
|
|
||||||
&.{},
|
|
||||||
util.serialize.default_options,
|
|
||||||
);
|
|
||||||
return struct {
|
|
||||||
const Self = @This();
|
|
||||||
|
|
||||||
underlying: RawResults,
|
|
||||||
column_indices: [fields.len]u15,
|
|
||||||
|
|
||||||
fn from(underlying: RawResults) QueryError!Self {
|
|
||||||
if (std.debug.runtime_safety and fields.len != underlying.columnCount() catch unreachable) {
|
|
||||||
std.log.err("Expected {} columns in result, got {}", .{ fields.len, underlying.columnCount() catch unreachable });
|
|
||||||
return error.ColumnMismatch;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Self{ .underlying = underlying, .column_indices = blk: {
|
|
||||||
var indices: [fields.len]u15 = undefined;
|
|
||||||
inline for (fields) |f, i| {
|
|
||||||
if (comptime std.meta.trait.isTuple(T)) {
|
|
||||||
indices[i] = i;
|
|
||||||
} else {
|
|
||||||
const name = util.comptimeJoin(".", f);
|
|
||||||
indices[i] =
|
|
||||||
underlying.columnIndex(name) catch {
|
|
||||||
std.log.err("Could not find column index for field {s}", .{name});
|
|
||||||
return error.ColumnMismatch;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break :blk indices;
|
|
||||||
} };
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn finish(self: Self) void {
|
|
||||||
self.underlying.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the next row of results, or null if there are no more rows.
|
|
||||||
// Caller owns all memory allocated. The entire object can be deallocated with a
|
|
||||||
// call to util.deepFree
|
|
||||||
pub fn row(self: *Self, alloc: ?Allocator) RowError!?T {
|
|
||||||
if (try self.underlying.row()) |row_val| {
|
|
||||||
var result: T = undefined;
|
|
||||||
var fields_allocated: usize = 0;
|
|
||||||
errdefer inline for (fields) |f, i| {
|
|
||||||
// Iteration bounds must be defined at comptime (inline for) but the number of fields we could
|
|
||||||
// successfully allocate is defined at runtime. So we iterate over the entire field array and
|
|
||||||
// conditionally deallocate fields in the loop.
|
|
||||||
const ptr = fieldPtr(&result, f);
|
|
||||||
if (i < fields_allocated) util.deepFree(alloc, ptr.*);
|
|
||||||
};
|
|
||||||
|
|
||||||
inline for (fields) |f, i| {
|
|
||||||
// TODO: Causes compiler segfault. why?
|
|
||||||
//const F = f.field_type;
|
|
||||||
//const F = @TypeOf(@field(result, f.name));
|
|
||||||
const F = std.meta.Child(FieldPtr(*@TypeOf(result), f));
|
|
||||||
const ptr = fieldPtr(&result, f);
|
|
||||||
const name = comptime util.comptimeJoin(".", f);
|
|
||||||
|
|
||||||
const mode = comptime if (@hasDecl(T, "sql_serialize")) blk: {
|
|
||||||
if (@hasDecl(T.sql_serialize, name)) {
|
|
||||||
break :blk @field(T.sql_serialize, name);
|
|
||||||
}
|
|
||||||
break :blk .default;
|
|
||||||
} else .default;
|
|
||||||
switch (mode) {
|
|
||||||
.default => ptr.* = row_val.get(F, self.column_indices[i], alloc) catch |err| {
|
|
||||||
std.log.err("SQL: Error getting column {s} of type {}", .{ name, F });
|
|
||||||
return err;
|
|
||||||
},
|
|
||||||
.json => {
|
|
||||||
const str = row_val.get([]const u8, self.column_indices[i], alloc) catch |err| {
|
|
||||||
std.log.err("SQL: Error getting column {s} of type {}", .{ name, F });
|
|
||||||
return err;
|
|
||||||
};
|
|
||||||
const a = alloc orelse return error.AllocatorRequired;
|
|
||||||
defer a.free(str);
|
|
||||||
|
|
||||||
var ts = std.json.TokenStream.init(str);
|
|
||||||
ptr.* = std.json.parse(F, &ts, .{ .allocator = a }) catch |err| {
|
|
||||||
std.log.err("SQL: Error parsing columns {s} of type {}: {}", .{ name, F, err });
|
|
||||||
return error.ResultTypeMismatch;
|
|
||||||
};
|
|
||||||
},
|
|
||||||
else => @compileError("unknown mode"),
|
|
||||||
}
|
|
||||||
fields_allocated += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
} else return null;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Row is invalidated by the next call to result.row()
|
|
||||||
const Row = union(Engine) {
|
|
||||||
postgres: postgres.Row,
|
|
||||||
sqlite: sqlite.Row,
|
|
||||||
|
|
||||||
// Returns a value of type T from the zero-indexed column given by idx.
|
|
||||||
// Not all types require an allocator to be present. If an allocator is needed but
|
|
||||||
// not required, it will return error.AllocatorRequired.
|
|
||||||
// The caller is responsible for deallocating T, if relevant.
|
|
||||||
fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T {
|
|
||||||
if (T == void) return;
|
|
||||||
return switch (self) {
|
|
||||||
.postgres => |pg| try pg.get(T, idx, alloc),
|
|
||||||
.sqlite => |lite| try lite.get(T, idx, alloc),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const ConstraintMode = enum {
|
|
||||||
deferred,
|
|
||||||
immediate,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const ConnPool = struct {
|
|
||||||
const max_conns = 4;
|
|
||||||
const Conn = struct {
|
|
||||||
engine: union(Engine) {
|
|
||||||
postgres: postgres.Db,
|
|
||||||
sqlite: sqlite.Db,
|
|
||||||
},
|
|
||||||
in_use: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false),
|
|
||||||
current_tx_level: u8 = 0,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
config: Config,
|
vtable: VTable,
|
||||||
connections: [max_conns]Conn,
|
|
||||||
|
|
||||||
pub fn init(cfg: Config) OpenError!ConnPool {
|
|
||||||
var self = ConnPool{
|
|
||||||
.config = cfg,
|
|
||||||
.connections = undefined,
|
|
||||||
};
|
|
||||||
var count: usize = 0;
|
|
||||||
errdefer for (self.connections[0..count]) |*c| closeConn(c);
|
|
||||||
for (self.connections) |*c| {
|
|
||||||
c.* = try self.createConn();
|
|
||||||
count += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return self;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deinit(self: *ConnPool) void {
|
|
||||||
for (self.connections) |*c| closeConn(c);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn acquire(self: *ConnPool) AcquireError!Db {
|
|
||||||
for (self.connections) |*c| {
|
|
||||||
if (tryAcquire(c)) return Db{ .conn = c };
|
|
||||||
}
|
|
||||||
return error.NoConnectionsLeft;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn tryAcquire(conn: *Conn) bool {
|
|
||||||
const acquired = !conn.in_use.swap(true, .AcqRel);
|
|
||||||
if (acquired) {
|
|
||||||
if (conn.current_tx_level != 0) @panic("Transaction still open on unused db connection");
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn createConn(self: *ConnPool) OpenError!Conn {
|
|
||||||
return switch (self.config) {
|
|
||||||
.postgres => |postgres_cfg| Conn{
|
|
||||||
.engine = .{
|
|
||||||
.postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
.sqlite => |lite_cfg| Conn{
|
|
||||||
.engine = .{
|
|
||||||
.sqlite = if (lite_cfg.sqlite_is_uri)
|
|
||||||
try sqlite.Db.openUri(lite_cfg.sqlite_file_path)
|
|
||||||
else
|
|
||||||
try sqlite.Db.open(lite_cfg.sqlite_file_path),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn closeConn(conn: *Conn) void {
|
|
||||||
if (conn.in_use.loadUnchecked()) @panic("DB Conn still open");
|
|
||||||
switch (conn.engine) {
|
|
||||||
.postgres => |pg| pg.close(),
|
|
||||||
.sqlite => |lite| lite.close(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Db = Tx(0);
|
pub const ColumnCountError = error{OutOfRange};
|
||||||
|
pub const ColumnIndexError = error{ NotFound, OutOfRange };
|
||||||
|
pub const ColumnIndex = u32;
|
||||||
|
|
||||||
/// When tx_level == 0, the DB is operating in "implied transaction" mode where
|
pub const RowError = error{
|
||||||
/// every command is its own transaction
|
Cancelled,
|
||||||
/// When tx_level >= 1, the DB has an explicit transaction open
|
BadConnection,
|
||||||
/// When tx_level >= 2, the DB has (tx_level - 1) levels of transaction savepoints open
|
InternalException,
|
||||||
/// (like nested transactions)
|
DatabaseBusy,
|
||||||
fn Tx(comptime tx_level: u8) type {
|
PermissionDenied,
|
||||||
return struct {
|
SqlException,
|
||||||
const Self = @This();
|
} || ConstraintError || UnexpectedError;
|
||||||
const savepoint_name = if (tx_level == 0)
|
|
||||||
@compileError("Transaction not started")
|
|
||||||
else
|
|
||||||
std.fmt.comptimePrint("save_{}", .{tx_level});
|
|
||||||
const next_savepoint_name = Tx(tx_level + 1).savepoint_name;
|
|
||||||
|
|
||||||
conn: *ConnPool.Conn,
|
pub const Results = struct {
|
||||||
|
const VTable = struct {
|
||||||
/// The type of SQL engine being used. Use of this function should be discouraged
|
columnCount: *const fn (ctx: *anyopaque) ColumnCountError!ColumnIndex,
|
||||||
pub fn sqlEngine(self: Self) Engine {
|
columnIndex: *const fn (ctx: *anyopaque) ColumnIndexError!ColumnIndex,
|
||||||
return self.conn.engine;
|
row: *const fn (ctx: *anyopaque) RowError!?Row,
|
||||||
}
|
finish: *const fn (ctx: *anyopaque) void,
|
||||||
|
|
||||||
/// Return the connection to the pool
|
|
||||||
pub fn releaseConnection(self: Self) void {
|
|
||||||
if (tx_level != 0) @compileError("close must be called on root db");
|
|
||||||
if (self.conn.current_tx_level != 0) {
|
|
||||||
std.log.warn("Database released while transaction in progress!", .{});
|
|
||||||
self.rollbackUnchecked() catch {}; // TODO: Burn database connection
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ********* Transaction management functions **********
|
|
||||||
|
|
||||||
/// Start an explicit transaction
|
|
||||||
pub fn begin(self: Self) !Tx(1) {
|
|
||||||
if (tx_level != 0) @compileError("Transaction already started");
|
|
||||||
if (self.conn.current_tx_level != 0) return error.BadTransactionState;
|
|
||||||
|
|
||||||
try self.exec("BEGIN", {}, null);
|
|
||||||
|
|
||||||
self.conn.current_tx_level = 1;
|
|
||||||
|
|
||||||
return Tx(1){ .conn = self.conn };
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a savepoint (nested transaction)
|
|
||||||
pub fn savepoint(self: Self) !Tx(tx_level + 1) {
|
|
||||||
if (tx_level == 0) @compileError("Cannot place a savepoint on an implicit transaction");
|
|
||||||
if (self.conn.current_tx_level != tx_level) return error.BadTransactionState;
|
|
||||||
|
|
||||||
try self.exec("SAVEPOINT " ++ next_savepoint_name, {}, null);
|
|
||||||
|
|
||||||
self.conn.current_tx_level = tx_level + 1;
|
|
||||||
|
|
||||||
return Tx(tx_level + 1){ .conn = self.conn };
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Commit the entire transaction
|
|
||||||
pub fn commit(self: Self) !void {
|
|
||||||
if (tx_level == 0) @compileError("Transaction not started");
|
|
||||||
if (tx_level >= 2) @compileError("Cannot commit a savepoint");
|
|
||||||
if (self.conn.current_tx_level == 0) return error.BadTransactionState;
|
|
||||||
|
|
||||||
try self.exec("COMMIT", {}, null);
|
|
||||||
|
|
||||||
self.conn.current_tx_level = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Release the current savepoint and all savepoints created from it.
|
|
||||||
pub fn release(self: Self) !void {
|
|
||||||
if (tx_level == 0) @compileError("Transaction not started");
|
|
||||||
if (tx_level == 1) @compileError("Cannot release a transaction");
|
|
||||||
if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;
|
|
||||||
|
|
||||||
try self.exec("RELEASE SAVEPOINT " ++ savepoint_name, {}, null);
|
|
||||||
|
|
||||||
self.conn.current_tx_level = tx_level - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Rolls back the entire transaction
|
|
||||||
pub fn rollbackTx(self: Self) !void {
|
|
||||||
if (tx_level == 0) @compileError("Transaction not started");
|
|
||||||
if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint");
|
|
||||||
if (self.conn.current_tx_level == 0) return error.BadTransactionState;
|
|
||||||
|
|
||||||
try self.rollbackUnchecked();
|
|
||||||
|
|
||||||
self.conn.current_tx_level = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Attempts to roll back to a savepoint
|
|
||||||
pub fn rollbackSavepoint(self: Self) !void {
|
|
||||||
if (tx_level == 0) @compileError("Transaction not started");
|
|
||||||
if (tx_level == 1) @compileError("Cannot rollback a savepoint on the entire transaction");
|
|
||||||
if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;
|
|
||||||
|
|
||||||
try self.exec("ROLLBACK TO " ++ savepoint_name, {}, null);
|
|
||||||
|
|
||||||
self.conn.current_tx_level = tx_level - 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Perform whichever rollback is appropriate for the situation
|
|
||||||
pub fn rollback(self: Self) void {
|
|
||||||
(if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| {
|
|
||||||
std.log.err("Failed to rollback transaction: {}", .{err});
|
|
||||||
std.log.err("{any}", .{@errorReturnTrace()});
|
|
||||||
@panic("TODO: more gracefully handle rollback failures");
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const BeginOrSavepoint = Tx(tx_level + 1);
|
|
||||||
pub const beginOrSavepoint = if (tx_level == 0) begin else savepoint;
|
|
||||||
pub const commitOrRelease = if (tx_level < 2) commit else release;
|
|
||||||
|
|
||||||
// Allows relaxing *some* constraints for the lifetime of the transaction.
|
|
||||||
// You should generally not do this, but it's useful when bootstrapping
|
|
||||||
// the initial admin community and cluster operator user.
|
|
||||||
pub fn setConstraintMode(self: Self, mode: ConstraintMode) !void {
|
|
||||||
if (tx_level == 0) @compileError("Transaction not started");
|
|
||||||
if (tx_level >= 2) @compileError("Cannot set constraint mode on a savepoint");
|
|
||||||
switch (self.sqlEngine()) {
|
|
||||||
.sqlite => try self.exec(
|
|
||||||
switch (mode) {
|
|
||||||
.immediate => "PRAGMA defer_foreign_keys = FALSE",
|
|
||||||
.deferred => "PRAGMA defer_foreign_keys = TRUE",
|
|
||||||
},
|
|
||||||
{},
|
|
||||||
null,
|
|
||||||
),
|
|
||||||
.postgres => try self.exec(
|
|
||||||
switch (mode) {
|
|
||||||
.immediate => "SET CONSTRAINTS ALL IMMEDIATE",
|
|
||||||
.deferred => "SET CONSTRAINTS ALL DEFERRED",
|
|
||||||
},
|
|
||||||
{},
|
|
||||||
null,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ********** Query Helpers ************
|
|
||||||
|
|
||||||
/// Runs a command without returning results
|
|
||||||
pub fn exec(
|
|
||||||
self: Self,
|
|
||||||
sql: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
alloc: ?std.mem.Allocator,
|
|
||||||
) !void {
|
|
||||||
try self.execInternal(sql, args, .{ .allocator = alloc }, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn execWithOptions(
|
|
||||||
self: Self,
|
|
||||||
sql: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
options: QueryOptions,
|
|
||||||
) !void {
|
|
||||||
try self.execInternal(sql, args, options, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn queryWithOptions(
|
|
||||||
self: Self,
|
|
||||||
comptime RowType: type,
|
|
||||||
sql: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
options: QueryOptions,
|
|
||||||
) QueryError!Results(RowType) {
|
|
||||||
return Results(RowType).from(try self.runSql(sql, args, options, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn query(
|
|
||||||
self: Self,
|
|
||||||
comptime RowType: type,
|
|
||||||
sql: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
alloc: ?Allocator,
|
|
||||||
) QueryError!Results(RowType) {
|
|
||||||
return self.queryWithOptions(RowType, sql, args, .{ .allocator = alloc });
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Runs a query to completion and returns a row of results, unless the query
|
|
||||||
/// returned a different number of rows.
|
|
||||||
pub fn queryRow(
|
|
||||||
self: Self,
|
|
||||||
comptime RowType: type,
|
|
||||||
q: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
alloc: ?Allocator,
|
|
||||||
) QueryRowError!RowType {
|
|
||||||
var results = try self.query(RowType, q, args, alloc);
|
|
||||||
defer results.finish();
|
|
||||||
|
|
||||||
const row = (try results.row(alloc)) orelse return error.NoRows;
|
|
||||||
errdefer util.deepFree(alloc, row);
|
|
||||||
|
|
||||||
// execute query to completion
|
|
||||||
var more_rows = false;
|
|
||||||
while (try results.row(alloc)) |r| {
|
|
||||||
util.deepFree(alloc, r);
|
|
||||||
more_rows = true;
|
|
||||||
}
|
|
||||||
if (more_rows) return error.TooManyRows;
|
|
||||||
|
|
||||||
return row;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn queryRows(
|
|
||||||
self: Self,
|
|
||||||
comptime RowType: type,
|
|
||||||
q: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
max_items: ?usize,
|
|
||||||
alloc: std.mem.Allocator,
|
|
||||||
) QueryRowError![]RowType {
|
|
||||||
return try self.queryRowsWithOptions(RowType, q, args, max_items, .{ .allocator = alloc });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Runs a query to completion and returns the results as a slice
|
|
||||||
pub fn queryRowsWithOptions(
|
|
||||||
self: Self,
|
|
||||||
comptime RowType: type,
|
|
||||||
q: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
max_items: ?usize,
|
|
||||||
options: QueryOptions,
|
|
||||||
) QueryRowError![]RowType {
|
|
||||||
var results = try self.queryWithOptions(RowType, q, args, options);
|
|
||||||
defer results.finish();
|
|
||||||
|
|
||||||
const alloc = options.allocator orelse return error.AllocatorRequired;
|
|
||||||
|
|
||||||
var result_array = std.ArrayList(RowType).init(alloc);
|
|
||||||
errdefer result_array.deinit();
|
|
||||||
if (max_items) |max| try result_array.ensureTotalCapacity(max);
|
|
||||||
|
|
||||||
errdefer for (result_array.items) |r| util.deepFree(alloc, r);
|
|
||||||
|
|
||||||
var too_many: bool = false;
|
|
||||||
while (try results.row(alloc)) |row| {
|
|
||||||
errdefer util.deepFree(alloc, row);
|
|
||||||
if (max_items) |max| {
|
|
||||||
if (result_array.items.len >= max) {
|
|
||||||
util.deepFree(alloc, row);
|
|
||||||
too_many = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try result_array.append(row);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (too_many) return error.TooManyRows;
|
|
||||||
|
|
||||||
return result_array.toOwnedSlice();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inserts a single value into a table
|
|
||||||
pub fn insert(
|
|
||||||
self: Self,
|
|
||||||
comptime table: []const u8,
|
|
||||||
value: anytype,
|
|
||||||
alloc: ?std.mem.Allocator,
|
|
||||||
) !void {
|
|
||||||
const ValueType = comptime @TypeOf(value);
|
|
||||||
|
|
||||||
const fields = std.meta.fields(ValueType);
|
|
||||||
comptime var types: [fields.len]type = undefined;
|
|
||||||
comptime var table_spec: []const u8 = table ++ "(";
|
|
||||||
comptime var value_spec: []const u8 = "(";
|
|
||||||
inline for (fields) |field, i| {
|
|
||||||
// This causes a compiler crash. Why?
|
|
||||||
//const F = field.field_type;
|
|
||||||
const F = @TypeOf(@field(value, field.name));
|
|
||||||
// causes issues if F is @TypeOf(null), use dummy type
|
|
||||||
types[i] = if (F == @TypeOf(null)) ?i64 else F;
|
|
||||||
table_spec = comptime (table_spec ++ field.name ++ ",");
|
|
||||||
value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1});
|
|
||||||
}
|
|
||||||
table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")";
|
|
||||||
value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")";
|
|
||||||
const q = comptime std.fmt.comptimePrint(
|
|
||||||
"INSERT INTO {s} VALUES {s}",
|
|
||||||
.{ table_spec, value_spec },
|
|
||||||
);
|
|
||||||
|
|
||||||
var args_tuple: std.meta.Tuple(&types) = undefined;
|
|
||||||
inline for (fields) |field, i| {
|
|
||||||
args_tuple[i] = @field(value, field.name);
|
|
||||||
}
|
|
||||||
try self.exec(q, args_tuple, alloc);
|
|
||||||
}
|
|
||||||
|
|
||||||
// internal helpers
|
|
||||||
|
|
||||||
fn runSql(
|
|
||||||
self: Self,
|
|
||||||
sql: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
opt: QueryOptions,
|
|
||||||
comptime check_tx: bool,
|
|
||||||
) !RawResults {
|
|
||||||
if (check_tx and self.conn.current_tx_level != tx_level) return error.BadTransactionState;
|
|
||||||
|
|
||||||
return switch (self.conn.engine) {
|
|
||||||
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) },
|
|
||||||
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
fn execInternal(
|
|
||||||
self: Self,
|
|
||||||
sql: [:0]const u8,
|
|
||||||
args: anytype,
|
|
||||||
options: QueryOptions,
|
|
||||||
comptime check_tx: bool,
|
|
||||||
) !void {
|
|
||||||
var results = try self.runSql(sql, args, options, check_tx);
|
|
||||||
defer results.finish();
|
|
||||||
|
|
||||||
while (try results.row()) |_| {}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rollbackUnchecked(self: Self) !void {
|
|
||||||
try self.execInternal("ROLLBACK", {}, .{}, false);
|
|
||||||
self.conn.current_tx_level = 0;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
}
|
|
||||||
|
vtable: VTable,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const GetError = error{
|
||||||
|
TypeMismatch,
|
||||||
|
InvalidIndex,
|
||||||
|
} || UnexpectedError;
|
||||||
|
|
||||||
|
pub const Row = struct {
|
||||||
|
const VTable = struct {
|
||||||
|
isNull: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!bool,
|
||||||
|
getStr: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError![]const u8,
|
||||||
|
getInt: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!i64,
|
||||||
|
getUint: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!u64,
|
||||||
|
getFloat: *const fn (ctx: *anyopaque, idx: ColumnIndex) GetError!f64,
|
||||||
|
};
|
||||||
|
|
||||||
|
vtable: VTable,
|
||||||
|
};
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue