Add nested transaction logic

This commit is contained in:
jaina heartles 2022-10-03 22:41:22 -07:00
parent 955df7b044
commit 753ae2729e
5 changed files with 287 additions and 345 deletions

View file

@ -88,7 +88,7 @@ pub fn getRandom() std.rand.Random {
return prng.random(); return prng.random();
} }
pub fn isAdminSetup(db: *sql.Db) !bool { pub fn isAdminSetup(db: sql.Db) !bool {
_ = services.communities.adminCommunityId(db) catch |err| switch (err) { _ = services.communities.adminCommunityId(db) catch |err| switch (err) {
error.NotFound => return false, error.NotFound => return false,
else => return err, else => return err,
@ -97,7 +97,7 @@ pub fn isAdminSetup(db: *sql.Db) !bool {
return true; return true;
} }
pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) anyerror!void { pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) anyerror!void {
const tx = try db.begin(); const tx = try db.begin();
errdefer tx.rollback(); errdefer tx.rollback();
var arena = std.heap.ArenaAllocator.init(allocator); var arena = std.heap.ArenaAllocator.init(allocator);
@ -125,17 +125,17 @@ pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, passwor
} }
pub const ApiSource = struct { pub const ApiSource = struct {
db: *sql.Db, db_conn: *sql.Conn,
internal_alloc: std.mem.Allocator, internal_alloc: std.mem.Allocator,
config: Config, config: Config,
pub const Conn = ApiConn(*sql.Db); pub const Conn = ApiConn(sql.Db);
const root_username = "root"; const root_username = "root";
pub fn init(alloc: std.mem.Allocator, cfg: Config, db_conn: *sql.Db) !ApiSource { pub fn init(alloc: std.mem.Allocator, cfg: Config, db_conn: *sql.Conn) !ApiSource {
return ApiSource{ return ApiSource{
.db = db_conn, .db_conn = db_conn,
.internal_alloc = alloc, .internal_alloc = alloc,
.config = cfg, .config = cfg,
}; };
@ -145,10 +145,11 @@ pub const ApiSource = struct {
var arena = std.heap.ArenaAllocator.init(alloc); var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit(); errdefer arena.deinit();
const community = try services.communities.getByHost(self.db, host, arena.allocator()); const db = try self.db_conn.acquire();
const community = try services.communities.getByHost(db, host, arena.allocator());
return Conn{ return Conn{
.db = self.db, .db = db,
.internal_alloc = self.internal_alloc, .internal_alloc = self.internal_alloc,
.user_id = null, .user_id = null,
.community = community, .community = community,
@ -160,17 +161,18 @@ pub const ApiSource = struct {
var arena = std.heap.ArenaAllocator.init(alloc); var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit(); errdefer arena.deinit();
const community = try services.communities.getByHost(self.db, host, arena.allocator()); const db = try self.db_conn.acquire();
const community = try services.communities.getByHost(db, host, arena.allocator());
const token_info = try services.auth.verifyToken( const token_info = try services.auth.verifyToken(
self.db, db,
token, token,
community.id, community.id,
arena.allocator(), arena.allocator(),
); );
return Conn{ return Conn{
.db = self.db, .db = db,
.internal_alloc = self.internal_alloc, .internal_alloc = self.internal_alloc,
.token_info = token_info, .token_info = token_info,
.user_id = token_info.account_id, .user_id = token_info.account_id,

View file

@ -34,41 +34,9 @@ pub fn register(
const hash = try hashPassword(password, alloc); const hash = try hashPassword(password, alloc);
defer alloc.free(hash); defer alloc.free(hash);
// transaction may already be running during initial db setup const tx = db.beginOrSavepoint() catch return error.DatabaseFailure;
if (@TypeOf(db).is_transaction) return registerTransaction(
db,
username,
hash,
community_id,
options,
alloc,
);
const tx = try db.begin();
errdefer tx.rollback(); errdefer tx.rollback();
const id = registerTransaction(
tx,
username,
hash,
community_id,
options,
alloc,
);
try tx.commit();
return id;
}
fn registerTransaction(
tx: anytype,
username: []const u8,
password_hash: []const u8,
community_id: Uuid,
options: RegistrationOptions,
alloc: std.mem.Allocator,
) RegistrationError!Uuid {
const id = try users.create(tx, username, community_id, options.kind, alloc); const id = try users.create(tx, username, community_id, options.kind, alloc);
tx.insert("local_account", .{ tx.insert("local_account", .{
.account_id = id, .account_id = id,
@ -77,9 +45,11 @@ fn registerTransaction(
}, alloc) catch return error.DatabaseFailure; }, alloc) catch return error.DatabaseFailure;
tx.insert("password", .{ tx.insert("password", .{
.account_id = id, .account_id = id,
.hash = password_hash, .hash = hash,
}, alloc) catch return error.DatabaseFailure; }, alloc) catch return error.DatabaseFailure;
tx.commitOrRelease() catch return error.DatabaseFailure;
return id; return id;
} }

View file

@ -100,7 +100,7 @@ const admin_origin_envvar = "CLUSTER_ADMIN_ORIGIN";
const admin_username_envvar = "CLUSTER_ADMIN_USERNAME"; const admin_username_envvar = "CLUSTER_ADMIN_USERNAME";
const admin_password_envvar = "CLUSTER_ADMIN_PASSWORD"; const admin_password_envvar = "CLUSTER_ADMIN_PASSWORD";
fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void { fn runAdminSetup(db: sql.Db, alloc: std.mem.Allocator) !void {
const origin = std.os.getenv(admin_origin_envvar) orelse return error.MissingArgument; const origin = std.os.getenv(admin_origin_envvar) orelse return error.MissingArgument;
const username = std.os.getenv(admin_username_envvar) orelse return error.MissingArgument; const username = std.os.getenv(admin_username_envvar) orelse return error.MissingArgument;
const password = std.os.getenv(admin_password_envvar) orelse return error.MissingArgument; const password = std.os.getenv(admin_password_envvar) orelse return error.MissingArgument;
@ -108,7 +108,7 @@ fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void {
try api.setupAdmin(db, origin, username, password, alloc); try api.setupAdmin(db, origin, username, password, alloc);
} }
fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void { fn prepareDb(db: sql.Db, alloc: std.mem.Allocator) !void {
try migrations.up(db); try migrations.up(db);
api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp())); api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp()));
@ -136,8 +136,8 @@ fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void {
pub fn main() !void { pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var gpa = std.heap.GeneralPurposeAllocator(.{}){};
var cfg = try loadConfig(gpa.allocator()); var cfg = try loadConfig(gpa.allocator());
var db_conn = try sql.Db.open(cfg.db); var db_conn = try sql.Conn.open(cfg.db);
try prepareDb(&db_conn, gpa.allocator()); try prepareDb(try db_conn.acquire(), gpa.allocator());
//try migrations.up(&db_conn); //try migrations.up(&db_conn);
//try api.setupAdmin(&db_conn, "http://localhost:8080", "root", "password", gpa.allocator()); //try api.setupAdmin(&db_conn, "http://localhost:8080", "root", "password", gpa.allocator());

View file

@ -16,13 +16,13 @@ fn firstIndexOf(str: []const u8, char: u8) ?usize {
return null; return null;
} }
fn execStmt(tx: sql.Tx, stmt: []const u8, alloc: std.mem.Allocator) !void { fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void {
const stmt_null = try std.cstr.addNullByte(alloc, stmt); const stmt_null = try std.cstr.addNullByte(alloc, stmt);
defer alloc.free(stmt_null); defer alloc.free(stmt_null);
try tx.exec(stmt_null, {}, null); try tx.exec(stmt_null, {}, null);
} }
fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void { fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void {
const tx = try db.begin(); const tx = try db.begin();
errdefer tx.rollback(); errdefer tx.rollback();
@ -49,7 +49,7 @@ fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !boo
}; };
} }
pub fn up(db: *sql.Db) !void { pub fn up(db: anytype) !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var gpa = std.heap.GeneralPurposeAllocator(.{}){};
defer _ = gpa.deinit(); defer _ = gpa.deinit();
std.log.info("Running migrations...", .{}); std.log.info("Running migrations...", .{});

View file

@ -155,36 +155,205 @@ const Row = union(Engine) {
} }
}; };
const QueryHelper = union(Engine) { pub const ConstraintMode = enum {
deferred,
immediate,
};
pub const Conn = struct {
engine: union(Engine) {
postgres: postgres.Db, postgres: postgres.Db,
sqlite: sqlite.Db, sqlite: sqlite.Db,
},
current_tx_level: u8 = 0,
is_tx_failed: bool = false,
// internal helper fn pub fn open(cfg: Config) OpenError!Conn {
fn queryInternal( return switch (cfg) {
self: QueryHelper, .postgres => |postgres_cfg| Conn{
sql: [:0]const u8, .engine = .{
args: anytype, .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
opt: QueryOptions, },
) QueryError!RawResults { },
return switch (self) { .sqlite => |lite_cfg| Conn{
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) }, .engine = .{
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) }, .sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path),
},
},
}; };
} }
fn queryWithOptions( pub fn close(self: *Conn) void {
self: QueryHelper, switch (self.engine) {
.postgres => |pg| pg.close(),
.sqlite => |lite| lite.close(),
}
}
pub fn acquire(self: *Conn) !Db {
if (self.current_tx_level != 0) return error.BadTransactionState;
return Db{ .conn = self };
}
pub fn sqlEngine(self: *Conn) Engine {
return self.engine;
}
};
pub const Db = Tx(0);
/// When tx_level == 0, the DB is operating in "implied transaction" mode where
/// every command is its own transaction
/// When tx_level >= 1, the DB has an explicit transaction open
/// When tx_level >= 2, the DB has (tx_level - 1) levels of transaction savepoints open
/// (like nested transactions)
fn Tx(comptime tx_level: u8) type {
return struct {
const Self = @This();
const savepoint_name = if (tx_level == 0)
@compileError("Transaction not started")
else
std.fmt.comptimePrint("save_{}", .{tx_level});
const next_savepoint_name = Tx(tx_level + 1).savepoint_name;
conn: *Conn,
/// The type of SQL engine being used. Use of this function should be discouraged
pub fn sqlEngine(self: Self) Engine {
return self.conn.sqlEngine();
}
// ********* Transaction management functions **********
/// Start an explicit transaction
pub fn begin(self: Self) !Tx(1) {
if (tx_level != 0) @compileError("Transaction already started");
if (self.conn.current_tx_level != 0) return error.BadTransactionState;
try self.exec("BEGIN", {}, null);
self.conn.current_tx_level = 1;
return Tx(1){ .conn = self.conn };
}
/// Create a savepoint (nested transaction)
pub fn savepoint(self: Self) !Tx(tx_level + 1) {
if (tx_level == 0) @compileError("Cannot place a savepoint on an implicit transaction");
if (self.conn.current_tx_level != tx_level) return error.BadTransactionState;
try self.exec("SAVEPOINT " ++ next_savepoint_name, {}, null);
self.conn.current_tx_level = tx_level + 1;
return Tx(tx_level + 1){ .conn = self.conn };
}
/// Commit the entire transaction
pub fn commit(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level >= 2) @compileError("Cannot commit a savepoint");
if (self.conn.current_tx_level == 0) return error.BadTransactionState;
try self.exec("COMMIT", {}, null);
self.conn.current_tx_level = 0;
}
/// Release the current savepoint and all savepoints created from it.
pub fn release(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level == 1) @compileError("Cannot release a transaction");
if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;
try self.exec("RELEASE SAVEPOINT " ++ savepoint_name, {}, null);
self.conn.current_tx_level = tx_level - 1;
}
/// Rolls back the entire transaction
pub fn rollbackTx(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint");
if (self.conn.current_tx_level == 0) return error.BadTransactionState;
try self.exec("ROLLBACK", {}, null);
self.conn.current_tx_level = 0;
}
/// Attempts to roll back to a savepoint
pub fn rollbackSavepoint(self: Self) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level == 1) @compileError("Cannot rollback a savepoint on the entire transaction");
if (self.conn.current_tx_level < tx_level) return error.BadTransactionState;
try self.exec("ROLLBACK TO " ++ savepoint_name, {}, null);
self.conn.current_tx_level = tx_level - 1;
}
/// Perform whichever rollback is appropriate for the situation
pub fn rollback(self: Self) void {
(if (tx_level < 2) self.rollbackTx() else self.rollbackSavepoint()) catch |err| {
std.log.err("Failed to rollback transaction: {}", .{err});
@panic("TODO: more gracefully handle rollback failures");
};
}
pub const beginOrSavepoint = if (tx_level == 0) begin else savepoint;
pub const commitOrRelease = if (tx_level < 2) commit else release;
// Allows relaxing *some* constraints for the lifetime of the transaction.
// You should generally not do this, but it's useful when bootstrapping
// the initial admin community and cluster operator user.
pub fn setConstraintMode(self: Self, mode: ConstraintMode) !void {
if (tx_level == 0) @compileError("Transaction not started");
if (tx_level >= 2) @compileError("Cannot set constraint mode on a savepoint");
switch (self.sqlEngine()) {
.sqlite => try self.exec(
switch (mode) {
.immediate => "PRAGMA defer_foreign_keys = FALSE",
.deferred => "PRAGMA defer_foreign_keys = TRUE",
},
{},
null,
),
.postgres => try self.exec(
switch (mode) {
.immediate => "SET CONSTRAINTS ALL IMMEDIATE",
.deferred => "SET CONSTRAINTS ALL DEFERRED",
},
{},
null,
),
}
}
// ********** Query Helpers ************
/// Runs a command without returning results
pub fn exec(
self: Self,
sql: [:0]const u8,
args: anytype,
alloc: ?std.mem.Allocator,
) !void {
try self.execInternal(sql, args, alloc, true);
}
pub fn queryWithOptions(
self: Self,
comptime RowType: type, comptime RowType: type,
sql: [:0]const u8, sql: [:0]const u8,
args: anytype, args: anytype,
options: QueryOptions, options: QueryOptions,
) QueryError!Results(RowType) { ) QueryError!Results(RowType) {
return Results(RowType).from(try self.queryInternal(sql, args, options)); return Results(RowType).from(try self.runSql(sql, args, options, true));
} }
// Executes a query and returns the result set pub fn query(
fn query( self: Self,
self: QueryHelper,
comptime RowType: type, comptime RowType: type,
sql: [:0]const u8, sql: [:0]const u8,
args: anytype, args: anytype,
@ -193,24 +362,10 @@ const QueryHelper = union(Engine) {
return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc }); return self.queryWithOptions(RowType, sql, args, .{ .prep_allocator = alloc });
} }
// Executes a query without returning results /// Runs a query to completion and returns a row of results, unless the query
fn exec( /// returned a different number of rows.
self: QueryHelper, pub fn queryRow(
sql: [:0]const u8, self: Self,
args: anytype,
alloc: ?Allocator,
) QueryError!void {
_ = self.queryRow(void, sql, args, alloc) catch |err| return switch (err) {
error.NoRows => {},
error.TooManyRows => error.SqlException,
error.ResultTypeMismatch => unreachable,
else => |err2| err2,
};
}
// Runs a query and returns a single row
fn queryRow(
self: QueryHelper,
comptime RowType: type, comptime RowType: type,
q: [:0]const u8, q: [:0]const u8,
args: anytype, args: anytype,
@ -234,8 +389,8 @@ const QueryHelper = union(Engine) {
} }
// Inserts a single value into a table // Inserts a single value into a table
fn insert( pub fn insert(
self: QueryHelper, self: Self,
comptime table: []const u8, comptime table: []const u8,
value: anytype, value: anytype,
alloc: ?std.mem.Allocator, alloc: ?std.mem.Allocator,
@ -268,220 +423,35 @@ const QueryHelper = union(Engine) {
} }
try self.exec(q, args_tuple, alloc); try self.exec(q, args_tuple, alloc);
} }
};
pub const ConstraintMode = enum { // internal helpers
deferred,
immediate,
};
pub const Db = struct { fn runSql(
tx_open: bool = false, self: Self,
engine: QueryHelper,
pub const is_transaction = false;
pub fn open(cfg: Config) OpenError!Db {
return switch (cfg) {
.postgres => |postgres_cfg| Db{
.engine = .{
.postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
},
},
.sqlite => |lite_cfg| Db{
.engine = .{
.sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path),
},
},
};
}
pub fn close(self: *Db) void {
switch (self.engine) {
.postgres => |pg| pg.close(),
.sqlite => |lite| lite.close(),
}
}
pub fn queryWithOptions(
self: *Db,
comptime RowType: type,
sql: [:0]const u8, sql: [:0]const u8,
args: anytype, args: anytype,
opt: QueryOptions, opt: QueryOptions,
) QueryError!Results(RowType) { comptime check_tx: bool,
if (self.tx_open) return error.BadTransactionState; ) !RawResults {
return self.engine.queryWithOptions(RowType, sql, args, opt); if (check_tx and self.conn.current_tx_level != tx_level) return error.BadTransactionState;
return switch (self.conn.engine) {
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) },
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
};
} }
pub fn query( fn execInternal(
self: *Db, self: Self,
comptime RowType: type,
sql: [:0]const u8, sql: [:0]const u8,
args: anytype, args: anytype,
alloc: ?Allocator,
) QueryError!Results(RowType) {
if (self.tx_open) return error.BadTransactionState;
return self.engine.query(RowType, sql, args, alloc);
}
pub fn exec(
self: *Db,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!void {
if (self.tx_open) return error.BadTransactionState;
return self.engine.exec(sql, args, alloc);
}
pub fn queryRow(
self: *Db,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!RowType {
if (self.tx_open) return error.BadTransactionState;
return self.engine.queryRow(RowType, sql, args, alloc);
}
pub fn insert(
self: *Db,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator, alloc: ?std.mem.Allocator,
comptime check_tx: bool,
) !void { ) !void {
if (self.tx_open) return error.BadTransactionState; var results = try self.runSql(sql, args, .{ .prep_allocator = alloc }, check_tx);
return self.engine.insert(table, value, alloc); defer results.finish();
}
pub fn sqlEngine(self: *Db) Engine { while (try results.row()) |_| {}
return self.engine;
}
// Begins a transaction
pub fn begin(self: *Db) !Tx {
if (self.tx_open) return error.BadTransactionState;
try self.exec("BEGIN", {}, null);
self.tx_open = true;
return Tx{ .db = self };
} }
}; };
pub const Tx = struct {
db: *Db,
pub const is_transaction = true;
pub fn queryWithOptions(
self: Tx,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
options: QueryOptions,
) QueryError!Results(RowType) {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.queryWithOptions(RowType, sql, args, options);
} }
// Executes a query and returns the result set
pub fn query(
self: Tx,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!Results(RowType) {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.query(RowType, sql, args, alloc);
}
// Executes a query without returning results
pub fn exec(
self: Tx,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryError!void {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.exec(sql, args, alloc);
}
// Runs a query and returns a single row
pub fn queryRow(
self: Tx,
comptime RowType: type,
sql: [:0]const u8,
args: anytype,
alloc: ?Allocator,
) QueryRowError!RowType {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.queryRow(RowType, sql, args, alloc);
}
// Inserts a single value into a table
pub fn insert(
self: Tx,
comptime table: []const u8,
value: anytype,
alloc: ?std.mem.Allocator,
) !void {
if (!self.db.tx_open) return error.BadTransactionState;
return self.db.engine.insert(table, value, alloc);
}
pub fn sqlEngine(self: Tx) Engine {
return self.db.engine;
}
// Allows relaxing *some* constraints for the lifetime of the transaction.
// You should generally not do this, but it's useful when bootstrapping
// the initial admin community and cluster operator user.
pub fn setConstraintMode(self: Tx, mode: ConstraintMode) QueryError!void {
if (!self.db.tx_open) return error.BadTransactionState;
switch (self.db.engine) {
.sqlite => try self.exec(
switch (mode) {
.immediate => "PRAGMA defer_foreign_keys = FALSE",
.deferred => "PRAGMA defer_foreign_keys = TRUE",
},
{},
null,
),
.postgres => try self.exec(
switch (mode) {
.immediate => "SET CONSTRAINTS ALL IMMEDIATE",
.deferred => "SET CONSTRAINTS ALL DEFERRED",
},
{},
null,
),
}
}
pub fn rollback(self: Tx) void {
if (!self.db.tx_open) @panic("Transaction not open");
self.exec("ROLLBACK", {}, null) catch |err| {
std.log.err("Error occured during rollback operation: {}", .{err});
};
self.db.tx_open = false;
}
pub fn commit(self: Tx) CommitError!void {
if (!self.db.tx_open) return error.BadTransactionState;
self.exec("COMMIT", {}, null) catch |err| switch (err) {
error.BindException,
error.OutOfMemory,
error.UnusedArgument,
error.AllocatorRequired,
=> return error.Unexpected,
// use a new capture because it's got a smaller error set
else => |err2| return err2,
};
self.db.tx_open = false;
}
};