Add nested transaction logic

This commit is contained in:
jaina heartles 2022-10-03 22:41:22 -07:00
parent 955df7b044
commit 753ae2729e
5 changed files with 287 additions and 345 deletions

View File

@ -88,7 +88,7 @@ pub fn getRandom() std.rand.Random {
return prng.random();
}
pub fn isAdminSetup(db: *sql.Db) !bool {
pub fn isAdminSetup(db: sql.Db) !bool {
_ = services.communities.adminCommunityId(db) catch |err| switch (err) {
error.NotFound => return false,
else => return err,
@ -97,7 +97,7 @@ pub fn isAdminSetup(db: *sql.Db) !bool {
return true;
}
pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) anyerror!void {
pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) anyerror!void {
const tx = try db.begin();
errdefer tx.rollback();
var arena = std.heap.ArenaAllocator.init(allocator);
@ -125,17 +125,17 @@ pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, passwor
}
pub const ApiSource = struct {
db: *sql.Db,
db_conn: *sql.Conn,
internal_alloc: std.mem.Allocator,
config: Config,
pub const Conn = ApiConn(*sql.Db);
pub const Conn = ApiConn(sql.Db);
const root_username = "root";
pub fn init(alloc: std.mem.Allocator, cfg: Config, db_conn: *sql.Db) !ApiSource {
pub fn init(alloc: std.mem.Allocator, cfg: Config, db_conn: *sql.Conn) !ApiSource {
return ApiSource{
.db = db_conn,
.db_conn = db_conn,
.internal_alloc = alloc,
.config = cfg,
};
@ -145,10 +145,11 @@ pub const ApiSource = struct {
var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit();
const community = try services.communities.getByHost(self.db, host, arena.allocator());
const db = try self.db_conn.acquire();
const community = try services.communities.getByHost(db, host, arena.allocator());
return Conn{
.db = self.db,
.db = db,
.internal_alloc = self.internal_alloc,
.user_id = null,
.community = community,
@ -160,17 +161,18 @@ pub const ApiSource = struct {
var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit();
const community = try services.communities.getByHost(self.db, host, arena.allocator());
const db = try self.db_conn.acquire();
const community = try services.communities.getByHost(db, host, arena.allocator());
const token_info = try services.auth.verifyToken(
self.db,
db,
token,
community.id,
arena.allocator(),
);
return Conn{
.db = self.db,
.db = db,
.internal_alloc = self.internal_alloc,
.token_info = token_info,
.user_id = token_info.account_id,

View File

@ -34,41 +34,9 @@ pub fn register(
const hash = try hashPassword(password, alloc);
defer alloc.free(hash);
// transaction may already be running during initial db setup
if (@TypeOf(db).is_transaction) return registerTransaction(
db,
username,
hash,
community_id,
options,
alloc,
);
const tx = try db.begin();
const tx = db.beginOrSavepoint() catch return error.DatabaseFailure;
errdefer tx.rollback();
const id = registerTransaction(
tx,
username,
hash,
community_id,
options,
alloc,
);
try tx.commit();
return id;
}
fn registerTransaction(
tx: anytype,
username: []const u8,
password_hash: []const u8,
community_id: Uuid,
options: RegistrationOptions,
alloc: std.mem.Allocator,
) RegistrationError!Uuid {
const id = try users.create(tx, username, community_id, options.kind, alloc);
tx.insert("local_account", .{
.account_id = id,
@ -77,9 +45,11 @@ fn registerTransaction(
}, alloc) catch return error.DatabaseFailure;
tx.insert("password", .{
.account_id = id,
.hash = password_hash,
.hash = hash,
}, alloc) catch return error.DatabaseFailure;
tx.commitOrRelease() catch return error.DatabaseFailure;
return id;
}

View File

@ -100,7 +100,7 @@ const admin_origin_envvar = "CLUSTER_ADMIN_ORIGIN";
const admin_username_envvar = "CLUSTER_ADMIN_USERNAME";
const admin_password_envvar = "CLUSTER_ADMIN_PASSWORD";
fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void {
fn runAdminSetup(db: sql.Db, alloc: std.mem.Allocator) !void {
const origin = std.os.getenv(admin_origin_envvar) orelse return error.MissingArgument;
const username = std.os.getenv(admin_username_envvar) orelse return error.MissingArgument;
const password = std.os.getenv(admin_password_envvar) orelse return error.MissingArgument;
@ -108,7 +108,7 @@ fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void {
try api.setupAdmin(db, origin, username, password, alloc);
}
fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void {
fn prepareDb(db: sql.Db, alloc: std.mem.Allocator) !void {
try migrations.up(db);
api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp()));
@ -136,8 +136,8 @@ fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void {
pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
var cfg = try loadConfig(gpa.allocator());
var db_conn = try sql.Db.open(cfg.db);
try prepareDb(&db_conn, gpa.allocator());
var db_conn = try sql.Conn.open(cfg.db);
try prepareDb(try db_conn.acquire(), gpa.allocator());
//try migrations.up(&db_conn);
//try api.setupAdmin(&db_conn, "http://localhost:8080", "root", "password", gpa.allocator());

View File

@ -16,13 +16,13 @@ fn firstIndexOf(str: []const u8, char: u8) ?usize {
return null;
}
fn execStmt(tx: sql.Tx, stmt: []const u8, alloc: std.mem.Allocator) !void {
fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void {
const stmt_null = try std.cstr.addNullByte(alloc, stmt);
defer alloc.free(stmt_null);
try tx.exec(stmt_null, {}, null);
}
fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void {
fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void {
const tx = try db.begin();
errdefer tx.rollback();
@ -49,7 +49,7 @@ fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !boo
};
}
pub fn up(db: *sql.Db) !void {
pub fn up(db: anytype) !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit();
std.log.info("Running migrations...", .{});

View File

@ -155,140 +155,27 @@ const Row = union(Engine) {
}
};
const QueryHelper = union(Engine) {
postgres: postgres.Db,
sqlite: sqlite.Db,
// internal helper fn
fn queryInternal(
self: QueryHelper,
sql: [:0]const u8,
args: anytype,
opt: QueryOptions,
) QueryError!RawResults {
return switch (self) {
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) },
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
};
}
fn queryWithOptions(
self: QueryHelper,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
options: QueryOptions,
) QueryError!Results(RowType) {
return Results(RowType).from(try self.queryInternal(sql, args, options));
}
// Executes a query and returns the result set
fn query(
self: QueryHelper,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!Results(RowType) {
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
}
// Executes a query without returning results
fn exec(
self: QueryHelper,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!void {
_ = self.queryRow(void, sql, args, alloc) catch |err| return switch (err) {
error.NoRows => {},
error.TooManyRows => error.SqlException,
error.ResultTypeMismatch => unreachable,
else => |err2| err2,
};
}
// Runs a query and returns a single row
fn queryRow(
self: QueryHelper,
comptime RowType: type,
q: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!RowType {
var results = try self.query(RowType, q, args, alloc);
defer results.finish();
const row = (try results.row(alloc)) orelse return error.NoRows;
errdefer util.deepFree(alloc, row);
// execute query to completion
var more_rows = false;
while (try results.row(alloc)) |r| {
util.deepFree(alloc, r);
more_rows = true;
}
if (more_rows) return error.TooManyRows;
return row;
}
// Inserts a single value into a table
fn insert(
self: QueryHelper,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
const ValueType = comptime @TypeOf(value);
const fields = std.meta.fields(ValueType);
comptime var types: [fields.len]type = undefined;
comptime var table_spec: []const u8 = table ++ "(";
comptime var value_spec: []const u8 = "(";
inline for (fields) |field, i| {
// This causes a compile error. Why?
//const F = field.field_type;
const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name));
// causes issues if F is @TypeOf(null), use dummy type
types[i] = if (F == @TypeOf(null)) ?i64 else F;
table_spec = comptime (table_spec ++ field.name ++ ",");
value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1});
}
table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")";
value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")";
const q = comptime std.fmt.comptimePrint(
"INSERT INTO {s} VALUES {s}",
.{ table_spec, value_spec },
);
var args_tuple: std.meta.Tuple(&types) = undefined;
inline for (fields) |field, i| {
args_tuple[i] = @field(value, field.name);
}
try self.exec(q, args_tuple, alloc);
}
};
pub const ConstraintMode = enum {
deferred,
immediate,
};
pub const Db = struct {
tx_open: bool = false,
engine: QueryHelper,
pub const Conn = struct {
engine: union(Engine) {
postgres: postgres.Db,
sqlite: sqlite.Db,
},
current_tx_level: u8 = 0,
is_tx_failed: bool = false,
pub const is_transaction = false;
pub fn open(cfg: Config) OpenError!Db {
pub fn open(cfg: Config) OpenError!Conn {
return switch (cfg) {
.postgres => |postgres_cfg| Db{
.postgres => |postgres_cfg| Conn{
.engine = .{
.postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
},
},
.sqlite => |lite_cfg| Db{
.sqlite => |lite_cfg| Conn{
.engine = .{
.sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path),
},
@ -296,192 +183,275 @@ pub const Db = struct {
};
}
pub fn close(self: *Db) void {
pub fn close(self: *Conn) void {
switch (self.engine) {
.postgres => |pg| pg.close(),
.sqlite => |lite| lite.close(),
}
}
pub fn queryWithOptions(
self: *Db,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
opt: QueryOptions,
) QueryError!Results(RowType) {
if (self.tx_open) return error.BadTransactionState;
return self.engine.queryWithOptions(RowType, sql, args, opt);
pub fn acquire(self: *Conn) !Db {
if (self.current_tx_level != 0) return error.BadTransactionState;
return Db{ .conn = self };
}
pub fn query(
self: *Db,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!Results(RowType) {
if (self.tx_open) return error.BadTransactionState;
return self.engine.query(RowType, sql, args, alloc);
}
pub fn exec(
self: *Db,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!void {
if (self.tx_open) return error.BadTransactionState;
return self.engine.exec(sql, args, alloc);
}
pub fn queryRow(
self: *Db,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!RowType {
if (self.tx_open) return error.BadTransactionState;
return self.engine.queryRow(RowType, sql, args, alloc);
}
pub fn insert(
self: *Db,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
if (self.tx_open) return error.BadTransactionState;
return self.engine.insert(table, value, alloc);
}
pub fn sqlEngine(self: *Db) Engine {
pub fn sqlEngine(self: *Conn) Engine {
return self.engine;
}
// Begins a transaction
pub fn begin(self: *Db) !Tx {
if (self.tx_open) return error.BadTransactionState;
try self.exec("BEGIN", {}, null);
self.tx_open = true;
return Tx{ .db = self };
}
};
pub const Tx = struct {
db: *Db,
pub const Db = Tx(0);
pub const is_transaction = true;
/// When tx_level == 0, the DB is operating in "implied transaction" mode where
/// every command is its own transaction
/// When tx_level >= 1, the DB has an explicit transaction open
/// When tx_level >= 2, the DB has (tx_level - 1) levels of transaction savepoints open
/// (like nested transactions)
fn Tx(comptime tx_level: u8) type {
return struct {
const Self = @This();
const savepoint_name = if (tx_level == 0)
@compileError("Transaction not started")
else
std.fmt.comptimePrint("save_{}", .{tx_level});
const next_savepoint_name = Tx(tx_level + 1).savepoint_name;
pub fn queryWithOptions(
self: Tx,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
options: QueryOptions,
) QueryError!Results(RowType) {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.queryWithOptions(RowType, sql, args, options);
}
conn: *Conn,
// Executes a query and returns the result set
pub fn query(
self: Tx,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!Results(RowType) {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.query(RowType, sql, args, alloc);
}
// Executes a query without returning results
pub fn exec(
self: Tx,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!void {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.exec(sql, args, alloc);
}
// Runs a query and returns a single row
pub fn queryRow(
self: Tx,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!RowType {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.queryRow(RowType, sql, args, alloc);
}
// Inserts a single value into a table
pub fn insert(
self: Tx,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.insert(table, value, alloc);
}
pub fn sqlEngine(self: Tx) Engine {
return self.db.engine;
}
// Allows relaxing *some* constraints for the lifetime of the transaction.
// You should generally not do this, but it's useful when bootstrapping
// the initial admin community and cluster operator user.
pub fn setConstraintMode(self: Tx, mode: ConstraintMode) QueryError!void {
if (!self.db.tx_open) return error.BadTransactionState;
switch (self.db.engine) {
.sqlite => try self.exec(
switch (mode) {
.immediate => "PRAGMA defer_foreign_keys = FALSE",
.deferred => "PRAGMA defer_foreign_keys = TRUE",
},
{},
null,
),
.postgres => try self.exec(
switch (mode) {
.immediate => "SET CONSTRAINTS ALL IMMEDIATE",
.deferred => "SET CONSTRAINTS ALL DEFERRED",
},
{},
null,
),
/// The type of SQL engine being used. Use of this function should be discouraged
pub fn sqlEngine(self: Self) Engine {
return self.conn.sqlEngine();
}
}
pub fn rollback(self: Tx) void {
if (!self.db.tx_open) @panic("Transaction not open");
self.exec("ROLLBACK", {}, null) catch |err| {
std.log.err("Error occured during rollback operation: {}", .{err});
};
self.db.tx_open = false;
}
// ********* Transaction management functions **********
pub fn commit(self: Tx) CommitError!void {
if (!self.db.tx_open) return error.BadTransactionState;
self.exec("COMMIT", {}, null) catch |err| switch (err) {
error.BindException,
error.OutOfMemory,
error.UnusedArgument,
error.AllocatorRequired,
=> return error.Unexpected,
/// Start an explicit transaction
pub fn begin(self: Self) !Tx(1) {
if (tx_level != 0) @compileError("Transaction already started");
if (self.conn.current_tx_level != 0) return error.BadTransactionState;
// use a new capture because it's got a smaller error set
else => |err2| return err2,
};
self.db.tx_open = false;
}
};
try self.exec("BEGIN", {}, null);
self.conn.current_tx_level = 1;
return Tx(1){ .conn = self.conn };
}
/// Create a savepoint (nested transaction)
pub fn savepoint(self: Self) !Tx(tx_level + 1) {
if (tx_level == 0) @compileError("Cannot place a savepoint on an implicit transaction");
if (self.conn.current_tx_level != tx_level) return error.BadTransactionState;
try self.exec("SAVEPOINT " ++ next_savepoint_name, {}, null);
self.conn.current_tx_level = tx_level + 1;
return Tx(tx_level + 1){ .conn = self.conn };
}
/// Commit the entire transaction
pub fn commit(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level >= 2) @compileError("Cannot commit a savepoint");
if (self.conn.current_tx_level == 0) return error.BadTransactionState;
try self.exec("COMMIT", {}, null);
self.conn.current_tx_level = 0;
}
/// Release the current savepoint and all savepoints created from it.
pub fn release(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level == 1) @compileError("Cannot release a transaction");
if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;
try self.exec("RELEASE SAVEPOINT " ++ savepoint_name, {}, null);
self.conn.current_tx_level = tx_level - 1;
}
/// Rolls back the entire transaction
pub fn rollbackTx(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint");
if (self.conn.current_tx_level == 0) return error.BadTransactionState;
try self.exec("ROLLBACK", {}, null);
self.conn.current_tx_level = 0;
}
/// Attempts to roll back to a savepoint
pub fn rollbackSavepoint(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level == 1) @compileError("Cannot rollback a savepoint on the entire transaction");
if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;
try self.exec("ROLLBACK TO " ++ savepoint_name, {}, null);
self.conn.current_tx_level = tx_level - 1;
}
/// Perform whichever rollback is appropriate for the situation
pub fn rollback(self: Self) void {
(if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| {
std.log.err("Failed to rollback transaction: {}", .{err});
@panic("TODO: more gracefully handle rollback failures");
};
}
pub const beginOrSavepoint = if (tx_level == 0) begin else savepoint;
pub const commitOrRelease = if (tx_level < 2) commit else release;
// Allows relaxing *some* constraints for the lifetime of the transaction.
// You should generally not do this, but it's useful when bootstrapping
// the initial admin community and cluster operator user.
pub fn setConstraintMode(self: Self, mode: ConstraintMode) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level >= 2) @compileError("Cannot set constraint mode on a savepoint");
switch (self.sqlEngine()) {
.sqlite => try self.exec(
switch (mode) {
.immediate => "PRAGMA defer_foreign_keys = FALSE",
.deferred => "PRAGMA defer_foreign_keys = TRUE",
},
{},
null,
),
.postgres => try self.exec(
switch (mode) {
.immediate => "SET CONSTRAINTS ALL IMMEDIATE",
.deferred => "SET CONSTRAINTS ALL DEFERRED",
},
{},
null,
),
}
}
// ********** Query Helpers ************
/// Runs a command without returning results
pub fn exec(
self: Self,
sql: [:0]const u8,
args: anytype,
alloc: ?std.mem.Allocator,
) !void {
try self.execInternal(sql, args, alloc, true);
}
pub fn queryWithOptions(
self: Self,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
options: QueryOptions,
) QueryError!Results(RowType) {
return Results(RowType).from(try self.runSql(sql, args, options, true));
}
pub fn query(
self: Self,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!Results(RowType) {
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
}
/// Runs a query to completion and returns a row of results, unless the query
/// returned a different number of rows.
pub fn queryRow(
self: Self,
comptime RowType: type,
q: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!RowType {
var results = try self.query(RowType, q, args, alloc);
defer results.finish();
const row = (try results.row(alloc)) orelse return error.NoRows;
errdefer util.deepFree(alloc, row);
// execute query to completion
var more_rows = false;
while (try results.row(alloc)) |r| {
util.deepFree(alloc, r);
more_rows = true;
}
if (more_rows) return error.TooManyRows;
return row;
}
// Inserts a single value into a table
pub fn insert(
self: Self,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
const ValueType = comptime @TypeOf(value);
const fields = std.meta.fields(ValueType);
comptime var types: [fields.len]type = undefined;
comptime var table_spec: []const u8 = table ++ "(";
comptime var value_spec: []const u8 = "(";
inline for (fields) |field, i| {
// This causes a compile error. Why?
//const F = field.field_type;
const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name));
// causes issues if F is @TypeOf(null), use dummy type
types[i] = if (F == @TypeOf(null)) ?i64 else F;
table_spec = comptime (table_spec ++ field.name ++ ",");
value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1});
}
table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")";
value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")";
const q = comptime std.fmt.comptimePrint(
"INSERT INTO {s} VALUES {s}",
.{ table_spec, value_spec },
);
var args_tuple: std.meta.Tuple(&types) = undefined;
inline for (fields) |field, i| {
args_tuple[i] = @field(value, field.name);
}
try self.exec(q, args_tuple, alloc);
}
// internal helpers
fn runSql(
self: Self,
sql: [:0]const u8,
args: anytype,
opt: QueryOptions,
comptime check_tx: bool,
) !RawResults {
if (check_tx and self.conn.current_tx_level != tx_level) return error.BadTransactionState;
return switch (self.conn.engine) {
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) },
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
};
}
fn execInternal(
self: Self,
sql: [:0]const u8,
args: anytype,
alloc: ?std.mem.Allocator,
comptime check_tx: bool,
) !void {
var results = try self.runSql(sql, args, .{ .prep_allocator = alloc }, check_tx);
defer results.finish();
while (try results.row()) |_| {}
}
};
}