This commit is contained in:
jaina heartles 2022-10-03 19:41:59 -07:00
parent e227a3de0f
commit 955df7b044
20 changed files with 1181 additions and 745 deletions

View File

@ -7,6 +7,12 @@
- System libraries - System libraries
* `sqlite3` * `sqlite3`
NOTE: compilation is broken right now because of:
https://github.com/ziglang/zig/issues/12240
for a temporary fix, rebuild zig after changing `$stdlibdir/crypto/scrypt.zig:465` to use `default_salt_len` instead of `salt_bin.len`
### Commands ### Commands
To build a binary: `zig build` To build a binary: `zig build`

View File

@ -1,11 +0,0 @@
# General overview
- `/controllers/**`
Handles serialization/deserialization of api calls from HTTP requests
- `/api.zig`
Business rules
- `/api/*.zig`
Performs the actual actions in the DB associated with a call
- `/db.zig`
SQL query wrapper

View File

@ -3,10 +3,8 @@ const util = @import("util");
const builtin = @import("builtin"); const builtin = @import("builtin");
const sql = @import("sql"); const sql = @import("sql");
const models = @import("./db/models.zig"); const DateTime = util.DateTime;
const migrations = @import("./migrations.zig"); const Uuid = util.Uuid;
pub const DateTime = util.DateTime;
pub const Uuid = util.Uuid;
const Config = @import("./main.zig").Config; const Config = @import("./main.zig").Config;
const services = struct { const services = struct {
@ -25,21 +23,17 @@ pub const RegistrationRequest = struct {
}; };
pub const InviteRequest = struct { pub const InviteRequest = struct {
pub const Type = services.invites.InviteType; pub const Kind = services.invites.Kind;
name: ?[]const u8 = null, name: ?[]const u8 = null,
expires_at: ?DateTime = null, // TODO: Change this to lifespan lifespan: ?DateTime.Duration = null,
max_uses: ?u16 = null, max_uses: ?u16 = null,
invite_type: Type = .user, // must be user unless the creator is an admin kind: Kind = .user, // must be user unless the creator is an admin
to_community: ?[]const u8 = null, // only valid on admin community to_community: ?[]const u8 = null, // only valid on admin community
}; };
pub const LoginResponse = struct { pub const LoginResponse = services.auth.LoginResult;
token: services.auth.tokens.Token.Value,
user_id: Uuid,
issued_at: DateTime,
};
pub const UserResponse = struct { pub const UserResponse = struct {
id: Uuid, id: Uuid,
@ -103,7 +97,7 @@ pub fn isAdminSetup(db: *sql.Db) !bool {
return true; return true;
} }
pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) !void { pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) anyerror!void {
const tx = try db.begin(); const tx = try db.begin();
errdefer tx.rollback(); errdefer tx.rollback();
var arena = std.heap.ArenaAllocator.init(allocator); var arena = std.heap.ArenaAllocator.init(allocator);
@ -111,22 +105,22 @@ pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, passwor
try tx.setConstraintMode(.deferred); try tx.setConstraintMode(.deferred);
const community = try services.communities.create( const community_id = try services.communities.create(
tx, tx,
origin, origin,
Uuid.nil,
.{ .name = "Cluster Admin", .kind = .admin }, .{ .name = "Cluster Admin", .kind = .admin },
arena.allocator(),
); );
const user = try services.users.create(tx, username, password, community.id, .{ .role = .admin }, arena.allocator()); const user = try services.auth.register(tx, username, password, community_id, .{ .kind = .admin }, arena.allocator());
try services.communities.transferOwnership(tx, community.id, user); try services.communities.transferOwnership(tx, community_id, user);
try tx.commit(); try tx.commit();
std.log.info( std.log.info(
"Created admin user {s} (id {}) with cluster admin origin {s} (id {})", "Created admin user {s} (id {}) with cluster admin origin {s} (id {})",
.{ username, user, origin, community.id }, .{ username, user, origin, community_id },
); );
} }
@ -168,12 +162,18 @@ pub const ApiSource = struct {
const community = try services.communities.getByHost(self.db, host, arena.allocator()); const community = try services.communities.getByHost(self.db, host, arena.allocator());
const token_info = try services.auth.tokens.verify(self.db, token, community.id); const token_info = try services.auth.verifyToken(
self.db,
token,
community.id,
arena.allocator(),
);
return Conn{ return Conn{
.db = self.db, .db = self.db,
.internal_alloc = self.internal_alloc, .internal_alloc = self.internal_alloc,
.user_id = token_info.user_id, .token_info = token_info,
.user_id = token_info.account_id,
.community = community, .community = community,
.arena = arena, .arena = arena,
}; };
@ -186,7 +186,8 @@ fn ApiConn(comptime DbConn: type) type {
db: DbConn, db: DbConn,
internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers
user_id: ?Uuid, token_info: ?services.auth.TokenInfo = null,
user_id: ?Uuid = null,
community: services.communities.Community, community: services.communities.Community,
arena: std.heap.ArenaAllocator, arena: std.heap.ArenaAllocator,
@ -200,32 +201,35 @@ fn ApiConn(comptime DbConn: type) type {
} }
pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse { pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse {
const user_id = (try services.users.lookupByUsername(self.db, username, self.community.id)) orelse return error.InvalidLogin; return services.auth.login(
try services.auth.passwords.verify(self.db, user_id, password, self.internal_alloc); self.db,
username,
const token = try services.auth.tokens.create(self.db, user_id); self.community.id,
password,
return LoginResponse{ self.arena.allocator(),
.user_id = user_id, );
.token = token.value,
.issued_at = token.info.issued_at,
};
} }
const TokenInfo = struct { pub const AuthorizationInfo = struct {
id: Uuid,
username: []const u8, username: []const u8,
community_id: Uuid,
host: []const u8,
issued_at: DateTime,
}; };
pub fn getTokenInfo(self: *Self) !TokenInfo { pub fn verifyAuthorization(self: *Self) !AuthorizationInfo {
if (self.user_id) |user_id| { if (self.token_info) |info| {
const result = (try self.db.queryRow( const user = try services.users.get(self.db, info.account_id, self.arena.allocator());
std.meta.Tuple(&.{[]const u8}),
"SELECT username FROM user WHERE id = $1", return AuthorizationInfo{
.{user_id}, .id = user.id,
self.arena.allocator(), .username = user.username,
)) orelse { .community_id = self.community.id,
return error.UserNotFound; .host = self.community.host,
.issued_at = info.issued_at,
}; };
return TokenInfo{ .username = result[0] };
} }
return error.Unauthorized; return error.Unauthorized;
@ -236,7 +240,27 @@ fn ApiConn(comptime DbConn: type) type {
return error.PermissionDenied; return error.PermissionDenied;
} }
return services.communities.create(self.db, origin, self.user_id.?, .{}); const tx = try self.db.begin();
errdefer tx.rollback();
const community_id = try services.communities.create(
tx,
origin,
.{},
self.arena.allocator(),
);
const community = services.communities.get(
tx,
community_id,
self.arena.allocator(),
) catch |err| return switch (err) {
error.NotFound => error.DatabaseError,
else => |err2| err2,
};
try tx.commit();
return community;
} }
pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite { pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite {
@ -254,21 +278,23 @@ fn ApiConn(comptime DbConn: type) type {
} else self.community.id; } else self.community.id;
// Users can only make user invites // Users can only make user invites
if (options.invite_type != .user and !self.isAdmin()) return error.PermissionDenied; if (options.kind != .user and !self.isAdmin()) return error.PermissionDenied;
return try services.invites.create(self.db, user_id, community_id, .{ const invite_id = try services.invites.create(self.db, user_id, community_id, .{
.name = options.name, .name = options.name,
.expires_at = options.expires_at, .lifespan = options.lifespan,
.max_uses = options.max_uses, .max_uses = options.max_uses,
.invite_type = options.invite_type, .kind = options.kind,
}, self.arena.allocator()); }, self.arena.allocator());
return try services.invites.get(self.db, invite_id, self.arena.allocator());
} }
pub fn register(self: *Self, request: RegistrationRequest) !UserResponse { pub fn register(self: *Self, request: RegistrationRequest) !UserResponse {
std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code }); std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code });
const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator()); const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator());
if (!Uuid.eql(invite.to_community, self.community.id)) return error.NotFound; if (!Uuid.eql(invite.community_id, self.community.id)) return error.NotFound;
if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired; if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired;
if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired; if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired;

View File

@ -1,157 +1,272 @@
const std = @import("std"); const std = @import("std");
const util = @import("util"); const util = @import("util");
const users = @import("./users.zig");
const Uuid = util.Uuid; const Uuid = util.Uuid;
const DateTime = util.DateTime; const DateTime = util.DateTime;
pub const passwords = struct { pub const RegistrationError = error{
const PwHash = std.crypto.pwhash.scrypt; PasswordTooShort,
const pw_hash_params = PwHash.Params.interactive; DatabaseFailure,
const pw_hash_encoding = .phc; HashFailure,
const pw_hash_buf_size = 128; OutOfMemory,
} || users.CreateError;
const PwHashBuf = [pw_hash_buf_size]u8; pub const min_password_chars = 12;
pub const RegistrationOptions = struct {
invite_id: ?Uuid = null,
email: ?[]const u8 = null,
kind: users.Kind = .user,
};
pub const VerifyError = error{ /// Creates a local account with the given information and returns the
InvalidLogin, /// account id
DatabaseFailure, pub fn register(
HashFailure, db: anytype,
username: []const u8,
password: []const u8,
community_id: Uuid,
options: RegistrationOptions,
alloc: std.mem.Allocator,
) RegistrationError!Uuid {
if (password.len < min_password_chars) return error.PasswordTooShort;
const hash = try hashPassword(password, alloc);
defer alloc.free(hash);
// transaction may already be running during initial db setup
if (@TypeOf(db).is_transaction) return registerTransaction(
db,
username,
hash,
community_id,
options,
alloc,
);
const tx = try db.begin();
errdefer tx.rollback();
const id = registerTransaction(
tx,
username,
hash,
community_id,
options,
alloc,
);
try tx.commit();
return id;
}
fn registerTransaction(
tx: anytype,
username: []const u8,
password_hash: []const u8,
community_id: Uuid,
options: RegistrationOptions,
alloc: std.mem.Allocator,
) RegistrationError!Uuid {
const id = try users.create(tx, username, community_id, options.kind, alloc);
tx.insert("local_account", .{
.account_id = id,
.invite_id = options.invite_id,
.email = options.email,
}, alloc) catch return error.DatabaseFailure;
tx.insert("password", .{
.account_id = id,
.hash = password_hash,
}, alloc) catch return error.DatabaseFailure;
return id;
}
pub const LoginError = error{
InvalidLogin,
HashFailure,
DatabaseFailure,
OutOfMemory,
};
pub const LoginResult = struct {
token: []const u8,
account_id: Uuid,
};
/// Attempts to login to the account `@username@community` and creates
/// a login token/cookie for the user
pub fn login(
db: anytype,
username: []const u8,
community_id: Uuid,
password: []const u8,
alloc: std.mem.Allocator,
) LoginError!LoginResult {
std.log.debug("user: {s}, community_id: {}", .{ username, community_id });
const info = db.queryRow(
struct { account_id: Uuid, hash: []const u8 },
\\SELECT account.id as account_id, password.hash
\\FROM password JOIN account
\\ ON password.account_id = account.id
\\WHERE account.username = $1
\\ AND account.community_id = $2
\\LIMIT 1
,
.{ username, community_id },
alloc,
) catch |err| return switch (err) {
error.NoRows => error.InvalidLogin,
else => error.DatabaseFailure,
}; };
pub fn verify( errdefer util.deepFree(alloc, info);
db: anytype, std.log.debug("got password", .{});
account_id: Uuid,
password: []const u8, try verifyPassword(info.hash, password, alloc);
alloc: std.mem.Allocator,
) VerifyError!void { const token = try generateToken(alloc);
// TODO: This could be done w/o the dynamically allocated hash buf errdefer util.deepFree(alloc, token);
const hash = db.queryRow( const token_hash = hashToken(token, alloc) catch |err| switch (err) {
std.meta.Tuple(&.{[]const u8}), error.OutOfMemory => return error.OutOfMemory,
\\SELECT hashed_password else => unreachable,
\\FROM account_password };
defer util.deepFree(alloc, token_hash);
const tx = db.begin() catch return error.DatabaseFailure;
errdefer tx.rollback();
// ensure that the password has not changed in the meantime
{
const updated_info = tx.queryRow(
struct { hash: []const u8 },
\\SELECT hash
\\FROM password
\\WHERE account_id = $1 \\WHERE account_id = $1
\\LIMIT 1 \\LIMIT 1
, ,
.{account_id}, .{info.account_id},
alloc,
) catch |err| return switch (err) {
error.NoRows => error.InvalidLogin,
else => error.DatabaseFailure,
};
errdefer alloc.free(hash[0]);
PwHash.strVerify(
hash[0],
password,
.{ .allocator = alloc },
) catch error.HashFailure;
}
pub const CreateError = error{ DatabaseFailure, HashFailure };
pub fn create(
db: anytype,
account_id: Uuid,
password: []const u8,
alloc: std.mem.Allocator,
) CreateError!void {
var buf: PwHashBuf = undefined;
const hash = PwHash.strHash(
password,
.{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding },
&buf,
) catch return error.HashFailure;
db.insert(
"account_password",
.{
.account_id = account_id,
.hashed_password = hash,
},
alloc, alloc,
) catch return error.DatabaseFailure; ) catch return error.DatabaseFailure;
defer util.deepFree(alloc, updated_info);
if (!std.mem.eql(u8, info.hash, updated_info.hash)) return error.InvalidLogin;
} }
};
pub const tokens = struct { tx.insert("token", .{
const token_len = 20; .account_id = info.account_id,
const token_str_len = std.base64.standard.Encoder.calcSize(token_len); .hash = token_hash,
pub const Token = struct { }, alloc) catch return error.DatabaseFailure;
pub const Value = [token_str_len]u8;
pub const Info = struct {
account_id: Uuid,
issued_at: DateTime,
};
value: Value, tx.commit() catch return error.DatabaseFailure;
info: Info,
return LoginResult{
.token = token,
.account_id = info.account_id,
}; };
}
const TokenHash = std.crypto.hash.sha2.Sha256; pub const VerifyTokenError = error{ InvalidToken, DatabaseFailure, OutOfMemory };
const TokenDigestBuf = [TokenHash.digest_length]u8; pub const TokenInfo = struct {
account_id: Uuid,
const DbToken = struct { issued_at: DateTime,
hash: []const u8,
account_id: Uuid,
issued_at: DateTime,
};
pub const CreateError = error{DatabaseFailure};
pub fn create(db: anytype, account_id: Uuid) CreateError!Token {
var token: [token_len]u8 = undefined;
std.crypto.random.bytes(&token);
var hash: TokenDigestBuf = undefined;
TokenHash.hash(&token, &hash, .{});
const issued_at = DateTime.now();
db.insert("token", DbToken{
.hash = &hash,
.account_id = account_id,
.issued_at = issued_at,
}) catch return error.DbError;
var token_enc: [token_str_len]u8 = undefined;
_ = std.base64.standard.Encoder.encode(&token_enc, &token);
return Token{ .value = token_enc, .info = .{
.account_id = account_id,
.issued_at = issued_at,
} };
}
pub const VerifyError = error{ InvalidToken, DatabaseError };
pub fn verify(
db: anytype,
token: []const u8,
community_id: Uuid,
alloc: std.mem.Allocator,
) VerifyError!Token.Info {
const decoded_len = std.base64.standard.Decoder.calcSizeForSlice(
token,
) catch return error.InvalidToken;
if (decoded_len != token_len) return error.InvalidToken;
var decoded: [token_len]u8 = undefined;
std.base64.standard.Decoder.decode(
&decoded,
token,
) catch return error.InvalidToken;
var hash: TokenDigestBuf = undefined;
TokenHash.hash(&decoded, &hash, .{});
return db.queryRow(
Token.Info,
\\SELECT account.id, token.issued_at
\\FROM token JOIN account ON token.account_id = account.id
\\WHERE token.hash = $1 AND account.community_id = $2
\\LIMIT 1
,
.{ hash, community_id },
alloc,
) catch |err| switch (err) {
error.NoRows => error.InvalidToken,
else => error.DatabaseFailure,
};
}
}; };
pub fn verifyToken(
db: anytype,
token: []const u8,
community_id: Uuid,
alloc: std.mem.Allocator,
) VerifyTokenError!TokenInfo {
const hash = try hashToken(token, alloc);
return db.queryRow(
TokenInfo,
\\SELECT token.account_id, token.issued_at
\\FROM token JOIN account
\\ ON token.account_id = account.id
\\WHERE token.hash = $1 AND account.community_id = $2
\\LIMIT 1
,
.{ hash, community_id },
alloc,
) catch |err| switch (err) {
error.NoRows => error.InvalidToken,
else => error.DatabaseFailure,
};
}
// We use scrypt, a password hashing algorithm that attempts to slow down
// GPU-based cracking approaches by using large amounts of memory, for
// password hashing.
// Attempting to calculate/verify a hash will use about 50mb of work space.
const scrypt = std.crypto.pwhash.scrypt;
const password_hash_len = 128;
fn verifyPassword(
hash: []const u8,
password: []const u8,
alloc: std.mem.Allocator,
) LoginError!void {
scrypt.strVerify(
hash,
password,
.{ .allocator = alloc },
) catch |err| return switch (err) {
error.PasswordVerificationFailed => error.InvalidLogin,
else => error.HashFailure,
};
}
fn hashPassword(password: []const u8, alloc: std.mem.Allocator) ![]const u8 {
const buf = try alloc.alloc(u8, password_hash_len);
errdefer alloc.free(buf);
return scrypt.strHash(
password,
.{
.allocator = alloc,
.params = scrypt.Params.interactive,
.encoding = .phc,
},
buf,
) catch error.HashFailure;
}
/// A raw token is a sequence of N random bytes, base64 encoded.
/// When the token is generated:
/// - The hash of the token is calculated by:
/// 1. Decoding the base64 text
/// 2. Calculating the SHA256 hash of this text
/// 3. Encoding the hash back as base64
/// - The b64 encoded hash is stored in the database
/// - The original token is returned to the user
/// * The user will treat it as opaque text
/// When the token is verified:
/// - The hash of the token is taken as shown above
/// - The database is scanned for a token matching this hash
/// - If none can be found, the token is invalid
const Sha256 = std.crypto.hash.sha2.Sha256;
const Base64Encoder = std.base64.standard.Encoder;
const Base64Decoder = std.base64.standard.Decoder;
const token_len = 12;
fn generateToken(alloc: std.mem.Allocator) ![]const u8 {
var token = std.mem.zeroes([token_len]u8);
std.crypto.random.bytes(&token);
const token_b64_len = Base64Encoder.calcSize(token.len);
const token_b64 = try alloc.alloc(u8, token_b64_len);
return Base64Encoder.encode(token_b64, &token);
}
fn hashToken(token_b64: []const u8, alloc: std.mem.Allocator) ![]const u8 {
const decoded_token_len = Base64Decoder.calcSizeForSlice(token_b64) catch return error.InvalidToken;
if (decoded_token_len != token_len) return error.InvalidToken;
var token = std.mem.zeroes([token_len]u8);
Base64Decoder.decode(&token, token_b64) catch return error.InvalidToken;
var hash = std.mem.zeroes([Sha256.digest_length]u8);
Sha256.hash(&token, &hash, .{});
const hash_b64_len = Base64Encoder.calcSize(hash.len);
const hash_b64 = try alloc.alloc(u8, hash_b64_len);
return Base64Encoder.encode(hash_b64, &hash);
}

View File

@ -25,7 +25,7 @@ pub const Kind = enum {
pub const Community = struct { pub const Community = struct {
id: Uuid, id: Uuid,
owner_id: Uuid, owner_id: ?Uuid,
host: []const u8, host: []const u8,
name: []const u8, name: []const u8,
@ -46,7 +46,7 @@ pub const CreateError = error{
CommunityExists, CommunityExists,
}; };
pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptions, alloc: std.mem.Allocator) CreateError!Uuid { pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: std.mem.Allocator) CreateError!Uuid {
const scheme_len = std.mem.indexOfScalar(u8, origin, ':') orelse return error.InvalidOrigin; const scheme_len = std.mem.indexOfScalar(u8, origin, ':') orelse return error.InvalidOrigin;
const scheme_str = origin[0..scheme_len]; const scheme_str = origin[0..scheme_len];
const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme; const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme;
@ -85,14 +85,14 @@ pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptio
else => return error.DatabaseFailure, else => return error.DatabaseFailure,
} }
try db.insert("community", .{ db.insert("community", .{
.id = id, .id = id,
.owner_id = owner, .owner_id = null,
.host = host, .host = host,
.name = options.name orelse host, .name = options.name orelse host,
.scheme = scheme, .scheme = scheme,
.kind = options.kind, .kind = options.kind,
}, alloc); }, alloc) catch return error.DatabaseFailure;
return id; return id;
} }
@ -101,18 +101,27 @@ pub const GetError = error{
NotFound, NotFound,
DatabaseFailure, DatabaseFailure,
}; };
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) GetError!Community {
fn getWhere(
db: anytype,
comptime where: []const u8,
args: anytype,
alloc: std.mem.Allocator,
) GetError!Community {
return db.queryRow( return db.queryRow(
Community, Community,
std.fmt.comptimePrint( std.fmt.comptimePrint(
\\SELECT {s} \\SELECT {s}
\\FROM community \\FROM community
\\WHERE host = $1 \\WHERE {s}
\\LIMIT 1 \\LIMIT 1
, ,
.{comptime sql.fieldList(Community)}, .{
comptime util.comptimeJoin(",", std.meta.fieldNames(Community)),
where,
},
), ),
.{host}, args,
alloc, alloc,
) catch |err| switch (err) { ) catch |err| switch (err) {
error.NoRows => error.NotFound, error.NoRows => error.NotFound,
@ -120,6 +129,14 @@ pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) GetErr
}; };
} }
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Community {
return getWhere(db, "id = $1", .{id}, alloc);
}
pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) GetError!Community {
return getWhere(db, "host = $1", .{host}, alloc);
}
pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void { pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void {
// TODO: check that this actually found/updated the row (needs update to sql lib) // TODO: check that this actually found/updated the row (needs update to sql lib)
db.exec( db.exec(

View File

@ -13,7 +13,7 @@ const code_len = 12;
const Encoder = std.base64.url_safe.Encoder; const Encoder = std.base64.url_safe.Encoder;
const Decoder = std.base64.url_safe.Decoder; const Decoder = std.base64.url_safe.Decoder;
pub const InviteKind = enum { pub const Kind = enum {
system, system,
community_owner, community_owner,
user, user,
@ -26,7 +26,7 @@ pub const Invite = struct {
id: Uuid, id: Uuid,
created_by: Uuid, // User ID created_by: Uuid, // User ID
to_community: ?Uuid, community_id: ?Uuid,
name: []const u8, name: []const u8,
code: []const u8, code: []const u8,
@ -36,17 +36,17 @@ pub const Invite = struct {
expires_at: ?DateTime, expires_at: ?DateTime,
max_uses: ?InviteCount, max_uses: ?InviteCount,
invite_kind: InviteKind, kind: Kind,
}; };
pub const InviteOptions = struct { pub const InviteOptions = struct {
name: ?[]const u8 = null, name: ?[]const u8 = null,
max_uses: ?InviteCount = null, max_uses: ?InviteCount = null,
lifespan: ?DateTime.Duration = null, lifespan: ?DateTime.Duration = null,
invite_kind: InviteKind = .user, kind: Kind = .user,
}; };
pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: InviteOptions, alloc: std.mem.Allocator) !Uuid { pub fn create(db: anytype, created_by: Uuid, community_id: ?Uuid, options: InviteOptions, alloc: std.mem.Allocator) !Uuid {
const id = Uuid.randV4(getRandom()); const id = Uuid.randV4(getRandom());
var code_bytes: [rand_len]u8 = undefined; var code_bytes: [rand_len]u8 = undefined;
@ -65,7 +65,7 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit
.id = id, .id = id,
.created_by = created_by, .created_by = created_by,
.to_community = to_community, .community_id = community_id,
.name = name, .name = name,
.code = code, .code = code,
@ -76,7 +76,7 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit
else else
null, null,
.invite_kind = options.invite_kind, .kind = options.kind,
}, },
alloc, alloc,
); );
@ -97,17 +97,27 @@ fn doGetQuery(
alloc: std.mem.Allocator, alloc: std.mem.Allocator,
) GetError!Invite { ) GetError!Invite {
// Generate list of fields from struct // Generate list of fields from struct
const field_list = util.comptimeJoinWithPrefix( const field_list = comptime util.comptimeJoinWithPrefix(
",", ",",
"invite.", "invite.",
std.meta.fieldNames(Invite), &.{
"id",
"created_by",
"community_id",
"name",
"code",
"created_at",
"expires_at",
"max_uses",
"kind",
},
); );
// times_used field is not stored directly in the DB, instead // times_used field is not stored directly in the DB, instead
// it is calculated based on the number of accounts that were created // it is calculated based on the number of accounts that were created
// from it // from it
const query = std.fmt.comptimePrint( const query = std.fmt.comptimePrint(
\\SELECT {s}, COUNT(local_account.id) AS times_used \\SELECT {s}, COUNT(local_account.account_id) AS times_used
\\FROM invite LEFT OUTER JOIN local_account \\FROM invite LEFT OUTER JOIN local_account
\\ ON invite.id = local_account.invite_id \\ ON invite.id = local_account.invite_id
\\WHERE {s} \\WHERE {s}

View File

@ -8,10 +8,13 @@ const getRandom = @import("../api.zig").getRandom;
pub const CreateError = error{ pub const CreateError = error{
UsernameTaken, UsernameTaken,
DbError, UsernameContainsInvalidChar,
UsernameTooLong,
UsernameEmpty,
DatabaseFailure,
}; };
pub const Role = enum { pub const Kind = enum {
user, user,
admin, admin,
}; };
@ -19,7 +22,7 @@ pub const Role = enum {
pub const CreateOptions = struct { pub const CreateOptions = struct {
invite_id: ?Uuid = null, invite_id: ?Uuid = null,
email: ?[]const u8 = null, email: ?[]const u8 = null,
role: Role = .user, kind: Kind = .user,
}; };
pub const LookupError = error{ pub const LookupError = error{
@ -48,40 +51,49 @@ pub fn lookupByUsername(
return row[0]; return row[0];
} }
// TODO: This fn sucks. pub const max_username_chars = 32;
// auth.passwords.create requires that the user exists, but we shouldn't pub const UsernameValidationError = error{
// hold onto a transaction for the ~0.5s that it takes to hash the password. UsernameContainsInvalidChar,
// Should probably change this to be specifically about creating the user, UsernameTooLong,
// and then have something in auth responsible for creating local accounts UsernameEmpty,
};
/// Usernames must satisfy:
/// - Be at least 1 character
/// - Be no more than 32 characters
/// - All characters are in [A-Za-z0-9_.]
/// Note that the '.' character is not allowed in all usernames, and
/// is intended for use in federated instance actors (as many instances do)
pub fn validateUsername(username: []const u8) UsernameValidationError!void {
if (username.len == 0) return error.UsernameEmpty;
if (username.len > max_username_chars) return error.UsernameTooLong;
for (username) |ch| {
const valid = std.ascii.isAlNum(ch) or ch == '_';
if (!valid) return error.UsernameContainsInvalidChar;
}
}
pub fn create( pub fn create(
db: anytype, db: anytype,
username: []const u8, username: []const u8,
password: []const u8,
community_id: Uuid, community_id: Uuid,
options: CreateOptions, kind: Kind,
alloc: std.mem.Allocator, alloc: std.mem.Allocator,
) CreateError!Uuid { ) CreateError!Uuid {
const id = Uuid.randV4(getRandom()); const id = Uuid.randV4(getRandom());
const tx = db.begin();
errdefer tx.rollback();
tx.insert("account", .{ try validateUsername(username);
db.insert("account", .{
.id = id, .id = id,
.username = username, .username = username,
.community_id = community_id, .community_id = community_id,
.role = options.role, .kind = kind,
}, alloc) catch |err| return switch (err) { }, alloc) catch |err| return switch (err) {
error.UniqueViolation => error.UsernameTaken, error.UniqueViolation => error.UsernameTaken,
else => error.DatabaseFailure, else => error.DatabaseFailure,
}; };
try auth.passwords.create(tx, id, password, alloc);
tx.insert("local_account", .{
.user_id = id,
.invite_id = options.invite_id,
.email = options.email,
}) catch return error.DatabaseFailure;
try tx.commit();
return id; return id;
} }
@ -93,7 +105,7 @@ pub const User = struct {
host: []const u8, host: []const u8,
community_id: Uuid, community_id: Uuid,
role: Role, kind: Kind,
created_at: DateTime, created_at: DateTime,
}; };
@ -101,9 +113,16 @@ pub const User = struct {
pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User { pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
return db.queryRow( return db.queryRow(
User, User,
\\SELECT user.username, community.host, community.id, user.created_at \\SELECT
\\FROM user JOIN community ON user.community_id = community.id \\ account.id,
\\WHERE user.id = $1 \\ account.username,
\\ community.host,
\\ account.community_id,
\\ account.kind,
\\ account.created_at
\\FROM account JOIN community
\\ ON account.community_id = community.id
\\WHERE account.id = $1
\\LIMIT 1 \\LIMIT 1
, ,
.{id}, .{id},

View File

@ -14,13 +14,13 @@ pub const login = struct {
pub const path = "/auth/login"; pub const path = "/auth/login";
pub const method = .POST; pub const method = .POST;
pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void {
std.debug.print("{s}", .{ctx.request.body.?});
const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx); const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx);
defer utils.freeRequestBody(credentials, ctx.alloc); defer utils.freeRequestBody(credentials, ctx.alloc);
var api = try utils.getApiConn(srv, ctx); var api = try utils.getApiConn(srv, ctx);
defer api.close(); defer api.close();
std.log.debug("connected to api", .{});
const token = try api.login(credentials.username, credentials.password); const token = try api.login(credentials.username, credentials.password);
try utils.respondJson(ctx, .ok, token); try utils.respondJson(ctx, .ok, token);
@ -37,7 +37,7 @@ pub const verify_login = struct {
// The self-hosted compiler doesn't like inferring this error set. // The self-hosted compiler doesn't like inferring this error set.
// do this for now // do this for now
const info = api.getTokenInfo() catch unreachable; const info = try api.verifyAuthorization();
try utils.respondJson(ctx, .ok, info); try utils.respondJson(ctx, .ok, info);
} }

View File

@ -5,7 +5,7 @@ const http = @import("http");
const util = @import("util"); const util = @import("util");
pub const api = @import("./api.zig"); pub const api = @import("./api.zig");
const migrations = @import("./migrations.zig"); pub const migrations = @import("./migrations.zig");
const Uuid = util.Uuid; const Uuid = util.Uuid;
const c = @import("./controllers.zig"); const c = @import("./controllers.zig");
@ -24,10 +24,10 @@ const router = Router{
prepare(c.invites.create), prepare(c.invites.create),
prepare(c.users.create), //prepare(c.users.create),
prepare(c.notes.create), //prepare(c.notes.create),
prepare(c.notes.get), //prepare(c.notes.get),
//Route.new(.GET, "/notes/:id/reacts", &c.notes.reacts.list), //Route.new(.GET, "/notes/:id/reacts", &c.notes.reacts.list),
//Route.new(.POST, "/notes/:id/reacts", &c.notes.reacts.create), //Route.new(.POST, "/notes/:id/reacts", &c.notes.reacts.create),
@ -82,9 +82,7 @@ pub const RequestServer = struct {
}; };
pub const Config = struct { pub const Config = struct {
cluster_host: []const u8,
db: sql.Config, db: sql.Config,
root_password: ?[]const u8 = null,
}; };
fn loadConfig(alloc: std.mem.Allocator) !Config { fn loadConfig(alloc: std.mem.Allocator) !Config {
@ -112,6 +110,7 @@ fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void {
fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void { fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void {
try migrations.up(db); try migrations.up(db);
api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp()));
if (!try api.isAdminSetup(db)) { if (!try api.isAdminSetup(db)) {
std.log.info("Performing first-time admin creation...", .{}); std.log.info("Performing first-time admin creation...", .{});
@ -134,14 +133,15 @@ fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void {
} }
} }
pub fn main() anyerror!void { pub fn main() !void {
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var gpa = std.heap.GeneralPurposeAllocator(.{}){};
var cfg = try loadConfig(gpa.allocator()); var cfg = try loadConfig(gpa.allocator());
var db_conn = try sql.Db.open(cfg.db); var db_conn = try sql.Db.open(cfg.db);
try prepareDb(&db_conn, gpa.allocator()); try prepareDb(&db_conn, gpa.allocator());
//try migrations.up(&db_conn);
//try api.setupAdmin(&db_conn, "http://localhost:8080", "root", "password", gpa.allocator());
var api_src = try api.ApiSource.init(gpa.allocator(), cfg, &db_conn); var api_src = try api.ApiSource.init(gpa.allocator(), cfg, &db_conn);
var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg);
api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp()));
return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);
} }

View File

@ -38,8 +38,15 @@ fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void {
} }
fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool {
const row = (try db.queryRow(std.meta.Tuple(&.{i32}), "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false; return if (db.queryRow(
return row[0] != 0; std.meta.Tuple(&.{i32}),
"SELECT COUNT(*) FROM migration WHERE name = $1 LIMIT 1",
.{name},
alloc,
)) |row| row[0] != 0 else |err| switch (err) {
error.NoRows => false,
else => error.DatabaseFailure,
};
} }
pub fn up(db: *sql.Db) !void { pub fn up(db: *sql.Db) !void {
@ -70,41 +77,42 @@ const create_migration_table =
// migrations into a single one. this will require db recreation // migrations into a single one. this will require db recreation
const migrations: []const Migration = &.{ const migrations: []const Migration = &.{
.{ .{
.name = "users", .name = "accounts",
.up = .up =
\\CREATE TABLE user( \\CREATE TABLE account(
\\ id TEXT NOT NULL PRIMARY KEY, \\ id UUID NOT NULL PRIMARY KEY,
\\ username TEXT NOT NULL, \\ username TEXT NOT NULL,
\\ \\
\\ kind TEXT NOT NULL CHECK (kind IN ('admin', 'user')),
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\); \\);
\\ \\
\\CREATE TABLE local_user( \\CREATE TABLE local_account(
\\ user_id TEXT NOT NULL PRIMARY KEY REFERENCES user(id), \\ account_id UUID NOT NULL PRIMARY KEY REFERENCES account(id),
\\ \\
\\ email TEXT \\ email TEXT
\\); \\);
\\ \\
\\CREATE TABLE account_password( \\CREATE TABLE password(
\\ user_id TEXT NOT NULL PRIMARY KEY REFERENCES user(id), \\ account_id UUID NOT NULL PRIMARY KEY REFERENCES account(id),
\\ \\
\\ hashed_password BLOB NOT NULL \\ hash BLOB NOT NULL
\\); \\);
, ,
.down = .down =
\\DROP TABLE account_password; \\DROP TABLE password;
\\DROP TABLE local_user; \\DROP TABLE local_account;
\\DROP TABLE user; \\DROP TABLE account;
, ,
}, },
.{ .{
.name = "notes", .name = "notes",
.up = .up =
\\CREATE TABLE note( \\CREATE TABLE note(
\\ id TEXT NOT NULL, \\ id UUID NOT NULL,
\\ \\
\\ content TEXT NOT NULL, \\ content TEXT NOT NULL,
\\ author_id TEXT NOT NULL REFERENCES user(id), \\ author_id UUID NOT NULL REFERENCES account(id),
\\ \\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\); \\);
@ -115,10 +123,10 @@ const migrations: []const Migration = &.{
.name = "note reactions", .name = "note reactions",
.up = .up =
\\CREATE TABLE reaction( \\CREATE TABLE reaction(
\\ id TEXT NOT NULL PRIMARY KEY, \\ id UUID NOT NULL PRIMARY KEY,
\\ \\
\\ user_id TEXT NOT NULL REFERENCES user(id), \\ account_id UUID NOT NULL REFERENCES account(id),
\\ note_id TEXT NOT NULL REFERENCES note(id), \\ note_id UUID NOT NULL REFERENCES note(id),
\\ \\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\); \\);
@ -126,11 +134,11 @@ const migrations: []const Migration = &.{
.down = "DROP TABLE reaction;", .down = "DROP TABLE reaction;",
}, },
.{ .{
.name = "user tokens", .name = "account tokens",
.up = .up =
\\CREATE TABLE token( \\CREATE TABLE token(
\\ hash TEXT NOT NULL PRIMARY KEY, \\ hash TEXT NOT NULL PRIMARY KEY,
\\ user_id TEXT NOT NULL REFERENCES local_user(id), \\ account_id UUID NOT NULL REFERENCES local_account(id),
\\ \\
\\ issued_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\ issued_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\); \\);
@ -138,26 +146,26 @@ const migrations: []const Migration = &.{
.down = "DROP TABLE token;", .down = "DROP TABLE token;",
}, },
.{ .{
.name = "user invites", .name = "account invites",
.up = .up =
\\CREATE TABLE invite( \\CREATE TABLE invite(
\\ id TEXT NOT NULL PRIMARY KEY, \\ id UUID NOT NULL PRIMARY KEY,
\\ \\
\\ name TEXT NOT NULL, \\ name TEXT NOT NULL,
\\ code TEXT NOT NULL UNIQUE, \\ code TEXT NOT NULL UNIQUE,
\\ created_by TEXT NOT NULL REFERENCES local_user(id), \\ created_by UUID NOT NULL REFERENCES local_account(id),
\\ \\
\\ max_uses INTEGER, \\ max_uses INTEGER,
\\ \\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
\\ expires_at TIMESTAMPTZ, \\ expires_at TIMESTAMPTZ,
\\ \\
\\ type TEXT NOT NULL CHECK (type in ('system_user', 'community_owner', 'user')) \\ kind TEXT NOT NULL CHECK (kind in ('system_user', 'community_owner', 'user'))
\\); \\);
\\ALTER TABLE local_user ADD COLUMN invite_id TEXT REFERENCES invite(id); \\ALTER TABLE local_account ADD COLUMN invite_id UUID REFERENCES invite(id);
, ,
.down = .down =
\\ALTER TABLE local_user DROP COLUMN invite_id; \\ALTER TABLE local_account DROP COLUMN invite_id;
\\DROP TABLE invite; \\DROP TABLE invite;
, ,
}, },
@ -165,9 +173,9 @@ const migrations: []const Migration = &.{
.name = "communities", .name = "communities",
.up = .up =
\\CREATE TABLE community( \\CREATE TABLE community(
\\ id TEXT NOT NULL PRIMARY KEY, \\ id UUID NOT NULL PRIMARY KEY,
\\ \\
\\ owner_id TEXT REFERENCES user(id), \\ owner_id UUID REFERENCES account(id),
\\ name TEXT NOT NULL, \\ name TEXT NOT NULL,
\\ host TEXT NOT NULL UNIQUE, \\ host TEXT NOT NULL UNIQUE,
\\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')), \\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')),
@ -175,12 +183,12 @@ const migrations: []const Migration = &.{
\\ \\
\\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
\\); \\);
\\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id); \\ALTER TABLE account ADD COLUMN community_id UUID REFERENCES community(id);
\\ALTER TABLE invite ADD COLUMN community_id TEXT REFERENCES community(id); \\ALTER TABLE invite ADD COLUMN community_id UUID REFERENCES community(id);
, ,
.down = .down =
\\ALTER TABLE invite DROP COLUMN community_id; \\ALTER TABLE invite DROP COLUMN community_id;
\\ALTER TABLE user DROP COLUMN community_id; \\ALTER TABLE account DROP COLUMN community_id;
\\DROP TABLE community; \\DROP TABLE community;
, ,
}, },

View File

@ -20,17 +20,17 @@ pub const OpenError = error{BadConnection} || UnexpectedError;
pub const ExecError = error{ pub const ExecError = error{
Cancelled, Cancelled,
ConnectionLost, BadConnection,
InternalException, InternalException,
DatabaseBusy, DatabaseBusy,
PermissionDenied, PermissionDenied,
SqlException, SqlException,
/// Argument could not be marshalled for query /// Argument could not be marshalled for query
InvalidArgument, BindException,
/// An argument was not used by the query (not checked in all DB engines) /// An argument was not used by the query (not checked in all DB engines)
UndefinedParameter, UnusedArgument,
/// Memory error when marshalling argument for query /// Memory error when marshalling argument for query
OutOfMemory, OutOfMemory,
@ -39,7 +39,7 @@ pub const ExecError = error{
pub const RowError = error{ pub const RowError = error{
Cancelled, Cancelled,
ConnectionLost, BadConnection,
InternalException, InternalException,
DatabaseBusy, DatabaseBusy,
PermissionDenied, PermissionDenied,
@ -49,7 +49,7 @@ pub const RowError = error{
pub const GetError = error{ pub const GetError = error{
OutOfMemory, OutOfMemory,
AllocatorRequired, AllocatorRequired,
TypeMismatch, ResultTypeMismatch,
} || UnexpectedError; } || UnexpectedError;
pub const ColumnCountError = error{OutOfRange}; pub const ColumnCountError = error{OutOfRange};
@ -97,13 +97,25 @@ pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]con
// Parse a (not-null) value from a string // Parse a (not-null) value from a string
pub fn parseValueNotNull(alloc: ?Allocator, comptime T: type, str: []const u8) !T { pub fn parseValueNotNull(alloc: ?Allocator, comptime T: type, str: []const u8) !T {
return switch (T) { return switch (T) {
Uuid => Uuid.parse(str), Uuid => Uuid.parse(str) catch |err| {
DateTime => DateTime.parse(str), std.log.err("Error {} parsing UUID: '{s}'", .{ err, str });
[]u8, []const u8 => if (alloc) |a| util.deepClone(a, str) else return error.AllocatorRequired, return error.ResultTypeMismatch;
},
DateTime => DateTime.parse(str) catch |err| {
std.log.err("Error {} parsing DateTime: '{s}'", .{ err, str });
return error.ResultTypeMismatch;
},
[]u8, []const u8 => if (alloc) |a| try util.deepClone(a, str) else return error.AllocatorRequired,
else => switch (@typeInfo(T)) { else => switch (@typeInfo(T)) {
.Int => std.fmt.parseInt(T, str, 0), .Int => std.fmt.parseInt(T, str, 0) catch |err| {
.Enum => std.meta.stringToEnum(T, str) orelse return error.InvalidValue, std.log.err("Could not parse int: {}", .{err});
return error.ResultTypeMismatch;
},
.Enum => std.meta.stringToEnum(T, str) orelse {
std.log.err("'{s}' is not a member of enum type {s}", .{ str, @typeName(T) });
return error.ResultTypeMismatch;
},
.Optional => try parseValueNotNull(alloc, std.meta.Child(T), str), .Optional => try parseValueNotNull(alloc, std.meta.Child(T), str),
else => @compileError("Type " ++ @typeName(T) ++ " not supported"), else => @compileError("Type " ++ @typeName(T) ++ " not supported"),

38
src/sql/engines/null.zig Normal file
View File

@ -0,0 +1,38 @@
const std = @import("std");
const common = @import("./common.zig");
const Allocator = std.mem.Allocator;
pub const Results = struct {
pub fn row(_: *Results) common.RowError!?Row {
unreachable;
}
pub fn columnCount(_: Results) common.ColumnCountError!u15 {
unreachable;
}
pub fn columnIndex(_: Results, _: []const u8) common.ColumnIndexError!u15 {
unreachable;
}
pub fn finish(_: Results) void {
unreachable;
}
};
pub const Row = struct {
pub fn get(_: Row, comptime T: type, _: u15, _: ?Allocator) common.GetError!T {
unreachable;
}
};
pub const Db = struct {
pub fn open(_: anytype) common.OpenError!Db {
unreachable;
}
pub fn close(_: Db) void {
unreachable;
}
pub fn exec(_: Db, _: [:0]const u8, _: anytype, _: common.QueryOptions) common.ExecError!Results {
unreachable;
}
};

View File

@ -21,6 +21,10 @@ pub const Results = struct {
}; };
} }
fn rowCount(self: Results) c_int {
return c.PQntuples(self.result);
}
pub fn columnCount(self: Results) common.ColumnCountError!u15 { pub fn columnCount(self: Results) common.ColumnCountError!u15 {
return std.math.cast(u15, c.PQnfields(self.result)) orelse error.OutOfRange; return std.math.cast(u15, c.PQnfields(self.result)) orelse error.OutOfRange;
} }
@ -36,7 +40,7 @@ pub const Results = struct {
} }
}; };
fn handleError(result: *c.PQresult) common.RowError { fn handleError(result: *c.PGresult) common.RowError {
const error_code = c.PQresultErrorField(result, c.PG_DIAG_SQLSTATE); const error_code = c.PQresultErrorField(result, c.PG_DIAG_SQLSTATE);
const state = errors.SqlState.parse(error_code) catch { const state = errors.SqlState.parse(error_code) catch {
std.log.err("Database returned invalid error code {?s}", .{error_code}); std.log.err("Database returned invalid error code {?s}", .{error_code});
@ -126,7 +130,7 @@ pub const Row = struct {
const val = c.PQgetvalue(self.result, self.row_index, idx); const val = c.PQgetvalue(self.result, self.row_index, idx);
const is_null = (c.PQgetisnull(self.result, self.row_index, idx) != 0); const is_null = (c.PQgetisnull(self.result, self.row_index, idx) != 0);
if (is_null) { if (is_null) {
return if (@typeInfo(T) == .Optional) null else error.TypeMismatch; return if (@typeInfo(T) == .Optional) null else error.ResultTypeMismatch;
} }
if (val == null) return error.Unexpected; if (val == null) return error.Unexpected;
@ -175,15 +179,30 @@ pub const Db = struct {
const format_text = 0; const format_text = 0;
const format_binary = 1; const format_binary = 1;
pub fn exec(self: Db, sql: [:0]const u8, args: anytype, alloc: ?Allocator) !Results { pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results {
const alloc = opt.prep_allocator;
const result = blk: { const result = blk: {
if (comptime args.len > 0) { if (@TypeOf(args) != void and args.len > 0) {
var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired); var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
defer arena.deinit(); defer arena.deinit();
const params = try arena.allocator().alloc(?[*:0]const u8, args.len); const params = try arena.allocator().alloc(?[*:0]const u8, args.len);
inline for (args) |arg, i| params[i] = if (try common.prepareParamText(&arena, arg)) |slice| slice.ptr else null; inline for (args) |arg, i| {
params[i] = if (try common.prepareParamText(&arena, arg)) |slice|
slice.ptr
else
null;
}
break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text); break :blk c.PQexecParams(
self.conn,
sql.ptr,
@intCast(c_int, params.len),
null,
params.ptr,
null,
null,
format_text,
);
} else { } else {
break :blk c.PQexecParams(self.conn, sql.ptr, 0, null, null, null, null, format_text); break :blk c.PQexecParams(self.conn, sql.ptr, 0, null, null, null, null, format_text);
} }
@ -199,11 +218,11 @@ pub const Db = struct {
c.PGRES_TUPLES_OK, c.PGRES_TUPLES_OK,
=> return Results{ .result = result }, => return Results{ .result = result },
c.PGRES_EMPTY_QUERY => return error.InvalidSql, c.PGRES_EMPTY_QUERY => return error.SqlException,
c.PGRES_BAD_RESPONSE => { c.PGRES_BAD_RESPONSE => {
std.log.err("Database returned invalid response: {?s}", .{c.PQresultErrorMessage(result)}); std.log.err("Database returned invalid response: {?s}", .{c.PQresultErrorMessage(result)});
return error.Database; return error.InternalException;
}, },
c.PGRES_FATAL_ERROR => return handleError(result), c.PGRES_FATAL_ERROR => return handleError(result),

View File

@ -1,15 +1,20 @@
const std = @import("std"); const std = @import("std");
const c = @import("./c.zig"); const c = @import("./c.zig");
const readIntBig = std.mem.readIntBig; const readVarInt = std.mem.readVarInt;
const Code = u40; // 8 * 5 = 40 const Code = u40; // 8 * 5 = 40
const code_class_mask: Code = 0xFFFF_000000;
pub const SqlStateClass = blk: { pub const SqlStateClass = blk: {
@setEvalBranchQuota(10000);
const info = @typeInfo(SqlState).Enum; const info = @typeInfo(SqlState).Enum;
var fields = &[0]std.builtin.Type.EnumField{}; const EnumField = std.builtin.Type.EnumField;
var fields: []const EnumField = &.{};
for (info.fields) |field| { for (info.fields) |field| {
if (field.value & ~code_class_mask == 0) fields = fields ++ &.{field}; const class_code = toClassCode(field.value);
if (class_code == field.value) {
fields = fields ++ &[_]EnumField{field};
}
} }
break :blk @Type(.{ .Enum = .{ break :blk @Type(.{ .Enum = .{
.layout = info.layout, .layout = info.layout,
@ -20,322 +25,352 @@ pub const SqlStateClass = blk: {
} }); } });
}; };
fn toCodeStr(code: Code) [5]u8 {
var code_str: [5]u8 = undefined;
code_str[0] = @intCast(u8, (code & 0xFF00_000000) >> 32);
code_str[1] = @intCast(u8, (code & 0x00FF_000000) >> 24);
code_str[2] = @intCast(u8, (code & 0x0000_FF0000) >> 16);
code_str[3] = @intCast(u8, (code & 0x0000_00FF00) >> 8);
code_str[4] = @intCast(u8, (code & 0x0000_0000FF) >> 0);
return code_str;
}
fn toClassCode(code: Code) Code {
var code_str = [_]u8{'0'} ** 5;
code_str[0] = @intCast(u8, (code & 0xFF00_000000) >> 32);
code_str[1] = @intCast(u8, (code & 0x00FF_000000) >> 24);
code_str[2] = '0';
code_str[3] = '0';
code_str[4] = '0';
return readVarInt(Code, &code_str, .Big);
}
// SqlState values for Postgres 14 // SqlState values for Postgres 14
pub const SqlState = enum(Code) { pub const SqlState = blk: {
pub const ParseError = error{ InvalidSize, NullPointer }; @setEvalBranchQuota(10000);
pub fn parse(code_str: [*c]const u8) ParseError!SqlState { break :blk enum(Code) {
if (code_str == null) return error.NullPointer; pub const ParseError = error{ InvalidSize, NullPointer };
const slice = std.mem.span(code_str); pub fn parse(code_str: [*c]const u8) ParseError!SqlState {
if (slice.len != @sizeOf(Code)) return error.InvalidSize; if (code_str == null) return error.NullPointer;
return @intToEnum(SqlState, std.mem.readIntSliceBig(Code, slice)); const slice = std.mem.span(code_str);
} if (slice.len != @sizeOf(Code)) return error.InvalidSize;
return @intToEnum(SqlState, readVarInt(Code, slice, .Big));
}
pub fn errorClass(code: SqlState) SqlStateClass { pub fn errorClass(code: SqlState) SqlStateClass {
return @intToEnum(SqlStateClass, @enumToInt(code) & code_class_mask); return @intToEnum(SqlStateClass, toClassCode(@enumToInt(code)));
} }
// Class 00 Successful Completion pub fn errorCodeStr(code: SqlState) [5]u8 {
successful_completion = readIntBig(Code, "00000"), return toCodeStr(@enumToInt(code));
// Class 01 Warning }
warning = readIntBig(Code, "01000"),
dynamic_result_sets_returned = readIntBig(Code, "0100C"),
implicit_zero_bit_padding = readIntBig(Code, "01008"),
null_value_eliminated_in_set_function = readIntBig(Code, "01003"),
privilege_not_granted = readIntBig(Code, "01007"),
privilege_not_revoked = readIntBig(Code, "01006"),
string_data_right_truncation = readIntBig(Code, "01004"),
deprecated_feature = readIntBig(Code, "01P01"),
// Class 02 No Data (this is also a warning class per the SQL standard)
no_data = readIntBig(Code, "02000"),
no_additional_dynamic_result_sets_returned = readIntBig(Code, "02001"),
// Class 03 SQL Statement Not Yet Complete
sql_statement_not_yet_complete = readIntBig(Code, "03000"),
// Class 08 Connection Exception
connection_exception = readIntBig(Code, "08000"),
connection_does_not_exist = readIntBig(Code, "08003"),
connection_failure = readIntBig(Code, "08006"),
sqlclient_unable_to_establish_sqlconnection = readIntBig(Code, "08001"),
sqlserver_rejected_establishment_of_sqlconnection = readIntBig(Code, "08004"),
transaction_resolution_unknown = readIntBig(Code, "08007"),
protocol_violation = readIntBig(Code, "08P01"),
// Class 09 Triggered Action Exception
triggered_action_exception = readIntBig(Code, "09000"),
// Class 0A Feature Not Supported
feature_not_supported = readIntBig(Code, "0A000"),
// Class 0B Invalid Transaction Initiation
invalid_transaction_initiation = readIntBig(Code, "0B000"),
// Class 0F Locator Exception
locator_exception = readIntBig(Code, "0F000"),
invalid_locator_specification = readIntBig(Code, "0F001"),
// Class 0L Invalid Grantor
invalid_grantor = readIntBig(Code, "0L000"),
invalid_grant_operation = readIntBig(Code, "0LP01"),
// Class 0P Invalid Role Specification
invalid_role_specification = readIntBig(Code, "0P000"),
// Class 0Z Diagnostics Exception
diagnostics_exception = readIntBig(Code, "0Z000"),
stacked_diagnostics_accessed_without_active_handler = readIntBig(Code, "0Z002"),
// Class 20 Case Not Found
case_not_found = readIntBig(Code, "20000"),
// Class 21 Cardinality Violation
cardinality_violation = readIntBig(Code, "21000"),
// Class 22 Data Exception
data_exception = readIntBig(Code, "22000"),
array_subscript_error = readIntBig(Code, "2202E"),
character_not_in_repertoire = readIntBig(Code, "22021"),
datetime_field_overflow = readIntBig(Code, "22008"),
division_by_zero = readIntBig(Code, "22012"),
error_in_assignment = readIntBig(Code, "22005"),
escape_character_conflict = readIntBig(Code, "2200B"),
indicator_overflow = readIntBig(Code, "22022"),
interval_field_overflow = readIntBig(Code, "22015"),
invalid_argument_for_logarithm = readIntBig(Code, "2201E"),
invalid_argument_for_ntile_function = readIntBig(Code, "22014"),
invalid_argument_for_nth_value_function = readIntBig(Code, "22016"),
invalid_argument_for_power_function = readIntBig(Code, "2201F"),
invalid_argument_for_width_bucket_function = readIntBig(Code, "2201G"),
invalid_character_value_for_cast = readIntBig(Code, "22018"),
invalid_datetime_format = readIntBig(Code, "22007"),
invalid_escape_character = readIntBig(Code, "22019"),
invalid_escape_octet = readIntBig(Code, "2200D"),
invalid_escape_sequence = readIntBig(Code, "22025"),
nonstandard_use_of_escape_character = readIntBig(Code, "22P06"),
invalid_indicator_parameter_value = readIntBig(Code, "22010"),
invalid_parameter_value = readIntBig(Code, "22023"),
invalid_preceding_or_following_size = readIntBig(Code, "22013"),
invalid_regular_expression = readIntBig(Code, "2201B"),
invalid_row_count_in_limit_clause = readIntBig(Code, "2201W"),
invalid_row_count_in_result_offset_clause = readIntBig(Code, "2201X"),
invalid_tablesample_argument = readIntBig(Code, "2202H"),
invalid_tablesample_repeat = readIntBig(Code, "2202G"),
invalid_time_zone_displacement_value = readIntBig(Code, "22009"),
invalid_use_of_escape_character = readIntBig(Code, "2200C"),
most_specific_type_mismatch = readIntBig(Code, "2200G"),
null_value_not_allowed = readIntBig(Code, "22004"),
null_value_no_indicator_parameter = readIntBig(Code, "22002"),
numeric_value_out_of_range = readIntBig(Code, "22003"),
sequence_generator_limit_exceeded = readIntBig(Code, "2200H"),
string_data_length_mismatch = readIntBig(Code, "22026"),
string_data_right_truncation = readIntBig(Code, "22001"),
substring_error = readIntBig(Code, "22011"),
trim_error = readIntBig(Code, "22027"),
unterminated_c_string = readIntBig(Code, "22024"),
zero_length_character_string = readIntBig(Code, "2200F"),
floating_point_exception = readIntBig(Code, "22P01"),
invalid_text_representation = readIntBig(Code, "22P02"),
invalid_binary_representation = readIntBig(Code, "22P03"),
bad_copy_file_format = readIntBig(Code, "22P04"),
untranslatable_character = readIntBig(Code, "22P05"),
not_an_xml_document = readIntBig(Code, "2200L"),
invalid_xml_document = readIntBig(Code, "2200M"),
invalid_xml_content = readIntBig(Code, "2200N"),
invalid_xml_comment = readIntBig(Code, "2200S"),
invalid_xml_processing_instruction = readIntBig(Code, "2200T"),
duplicate_json_object_key_value = readIntBig(Code, "22030"),
invalid_argument_for_sql_json_datetime_function = readIntBig(Code, "22031"),
invalid_json_text = readIntBig(Code, "22032"),
invalid_sql_json_subscript = readIntBig(Code, "22033"),
more_than_one_sql_json_item = readIntBig(Code, "22034"),
no_sql_json_item = readIntBig(Code, "22035"),
non_numeric_sql_json_item = readIntBig(Code, "22036"),
non_unique_keys_in_a_json_object = readIntBig(Code, "22037"),
singleton_sql_json_item_required = readIntBig(Code, "22038"),
sql_json_array_not_found = readIntBig(Code, "22039"),
sql_json_member_not_found = readIntBig(Code, "2203A"),
sql_json_number_not_found = readIntBig(Code, "2203B"),
sql_json_object_not_found = readIntBig(Code, "2203C"),
too_many_json_array_elements = readIntBig(Code, "2203D"),
too_many_json_object_members = readIntBig(Code, "2203E"),
sql_json_scalar_required = readIntBig(Code, "2203F"),
// Class 23 Integrity Constraint Violation
integrity_constraint_violation = readIntBig(Code, "23000"),
restrict_violation = readIntBig(Code, "23001"),
not_null_violation = readIntBig(Code, "23502"),
foreign_key_violation = readIntBig(Code, "23503"),
unique_violation = readIntBig(Code, "23505"),
check_violation = readIntBig(Code, "23514"),
exclusion_violation = readIntBig(Code, "23P01"),
// Class 24 Invalid Cursor State
invalid_cursor_state = readIntBig(Code, "24000"),
// Class 25 Invalid Transaction State
invalid_transaction_state = readIntBig(Code, "25000"),
active_sql_transaction = readIntBig(Code, "25001"),
branch_transaction_already_active = readIntBig(Code, "25002"),
held_cursor_requires_same_isolation_level = readIntBig(Code, "25008"),
inappropriate_access_mode_for_branch_transaction = readIntBig(Code, "25003"),
inappropriate_isolation_level_for_branch_transaction = readIntBig(Code, "25004"),
no_active_sql_transaction_for_branch_transaction = readIntBig(Code, "25005"),
read_only_sql_transaction = readIntBig(Code, "25006"),
schema_and_data_statement_mixing_not_supported = readIntBig(Code, "25007"),
no_active_sql_transaction = readIntBig(Code, "25P01"),
in_failed_sql_transaction = readIntBig(Code, "25P02"),
idle_in_transaction_session_timeout = readIntBig(Code, "25P03"),
// Class 26 Invalid SQL Statement Name
invalid_sql_statement_name = readIntBig(Code, "26000"),
// Class 27 Triggered Data Change Violation
triggered_data_change_violation = readIntBig(Code, "27000"),
// Class 28 Invalid Authorization Specification
invalid_authorization_specification = readIntBig(Code, "28000"),
invalid_password = readIntBig(Code, "28P01"),
// Class 2B Dependent Privilege Descriptors Still Exist
dependent_privilege_descriptors_still_exist = readIntBig(Code, "2B000"),
dependent_objects_still_exist = readIntBig(Code, "2BP01"),
// Class 2D Invalid Transaction Termination
invalid_transaction_termination = readIntBig(Code, "2D000"),
// Class 2F SQL Routine Exception
sql_routine_exception = readIntBig(Code, "2F000"),
function_executed_no_return_statement = readIntBig(Code, "2F005"),
modifying_sql_data_not_permitted = readIntBig(Code, "2F002"),
prohibited_sql_statement_attempted = readIntBig(Code, "2F003"),
reading_sql_data_not_permitted = readIntBig(Code, "2F004"),
// Class 34 Invalid Cursor Name
invalid_cursor_name = readIntBig(Code, "34000"),
// Class 38 External Routine Exception
external_routine_exception = readIntBig(Code, "38000"),
containing_sql_not_permitted = readIntBig(Code, "38001"),
modifying_sql_data_not_permitted = readIntBig(Code, "38002"),
prohibited_sql_statement_attempted = readIntBig(Code, "38003"),
reading_sql_data_not_permitted = readIntBig(Code, "38004"),
// Class 39 External Routine Invocation Exception
external_routine_invocation_exception = readIntBig(Code, "39000"),
invalid_sqlstate_returned = readIntBig(Code, "39001"),
null_value_not_allowed = readIntBig(Code, "39004"),
trigger_protocol_violated = readIntBig(Code, "39P01"),
srf_protocol_violated = readIntBig(Code, "39P02"),
event_trigger_protocol_violated = readIntBig(Code, "39P03"),
// Class 3B Savepoint Exception
savepoint_exception = readIntBig(Code, "3B000"),
invalid_savepoint_specification = readIntBig(Code, "3B001"),
// Class 3D Invalid Catalog Name
invalid_catalog_name = readIntBig(Code, "3D000"),
// Class 3F Invalid Schema Name
invalid_schema_name = readIntBig(Code, "3F000"),
// Class 40 Transaction Rollback
transaction_rollback = readIntBig(Code, "40000"),
transaction_integrity_constraint_violation = readIntBig(Code, "40002"),
serialization_failure = readIntBig(Code, "40001"),
statement_completion_unknown = readIntBig(Code, "40003"),
deadlock_detected = readIntBig(Code, "40P01"),
// Class 42 Syntax Error or Access Rule Violation
syntax_error_or_access_rule_violation = readIntBig(Code, "42000"),
syntax_error = readIntBig(Code, "42601"),
insufficient_privilege = readIntBig(Code, "42501"),
cannot_coerce = readIntBig(Code, "42846"),
grouping_error = readIntBig(Code, "42803"),
windowing_error = readIntBig(Code, "42P20"),
invalid_recursion = readIntBig(Code, "42P19"),
invalid_foreign_key = readIntBig(Code, "42830"),
invalid_name = readIntBig(Code, "42602"),
name_too_long = readIntBig(Code, "42622"),
reserved_name = readIntBig(Code, "42939"),
datatype_mismatch = readIntBig(Code, "42804"),
indeterminate_datatype = readIntBig(Code, "42P18"),
collation_mismatch = readIntBig(Code, "42P21"),
indeterminate_collation = readIntBig(Code, "42P22"),
wrong_object_type = readIntBig(Code, "42809"),
generated_always = readIntBig(Code, "428C9"),
undefined_column = readIntBig(Code, "42703"),
undefined_function = readIntBig(Code, "42883"),
undefined_table = readIntBig(Code, "42P01"),
undefined_parameter = readIntBig(Code, "42P02"),
undefined_object = readIntBig(Code, "42704"),
duplicate_column = readIntBig(Code, "42701"),
duplicate_cursor = readIntBig(Code, "42P03"),
duplicate_database = readIntBig(Code, "42P04"),
duplicate_function = readIntBig(Code, "42723"),
duplicate_prepared_statement = readIntBig(Code, "42P05"),
duplicate_schema = readIntBig(Code, "42P06"),
duplicate_table = readIntBig(Code, "42P07"),
duplicate_alias = readIntBig(Code, "42712"),
duplicate_object = readIntBig(Code, "42710"),
ambiguous_column = readIntBig(Code, "42702"),
ambiguous_function = readIntBig(Code, "42725"),
ambiguous_parameter = readIntBig(Code, "42P08"),
ambiguous_alias = readIntBig(Code, "42P09"),
invalid_column_reference = readIntBig(Code, "42P10"),
invalid_column_definition = readIntBig(Code, "42611"),
invalid_cursor_definition = readIntBig(Code, "42P11"),
invalid_database_definition = readIntBig(Code, "42P12"),
invalid_function_definition = readIntBig(Code, "42P13"),
invalid_prepared_statement_definition = readIntBig(Code, "42P14"),
invalid_schema_definition = readIntBig(Code, "42P15"),
invalid_table_definition = readIntBig(Code, "42P16"),
invalid_object_definition = readIntBig(Code, "42P17"),
// Class 44 WITH CHECK OPTION Violation
with_check_option_violation = readIntBig(Code, "44000"),
// Class 53 Insufficient Resources
insufficient_resources = readIntBig(Code, "53000"),
disk_full = readIntBig(Code, "53100"),
out_of_memory = readIntBig(Code, "53200"),
too_many_connections = readIntBig(Code, "53300"),
configuration_limit_exceeded = readIntBig(Code, "53400"),
// Class 54 Program Limit Exceeded
program_limit_exceeded = readIntBig(Code, "54000"),
statement_too_complex = readIntBig(Code, "54001"),
too_many_columns = readIntBig(Code, "54011"),
too_many_arguments = readIntBig(Code, "54023"),
// Class 55 Object Not In Prerequisite State
object_not_in_prerequisite_state = readIntBig(Code, "55000"),
object_in_use = readIntBig(Code, "55006"),
cant_change_runtime_param = readIntBig(Code, "55P02"),
lock_not_available = readIntBig(Code, "55P03"),
unsafe_new_enum_value_usage = readIntBig(Code, "55P04"),
// Class 57 Operator Intervention
operator_intervention = readIntBig(Code, "57000"),
query_canceled = readIntBig(Code, "57014"),
admin_shutdown = readIntBig(Code, "57P01"),
crash_shutdown = readIntBig(Code, "57P02"),
cannot_connect_now = readIntBig(Code, "57P03"),
database_dropped = readIntBig(Code, "57P04"),
idle_session_timeout = readIntBig(Code, "57P05"),
// Class 58 System Error (errors external to PostgreSQL itself)
system_error = readIntBig(Code, "58000"),
io_error = readIntBig(Code, "58030"),
undefined_file = readIntBig(Code, "58P01"),
duplicate_file = readIntBig(Code, "58P02"),
// Class 72 Snapshot Failure
snapshot_too_old = readIntBig(Code, "72000"),
// Class F0 Configuration File Error
config_file_error = readIntBig(Code, "F0000"),
lock_file_exists = readIntBig(Code, "F0001"),
// Class HV Foreign Data Wrapper Error (SQL/MED)
fdw_error = readIntBig(Code, "HV000"),
fdw_column_name_not_found = readIntBig(Code, "HV005"),
fdw_dynamic_parameter_value_needed = readIntBig(Code, "HV002"),
fdw_function_sequence_error = readIntBig(Code, "HV010"),
fdw_inconsistent_descriptor_information = readIntBig(Code, "HV021"),
fdw_invalid_attribute_value = readIntBig(Code, "HV024"),
fdw_invalid_column_name = readIntBig(Code, "HV007"),
fdw_invalid_column_number = readIntBig(Code, "HV008"),
fdw_invalid_data_type = readIntBig(Code, "HV004"),
fdw_invalid_data_type_descriptors = readIntBig(Code, "HV006"),
fdw_invalid_descriptor_field_identifier = readIntBig(Code, "HV091"),
fdw_invalid_handle = readIntBig(Code, "HV00B"),
fdw_invalid_option_index = readIntBig(Code, "HV00C"),
fdw_invalid_option_name = readIntBig(Code, "HV00D"),
fdw_invalid_string_length_or_buffer_length = readIntBig(Code, "HV090"),
fdw_invalid_string_format = readIntBig(Code, "HV00A"),
fdw_invalid_use_of_null_pointer = readIntBig(Code, "HV009"),
fdw_too_many_handles = readIntBig(Code, "HV014"),
fdw_out_of_memory = readIntBig(Code, "HV001"),
fdw_no_schemas = readIntBig(Code, "HV00P"),
fdw_option_name_not_found = readIntBig(Code, "HV00J"),
fdw_reply_handle = readIntBig(Code, "HV00K"),
fdw_schema_not_found = readIntBig(Code, "HV00Q"),
fdw_table_not_found = readIntBig(Code, "HV00R"),
fdw_unable_to_create_execution = readIntBig(Code, "HV00L"),
fdw_unable_to_create_reply = readIntBig(Code, "HV00M"),
fdw_unable_to_establish_connection = readIntBig(Code, "HV00N"),
// Class P0 PL/pgSQL Error
plpgsql_error = readIntBig(Code, "P0000"),
raise_exception = readIntBig(Code, "P0001"),
no_data_found = readIntBig(Code, "P0002"),
too_many_rows = readIntBig(Code, "P0003"),
assert_failure = readIntBig(Code, "P0004"),
// Class XX Internal Error
internal_error = readIntBig(Code, "XX000"),
data_corrupted = readIntBig(Code, "XX001"),
index_corrupted = readIntBig(Code, "XX002"),
_, // Class 00 Successful Completion
successful_completion = readVarInt(Code, "00000", .Big),
// Class 01 Warning
warning = readVarInt(Code, "01000", .Big),
dynamic_result_sets_returned = readVarInt(Code, "0100C", .Big),
implicit_zero_bit_padding = readVarInt(Code, "01008", .Big),
null_value_eliminated_in_set_function = readVarInt(Code, "01003", .Big),
privilege_not_granted = readVarInt(Code, "01007", .Big),
privilege_not_revoked = readVarInt(Code, "01006", .Big),
string_data_right_truncation_warning = readVarInt(Code, "01004", .Big),
deprecated_feature = readVarInt(Code, "01P01", .Big),
// Class 02 No Data (this is also a warning class per the SQL standard)
no_data = readVarInt(Code, "02000", .Big),
no_additional_dynamic_result_sets_returned = readVarInt(Code, "02001", .Big),
// Class 03 SQL Statement Not Yet Complete
sql_statement_not_yet_complete = readVarInt(Code, "03000", .Big),
// Class 08 Connection Exception
connection_exception = readVarInt(Code, "08000", .Big),
connection_does_not_exist = readVarInt(Code, "08003", .Big),
connection_failure = readVarInt(Code, "08006", .Big),
sqlclient_unable_to_establish_sqlconnection = readVarInt(Code, "08001", .Big),
sqlserver_rejected_establishment_of_sqlconnection = readVarInt(Code, "08004", .Big),
transaction_resolution_unknown = readVarInt(Code, "08007", .Big),
protocol_violation = readVarInt(Code, "08P01", .Big),
// Class 09 Triggered Action Exception
triggered_action_exception = readVarInt(Code, "09000", .Big),
// Class 0A Feature Not Supported
feature_not_supported = readVarInt(Code, "0A000", .Big),
// Class 0B Invalid Transaction Initiation
invalid_transaction_initiation = readVarInt(Code, "0B000", .Big),
// Class 0F Locator Exception
locator_exception = readVarInt(Code, "0F000", .Big),
invalid_locator_specification = readVarInt(Code, "0F001", .Big),
// Class 0L Invalid Grantor
invalid_grantor = readVarInt(Code, "0L000", .Big),
invalid_grant_operation = readVarInt(Code, "0LP01", .Big),
// Class 0P Invalid Role Specification
invalid_role_specification = readVarInt(Code, "0P000", .Big),
// Class 0Z Diagnostics Exception
diagnostics_exception = readVarInt(Code, "0Z000", .Big),
stacked_diagnostics_accessed_without_active_handler = readVarInt(Code, "0Z002", .Big),
// Class 20 Case Not Found
case_not_found = readVarInt(Code, "20000", .Big),
// Class 21 Cardinality Violation
cardinality_violation = readVarInt(Code, "21000", .Big),
// Class 22 Data Exception
data_exception = readVarInt(Code, "22000", .Big),
array_subscript_error = readVarInt(Code, "2202E", .Big),
character_not_in_repertoire = readVarInt(Code, "22021", .Big),
datetime_field_overflow = readVarInt(Code, "22008", .Big),
division_by_zero = readVarInt(Code, "22012", .Big),
error_in_assignment = readVarInt(Code, "22005", .Big),
escape_character_conflict = readVarInt(Code, "2200B", .Big),
indicator_overflow = readVarInt(Code, "22022", .Big),
interval_field_overflow = readVarInt(Code, "22015", .Big),
invalid_argument_for_logarithm = readVarInt(Code, "2201E", .Big),
invalid_argument_for_ntile_function = readVarInt(Code, "22014", .Big),
invalid_argument_for_nth_value_function = readVarInt(Code, "22016", .Big),
invalid_argument_for_power_function = readVarInt(Code, "2201F", .Big),
invalid_argument_for_width_bucket_function = readVarInt(Code, "2201G", .Big),
invalid_character_value_for_cast = readVarInt(Code, "22018", .Big),
invalid_datetime_format = readVarInt(Code, "22007", .Big),
invalid_escape_character = readVarInt(Code, "22019", .Big),
invalid_escape_octet = readVarInt(Code, "2200D", .Big),
invalid_escape_sequence = readVarInt(Code, "22025", .Big),
nonstandard_use_of_escape_character = readVarInt(Code, "22P06", .Big),
invalid_indicator_parameter_value = readVarInt(Code, "22010", .Big),
invalid_parameter_value = readVarInt(Code, "22023", .Big),
invalid_preceding_or_following_size = readVarInt(Code, "22013", .Big),
invalid_regular_expression = readVarInt(Code, "2201B", .Big),
invalid_row_count_in_limit_clause = readVarInt(Code, "2201W", .Big),
invalid_row_count_in_result_offset_clause = readVarInt(Code, "2201X", .Big),
invalid_tablesample_argument = readVarInt(Code, "2202H", .Big),
invalid_tablesample_repeat = readVarInt(Code, "2202G", .Big),
invalid_time_zone_displacement_value = readVarInt(Code, "22009", .Big),
invalid_use_of_escape_character = readVarInt(Code, "2200C", .Big),
most_specific_type_mismatch = readVarInt(Code, "2200G", .Big),
null_value_not_allowed_data_exception = readVarInt(Code, "22004", .Big),
null_value_no_indicator_parameter = readVarInt(Code, "22002", .Big),
numeric_value_out_of_range = readVarInt(Code, "22003", .Big),
sequence_generator_limit_exceeded = readVarInt(Code, "2200H", .Big),
string_data_length_mismatch = readVarInt(Code, "22026", .Big),
string_data_right_truncation_exception = readVarInt(Code, "22001", .Big),
substring_error = readVarInt(Code, "22011", .Big),
trim_error = readVarInt(Code, "22027", .Big),
unterminated_c_string = readVarInt(Code, "22024", .Big),
zero_length_character_string = readVarInt(Code, "2200F", .Big),
floating_point_exception = readVarInt(Code, "22P01", .Big),
invalid_text_representation = readVarInt(Code, "22P02", .Big),
invalid_binary_representation = readVarInt(Code, "22P03", .Big),
bad_copy_file_format = readVarInt(Code, "22P04", .Big),
untranslatable_character = readVarInt(Code, "22P05", .Big),
not_an_xml_document = readVarInt(Code, "2200L", .Big),
invalid_xml_document = readVarInt(Code, "2200M", .Big),
invalid_xml_content = readVarInt(Code, "2200N", .Big),
invalid_xml_comment = readVarInt(Code, "2200S", .Big),
invalid_xml_processing_instruction = readVarInt(Code, "2200T", .Big),
duplicate_json_object_key_value = readVarInt(Code, "22030", .Big),
invalid_argument_for_sql_json_datetime_function = readVarInt(Code, "22031", .Big),
invalid_json_text = readVarInt(Code, "22032", .Big),
invalid_sql_json_subscript = readVarInt(Code, "22033", .Big),
more_than_one_sql_json_item = readVarInt(Code, "22034", .Big),
no_sql_json_item = readVarInt(Code, "22035", .Big),
non_numeric_sql_json_item = readVarInt(Code, "22036", .Big),
non_unique_keys_in_a_json_object = readVarInt(Code, "22037", .Big),
singleton_sql_json_item_required = readVarInt(Code, "22038", .Big),
sql_json_array_not_found = readVarInt(Code, "22039", .Big),
sql_json_member_not_found = readVarInt(Code, "2203A", .Big),
sql_json_number_not_found = readVarInt(Code, "2203B", .Big),
sql_json_object_not_found = readVarInt(Code, "2203C", .Big),
too_many_json_array_elements = readVarInt(Code, "2203D", .Big),
too_many_json_object_members = readVarInt(Code, "2203E", .Big),
sql_json_scalar_required = readVarInt(Code, "2203F", .Big),
// Class 23 Integrity Constraint Violation
integrity_constraint_violation = readVarInt(Code, "23000", .Big),
restrict_violation = readVarInt(Code, "23001", .Big),
not_null_violation = readVarInt(Code, "23502", .Big),
foreign_key_violation = readVarInt(Code, "23503", .Big),
unique_violation = readVarInt(Code, "23505", .Big),
check_violation = readVarInt(Code, "23514", .Big),
exclusion_violation = readVarInt(Code, "23P01", .Big),
// Class 24 Invalid Cursor State
invalid_cursor_state = readVarInt(Code, "24000", .Big),
// Class 25 Invalid Transaction State
invalid_transaction_state = readVarInt(Code, "25000", .Big),
active_sql_transaction = readVarInt(Code, "25001", .Big),
branch_transaction_already_active = readVarInt(Code, "25002", .Big),
held_cursor_requires_same_isolation_level = readVarInt(Code, "25008", .Big),
inappropriate_access_mode_for_branch_transaction = readVarInt(Code, "25003", .Big),
inappropriate_isolation_level_for_branch_transaction = readVarInt(Code, "25004", .Big),
no_active_sql_transaction_for_branch_transaction = readVarInt(Code, "25005", .Big),
read_only_sql_transaction = readVarInt(Code, "25006", .Big),
schema_and_data_statement_mixing_not_supported = readVarInt(Code, "25007", .Big),
no_active_sql_transaction = readVarInt(Code, "25P01", .Big),
in_failed_sql_transaction = readVarInt(Code, "25P02", .Big),
idle_in_transaction_session_timeout = readVarInt(Code, "25P03", .Big),
// Class 26 Invalid SQL Statement Name
invalid_sql_statement_name = readVarInt(Code, "26000", .Big),
// Class 27 Triggered Data Change Violation
triggered_data_change_violation = readVarInt(Code, "27000", .Big),
// Class 28 Invalid Authorization Specification
invalid_authorization_specification = readVarInt(Code, "28000", .Big),
invalid_password = readVarInt(Code, "28P01", .Big),
// Class 2B Dependent Privilege Descriptors Still Exist
dependent_privilege_descriptors_still_exist = readVarInt(Code, "2B000", .Big),
dependent_objects_still_exist = readVarInt(Code, "2BP01", .Big),
// Class 2D Invalid Transaction Termination
invalid_transaction_termination = readVarInt(Code, "2D000", .Big),
// Class 2F SQL Routine Exception
sql_routine_exception = readVarInt(Code, "2F000", .Big),
function_executed_no_return_statement = readVarInt(Code, "2F005", .Big),
modifying_sql_data_not_permitted_sql_exception = readVarInt(Code, "2F002", .Big),
prohibited_sql_statement_attempted_sql_exception = readVarInt(Code, "2F003", .Big),
reading_sql_data_not_permitted_sql_exception = readVarInt(Code, "2F004", .Big),
// Class 34 Invalid Cursor Name
invalid_cursor_name = readVarInt(Code, "34000", .Big),
// Class 38 External Routine Exception
external_routine_exception = readVarInt(Code, "38000", .Big),
containing_sql_not_permitted = readVarInt(Code, "38001", .Big),
modifying_sql_data_not_permitted_external_exception = readVarInt(Code, "38002", .Big),
prohibited_sql_statement_attempted_external_exception = readVarInt(Code, "38003", .Big),
reading_sql_data_not_permitted_external_exception = readVarInt(Code, "38004", .Big),
// Class 39 External Routine Invocation Exception
external_routine_invocation_exception = readVarInt(Code, "39000", .Big),
invalid_sqlstate_returned = readVarInt(Code, "39001", .Big),
null_value_not_allowed_external_exception = readVarInt(Code, "39004", .Big),
trigger_protocol_violated = readVarInt(Code, "39P01", .Big),
srf_protocol_violated = readVarInt(Code, "39P02", .Big),
event_trigger_protocol_violated = readVarInt(Code, "39P03", .Big),
// Class 3B Savepoint Exception
savepoint_exception = readVarInt(Code, "3B000", .Big),
invalid_savepoint_specification = readVarInt(Code, "3B001", .Big),
// Class 3D Invalid Catalog Name
invalid_catalog_name = readVarInt(Code, "3D000", .Big),
// Class 3F Invalid Schema Name
invalid_schema_name = readVarInt(Code, "3F000", .Big),
// Class 40 Transaction Rollback
transaction_rollback = readVarInt(Code, "40000", .Big),
transaction_integrity_constraint_violation = readVarInt(Code, "40002", .Big),
serialization_failure = readVarInt(Code, "40001", .Big),
statement_completion_unknown = readVarInt(Code, "40003", .Big),
deadlock_detected = readVarInt(Code, "40P01", .Big),
// Class 42 Syntax Error or Access Rule Violation
syntax_error_or_access_rule_violation = readVarInt(Code, "42000", .Big),
syntax_error = readVarInt(Code, "42601", .Big),
insufficient_privilege = readVarInt(Code, "42501", .Big),
cannot_coerce = readVarInt(Code, "42846", .Big),
grouping_error = readVarInt(Code, "42803", .Big),
windowing_error = readVarInt(Code, "42P20", .Big),
invalid_recursion = readVarInt(Code, "42P19", .Big),
invalid_foreign_key = readVarInt(Code, "42830", .Big),
invalid_name = readVarInt(Code, "42602", .Big),
name_too_long = readVarInt(Code, "42622", .Big),
reserved_name = readVarInt(Code, "42939", .Big),
datatype_mismatch = readVarInt(Code, "42804", .Big),
indeterminate_datatype = readVarInt(Code, "42P18", .Big),
collation_mismatch = readVarInt(Code, "42P21", .Big),
indeterminate_collation = readVarInt(Code, "42P22", .Big),
wrong_object_type = readVarInt(Code, "42809", .Big),
generated_always = readVarInt(Code, "428C9", .Big),
undefined_column = readVarInt(Code, "42703", .Big),
undefined_function = readVarInt(Code, "42883", .Big),
undefined_table = readVarInt(Code, "42P01", .Big),
undefined_parameter = readVarInt(Code, "42P02", .Big),
undefined_object = readVarInt(Code, "42704", .Big),
duplicate_column = readVarInt(Code, "42701", .Big),
duplicate_cursor = readVarInt(Code, "42P03", .Big),
duplicate_database = readVarInt(Code, "42P04", .Big),
duplicate_function = readVarInt(Code, "42723", .Big),
duplicate_prepared_statement = readVarInt(Code, "42P05", .Big),
duplicate_schema = readVarInt(Code, "42P06", .Big),
duplicate_table = readVarInt(Code, "42P07", .Big),
duplicate_alias = readVarInt(Code, "42712", .Big),
duplicate_object = readVarInt(Code, "42710", .Big),
ambiguous_column = readVarInt(Code, "42702", .Big),
ambiguous_function = readVarInt(Code, "42725", .Big),
ambiguous_parameter = readVarInt(Code, "42P08", .Big),
ambiguous_alias = readVarInt(Code, "42P09", .Big),
invalid_column_reference = readVarInt(Code, "42P10", .Big),
invalid_column_definition = readVarInt(Code, "42611", .Big),
invalid_cursor_definition = readVarInt(Code, "42P11", .Big),
invalid_database_definition = readVarInt(Code, "42P12", .Big),
invalid_function_definition = readVarInt(Code, "42P13", .Big),
invalid_prepared_statement_definition = readVarInt(Code, "42P14", .Big),
invalid_schema_definition = readVarInt(Code, "42P15", .Big),
invalid_table_definition = readVarInt(Code, "42P16", .Big),
invalid_object_definition = readVarInt(Code, "42P17", .Big),
// Class 44 WITH CHECK OPTION Violation
with_check_option_violation = readVarInt(Code, "44000", .Big),
// Class 53 Insufficient Resources
insufficient_resources = readVarInt(Code, "53000", .Big),
disk_full = readVarInt(Code, "53100", .Big),
out_of_memory = readVarInt(Code, "53200", .Big),
too_many_connections = readVarInt(Code, "53300", .Big),
configuration_limit_exceeded = readVarInt(Code, "53400", .Big),
// Class 54 Program Limit Exceeded
program_limit_exceeded = readVarInt(Code, "54000", .Big),
statement_too_complex = readVarInt(Code, "54001", .Big),
too_many_columns = readVarInt(Code, "54011", .Big),
too_many_arguments = readVarInt(Code, "54023", .Big),
// Class 55 Object Not In Prerequisite State
object_not_in_prerequisite_state = readVarInt(Code, "55000", .Big),
object_in_use = readVarInt(Code, "55006", .Big),
cant_change_runtime_param = readVarInt(Code, "55P02", .Big),
lock_not_available = readVarInt(Code, "55P03", .Big),
unsafe_new_enum_value_usage = readVarInt(Code, "55P04", .Big),
// Class 57 Operator Intervention
operator_intervention = readVarInt(Code, "57000", .Big),
query_canceled = readVarInt(Code, "57014", .Big),
admin_shutdown = readVarInt(Code, "57P01", .Big),
crash_shutdown = readVarInt(Code, "57P02", .Big),
cannot_connect_now = readVarInt(Code, "57P03", .Big),
database_dropped = readVarInt(Code, "57P04", .Big),
idle_session_timeout = readVarInt(Code, "57P05", .Big),
// Class 58 System Error (errors external to PostgreSQL itself)
system_error = readVarInt(Code, "58000", .Big),
io_error = readVarInt(Code, "58030", .Big),
undefined_file = readVarInt(Code, "58P01", .Big),
duplicate_file = readVarInt(Code, "58P02", .Big),
// Class 72 Snapshot Failure
snapshot_too_old = readVarInt(Code, "72000", .Big),
// Class F0 Configuration File Error
config_file_error = readVarInt(Code, "F0000", .Big),
lock_file_exists = readVarInt(Code, "F0001", .Big),
// Class HV Foreign Data Wrapper Error (SQL/MED)
fdw_error = readVarInt(Code, "HV000", .Big),
fdw_column_name_not_found = readVarInt(Code, "HV005", .Big),
fdw_dynamic_parameter_value_needed = readVarInt(Code, "HV002", .Big),
fdw_function_sequence_error = readVarInt(Code, "HV010", .Big),
fdw_inconsistent_descriptor_information = readVarInt(Code, "HV021", .Big),
fdw_invalid_attribute_value = readVarInt(Code, "HV024", .Big),
fdw_invalid_column_name = readVarInt(Code, "HV007", .Big),
fdw_invalid_column_number = readVarInt(Code, "HV008", .Big),
fdw_invalid_data_type = readVarInt(Code, "HV004", .Big),
fdw_invalid_data_type_descriptors = readVarInt(Code, "HV006", .Big),
fdw_invalid_descriptor_field_identifier = readVarInt(Code, "HV091", .Big),
fdw_invalid_handle = readVarInt(Code, "HV00B", .Big),
fdw_invalid_option_index = readVarInt(Code, "HV00C", .Big),
fdw_invalid_option_name = readVarInt(Code, "HV00D", .Big),
fdw_invalid_string_length_or_buffer_length = readVarInt(Code, "HV090", .Big),
fdw_invalid_string_format = readVarInt(Code, "HV00A", .Big),
fdw_invalid_use_of_null_pointer = readVarInt(Code, "HV009", .Big),
fdw_too_many_handles = readVarInt(Code, "HV014", .Big),
fdw_out_of_memory = readVarInt(Code, "HV001", .Big),
fdw_no_schemas = readVarInt(Code, "HV00P", .Big),
fdw_option_name_not_found = readVarInt(Code, "HV00J", .Big),
fdw_reply_handle = readVarInt(Code, "HV00K", .Big),
fdw_schema_not_found = readVarInt(Code, "HV00Q", .Big),
fdw_table_not_found = readVarInt(Code, "HV00R", .Big),
fdw_unable_to_create_execution = readVarInt(Code, "HV00L", .Big),
fdw_unable_to_create_reply = readVarInt(Code, "HV00M", .Big),
fdw_unable_to_establish_connection = readVarInt(Code, "HV00N", .Big),
// Class P0 PL/pgSQL Error
plpgsql_error = readVarInt(Code, "P0000", .Big),
raise_exception = readVarInt(Code, "P0001", .Big),
no_data_found = readVarInt(Code, "P0002", .Big),
too_many_rows = readVarInt(Code, "P0003", .Big),
assert_failure = readVarInt(Code, "P0004", .Big),
// Class XX Internal Error
internal_error = readVarInt(Code, "XX000", .Big),
data_corrupted = readVarInt(Code, "XX001", .Big),
index_corrupted = readVarInt(Code, "XX002", .Big),
_,
};
}; };

View File

@ -28,7 +28,7 @@ fn getCharPos(text: []const u8, offset: c_int) struct { row: usize, col: usize }
return .{ .row = row, .col = col }; return .{ .row = row, .col = col };
} }
fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) anyerror { fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) error{Unexpected} {
std.log.err("Unexpected error in SQLite engine: {s} ({})", .{ c.sqlite3_errstr(code), code }); std.log.err("Unexpected error in SQLite engine: {s} ({})", .{ c.sqlite3_errstr(code), code });
std.log.debug("Additional details:", .{}); std.log.debug("Additional details:", .{});
@ -51,7 +51,7 @@ pub const Db = struct {
pub fn open(path: [:0]const u8) common.OpenError!Db { pub fn open(path: [:0]const u8) common.OpenError!Db {
const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE; const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE;
var db: [*c]c.sqlite3 = null; var db: ?*c.sqlite3 = null;
switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) {
c.SQLITE_OK => {}, c.SQLITE_OK => {},
else => |code| { else => |code| {
@ -61,7 +61,7 @@ pub const Db = struct {
"Unable to open SQLite DB \"{s}\". Error: {?s} ({})", "Unable to open SQLite DB \"{s}\". Error: {?s} ({})",
.{ path, c.sqlite3_errstr(code), code }, .{ path, c.sqlite3_errstr(code), code },
); );
return error.InternalException; return error.BadConnection;
} }
const ext_code = c.sqlite3_extended_errcode(db); const ext_code = c.sqlite3_extended_errcode(db);
@ -77,7 +77,7 @@ pub const Db = struct {
} }
return Db{ return Db{
.db = db, .db = db.?,
}; };
} }
@ -109,55 +109,115 @@ pub const Db = struct {
}, },
}; };
inline for (args) |arg, i| { if (@TypeOf(args) != void) {
// SQLite treats $NNN args as having the name NNN, not index NNN. inline for (args) |arg, i| {
// As such, if you reference $2 and not $1 in your query (such as // SQLite treats $NNN args as having the name NNN, not index NNN.
// when dynamically constructing queries), it could assign $2 the // As such, if you reference $2 and not $1 in your query (such as
// index 1. So we can't assume the index according to the 1-indexed // when dynamically constructing queries), it could assign $2 the
// arg array is equivalent to the param to bind it to. // index 1. So we can't assume the index according to the 1-indexed
// We can, however, look up the exact index to bind to. // arg array is equivalent to the param to bind it to.
// If the argument is not used in the query, then it will have an "index" // We can, however, look up the exact index to bind to.
// of 0, and we must not bind the argument. // If the argument is not used in the query, then it will have an "index"
const name = std.fmt.comptimePrint("${}", .{i + 1}); // of 0, and we must not bind the argument.
const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); const name = std.fmt.comptimePrint("${}", .{i + 1});
if (db_idx != 0) { const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name);
switch (bindArg(stmt.?, @intCast(u15, db_idx), arg)) { if (db_idx != 0)
c.SQLITE_OK => {}, try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg)
else => |err| { else if (!opts.ignore_unused_arguments)
return handleUnexpectedError(self.db, err, sql); return error.UnusedArgument;
}, }
}
} else if (!opts.ignore_unknown_parameters) return error.UndefinedParameter;
} }
return Results{ .stmt = stmt.?, .db = self.db }; return Results{ .stmt = stmt.?, .db = self.db };
} }
};
fn bindArg(stmt: *c.sqlite3_stmt, idx: u15, val: anytype) c_int { fn bindArgument(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: anytype) !void {
if (comptime std.meta.trait.isZigString(@TypeOf(val))) { if (comptime std.meta.trait.isZigString(@TypeOf(val))) {
const slice = @as([]const u8, val); return self.bindString(stmt, idx, val);
return c.sqlite3_bind_text(stmt, idx, slice.ptr, @intCast(c_int, slice.len), c.SQLITE_TRANSIENT); }
const T = @TypeOf(val);
switch (@typeInfo(T)) {
.Struct,
.Union,
.Opaque,
=> {
const arr = if (@hasDecl(T, "toCharArray"))
val.toCharArray()
else if (@hasDecl(T, "toCharArrayZ"))
val.toCharArrayZ()
else
@compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string");
const len = std.mem.len(&arr);
return self.bindString(stmt, idx, arr[0..len]);
},
.Enum => |info| {
const name = if (info.is_exhaustive)
@tagName(val)
else
@compileError("SQLite: Could not serialize non-exhaustive enum " ++ @typeName(T) ++ " into string");
return self.bindString(stmt, idx, name);
},
.Optional => {
return if (val) |v| self.bindArgument(stmt, idx, v) else self.bindNull(stmt, idx);
},
.Null => return self.bindNull(stmt, idx),
.Int => return self.bindInt(stmt, idx, val),
.Float => return self.bindFloat(stmt, idx, val),
else => @compileError("Unable to serialize type " ++ @typeName(T)),
}
} }
return switch (@TypeOf(val)) { fn bindString(self: Db, stmt: *c.sqlite3_stmt, idx: u15, str: []const u8) !void {
Uuid => blk: { const len = std.math.cast(c_int, str.len) orelse {
const arr = val.toCharArrayZ(); std.log.err("SQLite: string len {} too large", .{str.len});
break :blk bindArg(stmt, idx, &arr); return error.BindException;
}, };
DateTime => blk: {
const arr = val.toCharArrayZ(); switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) {
break :blk bindArg(stmt, idx, &arr); c.SQLITE_OK => {},
}, else => |result| {
@TypeOf(null) => c.sqlite3_bind_null(stmt, idx), std.log.err("SQLite: Unable to bind string to index {}", .{idx});
else => |T| switch (@typeInfo(T)) { std.log.debug("SQLite: {s}", .{str});
.Optional => if (val) |v| bindArg(stmt, idx, v) else bindArg(stmt, idx, null), return handleUnexpectedError(self.db, result, null);
.Enum => bindArg(stmt, idx, @tagName(val)), },
.Int => c.sqlite3_bind_int64(stmt, idx, @intCast(i64, val)), }
else => @compileError("unsupported type " ++ @typeName(T)), }
},
}; fn bindNull(self: Db, stmt: *c.sqlite3_stmt, idx: u15) !void {
} switch (c.sqlite3_bind_null(stmt, idx)) {
c.SQLITE_OK => {},
else => |result| {
std.log.err("SQLite: Unable to bind NULL to index {}", .{idx});
return handleUnexpectedError(self.db, result, null);
},
}
}
fn bindInt(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: i64) !void {
switch (c.sqlite3_bind_int64(stmt, idx, val)) {
c.SQLITE_OK => {},
else => |result| {
std.log.err("SQLite: Unable to bind int to index {}", .{idx});
std.log.debug("SQLite: {}", .{val});
return handleUnexpectedError(self.db, result, null);
},
}
}
fn bindFloat(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: f64) !void {
switch (c.sqlite3_bind_double(stmt, idx, val)) {
c.SQLITE_OK => {},
else => |result| {
std.log.err("SQLite: Unable to bind float to index {}", .{idx});
std.log.debug("SQLite: {}", .{val});
return handleUnexpectedError(self.db, result, null);
},
}
}
};
pub const Results = struct { pub const Results = struct {
stmt: *c.sqlite3_stmt, stmt: *c.sqlite3_stmt,
@ -193,12 +253,12 @@ pub const Results = struct {
return if (c.sqlite3_column_name(self.stmt, idx)) |ptr| return if (c.sqlite3_column_name(self.stmt, idx)) |ptr|
ptr[0..std.mem.len(ptr)] ptr[0..std.mem.len(ptr)]
else else
return error.Unexpected; unreachable;
} }
pub fn columnIndex(self: Results, name: []const u8) common.ColumnIndexError!u15 { pub fn columnIndex(self: Results, name: []const u8) common.ColumnIndexError!u15 {
var i: u15 = 0; var i: u15 = 0;
const count = self.columnCount(); const count = try self.columnCount();
while (i < count) : (i += 1) { while (i < count) : (i += 1) {
const column = try self.columnName(i); const column = try self.columnName(i);
if (std.mem.eql(u8, name, column)) return i; if (std.mem.eql(u8, name, column)) return i;
@ -213,33 +273,75 @@ pub const Row = struct {
db: *c.sqlite3, db: *c.sqlite3,
pub fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { pub fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T {
if (c.sqlite3_column_type(self.stmt, idx) == c.SQLITE_NULL) { return getColumn(self.stmt, T, idx, alloc);
return if (@typeInfo(T) == .Optional) null else error.TypeMismatch;
}
return self.getNotNull(T, idx, alloc);
}
fn getNotNull(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) error.GetError!T {
return switch (T) {
f32, f64 => @floatCast(T, c.sqlite3_column_double(self.stmt, idx)),
else => switch (@typeInfo(T)) {
.Int => |int| if (T == i63 or int.bits < 63)
@intCast(T, c.sqlite3_column_int64(self.stmt, idx))
else
self.getFromString(T, idx, alloc),
.Optional => try self.getNotNull(std.meta.Child(T), idx, alloc),
else => self.getFromString(T, idx, alloc),
},
};
}
fn getFromString(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) error.GetError!T {
const ptr = c.sqlite3_column_text(self.stmt, idx);
const size = @intCast(usize, c.sqlite3_column_bytes(self.stmt, idx));
const str = ptr[0..size];
return common.parseValueNotNull(alloc, T, str);
} }
}; };
fn getColumn(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T {
return switch (c.sqlite3_column_type(stmt, idx)) {
c.SQLITE_INTEGER => getColumnInt(stmt, T, idx),
c.SQLITE_FLOAT => getColumnFloat(stmt, T, idx),
c.SQLITE_TEXT => getColumnText(stmt, T, idx, alloc),
c.SQLITE_NULL => {
if (@typeInfo(T) != .Optional) {
std.log.err("SQLite column {}: Expected value of type {}, got (null)", .{ idx, T });
return error.ResultTypeMismatch;
}
return null;
},
c.SQLITE_BLOB => {
std.log.err("SQLite column {}: SQLite value had unsupported storage class BLOB", .{idx});
return error.ResultTypeMismatch;
},
else => |class| {
std.log.err("SQLite column {}: SQLite value had unknown storage class {}", .{ idx, class });
return error.ResultTypeMismatch;
},
};
}
fn getColumnInt(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T {
const val: i64 = c.sqlite3_column_int64(stmt, idx);
switch (T) {
DateTime => return DateTime{ .seconds_since_epoch = val },
else => switch (@typeInfo(T)) {
.Int => if (std.math.cast(T, val)) |v| return v else {
std.log.err("SQLite column {}: Expected value of type {}, got {} (outside of range)", .{ idx, T, val });
return error.ResultTypeMismatch;
},
else => {
std.log.err("SQLite column {}: Storage class INT cannot be parsed into type {}", .{ idx, T });
return error.ResultTypeMismatch;
},
},
}
}
fn getColumnFloat(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T {
const val: f64 = c.sqlite3_column_double(stmt, idx);
switch (T) {
// Only support floats that fit in range for now
f16, f32, f64 => return @floatCast(T, val),
DateTime => return DateTime{
.seconds_since_epoch = std.math.lossyCast(i64, val * @intToFloat(f64, std.time.epoch.secs_per_day)),
},
else => {
std.log.err("SQLite column {}: Storage class FLOAT cannot be parsed into type {}", .{ idx, T });
return error.ResultTypeMismatch;
},
}
}
fn getColumnText(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T {
if (c.sqlite3_column_text(stmt, idx)) |ptr| {
const size = @intCast(usize, c.sqlite3_column_bytes(stmt, idx));
const str = std.mem.sliceTo(ptr[0..size], 0);
return common.parseValueNotNull(alloc, T, str);
} else {
std.log.err("SQLite column {}: TEXT value stored but engine returned null pointer (out of memory?)", .{idx});
return error.ResultTypeMismatch;
}
}

View File

@ -17,7 +17,7 @@ const ConnectionError = error{
// - Filesystem full // - Filesystem full
// - Unknown crash // - Unknown crash
// - Filesystem permissions denied // - Filesystem permissions denied
InternalError, InternalException,
}; };
// Errors related to constraint validation // Errors related to constraint validation
@ -41,19 +41,21 @@ const ConstraintError = error{
// Errors related to argument binding // Errors related to argument binding
const ArgumentError = error{ const ArgumentError = error{
// One of the arguments passed could not be marshalled to pass to the SQL engine // One of the arguments passed could not be marshalled to pass to the SQL engine
InvalidArgument, BindException,
// The set of arguments passed did not map to query parameters // The set of arguments passed did not map to query parameters
UndefinedParameter, UnusedArgument,
// The allocator used for staging the query ran out of memory // The allocator used for staging the query ran out of memory
OutOfMemory, OutOfMemory,
AllocatorRequired,
}; };
// Errors related to retrieving query result columns // Errors related to retrieving query result columns
const ResultColumnError = error{ const ResultColumnError = error{
// The allocator used for retrieving the results ran out of memory // The allocator used for retrieving the results ran out of memory
OutOfMemory, OutOfMemory,
AllocatorRequired,
// A type error occurred when parsing results (means invalid data is in the DB) // A type error occurred when parsing results (means invalid data is in the DB)
ResultTypeMismatch, ResultTypeMismatch,
@ -68,10 +70,7 @@ const StartQueryError = error{
PermissionDenied, PermissionDenied,
// The SQL query had invalid syntax or used an invalid identifier // The SQL query had invalid syntax or used an invalid identifier
InvalidSql, SqlException,
// A type error occurred during the query (means query is written wrong)
QueryTypeMismatch,
// The set of columns to parse did not match the columns returned by the query // The set of columns to parse did not match the columns returned by the query
ColumnMismatch, ColumnMismatch,
@ -81,16 +80,6 @@ const StartQueryError = error{
BadTransactionState, BadTransactionState,
}; };
const RowCountError = error{
NoRows,
TooManyRows,
};
pub const OpenError = error{
BadConnection,
InternalError,
};
pub const library_errors = struct { pub const library_errors = struct {
const BaseError = ConnectionError || UnexpectedError; const BaseError = ConnectionError || UnexpectedError;
@ -98,8 +87,8 @@ pub const library_errors = struct {
pub const OpenError = BaseError; pub const OpenError = BaseError;
pub const QueryError = BaseError || ArgumentError || ConstraintError || StartQueryError; pub const QueryError = BaseError || ArgumentError || ConstraintError || StartQueryError;
pub const RowError = BaseError || ResultColumnError || ConstraintError || StartQueryError; pub const RowError = BaseError || ResultColumnError || ConstraintError || StartQueryError;
pub const QueryRowError = QueryError || RowError || RowCountError; pub const QueryRowError = QueryError || RowError || error{ NoRows, TooManyRows };
pub const ExecError = QueryError || RowCountError; pub const ExecError = QueryError;
pub const BeginError = BaseError || StartQueryError; pub const BeginError = BaseError || StartQueryError;
pub const CommitError = BaseError || StartQueryError || ConstraintError; pub const CommitError = BaseError || StartQueryError || ConstraintError;
}; };

View File

@ -2,7 +2,9 @@ const std = @import("std");
const util = @import("util"); const util = @import("util");
const postgres = @import("./engines/postgres.zig"); const postgres = @import("./engines/postgres.zig");
//const postgres = @import("./engines/null.zig");
const sqlite = @import("./engines/sqlite.zig"); const sqlite = @import("./engines/sqlite.zig");
//const sqlite = @import("./engines/null.zig");
const common = @import("./engines/common.zig"); const common = @import("./engines/common.zig");
const Allocator = std.mem.Allocator; const Allocator = std.mem.Allocator;
@ -42,17 +44,20 @@ const RawResults = union(Engine) {
} }
} }
fn columnCount(self: RawResults) u15 { fn columnCount(self: RawResults) !u15 {
return switch (self) { return try switch (self) {
.postgres => |pg| pg.columnCount(), .postgres => |pg| pg.columnCount(),
.sqlite => |lite| lite.columnCount(), .sqlite => |lite| lite.columnCount(),
}; };
} }
fn columnIndex(self: RawResults, name: []const u8) QueryError!u15 { fn columnIndex(self: RawResults, name: []const u8) error{ NotFound, Unexpected }!u15 {
return try switch (self) { return switch (self) {
.postgres => |pg| pg.columnIndex(name), .postgres => |pg| pg.columnIndex(name),
.sqlite => |lite| lite.columnIndex(name), .sqlite => |lite| lite.columnIndex(name),
} catch |err| switch (err) {
error.OutOfRange => error.Unexpected,
error.NotFound => error.NotFound,
}; };
} }
@ -69,7 +74,7 @@ const RawResults = union(Engine) {
// Must be deallocated by a call to finish() // Must be deallocated by a call to finish()
pub fn Results(comptime T: type) type { pub fn Results(comptime T: type) type {
// would normally make this a declaration of the struct, but it causes the compiler to crash // would normally make this a declaration of the struct, but it causes the compiler to crash
const fields = std.meta.fields(T); const fields = if (T == void) .{} else std.meta.fields(T);
return struct { return struct {
const Self = @This(); const Self = @This();
@ -77,15 +82,21 @@ pub fn Results(comptime T: type) type {
column_indices: [fields.len]u15, column_indices: [fields.len]u15,
fn from(underlying: RawResults) QueryError!Self { fn from(underlying: RawResults) QueryError!Self {
if (std.debug.runtime_safety and std.meta.trait.isTuple(T) and fields.len != underlying.columnCount()) { if (std.debug.runtime_safety and fields.len != underlying.columnCount() catch unreachable) {
std.log.err("Expected {} columns in result, got {}", .{ fields.len, underlying.columnCount() }); std.log.err("Expected {} columns in result, got {}", .{ fields.len, underlying.columnCount() catch unreachable });
return error.ColumnMismatch; return error.ColumnMismatch;
} }
return Self{ .underlying = underlying, .column_indices = blk: { return Self{ .underlying = underlying, .column_indices = blk: {
var indices: [fields.len]u15 = undefined; var indices: [fields.len]u15 = undefined;
inline for (fields) |f, i| { inline for (fields) |f, i| {
indices[i] = if (!std.meta.trait.isTuple(T)) try underlying.columnIndex(f.name) else i; indices[i] = if (!std.meta.trait.isTuple(T))
underlying.columnIndex(f.name) catch {
std.log.err("Could not find column index for field {s}", .{f.name});
return error.ColumnMismatch;
}
else
i;
} }
break :blk indices; break :blk indices;
} }; } };
@ -110,7 +121,13 @@ pub fn Results(comptime T: type) type {
}; };
inline for (fields) |f, i| { inline for (fields) |f, i| {
@field(result, f.name) = try row_val.get(f.field_type, self.column_indices[i], alloc); // TODO: Causes compiler segfault. why?
//const F = f.field_type;
const F = @TypeOf(@field(result, f.name));
@field(result, f.name) = row_val.get(F, self.column_indices[i], alloc) catch |err| {
std.log.err("SQL: Error getting column {s} of type {}", .{ f.name, F });
return err;
};
fields_allocated += 1; fields_allocated += 1;
} }
@ -129,7 +146,8 @@ const Row = union(Engine) {
// Not all types require an allocator to be present. If an allocator is needed but // Not all types require an allocator to be present. If an allocator is needed but
// not required, it will return error.AllocatorRequired. // not required, it will return error.AllocatorRequired.
// The caller is responsible for deallocating T, if relevant. // The caller is responsible for deallocating T, if relevant.
fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) !T { fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T {
if (T == void) return;
return switch (self) { return switch (self) {
.postgres => |pg| pg.get(T, idx, alloc), .postgres => |pg| pg.get(T, idx, alloc),
.sqlite => |lite| lite.get(T, idx, alloc), .sqlite => |lite| lite.get(T, idx, alloc),
@ -147,9 +165,9 @@ const QueryHelper = union(Engine) {
sql: [:0]const u8, sql: [:0]const u8,
args: anytype, args: anytype,
opt: QueryOptions, opt: QueryOptions,
) !RawResults { ) QueryError!RawResults {
return switch (self) { return switch (self) {
.postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt.prep_allocator) }, .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) },
.sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) }, .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
}; };
} }
@ -182,7 +200,12 @@ const QueryHelper = union(Engine) {
args: anytype, args: anytype,
alloc: ?Allocator, alloc: ?Allocator,
) QueryError!void { ) QueryError!void {
try self.queryRow(void, sql, args, alloc); _ = self.queryRow(void, sql, args, alloc) catch |err| return switch (err) {
error.NoRows => {},
error.TooManyRows => error.SqlException,
error.ResultTypeMismatch => unreachable,
else => |err2| err2,
};
} }
// Runs a query and returns a single row // Runs a query and returns a single row
@ -224,7 +247,11 @@ const QueryHelper = union(Engine) {
comptime var table_spec: []const u8 = table ++ "("; comptime var table_spec: []const u8 = table ++ "(";
comptime var value_spec: []const u8 = "("; comptime var value_spec: []const u8 = "(";
inline for (fields) |field, i| { inline for (fields) |field, i| {
types[i] = field.field_type; // This causes a compile error. Why?
//const F = field.field_type;
const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name));
// causes issues if F is @TypeOf(null), use dummy type
types[i] = if (F == @TypeOf(null)) ?i64 else F;
table_spec = comptime (table_spec ++ field.name ++ ","); table_spec = comptime (table_spec ++ field.name ++ ",");
value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1}); value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1});
} }
@ -252,6 +279,8 @@ pub const Db = struct {
tx_open: bool = false, tx_open: bool = false,
engine: QueryHelper, engine: QueryHelper,
pub const is_transaction = false;
pub fn open(cfg: Config) OpenError!Db { pub fn open(cfg: Config) OpenError!Db {
return switch (cfg) { return switch (cfg) {
.postgres => |postgres_cfg| Db{ .postgres => |postgres_cfg| Db{
@ -345,6 +374,8 @@ pub const Db = struct {
pub const Tx = struct { pub const Tx = struct {
db: *Db, db: *Db,
pub const is_transaction = true;
pub fn queryWithOptions( pub fn queryWithOptions(
self: Tx, self: Tx,
comptime RowType: type, comptime RowType: type,
@ -442,12 +473,14 @@ pub const Tx = struct {
pub fn commit(self: Tx) CommitError!void { pub fn commit(self: Tx) CommitError!void {
if (!self.db.tx_open) return error.BadTransactionState; if (!self.db.tx_open) return error.BadTransactionState;
self.exec("COMMIT", {}, null) catch |err| switch (err) { self.exec("COMMIT", {}, null) catch |err| switch (err) {
error.InvalidArgument, error.BindException,
error.OutOfMemory, error.OutOfMemory,
error.UndefinedParameter, error.UnusedArgument,
error.AllocatorRequired,
=> return error.Unexpected, => return error.Unexpected,
else => return err, // use a new capture because it's got a smaller error set
else => |err2| return err2,
}; };
self.db.tx_open = false; self.db.tx_open = false;
} }

View File

@ -9,10 +9,16 @@ pub const Duration = struct {
seconds_since_epoch: i64, seconds_since_epoch: i64,
// Tries the following methods for parsing, in order:
// 1. treats the string as a RFC 3339 DateTime
// 2. treats the string as the number of seconds since epoch
pub fn parse(str: []const u8) !DateTime { pub fn parse(str: []const u8) !DateTime {
// TODO: Try other formats return if (parseRfc3339(str)) |v|
v
return try parseRfc3339(str); else |_| if (std.fmt.parseInt(i64, str, 10)) |v|
DateTime{ .seconds_since_epoch = v }
else |_|
error.UnknownFormat;
} }
pub fn add(self: DateTime, duration: Duration) DateTime { pub fn add(self: DateTime, duration: Duration) DateTime {
@ -36,8 +42,8 @@ pub fn parseRfc3339(str: []const u8) !DateTime {
const month_num = try std.fmt.parseInt(std.meta.Tag(epoch.Month), str[5..7], 10); const month_num = try std.fmt.parseInt(std.meta.Tag(epoch.Month), str[5..7], 10);
const day_num = @as(i64, try std.fmt.parseInt(u9, str[8..10], 10)); const day_num = @as(i64, try std.fmt.parseInt(u9, str[8..10], 10));
const hour_num = @as(i64, try std.fmt.parseInt(u5, str[11..13], 10)); const hour_num = @as(i64, try std.fmt.parseInt(u5, str[11..13], 10));
const minute_num = @as(i64, try std.fmt.parseInt(u6, str[14..15], 10)); const minute_num = @as(i64, try std.fmt.parseInt(u6, str[14..16], 10));
const second_num = @as(i64, try std.fmt.parseInt(u6, str[16..17], 10)); const second_num = @as(i64, try std.fmt.parseInt(u6, str[17..19], 10));
const is_leap_year = epoch.isLeapYear(year_num); const is_leap_year = epoch.isLeapYear(year_num);
const leap_days_preceding_epoch = comptime epoch.epoch_year / 4 - epoch.epoch_year / 100 + epoch.epoch_year / 400; const leap_days_preceding_epoch = comptime epoch.epoch_year / 4 - epoch.epoch_year / 100 + epoch.epoch_year / 400;
@ -97,7 +103,7 @@ pub fn second(value: DateTime) u6 {
const array_len = 20; const array_len = 20;
pub fn toCharArray(value: DateTime) [array_len + 1]u8 { pub fn toCharArray(value: DateTime) [array_len]u8 {
var buf: [array_len]u8 = undefined; var buf: [array_len]u8 = undefined;
_ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable; _ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable;
return buf; return buf;

View File

@ -60,12 +60,15 @@ pub fn comptimeJoin(
/// to your enum type. /// to your enum type.
pub fn jsonSerializeEnumAsString( pub fn jsonSerializeEnumAsString(
enum_value: anytype, enum_value: anytype,
_: std.json.StringifyOptions, opt: std.json.StringifyOptions,
writer: anytype, writer: anytype,
) !void { ) !void {
switch (@typeInfo(@TypeOf(enum_value))) { switch (@typeInfo(@TypeOf(enum_value))) {
.Enum => |info| if (!info.is_exhaustive) @compileError("Enum must be exhaustive"), .Enum => |info| if (!info.is_exhaustive) @compileError("Enum must be exhaustive"),
else => @compileError("Must be enum type"), .Pointer => |info| if (info.size == .One) {
return jsonSerializeEnumAsString(enum_value.*, opt, writer);
} else @compileError("Must be enum type or pointer to enum, got " ++ @typeName(@TypeOf(enum_value))),
else => @compileError("Must be enum type or pointer to enum, got " ++ @typeName(@TypeOf(enum_value))),
} }
return std.fmt.format(writer, "\"{s}\"", .{@tagName(enum_value)}); return std.fmt.format(writer, "\"{s}\"", .{@tagName(enum_value)});

View File

@ -1,9 +1,8 @@
const std = @import("std"); const std = @import("std");
const main = @import("main"); const main = @import("main");
const sql = @import("sql");
const cluster_host = "test_host";
const test_config = .{ const test_config = .{
.cluster_host = cluster_host,
.db = .{ .db = .{
.sqlite = .{ .sqlite = .{
.db_file = ":memory:", .db_file = ":memory:",
@ -12,13 +11,22 @@ const test_config = .{
}; };
const ApiSource = main.api.ApiSource; const ApiSource = main.api.ApiSource;
const root_password = "password"; const root_user = "root";
const root_password = "password1234";
const admin_host = "example.com";
const admin_origin = "https://" ++ admin_host;
const random_seed = 1234; const random_seed = 1234;
fn makeApi(alloc: std.mem.Allocator) !ApiSource { fn makeDb(alloc: std.mem.Allocator) sql.Db {
var db = try sql.Db.open(test_config.db);
try main.migrations.up(&db);
try main.api.setupAdmin(&db, admin_origin, root_user, root_password, alloc);
}
fn makeApi(alloc: std.mem.Allocator, db: *sql.Db) !ApiSource {
main.api.initThreadPrng(random_seed); main.api.initThreadPrng(random_seed);
const source = try ApiSource.init(alloc, test_config, root_password); const source = try ApiSource.init(alloc, test_config, db);
return source; return source;
} }
@ -26,10 +34,11 @@ test "login as root" {
const alloc = std.testing.allocator; const alloc = std.testing.allocator;
var arena = std.heap.ArenaAllocator.init(alloc); var arena = std.heap.ArenaAllocator.init(alloc);
defer arena.deinit(); defer arena.deinit();
var src = try makeApi(alloc); var db = try makeDb(alloc);
var src = try makeApi(alloc, &db);
std.debug.print("\npassword: {s}\n", .{root_password}); std.debug.print("\npassword: {s}\n", .{root_password});
var api = try src.connectUnauthorized(cluster_host, arena.allocator()); var api = try src.connectUnauthorized(admin_host, arena.allocator());
defer api.close(); defer api.close();
_ = try api.login("root", root_password); _ = try api.login("root", root_password);