From 955df7b0447c161d2854c6b0836bdf468cb01cfb Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Mon, 3 Oct 2022 19:41:59 -0700 Subject: [PATCH] refactor --- README.md | 6 + src/main/README.md | 11 - src/main/api.zig | 122 +++-- src/main/api/auth.zig | 389 ++++++++++------ src/main/api/communities.zig | 35 +- src/main/api/invites.zig | 30 +- src/main/api/users.zig | 71 +-- src/main/controllers/auth.zig | 4 +- src/main/main.zig | 16 +- src/main/migrations.zig | 70 +-- src/sql/engines/common.zig | 32 +- src/sql/engines/null.zig | 38 ++ src/sql/engines/postgres.zig | 35 +- src/sql/engines/postgres/errors.zig | 671 +++++++++++++++------------- src/sql/engines/sqlite.zig | 254 +++++++---- src/sql/errors.zig | 27 +- src/sql/lib.zig | 67 ++- src/util/DateTime.zig | 18 +- src/util/lib.zig | 7 +- tests/api_integration/lib.zig | 23 +- 20 files changed, 1181 insertions(+), 745 deletions(-) delete mode 100644 src/main/README.md create mode 100644 src/sql/engines/null.zig diff --git a/README.md b/README.md index 0c7b7cb..6991978 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,12 @@ - System libraries * `sqlite3` +NOTE: compilation is broken right now because of: + +https://github.com/ziglang/zig/issues/12240 + +for a temporary fix, rebuild zig after changing `$stdlibdir/crypto/scrypt.zig:465` to use `default_salt_len` instead of `salt_bin.len` + ### Commands To build a binary: `zig build` diff --git a/src/main/README.md b/src/main/README.md deleted file mode 100644 index a9335b5..0000000 --- a/src/main/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# General overview - -- `/controllers/**` - Handles serialization/deserialization of api calls from HTTP requests -- `/api.zig` - Business rules -- `/api/*.zig` - Performs the actual actions in the DB associated with a call -- `/db.zig` - SQL query wrapper - diff --git a/src/main/api.zig b/src/main/api.zig index 2166683..f7473b2 100644 --- a/src/main/api.zig +++ b/src/main/api.zig @@ -3,10 +3,8 @@ const util = @import("util"); const builtin = @import("builtin"); const sql = @import("sql"); -const models = @import("./db/models.zig"); -const migrations = @import("./migrations.zig"); -pub const DateTime = util.DateTime; -pub const Uuid = util.Uuid; +const DateTime = util.DateTime; +const Uuid = util.Uuid; const Config = @import("./main.zig").Config; const services = struct { @@ -25,21 +23,17 @@ pub const RegistrationRequest = struct { }; pub const InviteRequest = struct { - pub const Type = services.invites.InviteType; + pub const Kind = services.invites.Kind; name: ?[]const u8 = null, - expires_at: ?DateTime = null, // TODO: Change this to lifespan + lifespan: ?DateTime.Duration = null, max_uses: ?u16 = null, - invite_type: Type = .user, // must be user unless the creator is an admin + kind: Kind = .user, // must be user unless the creator is an admin to_community: ?[]const u8 = null, // only valid on admin community }; -pub const LoginResponse = struct { - token: services.auth.tokens.Token.Value, - user_id: Uuid, - issued_at: DateTime, -}; +pub const LoginResponse = services.auth.LoginResult; pub const UserResponse = struct { id: Uuid, @@ -103,7 +97,7 @@ pub fn isAdminSetup(db: *sql.Db) !bool { return true; } -pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) !void { +pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, password: []const u8, allocator: std.mem.Allocator) anyerror!void { const tx = try db.begin(); errdefer tx.rollback(); var arena = std.heap.ArenaAllocator.init(allocator); @@ -111,22 +105,22 @@ pub fn setupAdmin(db: *sql.Db, origin: []const u8, username: []const u8, passwor try tx.setConstraintMode(.deferred); - const community = try services.communities.create( + const community_id = try services.communities.create( tx, origin, - Uuid.nil, .{ .name = "Cluster Admin", .kind = .admin }, + arena.allocator(), ); - const user = try services.users.create(tx, username, password, community.id, .{ .role = .admin }, arena.allocator()); + const user = try services.auth.register(tx, username, password, community_id, .{ .kind = .admin }, arena.allocator()); - try services.communities.transferOwnership(tx, community.id, user); + try services.communities.transferOwnership(tx, community_id, user); try tx.commit(); std.log.info( "Created admin user {s} (id {}) with cluster admin origin {s} (id {})", - .{ username, user, origin, community.id }, + .{ username, user, origin, community_id }, ); } @@ -168,12 +162,18 @@ pub const ApiSource = struct { const community = try services.communities.getByHost(self.db, host, arena.allocator()); - const token_info = try services.auth.tokens.verify(self.db, token, community.id); + const token_info = try services.auth.verifyToken( + self.db, + token, + community.id, + arena.allocator(), + ); return Conn{ .db = self.db, .internal_alloc = self.internal_alloc, - .user_id = token_info.user_id, + .token_info = token_info, + .user_id = token_info.account_id, .community = community, .arena = arena, }; @@ -186,7 +186,8 @@ fn ApiConn(comptime DbConn: type) type { db: DbConn, internal_alloc: std.mem.Allocator, // used *only* for large, internal buffers - user_id: ?Uuid, + token_info: ?services.auth.TokenInfo = null, + user_id: ?Uuid = null, community: services.communities.Community, arena: std.heap.ArenaAllocator, @@ -200,32 +201,35 @@ fn ApiConn(comptime DbConn: type) type { } pub fn login(self: *Self, username: []const u8, password: []const u8) !LoginResponse { - const user_id = (try services.users.lookupByUsername(self.db, username, self.community.id)) orelse return error.InvalidLogin; - try services.auth.passwords.verify(self.db, user_id, password, self.internal_alloc); - - const token = try services.auth.tokens.create(self.db, user_id); - - return LoginResponse{ - .user_id = user_id, - .token = token.value, - .issued_at = token.info.issued_at, - }; + return services.auth.login( + self.db, + username, + self.community.id, + password, + self.arena.allocator(), + ); } - const TokenInfo = struct { + pub const AuthorizationInfo = struct { + id: Uuid, username: []const u8, + community_id: Uuid, + host: []const u8, + + issued_at: DateTime, }; - pub fn getTokenInfo(self: *Self) !TokenInfo { - if (self.user_id) |user_id| { - const result = (try self.db.queryRow( - std.meta.Tuple(&.{[]const u8}), - "SELECT username FROM user WHERE id = $1", - .{user_id}, - self.arena.allocator(), - )) orelse { - return error.UserNotFound; + pub fn verifyAuthorization(self: *Self) !AuthorizationInfo { + if (self.token_info) |info| { + const user = try services.users.get(self.db, info.account_id, self.arena.allocator()); + + return AuthorizationInfo{ + .id = user.id, + .username = user.username, + .community_id = self.community.id, + .host = self.community.host, + + .issued_at = info.issued_at, }; - return TokenInfo{ .username = result[0] }; } return error.Unauthorized; @@ -236,7 +240,27 @@ fn ApiConn(comptime DbConn: type) type { return error.PermissionDenied; } - return services.communities.create(self.db, origin, self.user_id.?, .{}); + const tx = try self.db.begin(); + errdefer tx.rollback(); + const community_id = try services.communities.create( + tx, + origin, + .{}, + self.arena.allocator(), + ); + + const community = services.communities.get( + tx, + community_id, + self.arena.allocator(), + ) catch |err| return switch (err) { + error.NotFound => error.DatabaseError, + else => |err2| err2, + }; + + try tx.commit(); + + return community; } pub fn createInvite(self: *Self, options: InviteRequest) !services.invites.Invite { @@ -254,21 +278,23 @@ fn ApiConn(comptime DbConn: type) type { } else self.community.id; // Users can only make user invites - if (options.invite_type != .user and !self.isAdmin()) return error.PermissionDenied; + if (options.kind != .user and !self.isAdmin()) return error.PermissionDenied; - return try services.invites.create(self.db, user_id, community_id, .{ + const invite_id = try services.invites.create(self.db, user_id, community_id, .{ .name = options.name, - .expires_at = options.expires_at, + .lifespan = options.lifespan, .max_uses = options.max_uses, - .invite_type = options.invite_type, + .kind = options.kind, }, self.arena.allocator()); + + return try services.invites.get(self.db, invite_id, self.arena.allocator()); } pub fn register(self: *Self, request: RegistrationRequest) !UserResponse { std.log.debug("registering user {s} with code {s}", .{ request.username, request.invite_code }); const invite = try services.invites.getByCode(self.db, request.invite_code, self.arena.allocator()); - if (!Uuid.eql(invite.to_community, self.community.id)) return error.NotFound; + if (!Uuid.eql(invite.community_id, self.community.id)) return error.NotFound; if (invite.max_uses != null and invite.times_used >= invite.max_uses.?) return error.InviteExpired; if (invite.expires_at != null and DateTime.now().isAfter(invite.expires_at.?)) return error.InviteExpired; diff --git a/src/main/api/auth.zig b/src/main/api/auth.zig index 5ad7e91..d0c9772 100644 --- a/src/main/api/auth.zig +++ b/src/main/api/auth.zig @@ -1,157 +1,272 @@ const std = @import("std"); const util = @import("util"); +const users = @import("./users.zig"); const Uuid = util.Uuid; const DateTime = util.DateTime; -pub const passwords = struct { - const PwHash = std.crypto.pwhash.scrypt; - const pw_hash_params = PwHash.Params.interactive; - const pw_hash_encoding = .phc; - const pw_hash_buf_size = 128; +pub const RegistrationError = error{ + PasswordTooShort, + DatabaseFailure, + HashFailure, + OutOfMemory, +} || users.CreateError; - const PwHashBuf = [pw_hash_buf_size]u8; +pub const min_password_chars = 12; +pub const RegistrationOptions = struct { + invite_id: ?Uuid = null, + email: ?[]const u8 = null, + kind: users.Kind = .user, +}; - pub const VerifyError = error{ - InvalidLogin, - DatabaseFailure, - HashFailure, +/// Creates a local account with the given information and returns the +/// account id +pub fn register( + db: anytype, + username: []const u8, + password: []const u8, + community_id: Uuid, + options: RegistrationOptions, + alloc: std.mem.Allocator, +) RegistrationError!Uuid { + if (password.len < min_password_chars) return error.PasswordTooShort; + + const hash = try hashPassword(password, alloc); + defer alloc.free(hash); + + // transaction may already be running during initial db setup + if (@TypeOf(db).is_transaction) return registerTransaction( + db, + username, + hash, + community_id, + options, + alloc, + ); + + const tx = try db.begin(); + errdefer tx.rollback(); + + const id = registerTransaction( + tx, + username, + hash, + community_id, + options, + alloc, + ); + + try tx.commit(); + + return id; +} + +fn registerTransaction( + tx: anytype, + username: []const u8, + password_hash: []const u8, + community_id: Uuid, + options: RegistrationOptions, + alloc: std.mem.Allocator, +) RegistrationError!Uuid { + const id = try users.create(tx, username, community_id, options.kind, alloc); + tx.insert("local_account", .{ + .account_id = id, + .invite_id = options.invite_id, + .email = options.email, + }, alloc) catch return error.DatabaseFailure; + tx.insert("password", .{ + .account_id = id, + .hash = password_hash, + }, alloc) catch return error.DatabaseFailure; + + return id; +} + +pub const LoginError = error{ + InvalidLogin, + HashFailure, + DatabaseFailure, + OutOfMemory, +}; + +pub const LoginResult = struct { + token: []const u8, + account_id: Uuid, +}; + +/// Attempts to login to the account `@username@community` and creates +/// a login token/cookie for the user +pub fn login( + db: anytype, + username: []const u8, + community_id: Uuid, + password: []const u8, + alloc: std.mem.Allocator, +) LoginError!LoginResult { + std.log.debug("user: {s}, community_id: {}", .{ username, community_id }); + const info = db.queryRow( + struct { account_id: Uuid, hash: []const u8 }, + \\SELECT account.id as account_id, password.hash + \\FROM password JOIN account + \\ ON password.account_id = account.id + \\WHERE account.username = $1 + \\ AND account.community_id = $2 + \\LIMIT 1 + , + .{ username, community_id }, + alloc, + ) catch |err| return switch (err) { + error.NoRows => error.InvalidLogin, + else => error.DatabaseFailure, }; - pub fn verify( - db: anytype, - account_id: Uuid, - password: []const u8, - alloc: std.mem.Allocator, - ) VerifyError!void { - // TODO: This could be done w/o the dynamically allocated hash buf - const hash = db.queryRow( - std.meta.Tuple(&.{[]const u8}), - \\SELECT hashed_password - \\FROM account_password + errdefer util.deepFree(alloc, info); + std.log.debug("got password", .{}); + + try verifyPassword(info.hash, password, alloc); + + const token = try generateToken(alloc); + errdefer util.deepFree(alloc, token); + const token_hash = hashToken(token, alloc) catch |err| switch (err) { + error.OutOfMemory => return error.OutOfMemory, + else => unreachable, + }; + defer util.deepFree(alloc, token_hash); + + const tx = db.begin() catch return error.DatabaseFailure; + errdefer tx.rollback(); + + // ensure that the password has not changed in the meantime + { + const updated_info = tx.queryRow( + struct { hash: []const u8 }, + \\SELECT hash + \\FROM password \\WHERE account_id = $1 \\LIMIT 1 , - .{account_id}, - alloc, - ) catch |err| return switch (err) { - error.NoRows => error.InvalidLogin, - else => error.DatabaseFailure, - }; - errdefer alloc.free(hash[0]); - - PwHash.strVerify( - hash[0], - password, - .{ .allocator = alloc }, - ) catch error.HashFailure; - } - - pub const CreateError = error{ DatabaseFailure, HashFailure }; - pub fn create( - db: anytype, - account_id: Uuid, - password: []const u8, - alloc: std.mem.Allocator, - ) CreateError!void { - var buf: PwHashBuf = undefined; - const hash = PwHash.strHash( - password, - .{ .allocator = alloc, .params = pw_hash_params, .encoding = pw_hash_encoding }, - &buf, - ) catch return error.HashFailure; - - db.insert( - "account_password", - .{ - .account_id = account_id, - .hashed_password = hash, - }, + .{info.account_id}, alloc, ) catch return error.DatabaseFailure; + defer util.deepFree(alloc, updated_info); + + if (!std.mem.eql(u8, info.hash, updated_info.hash)) return error.InvalidLogin; } -}; -pub const tokens = struct { - const token_len = 20; - const token_str_len = std.base64.standard.Encoder.calcSize(token_len); - pub const Token = struct { - pub const Value = [token_str_len]u8; - pub const Info = struct { - account_id: Uuid, - issued_at: DateTime, - }; + tx.insert("token", .{ + .account_id = info.account_id, + .hash = token_hash, + }, alloc) catch return error.DatabaseFailure; - value: Value, - info: Info, + tx.commit() catch return error.DatabaseFailure; + + return LoginResult{ + .token = token, + .account_id = info.account_id, }; +} - const TokenHash = std.crypto.hash.sha2.Sha256; - const TokenDigestBuf = [TokenHash.digest_length]u8; - - const DbToken = struct { - hash: []const u8, - account_id: Uuid, - issued_at: DateTime, - }; - - pub const CreateError = error{DatabaseFailure}; - pub fn create(db: anytype, account_id: Uuid) CreateError!Token { - var token: [token_len]u8 = undefined; - std.crypto.random.bytes(&token); - - var hash: TokenDigestBuf = undefined; - TokenHash.hash(&token, &hash, .{}); - - const issued_at = DateTime.now(); - - db.insert("token", DbToken{ - .hash = &hash, - .account_id = account_id, - .issued_at = issued_at, - }) catch return error.DbError; - - var token_enc: [token_str_len]u8 = undefined; - _ = std.base64.standard.Encoder.encode(&token_enc, &token); - - return Token{ .value = token_enc, .info = .{ - .account_id = account_id, - .issued_at = issued_at, - } }; - } - - pub const VerifyError = error{ InvalidToken, DatabaseError }; - pub fn verify( - db: anytype, - token: []const u8, - community_id: Uuid, - alloc: std.mem.Allocator, - ) VerifyError!Token.Info { - const decoded_len = std.base64.standard.Decoder.calcSizeForSlice( - token, - ) catch return error.InvalidToken; - if (decoded_len != token_len) return error.InvalidToken; - - var decoded: [token_len]u8 = undefined; - std.base64.standard.Decoder.decode( - &decoded, - token, - ) catch return error.InvalidToken; - - var hash: TokenDigestBuf = undefined; - TokenHash.hash(&decoded, &hash, .{}); - - return db.queryRow( - Token.Info, - \\SELECT account.id, token.issued_at - \\FROM token JOIN account ON token.account_id = account.id - \\WHERE token.hash = $1 AND account.community_id = $2 - \\LIMIT 1 - , - .{ hash, community_id }, - alloc, - ) catch |err| switch (err) { - error.NoRows => error.InvalidToken, - else => error.DatabaseFailure, - }; - } +pub const VerifyTokenError = error{ InvalidToken, DatabaseFailure, OutOfMemory }; +pub const TokenInfo = struct { + account_id: Uuid, + issued_at: DateTime, }; +pub fn verifyToken( + db: anytype, + token: []const u8, + community_id: Uuid, + alloc: std.mem.Allocator, +) VerifyTokenError!TokenInfo { + const hash = try hashToken(token, alloc); + + return db.queryRow( + TokenInfo, + \\SELECT token.account_id, token.issued_at + \\FROM token JOIN account + \\ ON token.account_id = account.id + \\WHERE token.hash = $1 AND account.community_id = $2 + \\LIMIT 1 + , + .{ hash, community_id }, + alloc, + ) catch |err| switch (err) { + error.NoRows => error.InvalidToken, + else => error.DatabaseFailure, + }; +} + +// We use scrypt, a password hashing algorithm that attempts to slow down +// GPU-based cracking approaches by using large amounts of memory, for +// password hashing. +// Attempting to calculate/verify a hash will use about 50mb of work space. +const scrypt = std.crypto.pwhash.scrypt; +const password_hash_len = 128; +fn verifyPassword( + hash: []const u8, + password: []const u8, + alloc: std.mem.Allocator, +) LoginError!void { + scrypt.strVerify( + hash, + password, + .{ .allocator = alloc }, + ) catch |err| return switch (err) { + error.PasswordVerificationFailed => error.InvalidLogin, + else => error.HashFailure, + }; +} + +fn hashPassword(password: []const u8, alloc: std.mem.Allocator) ![]const u8 { + const buf = try alloc.alloc(u8, password_hash_len); + errdefer alloc.free(buf); + return scrypt.strHash( + password, + .{ + .allocator = alloc, + .params = scrypt.Params.interactive, + .encoding = .phc, + }, + buf, + ) catch error.HashFailure; +} + +/// A raw token is a sequence of N random bytes, base64 encoded. +/// When the token is generated: +/// - The hash of the token is calculated by: +/// 1. Decoding the base64 text +/// 2. Calculating the SHA256 hash of this text +/// 3. Encoding the hash back as base64 +/// - The b64 encoded hash is stored in the database +/// - The original token is returned to the user +/// * The user will treat it as opaque text +/// When the token is verified: +/// - The hash of the token is taken as shown above +/// - The database is scanned for a token matching this hash +/// - If none can be found, the token is invalid +const Sha256 = std.crypto.hash.sha2.Sha256; +const Base64Encoder = std.base64.standard.Encoder; +const Base64Decoder = std.base64.standard.Decoder; +const token_len = 12; +fn generateToken(alloc: std.mem.Allocator) ![]const u8 { + var token = std.mem.zeroes([token_len]u8); + std.crypto.random.bytes(&token); + + const token_b64_len = Base64Encoder.calcSize(token.len); + const token_b64 = try alloc.alloc(u8, token_b64_len); + return Base64Encoder.encode(token_b64, &token); +} + +fn hashToken(token_b64: []const u8, alloc: std.mem.Allocator) ![]const u8 { + const decoded_token_len = Base64Decoder.calcSizeForSlice(token_b64) catch return error.InvalidToken; + if (decoded_token_len != token_len) return error.InvalidToken; + + var token = std.mem.zeroes([token_len]u8); + Base64Decoder.decode(&token, token_b64) catch return error.InvalidToken; + + var hash = std.mem.zeroes([Sha256.digest_length]u8); + Sha256.hash(&token, &hash, .{}); + + const hash_b64_len = Base64Encoder.calcSize(hash.len); + const hash_b64 = try alloc.alloc(u8, hash_b64_len); + return Base64Encoder.encode(hash_b64, &hash); +} diff --git a/src/main/api/communities.zig b/src/main/api/communities.zig index 12b0d3b..bec3a58 100644 --- a/src/main/api/communities.zig +++ b/src/main/api/communities.zig @@ -25,7 +25,7 @@ pub const Kind = enum { pub const Community = struct { id: Uuid, - owner_id: Uuid, + owner_id: ?Uuid, host: []const u8, name: []const u8, @@ -46,7 +46,7 @@ pub const CreateError = error{ CommunityExists, }; -pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptions, alloc: std.mem.Allocator) CreateError!Uuid { +pub fn create(db: anytype, origin: []const u8, options: CreateOptions, alloc: std.mem.Allocator) CreateError!Uuid { const scheme_len = std.mem.indexOfScalar(u8, origin, ':') orelse return error.InvalidOrigin; const scheme_str = origin[0..scheme_len]; const scheme = std.meta.stringToEnum(Scheme, scheme_str) orelse return error.UnsupportedScheme; @@ -85,14 +85,14 @@ pub fn create(db: anytype, origin: []const u8, owner: Uuid, options: CreateOptio else => return error.DatabaseFailure, } - try db.insert("community", .{ + db.insert("community", .{ .id = id, - .owner_id = owner, + .owner_id = null, .host = host, .name = options.name orelse host, .scheme = scheme, .kind = options.kind, - }, alloc); + }, alloc) catch return error.DatabaseFailure; return id; } @@ -101,18 +101,27 @@ pub const GetError = error{ NotFound, DatabaseFailure, }; -pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) GetError!Community { + +fn getWhere( + db: anytype, + comptime where: []const u8, + args: anytype, + alloc: std.mem.Allocator, +) GetError!Community { return db.queryRow( Community, std.fmt.comptimePrint( \\SELECT {s} \\FROM community - \\WHERE host = $1 + \\WHERE {s} \\LIMIT 1 , - .{comptime sql.fieldList(Community)}, + .{ + comptime util.comptimeJoin(",", std.meta.fieldNames(Community)), + where, + }, ), - .{host}, + args, alloc, ) catch |err| switch (err) { error.NoRows => error.NotFound, @@ -120,6 +129,14 @@ pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) GetErr }; } +pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) GetError!Community { + return getWhere(db, "id = $1", .{id}, alloc); +} + +pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) GetError!Community { + return getWhere(db, "host = $1", .{host}, alloc); +} + pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void { // TODO: check that this actually found/updated the row (needs update to sql lib) db.exec( diff --git a/src/main/api/invites.zig b/src/main/api/invites.zig index b10bd2f..54631e3 100644 --- a/src/main/api/invites.zig +++ b/src/main/api/invites.zig @@ -13,7 +13,7 @@ const code_len = 12; const Encoder = std.base64.url_safe.Encoder; const Decoder = std.base64.url_safe.Decoder; -pub const InviteKind = enum { +pub const Kind = enum { system, community_owner, user, @@ -26,7 +26,7 @@ pub const Invite = struct { id: Uuid, created_by: Uuid, // User ID - to_community: ?Uuid, + community_id: ?Uuid, name: []const u8, code: []const u8, @@ -36,17 +36,17 @@ pub const Invite = struct { expires_at: ?DateTime, max_uses: ?InviteCount, - invite_kind: InviteKind, + kind: Kind, }; pub const InviteOptions = struct { name: ?[]const u8 = null, max_uses: ?InviteCount = null, lifespan: ?DateTime.Duration = null, - invite_kind: InviteKind = .user, + kind: Kind = .user, }; -pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: InviteOptions, alloc: std.mem.Allocator) !Uuid { +pub fn create(db: anytype, created_by: Uuid, community_id: ?Uuid, options: InviteOptions, alloc: std.mem.Allocator) !Uuid { const id = Uuid.randV4(getRandom()); var code_bytes: [rand_len]u8 = undefined; @@ -65,7 +65,7 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit .id = id, .created_by = created_by, - .to_community = to_community, + .community_id = community_id, .name = name, .code = code, @@ -76,7 +76,7 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit else null, - .invite_kind = options.invite_kind, + .kind = options.kind, }, alloc, ); @@ -97,17 +97,27 @@ fn doGetQuery( alloc: std.mem.Allocator, ) GetError!Invite { // Generate list of fields from struct - const field_list = util.comptimeJoinWithPrefix( + const field_list = comptime util.comptimeJoinWithPrefix( ",", "invite.", - std.meta.fieldNames(Invite), + &.{ + "id", + "created_by", + "community_id", + "name", + "code", + "created_at", + "expires_at", + "max_uses", + "kind", + }, ); // times_used field is not stored directly in the DB, instead // it is calculated based on the number of accounts that were created // from it const query = std.fmt.comptimePrint( - \\SELECT {s}, COUNT(local_account.id) AS times_used + \\SELECT {s}, COUNT(local_account.account_id) AS times_used \\FROM invite LEFT OUTER JOIN local_account \\ ON invite.id = local_account.invite_id \\WHERE {s} diff --git a/src/main/api/users.zig b/src/main/api/users.zig index 4709973..7891a1d 100644 --- a/src/main/api/users.zig +++ b/src/main/api/users.zig @@ -8,10 +8,13 @@ const getRandom = @import("../api.zig").getRandom; pub const CreateError = error{ UsernameTaken, - DbError, + UsernameContainsInvalidChar, + UsernameTooLong, + UsernameEmpty, + DatabaseFailure, }; -pub const Role = enum { +pub const Kind = enum { user, admin, }; @@ -19,7 +22,7 @@ pub const Role = enum { pub const CreateOptions = struct { invite_id: ?Uuid = null, email: ?[]const u8 = null, - role: Role = .user, + kind: Kind = .user, }; pub const LookupError = error{ @@ -48,40 +51,49 @@ pub fn lookupByUsername( return row[0]; } -// TODO: This fn sucks. -// auth.passwords.create requires that the user exists, but we shouldn't -// hold onto a transaction for the ~0.5s that it takes to hash the password. -// Should probably change this to be specifically about creating the user, -// and then have something in auth responsible for creating local accounts +pub const max_username_chars = 32; +pub const UsernameValidationError = error{ + UsernameContainsInvalidChar, + UsernameTooLong, + UsernameEmpty, +}; + +/// Usernames must satisfy: +/// - Be at least 1 character +/// - Be no more than 32 characters +/// - All characters are in [A-Za-z0-9_.] +/// Note that the '.' character is not allowed in all usernames, and +/// is intended for use in federated instance actors (as many instances do) +pub fn validateUsername(username: []const u8) UsernameValidationError!void { + if (username.len == 0) return error.UsernameEmpty; + if (username.len > max_username_chars) return error.UsernameTooLong; + + for (username) |ch| { + const valid = std.ascii.isAlNum(ch) or ch == '_'; + if (!valid) return error.UsernameContainsInvalidChar; + } +} + pub fn create( db: anytype, username: []const u8, - password: []const u8, community_id: Uuid, - options: CreateOptions, + kind: Kind, alloc: std.mem.Allocator, ) CreateError!Uuid { const id = Uuid.randV4(getRandom()); - const tx = db.begin(); - errdefer tx.rollback(); - tx.insert("account", .{ + try validateUsername(username); + + db.insert("account", .{ .id = id, .username = username, .community_id = community_id, - .role = options.role, + .kind = kind, }, alloc) catch |err| return switch (err) { error.UniqueViolation => error.UsernameTaken, else => error.DatabaseFailure, }; - try auth.passwords.create(tx, id, password, alloc); - tx.insert("local_account", .{ - .user_id = id, - .invite_id = options.invite_id, - .email = options.email, - }) catch return error.DatabaseFailure; - - try tx.commit(); return id; } @@ -93,7 +105,7 @@ pub const User = struct { host: []const u8, community_id: Uuid, - role: Role, + kind: Kind, created_at: DateTime, }; @@ -101,9 +113,16 @@ pub const User = struct { pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User { return db.queryRow( User, - \\SELECT user.username, community.host, community.id, user.created_at - \\FROM user JOIN community ON user.community_id = community.id - \\WHERE user.id = $1 + \\SELECT + \\ account.id, + \\ account.username, + \\ community.host, + \\ account.community_id, + \\ account.kind, + \\ account.created_at + \\FROM account JOIN community + \\ ON account.community_id = community.id + \\WHERE account.id = $1 \\LIMIT 1 , .{id}, diff --git a/src/main/controllers/auth.zig b/src/main/controllers/auth.zig index c2b356a..3c4b1e1 100644 --- a/src/main/controllers/auth.zig +++ b/src/main/controllers/auth.zig @@ -14,13 +14,13 @@ pub const login = struct { pub const path = "/auth/login"; pub const method = .POST; pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void { - std.debug.print("{s}", .{ctx.request.body.?}); const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx); defer utils.freeRequestBody(credentials, ctx.alloc); var api = try utils.getApiConn(srv, ctx); defer api.close(); + std.log.debug("connected to api", .{}); const token = try api.login(credentials.username, credentials.password); try utils.respondJson(ctx, .ok, token); @@ -37,7 +37,7 @@ pub const verify_login = struct { // The self-hosted compiler doesn't like inferring this error set. // do this for now - const info = api.getTokenInfo() catch unreachable; + const info = try api.verifyAuthorization(); try utils.respondJson(ctx, .ok, info); } diff --git a/src/main/main.zig b/src/main/main.zig index 85d6ce1..9d561d1 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -5,7 +5,7 @@ const http = @import("http"); const util = @import("util"); pub const api = @import("./api.zig"); -const migrations = @import("./migrations.zig"); +pub const migrations = @import("./migrations.zig"); const Uuid = util.Uuid; const c = @import("./controllers.zig"); @@ -24,10 +24,10 @@ const router = Router{ prepare(c.invites.create), - prepare(c.users.create), + //prepare(c.users.create), - prepare(c.notes.create), - prepare(c.notes.get), + //prepare(c.notes.create), + //prepare(c.notes.get), //Route.new(.GET, "/notes/:id/reacts", &c.notes.reacts.list), //Route.new(.POST, "/notes/:id/reacts", &c.notes.reacts.create), @@ -82,9 +82,7 @@ pub const RequestServer = struct { }; pub const Config = struct { - cluster_host: []const u8, db: sql.Config, - root_password: ?[]const u8 = null, }; fn loadConfig(alloc: std.mem.Allocator) !Config { @@ -112,6 +110,7 @@ fn runAdminSetup(db: *sql.Db, alloc: std.mem.Allocator) !void { fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void { try migrations.up(db); + api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp())); if (!try api.isAdminSetup(db)) { std.log.info("Performing first-time admin creation...", .{}); @@ -134,14 +133,15 @@ fn prepareDb(db: *sql.Db, alloc: std.mem.Allocator) !void { } } -pub fn main() anyerror!void { +pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var cfg = try loadConfig(gpa.allocator()); var db_conn = try sql.Db.open(cfg.db); try prepareDb(&db_conn, gpa.allocator()); + //try migrations.up(&db_conn); + //try api.setupAdmin(&db_conn, "http://localhost:8080", "root", "password", gpa.allocator()); var api_src = try api.ApiSource.init(gpa.allocator(), cfg, &db_conn); var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); - api.initThreadPrng(@bitCast(u64, std.time.milliTimestamp())); return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); } diff --git a/src/main/migrations.zig b/src/main/migrations.zig index 9b35e75..b39cbf4 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -38,8 +38,15 @@ fn execScript(db: *sql.Db, script: []const u8, alloc: std.mem.Allocator) !void { } fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { - const row = (try db.queryRow(std.meta.Tuple(&.{i32}), "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false; - return row[0] != 0; + return if (db.queryRow( + std.meta.Tuple(&.{i32}), + "SELECT COUNT(*) FROM migration WHERE name = $1 LIMIT 1", + .{name}, + alloc, + )) |row| row[0] != 0 else |err| switch (err) { + error.NoRows => false, + else => error.DatabaseFailure, + }; } pub fn up(db: *sql.Db) !void { @@ -70,41 +77,42 @@ const create_migration_table = // migrations into a single one. this will require db recreation const migrations: []const Migration = &.{ .{ - .name = "users", + .name = "accounts", .up = - \\CREATE TABLE user( - \\ id TEXT NOT NULL PRIMARY KEY, + \\CREATE TABLE account( + \\ id UUID NOT NULL PRIMARY KEY, \\ username TEXT NOT NULL, \\ + \\ kind TEXT NOT NULL CHECK (kind IN ('admin', 'user')), \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); \\ - \\CREATE TABLE local_user( - \\ user_id TEXT NOT NULL PRIMARY KEY REFERENCES user(id), + \\CREATE TABLE local_account( + \\ account_id UUID NOT NULL PRIMARY KEY REFERENCES account(id), \\ \\ email TEXT \\); \\ - \\CREATE TABLE account_password( - \\ user_id TEXT NOT NULL PRIMARY KEY REFERENCES user(id), + \\CREATE TABLE password( + \\ account_id UUID NOT NULL PRIMARY KEY REFERENCES account(id), \\ - \\ hashed_password BLOB NOT NULL + \\ hash BLOB NOT NULL \\); , .down = - \\DROP TABLE account_password; - \\DROP TABLE local_user; - \\DROP TABLE user; + \\DROP TABLE password; + \\DROP TABLE local_account; + \\DROP TABLE account; , }, .{ .name = "notes", .up = \\CREATE TABLE note( - \\ id TEXT NOT NULL, + \\ id UUID NOT NULL, \\ \\ content TEXT NOT NULL, - \\ author_id TEXT NOT NULL REFERENCES user(id), + \\ author_id UUID NOT NULL REFERENCES account(id), \\ \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); @@ -115,10 +123,10 @@ const migrations: []const Migration = &.{ .name = "note reactions", .up = \\CREATE TABLE reaction( - \\ id TEXT NOT NULL PRIMARY KEY, + \\ id UUID NOT NULL PRIMARY KEY, \\ - \\ user_id TEXT NOT NULL REFERENCES user(id), - \\ note_id TEXT NOT NULL REFERENCES note(id), + \\ account_id UUID NOT NULL REFERENCES account(id), + \\ note_id UUID NOT NULL REFERENCES note(id), \\ \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); @@ -126,11 +134,11 @@ const migrations: []const Migration = &.{ .down = "DROP TABLE reaction;", }, .{ - .name = "user tokens", + .name = "account tokens", .up = \\CREATE TABLE token( \\ hash TEXT NOT NULL PRIMARY KEY, - \\ user_id TEXT NOT NULL REFERENCES local_user(id), + \\ account_id UUID NOT NULL REFERENCES local_account(id), \\ \\ issued_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); @@ -138,26 +146,26 @@ const migrations: []const Migration = &.{ .down = "DROP TABLE token;", }, .{ - .name = "user invites", + .name = "account invites", .up = \\CREATE TABLE invite( - \\ id TEXT NOT NULL PRIMARY KEY, + \\ id UUID NOT NULL PRIMARY KEY, \\ \\ name TEXT NOT NULL, \\ code TEXT NOT NULL UNIQUE, - \\ created_by TEXT NOT NULL REFERENCES local_user(id), + \\ created_by UUID NOT NULL REFERENCES local_account(id), \\ \\ max_uses INTEGER, \\ \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, \\ expires_at TIMESTAMPTZ, \\ - \\ type TEXT NOT NULL CHECK (type in ('system_user', 'community_owner', 'user')) + \\ kind TEXT NOT NULL CHECK (kind in ('system_user', 'community_owner', 'user')) \\); - \\ALTER TABLE local_user ADD COLUMN invite_id TEXT REFERENCES invite(id); + \\ALTER TABLE local_account ADD COLUMN invite_id UUID REFERENCES invite(id); , .down = - \\ALTER TABLE local_user DROP COLUMN invite_id; + \\ALTER TABLE local_account DROP COLUMN invite_id; \\DROP TABLE invite; , }, @@ -165,9 +173,9 @@ const migrations: []const Migration = &.{ .name = "communities", .up = \\CREATE TABLE community( - \\ id TEXT NOT NULL PRIMARY KEY, + \\ id UUID NOT NULL PRIMARY KEY, \\ - \\ owner_id TEXT REFERENCES user(id), + \\ owner_id UUID REFERENCES account(id), \\ name TEXT NOT NULL, \\ host TEXT NOT NULL UNIQUE, \\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')), @@ -175,12 +183,12 @@ const migrations: []const Migration = &.{ \\ \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP \\); - \\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id); - \\ALTER TABLE invite ADD COLUMN community_id TEXT REFERENCES community(id); + \\ALTER TABLE account ADD COLUMN community_id UUID REFERENCES community(id); + \\ALTER TABLE invite ADD COLUMN community_id UUID REFERENCES community(id); , .down = \\ALTER TABLE invite DROP COLUMN community_id; - \\ALTER TABLE user DROP COLUMN community_id; + \\ALTER TABLE account DROP COLUMN community_id; \\DROP TABLE community; , }, diff --git a/src/sql/engines/common.zig b/src/sql/engines/common.zig index d7fe906..849e470 100644 --- a/src/sql/engines/common.zig +++ b/src/sql/engines/common.zig @@ -20,17 +20,17 @@ pub const OpenError = error{BadConnection} || UnexpectedError; pub const ExecError = error{ Cancelled, - ConnectionLost, + BadConnection, InternalException, DatabaseBusy, PermissionDenied, SqlException, /// Argument could not be marshalled for query - InvalidArgument, + BindException, /// An argument was not used by the query (not checked in all DB engines) - UndefinedParameter, + UnusedArgument, /// Memory error when marshalling argument for query OutOfMemory, @@ -39,7 +39,7 @@ pub const ExecError = error{ pub const RowError = error{ Cancelled, - ConnectionLost, + BadConnection, InternalException, DatabaseBusy, PermissionDenied, @@ -49,7 +49,7 @@ pub const RowError = error{ pub const GetError = error{ OutOfMemory, AllocatorRequired, - TypeMismatch, + ResultTypeMismatch, } || UnexpectedError; pub const ColumnCountError = error{OutOfRange}; @@ -97,13 +97,25 @@ pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]con // Parse a (not-null) value from a string pub fn parseValueNotNull(alloc: ?Allocator, comptime T: type, str: []const u8) !T { return switch (T) { - Uuid => Uuid.parse(str), - DateTime => DateTime.parse(str), - []u8, []const u8 => if (alloc) |a| util.deepClone(a, str) else return error.AllocatorRequired, + Uuid => Uuid.parse(str) catch |err| { + std.log.err("Error {} parsing UUID: '{s}'", .{ err, str }); + return error.ResultTypeMismatch; + }, + DateTime => DateTime.parse(str) catch |err| { + std.log.err("Error {} parsing DateTime: '{s}'", .{ err, str }); + return error.ResultTypeMismatch; + }, + []u8, []const u8 => if (alloc) |a| try util.deepClone(a, str) else return error.AllocatorRequired, else => switch (@typeInfo(T)) { - .Int => std.fmt.parseInt(T, str, 0), - .Enum => std.meta.stringToEnum(T, str) orelse return error.InvalidValue, + .Int => std.fmt.parseInt(T, str, 0) catch |err| { + std.log.err("Could not parse int: {}", .{err}); + return error.ResultTypeMismatch; + }, + .Enum => std.meta.stringToEnum(T, str) orelse { + std.log.err("'{s}' is not a member of enum type {s}", .{ str, @typeName(T) }); + return error.ResultTypeMismatch; + }, .Optional => try parseValueNotNull(alloc, std.meta.Child(T), str), else => @compileError("Type " ++ @typeName(T) ++ " not supported"), diff --git a/src/sql/engines/null.zig b/src/sql/engines/null.zig new file mode 100644 index 0000000..f6b2d9d --- /dev/null +++ b/src/sql/engines/null.zig @@ -0,0 +1,38 @@ +const std = @import("std"); +const common = @import("./common.zig"); +const Allocator = std.mem.Allocator; + +pub const Results = struct { + pub fn row(_: *Results) common.RowError!?Row { + unreachable; + } + pub fn columnCount(_: Results) common.ColumnCountError!u15 { + unreachable; + } + pub fn columnIndex(_: Results, _: []const u8) common.ColumnIndexError!u15 { + unreachable; + } + pub fn finish(_: Results) void { + unreachable; + } +}; + +pub const Row = struct { + pub fn get(_: Row, comptime T: type, _: u15, _: ?Allocator) common.GetError!T { + unreachable; + } +}; + +pub const Db = struct { + pub fn open(_: anytype) common.OpenError!Db { + unreachable; + } + + pub fn close(_: Db) void { + unreachable; + } + + pub fn exec(_: Db, _: [:0]const u8, _: anytype, _: common.QueryOptions) common.ExecError!Results { + unreachable; + } +}; diff --git a/src/sql/engines/postgres.zig b/src/sql/engines/postgres.zig index e54f462..94ef674 100644 --- a/src/sql/engines/postgres.zig +++ b/src/sql/engines/postgres.zig @@ -21,6 +21,10 @@ pub const Results = struct { }; } + fn rowCount(self: Results) c_int { + return c.PQntuples(self.result); + } + pub fn columnCount(self: Results) common.ColumnCountError!u15 { return std.math.cast(u15, c.PQnfields(self.result)) orelse error.OutOfRange; } @@ -36,7 +40,7 @@ pub const Results = struct { } }; -fn handleError(result: *c.PQresult) common.RowError { +fn handleError(result: *c.PGresult) common.RowError { const error_code = c.PQresultErrorField(result, c.PG_DIAG_SQLSTATE); const state = errors.SqlState.parse(error_code) catch { std.log.err("Database returned invalid error code {?s}", .{error_code}); @@ -126,7 +130,7 @@ pub const Row = struct { const val = c.PQgetvalue(self.result, self.row_index, idx); const is_null = (c.PQgetisnull(self.result, self.row_index, idx) != 0); if (is_null) { - return if (@typeInfo(T) == .Optional) null else error.TypeMismatch; + return if (@typeInfo(T) == .Optional) null else error.ResultTypeMismatch; } if (val == null) return error.Unexpected; @@ -175,15 +179,30 @@ pub const Db = struct { const format_text = 0; const format_binary = 1; - pub fn exec(self: Db, sql: [:0]const u8, args: anytype, alloc: ?Allocator) !Results { + pub fn exec(self: Db, sql: [:0]const u8, args: anytype, opt: common.QueryOptions) common.ExecError!Results { + const alloc = opt.prep_allocator; const result = blk: { - if (comptime args.len > 0) { + if (@TypeOf(args) != void and args.len > 0) { var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired); defer arena.deinit(); const params = try arena.allocator().alloc(?[*:0]const u8, args.len); - inline for (args) |arg, i| params[i] = if (try common.prepareParamText(&arena, arg)) |slice| slice.ptr else null; + inline for (args) |arg, i| { + params[i] = if (try common.prepareParamText(&arena, arg)) |slice| + slice.ptr + else + null; + } - break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text); + break :blk c.PQexecParams( + self.conn, + sql.ptr, + @intCast(c_int, params.len), + null, + params.ptr, + null, + null, + format_text, + ); } else { break :blk c.PQexecParams(self.conn, sql.ptr, 0, null, null, null, null, format_text); } @@ -199,11 +218,11 @@ pub const Db = struct { c.PGRES_TUPLES_OK, => return Results{ .result = result }, - c.PGRES_EMPTY_QUERY => return error.InvalidSql, + c.PGRES_EMPTY_QUERY => return error.SqlException, c.PGRES_BAD_RESPONSE => { std.log.err("Database returned invalid response: {?s}", .{c.PQresultErrorMessage(result)}); - return error.Database; + return error.InternalException; }, c.PGRES_FATAL_ERROR => return handleError(result), diff --git a/src/sql/engines/postgres/errors.zig b/src/sql/engines/postgres/errors.zig index 1d91a17..33b064b 100644 --- a/src/sql/engines/postgres/errors.zig +++ b/src/sql/engines/postgres/errors.zig @@ -1,15 +1,20 @@ const std = @import("std"); const c = @import("./c.zig"); -const readIntBig = std.mem.readIntBig; +const readVarInt = std.mem.readVarInt; const Code = u40; // 8 * 5 = 40 -const code_class_mask: Code = 0xFFFF_000000; pub const SqlStateClass = blk: { + @setEvalBranchQuota(10000); const info = @typeInfo(SqlState).Enum; - var fields = &[0]std.builtin.Type.EnumField{}; + const EnumField = std.builtin.Type.EnumField; + + var fields: []const EnumField = &.{}; for (info.fields) |field| { - if (field.value & ~code_class_mask == 0) fields = fields ++ &.{field}; + const class_code = toClassCode(field.value); + if (class_code == field.value) { + fields = fields ++ &[_]EnumField{field}; + } } break :blk @Type(.{ .Enum = .{ .layout = info.layout, @@ -20,322 +25,352 @@ pub const SqlStateClass = blk: { } }); }; +fn toCodeStr(code: Code) [5]u8 { + var code_str: [5]u8 = undefined; + code_str[0] = @intCast(u8, (code & 0xFF00_000000) >> 32); + code_str[1] = @intCast(u8, (code & 0x00FF_000000) >> 24); + code_str[2] = @intCast(u8, (code & 0x0000_FF0000) >> 16); + code_str[3] = @intCast(u8, (code & 0x0000_00FF00) >> 8); + code_str[4] = @intCast(u8, (code & 0x0000_0000FF) >> 0); + + return code_str; +} + +fn toClassCode(code: Code) Code { + var code_str = [_]u8{'0'} ** 5; + code_str[0] = @intCast(u8, (code & 0xFF00_000000) >> 32); + code_str[1] = @intCast(u8, (code & 0x00FF_000000) >> 24); + + code_str[2] = '0'; + code_str[3] = '0'; + code_str[4] = '0'; + + return readVarInt(Code, &code_str, .Big); +} + // SqlState values for Postgres 14 -pub const SqlState = enum(Code) { - pub const ParseError = error{ InvalidSize, NullPointer }; - pub fn parse(code_str: [*c]const u8) ParseError!SqlState { - if (code_str == null) return error.NullPointer; - const slice = std.mem.span(code_str); - if (slice.len != @sizeOf(Code)) return error.InvalidSize; - return @intToEnum(SqlState, std.mem.readIntSliceBig(Code, slice)); - } +pub const SqlState = blk: { + @setEvalBranchQuota(10000); + break :blk enum(Code) { + pub const ParseError = error{ InvalidSize, NullPointer }; + pub fn parse(code_str: [*c]const u8) ParseError!SqlState { + if (code_str == null) return error.NullPointer; + const slice = std.mem.span(code_str); + if (slice.len != @sizeOf(Code)) return error.InvalidSize; + return @intToEnum(SqlState, readVarInt(Code, slice, .Big)); + } - pub fn errorClass(code: SqlState) SqlStateClass { - return @intToEnum(SqlStateClass, @enumToInt(code) & code_class_mask); - } + pub fn errorClass(code: SqlState) SqlStateClass { + return @intToEnum(SqlStateClass, toClassCode(@enumToInt(code))); + } - // Class 00 — Successful Completion - successful_completion = readIntBig(Code, "00000"), - // Class 01 — Warning - warning = readIntBig(Code, "01000"), - dynamic_result_sets_returned = readIntBig(Code, "0100C"), - implicit_zero_bit_padding = readIntBig(Code, "01008"), - null_value_eliminated_in_set_function = readIntBig(Code, "01003"), - privilege_not_granted = readIntBig(Code, "01007"), - privilege_not_revoked = readIntBig(Code, "01006"), - string_data_right_truncation = readIntBig(Code, "01004"), - deprecated_feature = readIntBig(Code, "01P01"), - // Class 02 — No Data (this is also a warning class per the SQL standard) - no_data = readIntBig(Code, "02000"), - no_additional_dynamic_result_sets_returned = readIntBig(Code, "02001"), - // Class 03 — SQL Statement Not Yet Complete - sql_statement_not_yet_complete = readIntBig(Code, "03000"), - // Class 08 — Connection Exception - connection_exception = readIntBig(Code, "08000"), - connection_does_not_exist = readIntBig(Code, "08003"), - connection_failure = readIntBig(Code, "08006"), - sqlclient_unable_to_establish_sqlconnection = readIntBig(Code, "08001"), - sqlserver_rejected_establishment_of_sqlconnection = readIntBig(Code, "08004"), - transaction_resolution_unknown = readIntBig(Code, "08007"), - protocol_violation = readIntBig(Code, "08P01"), - // Class 09 — Triggered Action Exception - triggered_action_exception = readIntBig(Code, "09000"), - // Class 0A — Feature Not Supported - feature_not_supported = readIntBig(Code, "0A000"), - // Class 0B — Invalid Transaction Initiation - invalid_transaction_initiation = readIntBig(Code, "0B000"), - // Class 0F — Locator Exception - locator_exception = readIntBig(Code, "0F000"), - invalid_locator_specification = readIntBig(Code, "0F001"), - // Class 0L — Invalid Grantor - invalid_grantor = readIntBig(Code, "0L000"), - invalid_grant_operation = readIntBig(Code, "0LP01"), - // Class 0P — Invalid Role Specification - invalid_role_specification = readIntBig(Code, "0P000"), - // Class 0Z — Diagnostics Exception - diagnostics_exception = readIntBig(Code, "0Z000"), - stacked_diagnostics_accessed_without_active_handler = readIntBig(Code, "0Z002"), - // Class 20 — Case Not Found - case_not_found = readIntBig(Code, "20000"), - // Class 21 — Cardinality Violation - cardinality_violation = readIntBig(Code, "21000"), - // Class 22 — Data Exception - data_exception = readIntBig(Code, "22000"), - array_subscript_error = readIntBig(Code, "2202E"), - character_not_in_repertoire = readIntBig(Code, "22021"), - datetime_field_overflow = readIntBig(Code, "22008"), - division_by_zero = readIntBig(Code, "22012"), - error_in_assignment = readIntBig(Code, "22005"), - escape_character_conflict = readIntBig(Code, "2200B"), - indicator_overflow = readIntBig(Code, "22022"), - interval_field_overflow = readIntBig(Code, "22015"), - invalid_argument_for_logarithm = readIntBig(Code, "2201E"), - invalid_argument_for_ntile_function = readIntBig(Code, "22014"), - invalid_argument_for_nth_value_function = readIntBig(Code, "22016"), - invalid_argument_for_power_function = readIntBig(Code, "2201F"), - invalid_argument_for_width_bucket_function = readIntBig(Code, "2201G"), - invalid_character_value_for_cast = readIntBig(Code, "22018"), - invalid_datetime_format = readIntBig(Code, "22007"), - invalid_escape_character = readIntBig(Code, "22019"), - invalid_escape_octet = readIntBig(Code, "2200D"), - invalid_escape_sequence = readIntBig(Code, "22025"), - nonstandard_use_of_escape_character = readIntBig(Code, "22P06"), - invalid_indicator_parameter_value = readIntBig(Code, "22010"), - invalid_parameter_value = readIntBig(Code, "22023"), - invalid_preceding_or_following_size = readIntBig(Code, "22013"), - invalid_regular_expression = readIntBig(Code, "2201B"), - invalid_row_count_in_limit_clause = readIntBig(Code, "2201W"), - invalid_row_count_in_result_offset_clause = readIntBig(Code, "2201X"), - invalid_tablesample_argument = readIntBig(Code, "2202H"), - invalid_tablesample_repeat = readIntBig(Code, "2202G"), - invalid_time_zone_displacement_value = readIntBig(Code, "22009"), - invalid_use_of_escape_character = readIntBig(Code, "2200C"), - most_specific_type_mismatch = readIntBig(Code, "2200G"), - null_value_not_allowed = readIntBig(Code, "22004"), - null_value_no_indicator_parameter = readIntBig(Code, "22002"), - numeric_value_out_of_range = readIntBig(Code, "22003"), - sequence_generator_limit_exceeded = readIntBig(Code, "2200H"), - string_data_length_mismatch = readIntBig(Code, "22026"), - string_data_right_truncation = readIntBig(Code, "22001"), - substring_error = readIntBig(Code, "22011"), - trim_error = readIntBig(Code, "22027"), - unterminated_c_string = readIntBig(Code, "22024"), - zero_length_character_string = readIntBig(Code, "2200F"), - floating_point_exception = readIntBig(Code, "22P01"), - invalid_text_representation = readIntBig(Code, "22P02"), - invalid_binary_representation = readIntBig(Code, "22P03"), - bad_copy_file_format = readIntBig(Code, "22P04"), - untranslatable_character = readIntBig(Code, "22P05"), - not_an_xml_document = readIntBig(Code, "2200L"), - invalid_xml_document = readIntBig(Code, "2200M"), - invalid_xml_content = readIntBig(Code, "2200N"), - invalid_xml_comment = readIntBig(Code, "2200S"), - invalid_xml_processing_instruction = readIntBig(Code, "2200T"), - duplicate_json_object_key_value = readIntBig(Code, "22030"), - invalid_argument_for_sql_json_datetime_function = readIntBig(Code, "22031"), - invalid_json_text = readIntBig(Code, "22032"), - invalid_sql_json_subscript = readIntBig(Code, "22033"), - more_than_one_sql_json_item = readIntBig(Code, "22034"), - no_sql_json_item = readIntBig(Code, "22035"), - non_numeric_sql_json_item = readIntBig(Code, "22036"), - non_unique_keys_in_a_json_object = readIntBig(Code, "22037"), - singleton_sql_json_item_required = readIntBig(Code, "22038"), - sql_json_array_not_found = readIntBig(Code, "22039"), - sql_json_member_not_found = readIntBig(Code, "2203A"), - sql_json_number_not_found = readIntBig(Code, "2203B"), - sql_json_object_not_found = readIntBig(Code, "2203C"), - too_many_json_array_elements = readIntBig(Code, "2203D"), - too_many_json_object_members = readIntBig(Code, "2203E"), - sql_json_scalar_required = readIntBig(Code, "2203F"), - // Class 23 — Integrity Constraint Violation - integrity_constraint_violation = readIntBig(Code, "23000"), - restrict_violation = readIntBig(Code, "23001"), - not_null_violation = readIntBig(Code, "23502"), - foreign_key_violation = readIntBig(Code, "23503"), - unique_violation = readIntBig(Code, "23505"), - check_violation = readIntBig(Code, "23514"), - exclusion_violation = readIntBig(Code, "23P01"), - // Class 24 — Invalid Cursor State - invalid_cursor_state = readIntBig(Code, "24000"), - // Class 25 — Invalid Transaction State - invalid_transaction_state = readIntBig(Code, "25000"), - active_sql_transaction = readIntBig(Code, "25001"), - branch_transaction_already_active = readIntBig(Code, "25002"), - held_cursor_requires_same_isolation_level = readIntBig(Code, "25008"), - inappropriate_access_mode_for_branch_transaction = readIntBig(Code, "25003"), - inappropriate_isolation_level_for_branch_transaction = readIntBig(Code, "25004"), - no_active_sql_transaction_for_branch_transaction = readIntBig(Code, "25005"), - read_only_sql_transaction = readIntBig(Code, "25006"), - schema_and_data_statement_mixing_not_supported = readIntBig(Code, "25007"), - no_active_sql_transaction = readIntBig(Code, "25P01"), - in_failed_sql_transaction = readIntBig(Code, "25P02"), - idle_in_transaction_session_timeout = readIntBig(Code, "25P03"), - // Class 26 — Invalid SQL Statement Name - invalid_sql_statement_name = readIntBig(Code, "26000"), - // Class 27 — Triggered Data Change Violation - triggered_data_change_violation = readIntBig(Code, "27000"), - // Class 28 — Invalid Authorization Specification - invalid_authorization_specification = readIntBig(Code, "28000"), - invalid_password = readIntBig(Code, "28P01"), - // Class 2B — Dependent Privilege Descriptors Still Exist - dependent_privilege_descriptors_still_exist = readIntBig(Code, "2B000"), - dependent_objects_still_exist = readIntBig(Code, "2BP01"), - // Class 2D — Invalid Transaction Termination - invalid_transaction_termination = readIntBig(Code, "2D000"), - // Class 2F — SQL Routine Exception - sql_routine_exception = readIntBig(Code, "2F000"), - function_executed_no_return_statement = readIntBig(Code, "2F005"), - modifying_sql_data_not_permitted = readIntBig(Code, "2F002"), - prohibited_sql_statement_attempted = readIntBig(Code, "2F003"), - reading_sql_data_not_permitted = readIntBig(Code, "2F004"), - // Class 34 — Invalid Cursor Name - invalid_cursor_name = readIntBig(Code, "34000"), - // Class 38 — External Routine Exception - external_routine_exception = readIntBig(Code, "38000"), - containing_sql_not_permitted = readIntBig(Code, "38001"), - modifying_sql_data_not_permitted = readIntBig(Code, "38002"), - prohibited_sql_statement_attempted = readIntBig(Code, "38003"), - reading_sql_data_not_permitted = readIntBig(Code, "38004"), - // Class 39 — External Routine Invocation Exception - external_routine_invocation_exception = readIntBig(Code, "39000"), - invalid_sqlstate_returned = readIntBig(Code, "39001"), - null_value_not_allowed = readIntBig(Code, "39004"), - trigger_protocol_violated = readIntBig(Code, "39P01"), - srf_protocol_violated = readIntBig(Code, "39P02"), - event_trigger_protocol_violated = readIntBig(Code, "39P03"), - // Class 3B — Savepoint Exception - savepoint_exception = readIntBig(Code, "3B000"), - invalid_savepoint_specification = readIntBig(Code, "3B001"), - // Class 3D — Invalid Catalog Name - invalid_catalog_name = readIntBig(Code, "3D000"), - // Class 3F — Invalid Schema Name - invalid_schema_name = readIntBig(Code, "3F000"), - // Class 40 — Transaction Rollback - transaction_rollback = readIntBig(Code, "40000"), - transaction_integrity_constraint_violation = readIntBig(Code, "40002"), - serialization_failure = readIntBig(Code, "40001"), - statement_completion_unknown = readIntBig(Code, "40003"), - deadlock_detected = readIntBig(Code, "40P01"), - // Class 42 — Syntax Error or Access Rule Violation - syntax_error_or_access_rule_violation = readIntBig(Code, "42000"), - syntax_error = readIntBig(Code, "42601"), - insufficient_privilege = readIntBig(Code, "42501"), - cannot_coerce = readIntBig(Code, "42846"), - grouping_error = readIntBig(Code, "42803"), - windowing_error = readIntBig(Code, "42P20"), - invalid_recursion = readIntBig(Code, "42P19"), - invalid_foreign_key = readIntBig(Code, "42830"), - invalid_name = readIntBig(Code, "42602"), - name_too_long = readIntBig(Code, "42622"), - reserved_name = readIntBig(Code, "42939"), - datatype_mismatch = readIntBig(Code, "42804"), - indeterminate_datatype = readIntBig(Code, "42P18"), - collation_mismatch = readIntBig(Code, "42P21"), - indeterminate_collation = readIntBig(Code, "42P22"), - wrong_object_type = readIntBig(Code, "42809"), - generated_always = readIntBig(Code, "428C9"), - undefined_column = readIntBig(Code, "42703"), - undefined_function = readIntBig(Code, "42883"), - undefined_table = readIntBig(Code, "42P01"), - undefined_parameter = readIntBig(Code, "42P02"), - undefined_object = readIntBig(Code, "42704"), - duplicate_column = readIntBig(Code, "42701"), - duplicate_cursor = readIntBig(Code, "42P03"), - duplicate_database = readIntBig(Code, "42P04"), - duplicate_function = readIntBig(Code, "42723"), - duplicate_prepared_statement = readIntBig(Code, "42P05"), - duplicate_schema = readIntBig(Code, "42P06"), - duplicate_table = readIntBig(Code, "42P07"), - duplicate_alias = readIntBig(Code, "42712"), - duplicate_object = readIntBig(Code, "42710"), - ambiguous_column = readIntBig(Code, "42702"), - ambiguous_function = readIntBig(Code, "42725"), - ambiguous_parameter = readIntBig(Code, "42P08"), - ambiguous_alias = readIntBig(Code, "42P09"), - invalid_column_reference = readIntBig(Code, "42P10"), - invalid_column_definition = readIntBig(Code, "42611"), - invalid_cursor_definition = readIntBig(Code, "42P11"), - invalid_database_definition = readIntBig(Code, "42P12"), - invalid_function_definition = readIntBig(Code, "42P13"), - invalid_prepared_statement_definition = readIntBig(Code, "42P14"), - invalid_schema_definition = readIntBig(Code, "42P15"), - invalid_table_definition = readIntBig(Code, "42P16"), - invalid_object_definition = readIntBig(Code, "42P17"), - // Class 44 — WITH CHECK OPTION Violation - with_check_option_violation = readIntBig(Code, "44000"), - // Class 53 — Insufficient Resources - insufficient_resources = readIntBig(Code, "53000"), - disk_full = readIntBig(Code, "53100"), - out_of_memory = readIntBig(Code, "53200"), - too_many_connections = readIntBig(Code, "53300"), - configuration_limit_exceeded = readIntBig(Code, "53400"), - // Class 54 — Program Limit Exceeded - program_limit_exceeded = readIntBig(Code, "54000"), - statement_too_complex = readIntBig(Code, "54001"), - too_many_columns = readIntBig(Code, "54011"), - too_many_arguments = readIntBig(Code, "54023"), - // Class 55 — Object Not In Prerequisite State - object_not_in_prerequisite_state = readIntBig(Code, "55000"), - object_in_use = readIntBig(Code, "55006"), - cant_change_runtime_param = readIntBig(Code, "55P02"), - lock_not_available = readIntBig(Code, "55P03"), - unsafe_new_enum_value_usage = readIntBig(Code, "55P04"), - // Class 57 — Operator Intervention - operator_intervention = readIntBig(Code, "57000"), - query_canceled = readIntBig(Code, "57014"), - admin_shutdown = readIntBig(Code, "57P01"), - crash_shutdown = readIntBig(Code, "57P02"), - cannot_connect_now = readIntBig(Code, "57P03"), - database_dropped = readIntBig(Code, "57P04"), - idle_session_timeout = readIntBig(Code, "57P05"), - // Class 58 — System Error (errors external to PostgreSQL itself) - system_error = readIntBig(Code, "58000"), - io_error = readIntBig(Code, "58030"), - undefined_file = readIntBig(Code, "58P01"), - duplicate_file = readIntBig(Code, "58P02"), - // Class 72 — Snapshot Failure - snapshot_too_old = readIntBig(Code, "72000"), - // Class F0 — Configuration File Error - config_file_error = readIntBig(Code, "F0000"), - lock_file_exists = readIntBig(Code, "F0001"), - // Class HV — Foreign Data Wrapper Error (SQL/MED) - fdw_error = readIntBig(Code, "HV000"), - fdw_column_name_not_found = readIntBig(Code, "HV005"), - fdw_dynamic_parameter_value_needed = readIntBig(Code, "HV002"), - fdw_function_sequence_error = readIntBig(Code, "HV010"), - fdw_inconsistent_descriptor_information = readIntBig(Code, "HV021"), - fdw_invalid_attribute_value = readIntBig(Code, "HV024"), - fdw_invalid_column_name = readIntBig(Code, "HV007"), - fdw_invalid_column_number = readIntBig(Code, "HV008"), - fdw_invalid_data_type = readIntBig(Code, "HV004"), - fdw_invalid_data_type_descriptors = readIntBig(Code, "HV006"), - fdw_invalid_descriptor_field_identifier = readIntBig(Code, "HV091"), - fdw_invalid_handle = readIntBig(Code, "HV00B"), - fdw_invalid_option_index = readIntBig(Code, "HV00C"), - fdw_invalid_option_name = readIntBig(Code, "HV00D"), - fdw_invalid_string_length_or_buffer_length = readIntBig(Code, "HV090"), - fdw_invalid_string_format = readIntBig(Code, "HV00A"), - fdw_invalid_use_of_null_pointer = readIntBig(Code, "HV009"), - fdw_too_many_handles = readIntBig(Code, "HV014"), - fdw_out_of_memory = readIntBig(Code, "HV001"), - fdw_no_schemas = readIntBig(Code, "HV00P"), - fdw_option_name_not_found = readIntBig(Code, "HV00J"), - fdw_reply_handle = readIntBig(Code, "HV00K"), - fdw_schema_not_found = readIntBig(Code, "HV00Q"), - fdw_table_not_found = readIntBig(Code, "HV00R"), - fdw_unable_to_create_execution = readIntBig(Code, "HV00L"), - fdw_unable_to_create_reply = readIntBig(Code, "HV00M"), - fdw_unable_to_establish_connection = readIntBig(Code, "HV00N"), - // Class P0 — PL/pgSQL Error - plpgsql_error = readIntBig(Code, "P0000"), - raise_exception = readIntBig(Code, "P0001"), - no_data_found = readIntBig(Code, "P0002"), - too_many_rows = readIntBig(Code, "P0003"), - assert_failure = readIntBig(Code, "P0004"), - // Class XX — Internal Error - internal_error = readIntBig(Code, "XX000"), - data_corrupted = readIntBig(Code, "XX001"), - index_corrupted = readIntBig(Code, "XX002"), + pub fn errorCodeStr(code: SqlState) [5]u8 { + return toCodeStr(@enumToInt(code)); + } - _, + // Class 00 — Successful Completion + successful_completion = readVarInt(Code, "00000", .Big), + // Class 01 — Warning + warning = readVarInt(Code, "01000", .Big), + dynamic_result_sets_returned = readVarInt(Code, "0100C", .Big), + implicit_zero_bit_padding = readVarInt(Code, "01008", .Big), + null_value_eliminated_in_set_function = readVarInt(Code, "01003", .Big), + privilege_not_granted = readVarInt(Code, "01007", .Big), + privilege_not_revoked = readVarInt(Code, "01006", .Big), + string_data_right_truncation_warning = readVarInt(Code, "01004", .Big), + deprecated_feature = readVarInt(Code, "01P01", .Big), + // Class 02 — No Data (this is also a warning class per the SQL standard) + no_data = readVarInt(Code, "02000", .Big), + no_additional_dynamic_result_sets_returned = readVarInt(Code, "02001", .Big), + // Class 03 — SQL Statement Not Yet Complete + sql_statement_not_yet_complete = readVarInt(Code, "03000", .Big), + // Class 08 — Connection Exception + connection_exception = readVarInt(Code, "08000", .Big), + connection_does_not_exist = readVarInt(Code, "08003", .Big), + connection_failure = readVarInt(Code, "08006", .Big), + sqlclient_unable_to_establish_sqlconnection = readVarInt(Code, "08001", .Big), + sqlserver_rejected_establishment_of_sqlconnection = readVarInt(Code, "08004", .Big), + transaction_resolution_unknown = readVarInt(Code, "08007", .Big), + protocol_violation = readVarInt(Code, "08P01", .Big), + // Class 09 — Triggered Action Exception + triggered_action_exception = readVarInt(Code, "09000", .Big), + // Class 0A — Feature Not Supported + feature_not_supported = readVarInt(Code, "0A000", .Big), + // Class 0B — Invalid Transaction Initiation + invalid_transaction_initiation = readVarInt(Code, "0B000", .Big), + // Class 0F — Locator Exception + locator_exception = readVarInt(Code, "0F000", .Big), + invalid_locator_specification = readVarInt(Code, "0F001", .Big), + // Class 0L — Invalid Grantor + invalid_grantor = readVarInt(Code, "0L000", .Big), + invalid_grant_operation = readVarInt(Code, "0LP01", .Big), + // Class 0P — Invalid Role Specification + invalid_role_specification = readVarInt(Code, "0P000", .Big), + // Class 0Z — Diagnostics Exception + diagnostics_exception = readVarInt(Code, "0Z000", .Big), + stacked_diagnostics_accessed_without_active_handler = readVarInt(Code, "0Z002", .Big), + // Class 20 — Case Not Found + case_not_found = readVarInt(Code, "20000", .Big), + // Class 21 — Cardinality Violation + cardinality_violation = readVarInt(Code, "21000", .Big), + // Class 22 — Data Exception + data_exception = readVarInt(Code, "22000", .Big), + array_subscript_error = readVarInt(Code, "2202E", .Big), + character_not_in_repertoire = readVarInt(Code, "22021", .Big), + datetime_field_overflow = readVarInt(Code, "22008", .Big), + division_by_zero = readVarInt(Code, "22012", .Big), + error_in_assignment = readVarInt(Code, "22005", .Big), + escape_character_conflict = readVarInt(Code, "2200B", .Big), + indicator_overflow = readVarInt(Code, "22022", .Big), + interval_field_overflow = readVarInt(Code, "22015", .Big), + invalid_argument_for_logarithm = readVarInt(Code, "2201E", .Big), + invalid_argument_for_ntile_function = readVarInt(Code, "22014", .Big), + invalid_argument_for_nth_value_function = readVarInt(Code, "22016", .Big), + invalid_argument_for_power_function = readVarInt(Code, "2201F", .Big), + invalid_argument_for_width_bucket_function = readVarInt(Code, "2201G", .Big), + invalid_character_value_for_cast = readVarInt(Code, "22018", .Big), + invalid_datetime_format = readVarInt(Code, "22007", .Big), + invalid_escape_character = readVarInt(Code, "22019", .Big), + invalid_escape_octet = readVarInt(Code, "2200D", .Big), + invalid_escape_sequence = readVarInt(Code, "22025", .Big), + nonstandard_use_of_escape_character = readVarInt(Code, "22P06", .Big), + invalid_indicator_parameter_value = readVarInt(Code, "22010", .Big), + invalid_parameter_value = readVarInt(Code, "22023", .Big), + invalid_preceding_or_following_size = readVarInt(Code, "22013", .Big), + invalid_regular_expression = readVarInt(Code, "2201B", .Big), + invalid_row_count_in_limit_clause = readVarInt(Code, "2201W", .Big), + invalid_row_count_in_result_offset_clause = readVarInt(Code, "2201X", .Big), + invalid_tablesample_argument = readVarInt(Code, "2202H", .Big), + invalid_tablesample_repeat = readVarInt(Code, "2202G", .Big), + invalid_time_zone_displacement_value = readVarInt(Code, "22009", .Big), + invalid_use_of_escape_character = readVarInt(Code, "2200C", .Big), + most_specific_type_mismatch = readVarInt(Code, "2200G", .Big), + null_value_not_allowed_data_exception = readVarInt(Code, "22004", .Big), + null_value_no_indicator_parameter = readVarInt(Code, "22002", .Big), + numeric_value_out_of_range = readVarInt(Code, "22003", .Big), + sequence_generator_limit_exceeded = readVarInt(Code, "2200H", .Big), + string_data_length_mismatch = readVarInt(Code, "22026", .Big), + string_data_right_truncation_exception = readVarInt(Code, "22001", .Big), + substring_error = readVarInt(Code, "22011", .Big), + trim_error = readVarInt(Code, "22027", .Big), + unterminated_c_string = readVarInt(Code, "22024", .Big), + zero_length_character_string = readVarInt(Code, "2200F", .Big), + floating_point_exception = readVarInt(Code, "22P01", .Big), + invalid_text_representation = readVarInt(Code, "22P02", .Big), + invalid_binary_representation = readVarInt(Code, "22P03", .Big), + bad_copy_file_format = readVarInt(Code, "22P04", .Big), + untranslatable_character = readVarInt(Code, "22P05", .Big), + not_an_xml_document = readVarInt(Code, "2200L", .Big), + invalid_xml_document = readVarInt(Code, "2200M", .Big), + invalid_xml_content = readVarInt(Code, "2200N", .Big), + invalid_xml_comment = readVarInt(Code, "2200S", .Big), + invalid_xml_processing_instruction = readVarInt(Code, "2200T", .Big), + duplicate_json_object_key_value = readVarInt(Code, "22030", .Big), + invalid_argument_for_sql_json_datetime_function = readVarInt(Code, "22031", .Big), + invalid_json_text = readVarInt(Code, "22032", .Big), + invalid_sql_json_subscript = readVarInt(Code, "22033", .Big), + more_than_one_sql_json_item = readVarInt(Code, "22034", .Big), + no_sql_json_item = readVarInt(Code, "22035", .Big), + non_numeric_sql_json_item = readVarInt(Code, "22036", .Big), + non_unique_keys_in_a_json_object = readVarInt(Code, "22037", .Big), + singleton_sql_json_item_required = readVarInt(Code, "22038", .Big), + sql_json_array_not_found = readVarInt(Code, "22039", .Big), + sql_json_member_not_found = readVarInt(Code, "2203A", .Big), + sql_json_number_not_found = readVarInt(Code, "2203B", .Big), + sql_json_object_not_found = readVarInt(Code, "2203C", .Big), + too_many_json_array_elements = readVarInt(Code, "2203D", .Big), + too_many_json_object_members = readVarInt(Code, "2203E", .Big), + sql_json_scalar_required = readVarInt(Code, "2203F", .Big), + // Class 23 — Integrity Constraint Violation + integrity_constraint_violation = readVarInt(Code, "23000", .Big), + restrict_violation = readVarInt(Code, "23001", .Big), + not_null_violation = readVarInt(Code, "23502", .Big), + foreign_key_violation = readVarInt(Code, "23503", .Big), + unique_violation = readVarInt(Code, "23505", .Big), + check_violation = readVarInt(Code, "23514", .Big), + exclusion_violation = readVarInt(Code, "23P01", .Big), + // Class 24 — Invalid Cursor State + invalid_cursor_state = readVarInt(Code, "24000", .Big), + // Class 25 — Invalid Transaction State + invalid_transaction_state = readVarInt(Code, "25000", .Big), + active_sql_transaction = readVarInt(Code, "25001", .Big), + branch_transaction_already_active = readVarInt(Code, "25002", .Big), + held_cursor_requires_same_isolation_level = readVarInt(Code, "25008", .Big), + inappropriate_access_mode_for_branch_transaction = readVarInt(Code, "25003", .Big), + inappropriate_isolation_level_for_branch_transaction = readVarInt(Code, "25004", .Big), + no_active_sql_transaction_for_branch_transaction = readVarInt(Code, "25005", .Big), + read_only_sql_transaction = readVarInt(Code, "25006", .Big), + schema_and_data_statement_mixing_not_supported = readVarInt(Code, "25007", .Big), + no_active_sql_transaction = readVarInt(Code, "25P01", .Big), + in_failed_sql_transaction = readVarInt(Code, "25P02", .Big), + idle_in_transaction_session_timeout = readVarInt(Code, "25P03", .Big), + // Class 26 — Invalid SQL Statement Name + invalid_sql_statement_name = readVarInt(Code, "26000", .Big), + // Class 27 — Triggered Data Change Violation + triggered_data_change_violation = readVarInt(Code, "27000", .Big), + // Class 28 — Invalid Authorization Specification + invalid_authorization_specification = readVarInt(Code, "28000", .Big), + invalid_password = readVarInt(Code, "28P01", .Big), + // Class 2B — Dependent Privilege Descriptors Still Exist + dependent_privilege_descriptors_still_exist = readVarInt(Code, "2B000", .Big), + dependent_objects_still_exist = readVarInt(Code, "2BP01", .Big), + // Class 2D — Invalid Transaction Termination + invalid_transaction_termination = readVarInt(Code, "2D000", .Big), + // Class 2F — SQL Routine Exception + sql_routine_exception = readVarInt(Code, "2F000", .Big), + function_executed_no_return_statement = readVarInt(Code, "2F005", .Big), + modifying_sql_data_not_permitted_sql_exception = readVarInt(Code, "2F002", .Big), + prohibited_sql_statement_attempted_sql_exception = readVarInt(Code, "2F003", .Big), + reading_sql_data_not_permitted_sql_exception = readVarInt(Code, "2F004", .Big), + // Class 34 — Invalid Cursor Name + invalid_cursor_name = readVarInt(Code, "34000", .Big), + // Class 38 — External Routine Exception + external_routine_exception = readVarInt(Code, "38000", .Big), + containing_sql_not_permitted = readVarInt(Code, "38001", .Big), + modifying_sql_data_not_permitted_external_exception = readVarInt(Code, "38002", .Big), + prohibited_sql_statement_attempted_external_exception = readVarInt(Code, "38003", .Big), + reading_sql_data_not_permitted_external_exception = readVarInt(Code, "38004", .Big), + // Class 39 — External Routine Invocation Exception + external_routine_invocation_exception = readVarInt(Code, "39000", .Big), + invalid_sqlstate_returned = readVarInt(Code, "39001", .Big), + null_value_not_allowed_external_exception = readVarInt(Code, "39004", .Big), + trigger_protocol_violated = readVarInt(Code, "39P01", .Big), + srf_protocol_violated = readVarInt(Code, "39P02", .Big), + event_trigger_protocol_violated = readVarInt(Code, "39P03", .Big), + // Class 3B — Savepoint Exception + savepoint_exception = readVarInt(Code, "3B000", .Big), + invalid_savepoint_specification = readVarInt(Code, "3B001", .Big), + // Class 3D — Invalid Catalog Name + invalid_catalog_name = readVarInt(Code, "3D000", .Big), + // Class 3F — Invalid Schema Name + invalid_schema_name = readVarInt(Code, "3F000", .Big), + // Class 40 — Transaction Rollback + transaction_rollback = readVarInt(Code, "40000", .Big), + transaction_integrity_constraint_violation = readVarInt(Code, "40002", .Big), + serialization_failure = readVarInt(Code, "40001", .Big), + statement_completion_unknown = readVarInt(Code, "40003", .Big), + deadlock_detected = readVarInt(Code, "40P01", .Big), + // Class 42 — Syntax Error or Access Rule Violation + syntax_error_or_access_rule_violation = readVarInt(Code, "42000", .Big), + syntax_error = readVarInt(Code, "42601", .Big), + insufficient_privilege = readVarInt(Code, "42501", .Big), + cannot_coerce = readVarInt(Code, "42846", .Big), + grouping_error = readVarInt(Code, "42803", .Big), + windowing_error = readVarInt(Code, "42P20", .Big), + invalid_recursion = readVarInt(Code, "42P19", .Big), + invalid_foreign_key = readVarInt(Code, "42830", .Big), + invalid_name = readVarInt(Code, "42602", .Big), + name_too_long = readVarInt(Code, "42622", .Big), + reserved_name = readVarInt(Code, "42939", .Big), + datatype_mismatch = readVarInt(Code, "42804", .Big), + indeterminate_datatype = readVarInt(Code, "42P18", .Big), + collation_mismatch = readVarInt(Code, "42P21", .Big), + indeterminate_collation = readVarInt(Code, "42P22", .Big), + wrong_object_type = readVarInt(Code, "42809", .Big), + generated_always = readVarInt(Code, "428C9", .Big), + undefined_column = readVarInt(Code, "42703", .Big), + undefined_function = readVarInt(Code, "42883", .Big), + undefined_table = readVarInt(Code, "42P01", .Big), + undefined_parameter = readVarInt(Code, "42P02", .Big), + undefined_object = readVarInt(Code, "42704", .Big), + duplicate_column = readVarInt(Code, "42701", .Big), + duplicate_cursor = readVarInt(Code, "42P03", .Big), + duplicate_database = readVarInt(Code, "42P04", .Big), + duplicate_function = readVarInt(Code, "42723", .Big), + duplicate_prepared_statement = readVarInt(Code, "42P05", .Big), + duplicate_schema = readVarInt(Code, "42P06", .Big), + duplicate_table = readVarInt(Code, "42P07", .Big), + duplicate_alias = readVarInt(Code, "42712", .Big), + duplicate_object = readVarInt(Code, "42710", .Big), + ambiguous_column = readVarInt(Code, "42702", .Big), + ambiguous_function = readVarInt(Code, "42725", .Big), + ambiguous_parameter = readVarInt(Code, "42P08", .Big), + ambiguous_alias = readVarInt(Code, "42P09", .Big), + invalid_column_reference = readVarInt(Code, "42P10", .Big), + invalid_column_definition = readVarInt(Code, "42611", .Big), + invalid_cursor_definition = readVarInt(Code, "42P11", .Big), + invalid_database_definition = readVarInt(Code, "42P12", .Big), + invalid_function_definition = readVarInt(Code, "42P13", .Big), + invalid_prepared_statement_definition = readVarInt(Code, "42P14", .Big), + invalid_schema_definition = readVarInt(Code, "42P15", .Big), + invalid_table_definition = readVarInt(Code, "42P16", .Big), + invalid_object_definition = readVarInt(Code, "42P17", .Big), + // Class 44 — WITH CHECK OPTION Violation + with_check_option_violation = readVarInt(Code, "44000", .Big), + // Class 53 — Insufficient Resources + insufficient_resources = readVarInt(Code, "53000", .Big), + disk_full = readVarInt(Code, "53100", .Big), + out_of_memory = readVarInt(Code, "53200", .Big), + too_many_connections = readVarInt(Code, "53300", .Big), + configuration_limit_exceeded = readVarInt(Code, "53400", .Big), + // Class 54 — Program Limit Exceeded + program_limit_exceeded = readVarInt(Code, "54000", .Big), + statement_too_complex = readVarInt(Code, "54001", .Big), + too_many_columns = readVarInt(Code, "54011", .Big), + too_many_arguments = readVarInt(Code, "54023", .Big), + // Class 55 — Object Not In Prerequisite State + object_not_in_prerequisite_state = readVarInt(Code, "55000", .Big), + object_in_use = readVarInt(Code, "55006", .Big), + cant_change_runtime_param = readVarInt(Code, "55P02", .Big), + lock_not_available = readVarInt(Code, "55P03", .Big), + unsafe_new_enum_value_usage = readVarInt(Code, "55P04", .Big), + // Class 57 — Operator Intervention + operator_intervention = readVarInt(Code, "57000", .Big), + query_canceled = readVarInt(Code, "57014", .Big), + admin_shutdown = readVarInt(Code, "57P01", .Big), + crash_shutdown = readVarInt(Code, "57P02", .Big), + cannot_connect_now = readVarInt(Code, "57P03", .Big), + database_dropped = readVarInt(Code, "57P04", .Big), + idle_session_timeout = readVarInt(Code, "57P05", .Big), + // Class 58 — System Error (errors external to PostgreSQL itself) + system_error = readVarInt(Code, "58000", .Big), + io_error = readVarInt(Code, "58030", .Big), + undefined_file = readVarInt(Code, "58P01", .Big), + duplicate_file = readVarInt(Code, "58P02", .Big), + // Class 72 — Snapshot Failure + snapshot_too_old = readVarInt(Code, "72000", .Big), + // Class F0 — Configuration File Error + config_file_error = readVarInt(Code, "F0000", .Big), + lock_file_exists = readVarInt(Code, "F0001", .Big), + // Class HV — Foreign Data Wrapper Error (SQL/MED) + fdw_error = readVarInt(Code, "HV000", .Big), + fdw_column_name_not_found = readVarInt(Code, "HV005", .Big), + fdw_dynamic_parameter_value_needed = readVarInt(Code, "HV002", .Big), + fdw_function_sequence_error = readVarInt(Code, "HV010", .Big), + fdw_inconsistent_descriptor_information = readVarInt(Code, "HV021", .Big), + fdw_invalid_attribute_value = readVarInt(Code, "HV024", .Big), + fdw_invalid_column_name = readVarInt(Code, "HV007", .Big), + fdw_invalid_column_number = readVarInt(Code, "HV008", .Big), + fdw_invalid_data_type = readVarInt(Code, "HV004", .Big), + fdw_invalid_data_type_descriptors = readVarInt(Code, "HV006", .Big), + fdw_invalid_descriptor_field_identifier = readVarInt(Code, "HV091", .Big), + fdw_invalid_handle = readVarInt(Code, "HV00B", .Big), + fdw_invalid_option_index = readVarInt(Code, "HV00C", .Big), + fdw_invalid_option_name = readVarInt(Code, "HV00D", .Big), + fdw_invalid_string_length_or_buffer_length = readVarInt(Code, "HV090", .Big), + fdw_invalid_string_format = readVarInt(Code, "HV00A", .Big), + fdw_invalid_use_of_null_pointer = readVarInt(Code, "HV009", .Big), + fdw_too_many_handles = readVarInt(Code, "HV014", .Big), + fdw_out_of_memory = readVarInt(Code, "HV001", .Big), + fdw_no_schemas = readVarInt(Code, "HV00P", .Big), + fdw_option_name_not_found = readVarInt(Code, "HV00J", .Big), + fdw_reply_handle = readVarInt(Code, "HV00K", .Big), + fdw_schema_not_found = readVarInt(Code, "HV00Q", .Big), + fdw_table_not_found = readVarInt(Code, "HV00R", .Big), + fdw_unable_to_create_execution = readVarInt(Code, "HV00L", .Big), + fdw_unable_to_create_reply = readVarInt(Code, "HV00M", .Big), + fdw_unable_to_establish_connection = readVarInt(Code, "HV00N", .Big), + // Class P0 — PL/pgSQL Error + plpgsql_error = readVarInt(Code, "P0000", .Big), + raise_exception = readVarInt(Code, "P0001", .Big), + no_data_found = readVarInt(Code, "P0002", .Big), + too_many_rows = readVarInt(Code, "P0003", .Big), + assert_failure = readVarInt(Code, "P0004", .Big), + // Class XX — Internal Error + internal_error = readVarInt(Code, "XX000", .Big), + data_corrupted = readVarInt(Code, "XX001", .Big), + index_corrupted = readVarInt(Code, "XX002", .Big), + + _, + }; }; diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index 7d0a93f..b884301 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -28,7 +28,7 @@ fn getCharPos(text: []const u8, offset: c_int) struct { row: usize, col: usize } return .{ .row = row, .col = col }; } -fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) anyerror { +fn handleUnexpectedError(db: *c.sqlite3, code: c_int, sql_text: ?[]const u8) error{Unexpected} { std.log.err("Unexpected error in SQLite engine: {s} ({})", .{ c.sqlite3_errstr(code), code }); std.log.debug("Additional details:", .{}); @@ -51,7 +51,7 @@ pub const Db = struct { pub fn open(path: [:0]const u8) common.OpenError!Db { const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE; - var db: [*c]c.sqlite3 = null; + var db: ?*c.sqlite3 = null; switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { c.SQLITE_OK => {}, else => |code| { @@ -61,7 +61,7 @@ pub const Db = struct { "Unable to open SQLite DB \"{s}\". Error: {?s} ({})", .{ path, c.sqlite3_errstr(code), code }, ); - return error.InternalException; + return error.BadConnection; } const ext_code = c.sqlite3_extended_errcode(db); @@ -77,7 +77,7 @@ pub const Db = struct { } return Db{ - .db = db, + .db = db.?, }; } @@ -109,55 +109,115 @@ pub const Db = struct { }, }; - inline for (args) |arg, i| { - // SQLite treats $NNN args as having the name NNN, not index NNN. - // As such, if you reference $2 and not $1 in your query (such as - // when dynamically constructing queries), it could assign $2 the - // index 1. So we can't assume the index according to the 1-indexed - // arg array is equivalent to the param to bind it to. - // We can, however, look up the exact index to bind to. - // If the argument is not used in the query, then it will have an "index" - // of 0, and we must not bind the argument. - const name = std.fmt.comptimePrint("${}", .{i + 1}); - const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); - if (db_idx != 0) { - switch (bindArg(stmt.?, @intCast(u15, db_idx), arg)) { - c.SQLITE_OK => {}, - else => |err| { - return handleUnexpectedError(self.db, err, sql); - }, - } - } else if (!opts.ignore_unknown_parameters) return error.UndefinedParameter; + if (@TypeOf(args) != void) { + inline for (args) |arg, i| { + // SQLite treats $NNN args as having the name NNN, not index NNN. + // As such, if you reference $2 and not $1 in your query (such as + // when dynamically constructing queries), it could assign $2 the + // index 1. So we can't assume the index according to the 1-indexed + // arg array is equivalent to the param to bind it to. + // We can, however, look up the exact index to bind to. + // If the argument is not used in the query, then it will have an "index" + // of 0, and we must not bind the argument. + const name = std.fmt.comptimePrint("${}", .{i + 1}); + const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); + if (db_idx != 0) + try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) + else if (!opts.ignore_unused_arguments) + return error.UnusedArgument; + } } return Results{ .stmt = stmt.?, .db = self.db }; } -}; -fn bindArg(stmt: *c.sqlite3_stmt, idx: u15, val: anytype) c_int { - if (comptime std.meta.trait.isZigString(@TypeOf(val))) { - const slice = @as([]const u8, val); - return c.sqlite3_bind_text(stmt, idx, slice.ptr, @intCast(c_int, slice.len), c.SQLITE_TRANSIENT); + fn bindArgument(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: anytype) !void { + if (comptime std.meta.trait.isZigString(@TypeOf(val))) { + return self.bindString(stmt, idx, val); + } + + const T = @TypeOf(val); + + switch (@typeInfo(T)) { + .Struct, + .Union, + .Opaque, + => { + const arr = if (@hasDecl(T, "toCharArray")) + val.toCharArray() + else if (@hasDecl(T, "toCharArrayZ")) + val.toCharArrayZ() + else + @compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string"); + + const len = std.mem.len(&arr); + return self.bindString(stmt, idx, arr[0..len]); + }, + .Enum => |info| { + const name = if (info.is_exhaustive) + @tagName(val) + else + @compileError("SQLite: Could not serialize non-exhaustive enum " ++ @typeName(T) ++ " into string"); + return self.bindString(stmt, idx, name); + }, + .Optional => { + return if (val) |v| self.bindArgument(stmt, idx, v) else self.bindNull(stmt, idx); + }, + .Null => return self.bindNull(stmt, idx), + .Int => return self.bindInt(stmt, idx, val), + .Float => return self.bindFloat(stmt, idx, val), + else => @compileError("Unable to serialize type " ++ @typeName(T)), + } } - return switch (@TypeOf(val)) { - Uuid => blk: { - const arr = val.toCharArrayZ(); - break :blk bindArg(stmt, idx, &arr); - }, - DateTime => blk: { - const arr = val.toCharArrayZ(); - break :blk bindArg(stmt, idx, &arr); - }, - @TypeOf(null) => c.sqlite3_bind_null(stmt, idx), - else => |T| switch (@typeInfo(T)) { - .Optional => if (val) |v| bindArg(stmt, idx, v) else bindArg(stmt, idx, null), - .Enum => bindArg(stmt, idx, @tagName(val)), - .Int => c.sqlite3_bind_int64(stmt, idx, @intCast(i64, val)), - else => @compileError("unsupported type " ++ @typeName(T)), - }, - }; -} + fn bindString(self: Db, stmt: *c.sqlite3_stmt, idx: u15, str: []const u8) !void { + const len = std.math.cast(c_int, str.len) orelse { + std.log.err("SQLite: string len {} too large", .{str.len}); + return error.BindException; + }; + + switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind string to index {}", .{idx}); + std.log.debug("SQLite: {s}", .{str}); + return handleUnexpectedError(self.db, result, null); + }, + } + } + + fn bindNull(self: Db, stmt: *c.sqlite3_stmt, idx: u15) !void { + switch (c.sqlite3_bind_null(stmt, idx)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind NULL to index {}", .{idx}); + return handleUnexpectedError(self.db, result, null); + }, + } + } + + fn bindInt(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: i64) !void { + switch (c.sqlite3_bind_int64(stmt, idx, val)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind int to index {}", .{idx}); + std.log.debug("SQLite: {}", .{val}); + return handleUnexpectedError(self.db, result, null); + }, + } + } + + fn bindFloat(self: Db, stmt: *c.sqlite3_stmt, idx: u15, val: f64) !void { + switch (c.sqlite3_bind_double(stmt, idx, val)) { + c.SQLITE_OK => {}, + else => |result| { + std.log.err("SQLite: Unable to bind float to index {}", .{idx}); + std.log.debug("SQLite: {}", .{val}); + return handleUnexpectedError(self.db, result, null); + }, + } + } +}; pub const Results = struct { stmt: *c.sqlite3_stmt, @@ -193,12 +253,12 @@ pub const Results = struct { return if (c.sqlite3_column_name(self.stmt, idx)) |ptr| ptr[0..std.mem.len(ptr)] else - return error.Unexpected; + unreachable; } pub fn columnIndex(self: Results, name: []const u8) common.ColumnIndexError!u15 { var i: u15 = 0; - const count = self.columnCount(); + const count = try self.columnCount(); while (i < count) : (i += 1) { const column = try self.columnName(i); if (std.mem.eql(u8, name, column)) return i; @@ -213,33 +273,75 @@ pub const Row = struct { db: *c.sqlite3, pub fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { - if (c.sqlite3_column_type(self.stmt, idx) == c.SQLITE_NULL) { - return if (@typeInfo(T) == .Optional) null else error.TypeMismatch; - } - - return self.getNotNull(T, idx, alloc); - } - - fn getNotNull(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) error.GetError!T { - return switch (T) { - f32, f64 => @floatCast(T, c.sqlite3_column_double(self.stmt, idx)), - - else => switch (@typeInfo(T)) { - .Int => |int| if (T == i63 or int.bits < 63) - @intCast(T, c.sqlite3_column_int64(self.stmt, idx)) - else - self.getFromString(T, idx, alloc), - .Optional => try self.getNotNull(std.meta.Child(T), idx, alloc), - else => self.getFromString(T, idx, alloc), - }, - }; - } - - fn getFromString(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) error.GetError!T { - const ptr = c.sqlite3_column_text(self.stmt, idx); - const size = @intCast(usize, c.sqlite3_column_bytes(self.stmt, idx)); - const str = ptr[0..size]; - - return common.parseValueNotNull(alloc, T, str); + return getColumn(self.stmt, T, idx, alloc); } }; + +fn getColumn(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { + return switch (c.sqlite3_column_type(stmt, idx)) { + c.SQLITE_INTEGER => getColumnInt(stmt, T, idx), + c.SQLITE_FLOAT => getColumnFloat(stmt, T, idx), + c.SQLITE_TEXT => getColumnText(stmt, T, idx, alloc), + c.SQLITE_NULL => { + if (@typeInfo(T) != .Optional) { + std.log.err("SQLite column {}: Expected value of type {}, got (null)", .{ idx, T }); + return error.ResultTypeMismatch; + } + + return null; + }, + c.SQLITE_BLOB => { + std.log.err("SQLite column {}: SQLite value had unsupported storage class BLOB", .{idx}); + return error.ResultTypeMismatch; + }, + else => |class| { + std.log.err("SQLite column {}: SQLite value had unknown storage class {}", .{ idx, class }); + return error.ResultTypeMismatch; + }, + }; +} + +fn getColumnInt(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T { + const val: i64 = c.sqlite3_column_int64(stmt, idx); + + switch (T) { + DateTime => return DateTime{ .seconds_since_epoch = val }, + else => switch (@typeInfo(T)) { + .Int => if (std.math.cast(T, val)) |v| return v else { + std.log.err("SQLite column {}: Expected value of type {}, got {} (outside of range)", .{ idx, T, val }); + return error.ResultTypeMismatch; + }, + else => { + std.log.err("SQLite column {}: Storage class INT cannot be parsed into type {}", .{ idx, T }); + return error.ResultTypeMismatch; + }, + }, + } +} + +fn getColumnFloat(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15) common.GetError!T { + const val: f64 = c.sqlite3_column_double(stmt, idx); + switch (T) { + // Only support floats that fit in range for now + f16, f32, f64 => return @floatCast(T, val), + DateTime => return DateTime{ + .seconds_since_epoch = std.math.lossyCast(i64, val * @intToFloat(f64, std.time.epoch.secs_per_day)), + }, + else => { + std.log.err("SQLite column {}: Storage class FLOAT cannot be parsed into type {}", .{ idx, T }); + return error.ResultTypeMismatch; + }, + } +} + +fn getColumnText(stmt: *c.sqlite3_stmt, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { + if (c.sqlite3_column_text(stmt, idx)) |ptr| { + const size = @intCast(usize, c.sqlite3_column_bytes(stmt, idx)); + const str = std.mem.sliceTo(ptr[0..size], 0); + + return common.parseValueNotNull(alloc, T, str); + } else { + std.log.err("SQLite column {}: TEXT value stored but engine returned null pointer (out of memory?)", .{idx}); + return error.ResultTypeMismatch; + } +} diff --git a/src/sql/errors.zig b/src/sql/errors.zig index 11db043..0a7fb9a 100644 --- a/src/sql/errors.zig +++ b/src/sql/errors.zig @@ -17,7 +17,7 @@ const ConnectionError = error{ // - Filesystem full // - Unknown crash // - Filesystem permissions denied - InternalError, + InternalException, }; // Errors related to constraint validation @@ -41,19 +41,21 @@ const ConstraintError = error{ // Errors related to argument binding const ArgumentError = error{ // One of the arguments passed could not be marshalled to pass to the SQL engine - InvalidArgument, + BindException, // The set of arguments passed did not map to query parameters - UndefinedParameter, + UnusedArgument, // The allocator used for staging the query ran out of memory OutOfMemory, + AllocatorRequired, }; // Errors related to retrieving query result columns const ResultColumnError = error{ // The allocator used for retrieving the results ran out of memory OutOfMemory, + AllocatorRequired, // A type error occurred when parsing results (means invalid data is in the DB) ResultTypeMismatch, @@ -68,10 +70,7 @@ const StartQueryError = error{ PermissionDenied, // The SQL query had invalid syntax or used an invalid identifier - InvalidSql, - - // A type error occurred during the query (means query is written wrong) - QueryTypeMismatch, + SqlException, // The set of columns to parse did not match the columns returned by the query ColumnMismatch, @@ -81,16 +80,6 @@ const StartQueryError = error{ BadTransactionState, }; -const RowCountError = error{ - NoRows, - TooManyRows, -}; - -pub const OpenError = error{ - BadConnection, - InternalError, -}; - pub const library_errors = struct { const BaseError = ConnectionError || UnexpectedError; @@ -98,8 +87,8 @@ pub const library_errors = struct { pub const OpenError = BaseError; pub const QueryError = BaseError || ArgumentError || ConstraintError || StartQueryError; pub const RowError = BaseError || ResultColumnError || ConstraintError || StartQueryError; - pub const QueryRowError = QueryError || RowError || RowCountError; - pub const ExecError = QueryError || RowCountError; + pub const QueryRowError = QueryError || RowError || error{ NoRows, TooManyRows }; + pub const ExecError = QueryError; pub const BeginError = BaseError || StartQueryError; pub const CommitError = BaseError || StartQueryError || ConstraintError; }; diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 1177e12..9eedda4 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -2,7 +2,9 @@ const std = @import("std"); const util = @import("util"); const postgres = @import("./engines/postgres.zig"); +//const postgres = @import("./engines/null.zig"); const sqlite = @import("./engines/sqlite.zig"); +//const sqlite = @import("./engines/null.zig"); const common = @import("./engines/common.zig"); const Allocator = std.mem.Allocator; @@ -42,17 +44,20 @@ const RawResults = union(Engine) { } } - fn columnCount(self: RawResults) u15 { - return switch (self) { + fn columnCount(self: RawResults) !u15 { + return try switch (self) { .postgres => |pg| pg.columnCount(), .sqlite => |lite| lite.columnCount(), }; } - fn columnIndex(self: RawResults, name: []const u8) QueryError!u15 { - return try switch (self) { + fn columnIndex(self: RawResults, name: []const u8) error{ NotFound, Unexpected }!u15 { + return switch (self) { .postgres => |pg| pg.columnIndex(name), .sqlite => |lite| lite.columnIndex(name), + } catch |err| switch (err) { + error.OutOfRange => error.Unexpected, + error.NotFound => error.NotFound, }; } @@ -69,7 +74,7 @@ const RawResults = union(Engine) { // Must be deallocated by a call to finish() pub fn Results(comptime T: type) type { // would normally make this a declaration of the struct, but it causes the compiler to crash - const fields = std.meta.fields(T); + const fields = if (T == void) .{} else std.meta.fields(T); return struct { const Self = @This(); @@ -77,15 +82,21 @@ pub fn Results(comptime T: type) type { column_indices: [fields.len]u15, fn from(underlying: RawResults) QueryError!Self { - if (std.debug.runtime_safety and std.meta.trait.isTuple(T) and fields.len != underlying.columnCount()) { - std.log.err("Expected {} columns in result, got {}", .{ fields.len, underlying.columnCount() }); + if (std.debug.runtime_safety and fields.len != underlying.columnCount() catch unreachable) { + std.log.err("Expected {} columns in result, got {}", .{ fields.len, underlying.columnCount() catch unreachable }); return error.ColumnMismatch; } return Self{ .underlying = underlying, .column_indices = blk: { var indices: [fields.len]u15 = undefined; inline for (fields) |f, i| { - indices[i] = if (!std.meta.trait.isTuple(T)) try underlying.columnIndex(f.name) else i; + indices[i] = if (!std.meta.trait.isTuple(T)) + underlying.columnIndex(f.name) catch { + std.log.err("Could not find column index for field {s}", .{f.name}); + return error.ColumnMismatch; + } + else + i; } break :blk indices; } }; @@ -110,7 +121,13 @@ pub fn Results(comptime T: type) type { }; inline for (fields) |f, i| { - @field(result, f.name) = try row_val.get(f.field_type, self.column_indices[i], alloc); + // TODO: Causes compiler segfault. why? + //const F = f.field_type; + const F = @TypeOf(@field(result, f.name)); + @field(result, f.name) = row_val.get(F, self.column_indices[i], alloc) catch |err| { + std.log.err("SQL: Error getting column {s} of type {}", .{ f.name, F }); + return err; + }; fields_allocated += 1; } @@ -129,7 +146,8 @@ const Row = union(Engine) { // Not all types require an allocator to be present. If an allocator is needed but // not required, it will return error.AllocatorRequired. // The caller is responsible for deallocating T, if relevant. - fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) !T { + fn get(self: Row, comptime T: type, idx: u15, alloc: ?Allocator) common.GetError!T { + if (T == void) return; return switch (self) { .postgres => |pg| pg.get(T, idx, alloc), .sqlite => |lite| lite.get(T, idx, alloc), @@ -147,9 +165,9 @@ const QueryHelper = union(Engine) { sql: [:0]const u8, args: anytype, opt: QueryOptions, - ) !RawResults { + ) QueryError!RawResults { return switch (self) { - .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt.prep_allocator) }, + .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt) }, .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) }, }; } @@ -182,7 +200,12 @@ const QueryHelper = union(Engine) { args: anytype, alloc: ?Allocator, ) QueryError!void { - try self.queryRow(void, sql, args, alloc); + _ = self.queryRow(void, sql, args, alloc) catch |err| return switch (err) { + error.NoRows => {}, + error.TooManyRows => error.SqlException, + error.ResultTypeMismatch => unreachable, + else => |err2| err2, + }; } // Runs a query and returns a single row @@ -224,7 +247,11 @@ const QueryHelper = union(Engine) { comptime var table_spec: []const u8 = table ++ "("; comptime var value_spec: []const u8 = "("; inline for (fields) |field, i| { - types[i] = field.field_type; + // This causes a compile error. Why? + //const F = field.field_type; + const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name)); + // causes issues if F is @TypeOf(null), use dummy type + types[i] = if (F == @TypeOf(null)) ?i64 else F; table_spec = comptime (table_spec ++ field.name ++ ","); value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1}); } @@ -252,6 +279,8 @@ pub const Db = struct { tx_open: bool = false, engine: QueryHelper, + pub const is_transaction = false; + pub fn open(cfg: Config) OpenError!Db { return switch (cfg) { .postgres => |postgres_cfg| Db{ @@ -345,6 +374,8 @@ pub const Db = struct { pub const Tx = struct { db: *Db, + pub const is_transaction = true; + pub fn queryWithOptions( self: Tx, comptime RowType: type, @@ -442,12 +473,14 @@ pub const Tx = struct { pub fn commit(self: Tx) CommitError!void { if (!self.db.tx_open) return error.BadTransactionState; self.exec("COMMIT", {}, null) catch |err| switch (err) { - error.InvalidArgument, + error.BindException, error.OutOfMemory, - error.UndefinedParameter, + error.UnusedArgument, + error.AllocatorRequired, => return error.Unexpected, - else => return err, + // use a new capture because it's got a smaller error set + else => |err2| return err2, }; self.db.tx_open = false; } diff --git a/src/util/DateTime.zig b/src/util/DateTime.zig index 87f2630..b0fad4c 100644 --- a/src/util/DateTime.zig +++ b/src/util/DateTime.zig @@ -9,10 +9,16 @@ pub const Duration = struct { seconds_since_epoch: i64, +// Tries the following methods for parsing, in order: +// 1. treats the string as a RFC 3339 DateTime +// 2. treats the string as the number of seconds since epoch pub fn parse(str: []const u8) !DateTime { - // TODO: Try other formats - - return try parseRfc3339(str); + return if (parseRfc3339(str)) |v| + v + else |_| if (std.fmt.parseInt(i64, str, 10)) |v| + DateTime{ .seconds_since_epoch = v } + else |_| + error.UnknownFormat; } pub fn add(self: DateTime, duration: Duration) DateTime { @@ -36,8 +42,8 @@ pub fn parseRfc3339(str: []const u8) !DateTime { const month_num = try std.fmt.parseInt(std.meta.Tag(epoch.Month), str[5..7], 10); const day_num = @as(i64, try std.fmt.parseInt(u9, str[8..10], 10)); const hour_num = @as(i64, try std.fmt.parseInt(u5, str[11..13], 10)); - const minute_num = @as(i64, try std.fmt.parseInt(u6, str[14..15], 10)); - const second_num = @as(i64, try std.fmt.parseInt(u6, str[16..17], 10)); + const minute_num = @as(i64, try std.fmt.parseInt(u6, str[14..16], 10)); + const second_num = @as(i64, try std.fmt.parseInt(u6, str[17..19], 10)); const is_leap_year = epoch.isLeapYear(year_num); const leap_days_preceding_epoch = comptime epoch.epoch_year / 4 - epoch.epoch_year / 100 + epoch.epoch_year / 400; @@ -97,7 +103,7 @@ pub fn second(value: DateTime) u6 { const array_len = 20; -pub fn toCharArray(value: DateTime) [array_len + 1]u8 { +pub fn toCharArray(value: DateTime) [array_len]u8 { var buf: [array_len]u8 = undefined; _ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable; return buf; diff --git a/src/util/lib.zig b/src/util/lib.zig index 6d7119a..2c5e1a9 100644 --- a/src/util/lib.zig +++ b/src/util/lib.zig @@ -60,12 +60,15 @@ pub fn comptimeJoin( /// to your enum type. pub fn jsonSerializeEnumAsString( enum_value: anytype, - _: std.json.StringifyOptions, + opt: std.json.StringifyOptions, writer: anytype, ) !void { switch (@typeInfo(@TypeOf(enum_value))) { .Enum => |info| if (!info.is_exhaustive) @compileError("Enum must be exhaustive"), - else => @compileError("Must be enum type"), + .Pointer => |info| if (info.size == .One) { + return jsonSerializeEnumAsString(enum_value.*, opt, writer); + } else @compileError("Must be enum type or pointer to enum, got " ++ @typeName(@TypeOf(enum_value))), + else => @compileError("Must be enum type or pointer to enum, got " ++ @typeName(@TypeOf(enum_value))), } return std.fmt.format(writer, "\"{s}\"", .{@tagName(enum_value)}); diff --git a/tests/api_integration/lib.zig b/tests/api_integration/lib.zig index 5b0f3d4..c20ef70 100644 --- a/tests/api_integration/lib.zig +++ b/tests/api_integration/lib.zig @@ -1,9 +1,8 @@ const std = @import("std"); const main = @import("main"); +const sql = @import("sql"); -const cluster_host = "test_host"; const test_config = .{ - .cluster_host = cluster_host, .db = .{ .sqlite = .{ .db_file = ":memory:", @@ -12,13 +11,22 @@ const test_config = .{ }; const ApiSource = main.api.ApiSource; -const root_password = "password"; +const root_user = "root"; +const root_password = "password1234"; +const admin_host = "example.com"; +const admin_origin = "https://" ++ admin_host; const random_seed = 1234; -fn makeApi(alloc: std.mem.Allocator) !ApiSource { +fn makeDb(alloc: std.mem.Allocator) sql.Db { + var db = try sql.Db.open(test_config.db); + try main.migrations.up(&db); + try main.api.setupAdmin(&db, admin_origin, root_user, root_password, alloc); +} + +fn makeApi(alloc: std.mem.Allocator, db: *sql.Db) !ApiSource { main.api.initThreadPrng(random_seed); - const source = try ApiSource.init(alloc, test_config, root_password); + const source = try ApiSource.init(alloc, test_config, db); return source; } @@ -26,10 +34,11 @@ test "login as root" { const alloc = std.testing.allocator; var arena = std.heap.ArenaAllocator.init(alloc); defer arena.deinit(); - var src = try makeApi(alloc); + var db = try makeDb(alloc); + var src = try makeApi(alloc, &db); std.debug.print("\npassword: {s}\n", .{root_password}); - var api = try src.connectUnauthorized(cluster_host, arena.allocator()); + var api = try src.connectUnauthorized(admin_host, arena.allocator()); defer api.close(); _ = try api.login("root", root_password);