Move to new DB api
commit 33cf0ff87a (parent db225b6689)
18 changed files with 258 additions and 339 deletions
@@ -1,9 +1,10 @@
 const std = @import("std");
 const util = @import("util");
 const builtin = @import("builtin");
-const db = @import("./db.zig");
+const sql = @import("sql");
 const models = @import("./db/models.zig");
+const migrations = @import("./migrations.zig");
 pub const DateTime = util.DateTime;
 pub const Uuid = util.Uuid;
 const Config = @import("./main.zig").Config;
@@ -28,7 +29,7 @@ pub const InviteRequest = struct {
 
     name: ?[]const u8 = null,
     expires_at: ?DateTime = null, // TODO: Change this to lifespan
-    max_uses: ?usize = null,
+    max_uses: ?u16 = null,
 
     invite_type: Type = .user, // must be user unless the creator is an admin
     to_community: ?[]const u8 = null, // only valid on admin community
@@ -94,40 +95,41 @@ pub fn getRandom() std.rand.Random {
 }
 
 pub const ApiSource = struct {
-    db: db.Database,
+    db: sql.Db,
     internal_alloc: std.mem.Allocator,
     config: Config,
 
-    pub const Conn = ApiConn(db.Database);
+    pub const Conn = ApiConn(sql.Db);
 
     const root_username = "root";
 
-    pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8) !ApiSource {
+    pub fn init(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8, db_conn: sql.Db) !ApiSource {
         var self = ApiSource{
-            .db = try db.Database.init(cfg.db.sqlite.db_file),
+            .db = db_conn,
            .internal_alloc = alloc,
            .config = cfg,
         };
 
+        try migrations.up(db_conn);
 
         if ((try services.users.lookupByUsername(&self.db, root_username, null)) == null) {
             std.log.info("No cluster root user detected. Creating...", .{});
 
             // TODO: Fix this
             const password = root_password orelse return error.NeedRootPassword;
-            std.debug.print("\npassword: {s}\n", .{password});
             var arena = std.heap.ArenaAllocator.init(alloc);
             defer arena.deinit();
             const user_id = try services.users.create(&self.db, root_username, password, null, .{}, arena.allocator());
-            std.debug.print("Created {s} ID {}", .{ root_username, user_id });
+            std.log.debug("Created {s} ID {}", .{ root_username, user_id });
         }
 
         return self;
     }
 
     fn getCommunityFromHost(self: *ApiSource, host: []const u8) !?Uuid {
-        if (try self.db.execRow(
+        if (try self.db.queryRow(
             &.{Uuid},
-            "SELECT id FROM community WHERE host = ?",
+            "SELECT id FROM community WHERE host = $1",
             .{host},
             null,
         )) |result| return result[0];
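With this hunk the connection is no longer opened inside ApiSource.init; the caller constructs it and injects it, and init runs the migrations itself. A caller-side sketch using only names visible in this diff (module paths and the wrapper function are assumed):

const std = @import("std");
const sql = @import("sql");
const api = @import("./api.zig"); // path assumed
const Config = @import("./main.zig").Config;

// Sketch: open one connection, hand it to the API layer, and let
// ApiSource.init run migrations.up() and the first-start root-user setup.
fn startApi(alloc: std.mem.Allocator, cfg: Config, root_password: ?[]const u8) !api.ApiSource {
    const db_conn = try sql.Db.open(cfg.db);
    return api.ApiSource.init(alloc, cfg, root_password, db_conn);
}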
@@ -204,9 +206,9 @@ fn ApiConn(comptime DbConn: type) type {
         };
         pub fn getTokenInfo(self: *Self) !TokenInfo {
             if (self.user_id) |user_id| {
-                const result = (try self.db.execRow(
+                const result = (try self.db.queryRow(
                     &.{[]const u8},
-                    "SELECT username FROM user WHERE id = ?",
+                    "SELECT username FROM user WHERE id = $1",
                     .{user_id},
                     self.arena.allocator(),
                 )) orelse {
@@ -23,9 +23,9 @@ pub const passwords = struct {
     };
     pub fn verify(db: anytype, user_id: Uuid, password: []const u8, alloc: std.mem.Allocator) VerifyError!void {
         // TODO: This could be done w/o the dynamically allocated hash buf
-        const hash = (db.execRow(
+        const hash = (db.queryRow(
             &.{[]const u8},
-            "SELECT hashed_password FROM account_password WHERE user_id = ? LIMIT 1",
+            "SELECT hashed_password FROM account_password WHERE user_id = $1 LIMIT 1",
             .{user_id},
             alloc,
         ) catch return error.DbError) orelse return error.InvalidLogin;
@@ -95,11 +95,11 @@ pub const tokens = struct {
     }
 
     fn lookupUserTokenFromHash(db: anytype, hash: []const u8, community_id: Uuid) !?Token.Info {
-        return if (try db.execRow(
+        return if (try db.queryRow(
             &.{ Uuid, DateTime },
             \\SELECT user.id, token.issued_at
             \\FROM token JOIN user ON token.user_id = user.id
-            \\WHERE user.community_id = ? AND token.hash = ?
+            \\WHERE user.community_id = $1 AND token.hash = $2
             \\LIMIT 1
         ,
             .{ community_id, hash },
@@ -114,11 +114,11 @@ pub const tokens = struct {
     }
 
     fn lookupSystemTokenFromHash(db: anytype, hash: []const u8) !?Token.Info {
-        return if (try db.execRow(
+        return if (try db.queryRow(
             &.{ Uuid, DateTime },
             \\SELECT user.id, token.issued_at
             \\FROM token JOIN user ON token.user_id = user.id
-            \\WHERE user.community_id IS NULL AND token.hash = ?
+            \\WHERE user.community_id IS NULL AND token.hash = $1
             \\LIMIT 1
         ,
             .{hash},
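The recurring pattern in these service hunks is execRow with ? placeholders becoming queryRow with numbered $n placeholders, which the new layer maps onto both backends (the new src/sql/test.zig further down exercises $1 against SQLite). A generic sketch of the call shape; the widget table and column are made up for illustration:

const std = @import("std");
const Uuid = @import("util").Uuid;

fn findWidgetName(db: anytype, widget_id: Uuid, alloc: std.mem.Allocator) !?[]const u8 {
    // queryRow(result column types, SQL with $1..$n, argument tuple, optional allocator)
    // returns null when no row matches.
    const row = (try db.queryRow(
        &.{[]const u8},
        "SELECT name FROM widget WHERE id = $1",
        .{widget_id},
        alloc,
    )) orelse return null;
    return row[0];
}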
@@ -3,8 +3,6 @@ const builtin = @import("builtin");
 const util = @import("util");
 const models = @import("../db/models.zig");
 
-const DbError = @import("../db.zig").ExecError;
-
 const getRandom = @import("../api.zig").getRandom;
 
 const Uuid = util.Uuid;
@@ -14,7 +12,7 @@ const CreateError = error{
     InvalidOrigin,
     UnsupportedScheme,
     CommunityExists,
-} || DbError;
+} || anyerror; // TODO
 
 pub const Scheme = enum {
     https,
@@ -76,7 +74,7 @@ pub fn create(db: anytype, origin: []const u8, name: ?[]const u8) CreateError!Co
         .created_at = DateTime.now(),
     };
 
-    if ((try db.execRow(&.{Uuid}, "SELECT id FROM community WHERE host = ?", .{host}, null)) != null) {
+    if ((try db.queryRow(&.{Uuid}, "SELECT id FROM community WHERE host = $1", .{host}, null)) != null) {
        return error.CommunityExists;
     }
 
@@ -94,7 +92,7 @@ fn firstIndexOf(str: []const u8, ch: u8) ?usize {
 }
 
 pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Community {
-    const result = (try db.execRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = ?", .{host}, alloc)) orelse return error.NotFound;
+    const result = (try db.queryRow(&.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime }, "SELECT id, owner_id, host, name, scheme, created_at FROM community WHERE host = $1", .{host}, alloc)) orelse return error.NotFound;
 
     return Community{
         .id = result[0],
@@ -107,7 +105,7 @@ pub fn getByHost(db: anytype, host: []const u8, alloc: std.mem.Allocator) !Commu
 }
 
 pub fn transferOwnership(db: anytype, community_id: Uuid, new_owner: Uuid) !void {
-    _ = try db.execRow(&.{i64}, "UPDATE community SET owner_id = ? WHERE id = ?", .{ new_owner, community_id }, null);
+    try db.exec("UPDATE community SET owner_id = $1 WHERE id = $2", .{ new_owner, community_id }, null);
 }
 
 pub const QueryArgs = struct {
@@ -238,7 +236,7 @@ pub fn query(db: anytype, args: QueryArgs, alloc: std.mem.Allocator) ![]Communit
         max_items,
     };
 
-    var results = try db.exec(
+    var results = try db.query(
         &.{ Uuid, ?Uuid, []const u8, []const u8, Scheme, DateTime },
         builder.array.items,
         query_args,
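Alongside queryRow, the communities hunks show the other two shapes of the new API: exec for statements with no result rows (transferOwnership) and query for multi-row result sets. A hedged sketch of the result-set shape with an illustrative statement; the row()/finish() usage follows src/sql/test.zig further down:

const std = @import("std");

fn logCommunityHosts(db: anytype, alloc: std.mem.Allocator) !void {
    // query returns a result set; finish() releases the underlying statement.
    var results = try db.query(&.{[]const u8}, "SELECT host FROM community", .{}, alloc);
    defer results.finish();
    while (try results.row(alloc)) |row| {
        std.log.info("community host: {s}", .{row[0]});
        alloc.free(row[0]);
    }
}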
@@ -2,7 +2,6 @@ const std = @import("std");
 const builtin = @import("builtin");
 const util = @import("util");
 const models = @import("../db/models.zig");
-const DbError = @import("../db.zig").ExecError;
 const getRandom = @import("../api.zig").getRandom;
 
 const Uuid = util.Uuid;
@@ -31,6 +30,7 @@ fn defaultJsonStringify(comptime T: type) fn (T, std.json.StringifyOptions, anyt
     }.jsonStringify;
 }
 
+const InviteCount = u16;
 pub const Invite = struct {
     id: Uuid,
 
@@ -40,10 +40,10 @@ pub const Invite = struct {
     code: []const u8,
 
     created_at: DateTime,
-    times_used: usize,
+    times_used: InviteCount,
 
     expires_at: ?DateTime,
-    max_uses: ?usize,
+    max_uses: ?InviteCount,
 
     invite_type: InviteType,
 };
@@ -59,7 +59,7 @@ const DbModel = struct {
     created_at: DateTime,
     expires_at: ?DateTime,
 
-    max_uses: ?usize,
+    max_uses: ?InviteCount,
 
     @"type": InviteType,
 };
@@ -72,7 +72,7 @@ fn cloneStr(str: []const u8, alloc: std.mem.Allocator) ![]const u8 {
 
 pub const InviteOptions = struct {
     name: ?[]const u8 = null,
-    max_uses: ?usize = null,
+    max_uses: ?InviteCount = null,
     expires_at: ?DateTime = null,
     invite_type: InviteType = .user,
 };
@@ -130,14 +130,14 @@ pub fn create(db: anytype, created_by: Uuid, to_community: ?Uuid, options: Invit
 
 pub fn getByCode(db: anytype, code: []const u8, alloc: std.mem.Allocator) !Invite {
     const code_clone = try cloneStr(code, alloc);
-    const info = (try db.execRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, usize, ?usize, InviteType },
+    const info = (try db.queryRow(&.{ Uuid, Uuid, Uuid, []const u8, DateTime, ?DateTime, InviteCount, ?InviteCount, InviteType },
         \\SELECT
         \\ invite.id, invite.created_by, invite.to_community, invite.name,
         \\ invite.created_at, invite.expires_at,
         \\ COUNT(local_user.user_id) as uses, invite.max_uses,
        \\ invite.type
         \\FROM invite LEFT OUTER JOIN local_user ON invite.id = local_user.invite_id
-        \\WHERE invite.code = ?
+        \\WHERE invite.code = $1
         \\GROUP BY invite.id
     , .{code}, alloc)) orelse return error.NotFound;
@@ -40,11 +40,11 @@ pub fn create(
 }
 
 pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !Note {
-    const result = (try db.execRow(
+    const result = (try db.queryRow(
         &.{ Uuid, []const u8, DateTime },
         \\SELECT author_id, content, created_at
         \\FROM note
-        \\WHERE id = ?
+        \\WHERE id = $1
         \\LIMIT 1
     ,
         .{id},
@@ -37,9 +37,9 @@ pub const CreateOptions = struct {
 };
 
 fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
-    return if (try db.execRow(
+    return if (try db.queryRow(
         &.{Uuid},
-        "SELECT user.id FROM user WHERE community_id IS NULL AND username = ?",
+        "SELECT user.id FROM user WHERE community_id IS NULL AND username = $1",
         .{username},
         null,
     )) |result|
@@ -49,9 +49,9 @@ fn lookupSystemUserByUsername(db: anytype, username: []const u8) !?Uuid {
 }
 
 fn lookupUserByUsername(db: anytype, username: []const u8, community_id: Uuid) !?Uuid {
-    return if (try db.execRow(
+    return if (try db.queryRow(
         &.{Uuid},
-        "SELECT user.id FROM user WHERE community_id = ? AND username = ?",
+        "SELECT user.id FROM user WHERE community_id = $1 AND username = $2",
         .{ community_id, username },
         null,
     )) |result|
@@ -107,11 +107,11 @@ pub const User = struct {
 };
 
 pub fn get(db: anytype, id: Uuid, alloc: std.mem.Allocator) !User {
-    const result = (try db.execRow(
+    const result = (try db.queryRow(
         &.{ []const u8, []const u8, Uuid, DateTime },
         \\SELECT user.username, community.host, community.id, user.created_at
         \\FROM user JOIN community ON user.community_id = community.id
-        \\WHERE user.id = ?
+        \\WHERE user.id = $1
         \\LIMIT 1
     ,
         .{id},
@@ -14,6 +14,7 @@ pub const login = struct {
     pub const path = "/auth/login";
     pub const method = .POST;
     pub fn handler(srv: *RequestServer, ctx: *http.server.Context, _: RouteArgs) !void {
+        std.debug.print("{s}", .{ctx.request.body.?});
         const credentials = try utils.parseRequestBody(struct { username: []const u8, password: []const u8 }, ctx);
         defer utils.freeRequestBody(credentials, ctx.alloc);
 
src/main/db.zig (214 lines deleted)
@@ -1,214 +0,0 @@
-const std = @import("std");
-const sql = @import("sql");
-const models = @import("./db/models.zig");
-const migrations = @import("./db/migrations.zig");
-const util = @import("util");
-
-const Uuid = util.Uuid;
-const DateTime = util.DateTime;
-const String = []const u8;
-const comptimePrint = std.fmt.comptimePrint;
-
-fn readRow(comptime RowTuple: type, row: sql.Row, allocator: ?std.mem.Allocator) !RowTuple {
-    var result: RowTuple = undefined;
-    // TODO: undo allocations on failure
-    inline for (std.meta.fields(RowTuple)) |f, i| {
-        @field(result, f.name) = try getAlloc(row, f.field_type, i, allocator);
-    }
-
-    return result;
-}
-
-pub fn ResultSet(comptime result_types: []const type) type {
-    return struct {
-        pub const Row = std.meta.Tuple(result_types);
-
-        _stmt: sql.PreparedStmt,
-        err: ?ExecError = null,
-
-        pub fn finish(self: *@This()) void {
-            self._stmt.finalize();
-        }
-
-        pub fn row(self: *@This(), allocator: ?std.mem.Allocator) ?Row {
-            const sql_result = self._stmt.step() catch |err| {
-                self.err = err;
-                return null;
-            };
-
-            if (sql_result) |sql_row| {
-                return readRow(Row, sql_row, allocator) catch |err| {
-                    self.err = err;
-                    return null;
-                };
-            } else return null;
-        }
-    };
-}
-
-// Binds a value to a parameter in the query. Use this instead of string
-// concatenation to avoid injection attacks;
-// If a given type is not supported by this function, you can add support by
-// declaring a method with the given signature:
-// pub fn bindToSql(val: T, stmt: sql.PreparedStmt, idx: u15) !void
-// TODO define what error set this ^ should return
-fn bind(stmt: sql.PreparedStmt, idx: u15, val: anytype) !void {
-    if (comptime std.meta.trait.isZigString(@TypeOf(val))) return stmt.bindText(idx, val);
-
-    return switch (@TypeOf(val)) {
-        i64 => stmt.bindI64(idx, val),
-        Uuid => stmt.bindUuid(idx, val),
-        DateTime => stmt.bindDateTime(idx, val),
-        @TypeOf(null) => stmt.bindNull(idx),
-        else => |T| switch (@typeInfo(T)) {
-            .Optional => if (val) |v| bind(stmt, idx, v) else stmt.bindNull(idx),
-            .Enum => stmt.bindText(idx, @tagName(val)),
-            .Struct, .Union, .Opaque => if (@hasDecl(T, "bindToSql"))
-                val.bindToSql(stmt, idx)
-            else
-                @compileError("unsupported type " ++ @typeName(T)),
-            .Int => stmt.bindI64(idx, @intCast(i64, val)),
-            else => @compileError("unsupported type " ++ @typeName(T)),
-        },
-    };
-}
-
-// Gets a value from the row, allocating memory if necessary.
-// If a given type is not supported by this function, you can add support by
-// declaring a method with the given signature:
-// pub fn getFromSql(row: sql.Row, idx: u15, alloc: std.mem.Allocator) !T
-// TODO define what error set this ^ should return
-fn getAlloc(row: sql.Row, comptime T: type, idx: u15, alloc: ?std.mem.Allocator) !T {
-    return switch (T) {
-        []u8, []const u8 => row.getTextAlloc(idx, alloc orelse return error.AllocatorRequired),
-        i64 => row.getI64(idx),
-        Uuid => row.getUuid(idx),
-        DateTime => row.getDateTime(idx),
-
-        else => switch (@typeInfo(T)) {
-            .Optional => if (try row.isNull(idx))
-                null
-            else
-                try getAlloc(row, std.meta.Child(T), idx, alloc),
-
-            .Struct, .Union, .Opaque => if (@hasDecl(T, "getFromSql"))
-                T.getFromSql(row, idx, alloc)
-            else
-                @compileError("unknown type " ++ @typeName(T)),
-
-            .Enum => try getEnum(row, T, idx),
-
-            .Int => @intCast(T, try row.getI64(idx)),
-
-            //else => unreachable,
-            else => @compileError("unknown type " ++ @typeName(T)),
-        },
-    };
-}
-
-fn maxTagLen(comptime T: type) usize {
-    var max: usize = 0;
-    for (std.meta.fields(T)) |f| {
-        if (f.name.len > max) {
-            max = f.name.len;
-        }
-    }
-    return max;
-}
-
-fn getEnum(row: sql.Row, comptime T: type, idx: u15) !T {
-    var tag_buf: [maxTagLen(T)]u8 = undefined;
-    const tag_name = try row.getText(idx, &tag_buf);
-    inline for (std.meta.fields(T)) |tag| {
-        if (std.mem.eql(u8, tag_name, tag.name)) return @intToEnum(T, tag.value);
-    }
-
-    return error.UnknownTag;
-}
-
-pub const ExecError = sql.PrepareError || sql.RowGetError || sql.BindError || std.mem.Allocator.Error || error{ AllocatorRequired, UnknownTag };
-
-pub const Database = struct {
-    db: sql.Sqlite,
-
-    pub fn init(file_path: [:0]const u8) !Database {
-        var db = try sql.Sqlite.open(file_path);
-        errdefer db.close();
-
-        try migrations.up(&db);
-
-        return Database{ .db = db };
-    }
-
-    pub fn deinit(self: *Database) void {
-        self.db.close();
-    }
-
-    pub fn exec(
-        self: *Database,
-        comptime result_types: []const type,
-        comptime q: []const u8,
-        args: anytype,
-    ) ExecError!ResultSet(result_types) {
-        std.log.debug("executing sql:\n===\n{s}\n===", .{q});
-
-        const stmt = try self.db.prepare(q);
-        errdefer stmt.finalize();
-
-        inline for (std.meta.fields(@TypeOf(args))) |field, i| {
-            try bind(stmt, @intCast(u15, i + 1), @field(args, field.name));
-        }
-
-        return ResultSet(result_types){
-            ._stmt = stmt,
-        };
-    }
-
-    pub fn execRow(
-        self: *Database,
-        comptime result_types: []const type,
-        comptime q: []const u8,
-        args: anytype,
-        allocator: ?std.mem.Allocator,
-    ) ExecError!?ResultSet(result_types).Row {
-        var results = try self.exec(result_types, q, args);
-        defer results.finish();
-
-        const row = results.row(allocator);
-        std.log.debug("done exec", .{});
-        if (row) |r| return r;
-        if (results.err) |err| {
-            std.log.debug("{}", .{err});
-            std.log.debug("{?}", .{@errorReturnTrace()});
-            return err;
-        }
-        return null;
-    }
-
-    fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
-        comptime {
-            const joiner = ",";
-            var result: []const u8 = "";
-            inline for (std.meta.fields(T)) |f| {
-                result = result ++ joiner ++ (placeholder orelse f.name);
-            }
-
-            return "(" ++ result[joiner.len..] ++ ")";
-        }
-    }
-
-    pub fn insert(
-        self: *Database,
-        comptime table: []const u8,
-        value: anytype,
-    ) ExecError!void {
-        const ValueType = comptime @TypeOf(value);
-        const table_spec = comptime table ++ build_field_list(ValueType, null);
-        const value_spec = comptime build_field_list(ValueType, "?");
-        const q = comptime std.fmt.comptimePrint(
-            "INSERT INTO {s} VALUES {s}",
-            .{ table_spec, value_spec },
-        );
-        _ = try self.execRow(&.{}, q, value, null);
-    }
-};
@@ -1,5 +1,6 @@
 const std = @import("std");
 const builtin = @import("builtin");
+const sql = @import("sql");
 const http = @import("http");
 const util = @import("util");
 
@@ -82,11 +83,7 @@ pub const RequestServer = struct {
 
 pub const Config = struct {
     cluster_host: []const u8,
-    db: struct {
-        sqlite: struct {
-            db_file: [:0]const u8,
-        },
-    },
+    db: sql.Config,
     root_password: ?[]const u8 = null,
 };
 
@@ -105,7 +102,8 @@ const root_password_envvar = "CLUSTER_ROOT_PASSWORD";
 pub fn main() anyerror!void {
     var gpa = std.heap.GeneralPurposeAllocator(.{}){};
     var cfg = try loadConfig(gpa.allocator());
-    var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar)) catch |err| switch (err) {
+    var db_conn = try sql.Db.open(cfg.db);
+    var api_src = api.ApiSource.init(gpa.allocator(), cfg, std.os.getenv(root_password_envvar), db_conn) catch |err| switch (err) {
         error.NeedRootPassword => {
             std.log.err(
                 "No root user created and no password specified. Please provide the password for the root user by the ${s} environment variable for initial startup. This only needs to be done once",
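Config.db is now the sql.Config union from src/sql rather than an inline sqlite-only struct, so the same config type can select either backend. Hypothetical literals for both variants (field names as renamed in the src/sql/lib.zig hunk further down; the file path and connection string are made up):

const sql = @import("sql");

// SQLite backend
const cfg_sqlite = sql.Config{ .sqlite = .{ .sqlite_file_path = "./cluster.db" } };
// Postgres backend
const cfg_postgres = sql.Config{ .postgres = .{ .pg_conn_str = "postgresql://localhost/example" } };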
@@ -1,8 +1,9 @@
+const std = @import("std");
 const sql = @import("sql");
 const DateTime = @import("util").DateTime;
 
 pub const Migration = struct {
-    name: []const u8,
+    name: [:0]const u8,
     up: []const u8,
     down: []const u8,
 };
@@ -15,54 +16,44 @@ fn firstIndexOf(str: []const u8, char: u8) ?usize {
     return null;
 }
 
-fn execStmt(db: *sql.Sqlite, stmt_sql: []const u8) !void {
-    const stmt = try db.prepare(stmt_sql);
-    defer stmt.finalize();
-    while (try stmt.step()) |_| {}
+fn execStmt(tx: sql.Tx, stmt: []const u8, alloc: std.mem.Allocator) !void {
+    const stmt_null = try std.cstr.addNullByte(alloc, stmt);
+    defer alloc.free(stmt_null);
+    try tx.exec(stmt_null, .{}, null);
 }
 
-fn execScript(db: *sql.Sqlite, script: []const u8) !void {
-    try execStmt(db, "BEGIN;");
-    errdefer {
-        _ = execStmt(db, "ROLLBACK;") catch unreachable;
-    }
+fn execScript(db: sql.Db, script: []const u8, alloc: std.mem.Allocator) !void {
+    const tx = try db.begin();
+    errdefer tx.rollback();
 
     var remaining = script;
     while (firstIndexOf(remaining, ';')) |last| {
-        try execStmt(db, remaining[0 .. last + 1]);
+        try execStmt(tx, remaining[0 .. last + 1], alloc);
 
         remaining = remaining[last + 1 ..];
     }
+    if (remaining.len > 1) try execStmt(tx, remaining, alloc);
 
-    try execStmt(db, "COMMIT;");
+    try tx.commit();
 }
 
-fn wasMigrationRan(db: *sql.Sqlite, name: []const u8) !bool {
-    const stmt = try db.prepare("SELECT COUNT(*) FROM migration WHERE name = ?;");
-    defer stmt.finalize();
-
-    try stmt.bindText(1, name);
-    const result = (try stmt.step()).?;
-
-    const count = try result.getI64(0);
-    return count != 0;
+fn wasMigrationRan(db: sql.Db, name: []const u8, alloc: std.mem.Allocator) !bool {
+    const row = (try db.queryRow(&.{i32}, "SELECT COUNT(*) FROM migration WHERE name = $1", .{name}, alloc)) orelse return false;
+    return row[0] != 0;
 }
 
-fn markMigrationAsRan(db: *sql.Sqlite, name: []const u8) !void {
-    const stmt = try db.prepare("INSERT INTO migration(name) VALUES(?);");
-    defer stmt.finalize();
-
-    try stmt.bindText(1, name);
-    _ = try stmt.step();
-}
-
-pub fn up(db: *sql.Sqlite) !void {
-    try execScript(db, create_migration_table);
-
+pub fn up(db: sql.Db) !void {
+    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
+    defer _ = gpa.deinit();
+    std.log.info("Running migrations...", .{});
+    try execScript(db, create_migration_table, gpa.allocator());
     for (migrations) |migration| {
-        if (!try wasMigrationRan(db, migration.name)) {
-            try execScript(db, migration.up);
-            try markMigrationAsRan(db, migration.name);
+        const was_ran = try wasMigrationRan(db, migration.name, gpa.allocator());
+        if (!was_ran) {
+            std.log.info("Running migration {s}", .{migration.name});
+            try execScript(db, migration.up, gpa.allocator());
+            try db.insert("migration", .{ .name = migration.name });
         }
     }
 }
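execScript now runs each ';'-terminated statement inside a real driver transaction instead of hand-written BEGIN/ROLLBACK/COMMIT statements, and up() records applied migrations through db.insert. A sketch of how a caller would drive this, using only names from the hunks above (the wrapper function and path are assumed):

const std = @import("std");
const sql = @import("sql");
const migrations = @import("./migrations.zig"); // path assumed

// Open a connection and apply every migration that has no row in the
// `migration` bookkeeping table; each script runs between begin() and commit().
fn migrate(cfg: sql.Config) !void {
    const db = try sql.Db.open(cfg);
    defer db.close();
    try migrations.up(db);
}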
@@ -71,7 +62,7 @@ const create_migration_table =
     \\CREATE TABLE IF NOT EXISTS
     \\migration(
     \\ name TEXT NOT NULL PRIMARY KEY,
-    \\ applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+    \\ applied_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
     \\);
 ;
 
@@ -85,7 +76,7 @@ const migrations: []const Migration = &.{
     \\ id TEXT NOT NULL PRIMARY KEY,
     \\ username TEXT NOT NULL,
     \\
-    \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+    \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
     \\);
     \\
     \\CREATE TABLE local_user(
@@ -115,7 +106,7 @@ const migrations: []const Migration = &.{
     \\ content TEXT NOT NULL,
     \\ author_id TEXT NOT NULL REFERENCES user(id),
     \\
-    \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+    \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
     \\);
     ,
     .down = "DROP TABLE note;",
@@ -129,7 +120,7 @@ const migrations: []const Migration = &.{
     \\ user_id TEXT NOT NULL REFERENCES user(id),
     \\ note_id TEXT NOT NULL REFERENCES note(id),
     \\
-    \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+    \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
     \\);
     ,
     .down = "DROP TABLE reaction;",
@@ -141,7 +132,7 @@ const migrations: []const Migration = &.{
     \\ hash TEXT NOT NULL PRIMARY KEY,
     \\ user_id TEXT NOT NULL REFERENCES local_user(id),
     \\
-    \\ issued_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+    \\ issued_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
     \\);
     ,
     .down = "DROP TABLE token;",
@@ -158,8 +149,8 @@ const migrations: []const Migration = &.{
     \\
     \\ max_uses INTEGER,
     \\
-    \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    \\ expires_at DATETIME,
+    \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    \\ expires_at TIMESTAMPTZ,
     \\
     \\ type TEXT NOT NULL CHECK (type in ('system', 'community_owner', 'user'))
     \\);
@@ -181,7 +172,7 @@ const migrations: []const Migration = &.{
     \\ host TEXT NOT NULL UNIQUE,
     \\ scheme TEXT NOT NULL CHECK (scheme IN ('http', 'https')),
     \\
-    \\ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
+    \\ created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
     \\);
     \\ALTER TABLE user ADD COLUMN community_id TEXT REFERENCES community(id);
     \\ALTER TABLE invite ADD COLUMN to_community TEXT REFERENCES community(id);
@@ -5,10 +5,30 @@ const Uuid = util.Uuid;
 const DateTime = util.DateTime;
 const Allocator = std.mem.Allocator;
 
+pub const QueryOptions = struct {
+    // If true, then it will not return an error on the SQLite backend
+    // if an argument passed does not map to a parameter in the query.
+    // Has no effect on the postgres backend.
+    ignore_unknown_parameters: bool = false,
+
+    // The allocator to use for query preparation and submission.
+    // All memory allocated with this allocator will be freed before results
+    // are retrieved.
+    // Some types (enums with constant representation, null terminated strings)
+    // do not require allocators for prep. If an allocator is needed but not
+    // provided, `error.AllocatorRequired` will be returned.
+    // Only used with the postgres backend.
+    prep_allocator: ?Allocator = null,
+};
+
 // Turns a value into its appropriate textual value (or null)
 // as appropriate using the given arena allocator
-pub fn prepareParamText(arena: std.heap.ArenaAllocator, val: anytype) !?[:0]const u8 {
-    if (comptime std.meta.trait.isZigString(@TypeOf(val))) return val;
+pub fn prepareParamText(arena: *std.heap.ArenaAllocator, val: anytype) !?[:0]const u8 {
+    if (comptime std.meta.trait.isZigString(@TypeOf(val))) {
+        if (comptime std.meta.sentinel(@TypeOf(val))) |s| if (comptime s == 0) return val;
+
+        return try std.cstr.addNullByte(arena.allocator(), val);
+    }
 
     return switch (@TypeOf(val)) {
         [:0]u8, [:0]const u8 => val,
@@ -36,7 +56,9 @@ pub fn parseValueNotNull(alloc: ?Allocator, comptime T: type, str: []const u8) !
         []u8, []const u8 => if (alloc) |a| util.deepClone(a, str) else return error.AllocatorRequired,
 
         else => switch (@typeInfo(T)) {
-            .Enum => parseEnum(T, str),
+            .Int => std.fmt.parseInt(T, str, 0),
+            .Enum => std.meta.stringToEnum(T, str) orelse return error.InvalidValue,
+            .Optional => try parseValueNotNull(alloc, std.meta.Child(T), str),
 
             else => @compileError("Type " ++ @typeName(T) ++ " not supported"),
         },
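parseValueNotNull now also maps the backend's textual values onto integers, enums, and optionals. A standalone illustration of the two std calls it leans on (this is not the project's code, only the mapping idea):

const std = @import("std");

const Color = enum { red, green, blue };

test "text-to-value mapping used by parseValueNotNull" {
    // Integers arrive as decimal text and go through std.fmt.parseInt.
    try std.testing.expectEqual(@as(u16, 42), try std.fmt.parseInt(u16, "42", 0));
    // Enums are matched by tag name via std.meta.stringToEnum.
    try std.testing.expectEqual(Color.green, std.meta.stringToEnum(Color, "green").?);
}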
@@ -3,8 +3,11 @@ const util = @import("util");
 
 const postgres = @import("./postgres.zig");
 const sqlite = @import("./sqlite.zig");
+const common = @import("./common.zig");
 const Allocator = std.mem.Allocator;
 
+pub const QueryOptions = common.QueryOptions;
+
 pub const Type = enum {
     postgres,
     sqlite,
@@ -12,10 +15,10 @@ pub const Type = enum {
 
 pub const Config = union(Type) {
     postgres: struct {
-        conn_str: [:0]const u8,
+        pg_conn_str: [:0]const u8,
     },
     sqlite: struct {
-        file_path: [:0]const u8,
+        sqlite_file_path: [:0]const u8,
     },
 };
 
@@ -106,12 +109,12 @@ pub const Db = struct {
         return switch (cfg) {
             .postgres => |postgres_cfg| Db{
                 .underlying = .{
-                    .postgres = try postgres.Db.open(postgres_cfg.conn_str),
+                    .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
                 },
             },
             .sqlite => |lite_cfg| Db{
                 .underlying = .{
-                    .sqlite = try sqlite.Db.open(lite_cfg.file_path),
+                    .sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path),
                },
             },
         };
@@ -124,6 +127,17 @@ pub const Db = struct {
         }
     }
 
+    pub fn queryWithOptions(
+        self: Db,
+        comptime result_types: []const type,
+        sql: [:0]const u8,
+        args: anytype,
+        opt: QueryOptions,
+    ) !Results(result_types) {
+        // Create fake transaction to use its functions
+        return (Tx{ .underlying = self.underlying }).queryWithOptions(result_types, sql, args, opt);
+    }
+
     pub fn query(
         self: Db,
         comptime result_types: []const type,
@@ -153,7 +167,7 @@ pub const Db = struct {
         alloc: ?Allocator,
     ) !?Results(result_types).RowTuple {
         // Create fake transaction to use its functions
-        return (Tx{ .underlying = self.underlying }).exec(sql, args, alloc);
+        return (Tx{ .underlying = self.underlying }).queryRow(result_types, sql, args, alloc);
     }
 
     pub fn insert(
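queryWithOptions is the new entry point that takes a QueryOptions value directly; the plain query()/queryRow() wrappers fill in prep_allocator for you. A hedged usage sketch (the COUNT query and the i64 result column are illustrative, and the Results row()/finish() usage is assumed to mirror the query() path):

const std = @import("std");
const sql = @import("sql");

fn countUsers(db: sql.Db, alloc: std.mem.Allocator) !i64 {
    var results = try db.queryWithOptions(
        &.{i64},
        "SELECT COUNT(*) FROM user",
        .{},
        .{ .prep_allocator = alloc, .ignore_unknown_parameters = false },
    );
    defer results.finish();
    const row = (try results.row(alloc)) orelse return error.NoRow;
    return row[0];
}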
@@ -182,14 +196,24 @@ pub const Tx = struct {
         self: Tx,
         sql: [:0]const u8,
         args: anytype,
-        alloc: ?Allocator,
+        opt: QueryOptions,
     ) !RawResults {
         return switch (self.underlying) {
-            .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, alloc) },
-            .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args) },
+            .postgres => |pg| RawResults{ .postgres = try pg.exec(sql, args, opt.prep_allocator) },
+            .sqlite => |lite| RawResults{ .sqlite = try lite.exec(sql, args, opt) },
         };
     }
 
+    pub fn queryWithOptions(
+        self: Tx,
+        comptime result_types: []const type,
+        sql: [:0]const u8,
+        args: anytype,
+        options: QueryOptions,
+    ) !Results(result_types) {
+        return Results(result_types){ .underlying = try self.queryInternal(sql, args, options) };
+    }
+
     // Executes a query and returns the result set
     pub fn query(
         self: Tx,
@@ -198,7 +222,7 @@ pub const Tx = struct {
         args: anytype,
         alloc: ?Allocator,
     ) !Results(result_types) {
-        return Results(result_types){ .unerlying = try self.queryInternal(sql, args, alloc) };
+        return self.queryWithOptions(result_types, sql, args, .{ .prep_allocator = alloc });
     }
 
     // Executes a query without returning results
@@ -208,7 +232,7 @@ pub const Tx = struct {
         args: anytype,
         alloc: ?Allocator,
     ) !void {
-        (try self.queryInternal(sql, args, alloc)).finish();
+        _ = try self.queryRow(&.{}, sql, args, alloc);
     }
 
     // Runs a query and returns a single row
@@ -242,13 +266,28 @@ pub const Tx = struct {
         value: anytype,
     ) !void {
         const ValueType = comptime @TypeOf(value);
-        const table_spec = comptime table ++ build_field_list(ValueType, null);
-        const value_spec = comptime build_field_list(ValueType, "?");
+        const fields = std.meta.fields(ValueType);
+        comptime var types: [fields.len]type = undefined;
+        comptime var table_spec: []const u8 = table ++ "(";
+        comptime var value_spec: []const u8 = "(";
+        inline for (fields) |field, i| {
+            types[i] = field.field_type;
+            table_spec = comptime (table_spec ++ field.name ++ ",");
+            value_spec = comptime value_spec ++ std.fmt.comptimePrint("${},", .{i + 1});
+        }
+        table_spec = comptime table_spec[0 .. table_spec.len - 1] ++ ")";
+        value_spec = comptime value_spec[0 .. value_spec.len - 1] ++ ")";
         const q = comptime std.fmt.comptimePrint(
             "INSERT INTO {s} VALUES {s}",
             .{ table_spec, value_spec },
         );
-        try self.exec(q, value, null);
+
+        var args_tuple: std.meta.Tuple(&types) = undefined;
+        inline for (fields) |field, i| {
+            args_tuple[i] = @field(value, field.name);
+        }
+        try self.exec(q, args_tuple, null);
     }
 
     pub fn rollback(self: Tx) void {
@@ -261,15 +300,3 @@ pub const Tx = struct {
         try self.exec("COMMIT", .{}, null);
     }
 };
-
-fn build_field_list(comptime T: type, comptime placeholder: ?[]const u8) []const u8 {
-    comptime {
-        const joiner = ",";
-        var result: []const u8 = "";
-        inline for (std.meta.fields(T)) |f| {
-            result = result ++ joiner ++ (placeholder orelse f.name);
-        }
-
-        return "(" ++ result[joiner.len..] ++ ")";
-    }
-}
@@ -90,7 +90,7 @@ pub const Db = struct {
         var arena = std.heap.ArenaAllocator.init(alloc orelse return error.AllocatorRequired);
         defer arena.deinit();
         const params = try arena.allocator().alloc(?[*]const u8, args.len);
-        inline for (args) |a, i| params[i] = if (try common.prepareParamText(arena, a)) |slice| slice.ptr else null;
+        inline for (args) |a, i| params[i] = if (try common.prepareParamText(&arena, a)) |slice| slice.ptr else null;
 
         break :blk c.PQexecParams(self.conn, sql.ptr, @intCast(c_int, params.len), null, params.ptr, null, null, format_text);
     } else {
@@ -103,7 +103,7 @@ pub const Db = struct {
         }
     }
 
-    pub fn exec(self: Db, sql: []const u8, args: anytype) !Results {
+    pub fn exec(self: Db, sql: []const u8, args: anytype, opts: common.QueryOptions) !Results {
         var stmt: ?*c.sqlite3_stmt = undefined;
         switch (c.sqlite3_prepare_v2(self.db, sql.ptr, @intCast(c_int, sql.len), &stmt, null)) {
             c.SQLITE_OK => {},
@@ -134,7 +134,7 @@ pub const Db = struct {
                 return handleUnexpectedError(self.db, err, sql);
             },
         }
-        } else unreachable;
+        } else if (!opts.ignore_unknown_parameters) return error.UnknownParameter;
     }
 
     return Results{ .stmt = stmt.?, .db = self.db };
@@ -216,7 +216,7 @@ pub const Row = struct {
             @intCast(T, c.sqlite3_column_int64(self.stmt, idx))
         else
             self.getFromString(T, idx, alloc),
-        .Optional => self.getNotNull(std.meta.Child(T), idx, alloc),
+        .Optional => try self.getNotNull(std.meta.Child(T), idx, alloc),
         else => self.getFromString(T, idx, alloc),
         },
     };
src/sql/test.zig (new file, 41 lines)
@@ -0,0 +1,41 @@
+const sql = @import("./lib.zig");
+const std = @import("std");
+const Uuid = @import("util").Uuid;
+
+const alloc = std.testing.allocator;
+
+pub fn main() !void {
+    const db = try sql.Db.open(.{
+        //.postgres = .{
+        //.conn_str = "postgresql://localhost",
+        //},
+        .sqlite = .{
+            .file_path = "./test.db",
+        },
+    });
+    defer db.close();
+
+    const tx = try db.begin();
+    try tx.commit();
+    tx.rollback();
+    try tx.commit();
+}
+
+test {
+    const db = try sql.Db.open(.{
+        .sqlite = .{
+            .file_path = "./test.db",
+        },
+    });
+    defer db.close();
+
+    var results = try db.query(&.{[]const u8}, "SELECT $1 as id", .{"abcdefg"}, alloc);
+    defer results.finish();
+
+    const row = (try results.row(alloc)) orelse unreachable;
+    defer alloc.free(row[0]);
+
+    try std.testing.expectEqualStrings("abcdefg", row[0]);
+
+    std.log.info("value: {s}", .{row[0]});
+}
@@ -1,9 +1,48 @@
 const DateTime = @This();
 
 const std = @import("std");
+const epoch = std.time.epoch;
 
 seconds_since_epoch: i64,
 
+pub fn parse(str: []const u8) !DateTime {
+    // TODO: Try other formats
+
+    return try parseRfc3339(str);
+}
+
+// TODO: Validate non-numeric aspects of datetime
+// TODO: Don't panic on bad string
+// TODO: Make seconds optional (see ActivityStreams 2.0 spec §2.3)
+// TODO: Handle times before 1970
+pub fn parseRfc3339(str: []const u8) !DateTime {
+    const year_num = try std.fmt.parseInt(u16, str[0..4], 10);
+    const month_num = try std.fmt.parseInt(std.meta.Tag(epoch.Month), str[5..7], 10);
+    const day_num = @as(i64, try std.fmt.parseInt(u9, str[8..10], 10));
+    const hour_num = @as(i64, try std.fmt.parseInt(u5, str[11..13], 10));
+    const minute_num = @as(i64, try std.fmt.parseInt(u6, str[14..15], 10));
+    const second_num = @as(i64, try std.fmt.parseInt(u6, str[16..17], 10));
+
+    const is_leap_year = epoch.isLeapYear(year_num);
+    const leap_days_preceding_epoch = comptime epoch.epoch_year / 4 - epoch.epoch_year / 100 + epoch.epoch_year / 400;
+    const leap_days_preceding_year = year_num / 4 - year_num / 100 + year_num / 400 - leap_days_preceding_epoch - if (is_leap_year) @as(i64, 1) else 0;
+
+    const epoch_day = (year_num - epoch.epoch_year) * 365 + leap_days_preceding_year + year_day: {
+        var days_preceding_month: i64 = 0;
+        var month_i: i64 = 1;
+        while (month_i < month_num) : (month_i += 1) {
+            days_preceding_month += epoch.getDaysInMonth(if (is_leap_year) .leap else .not_leap, @intToEnum(epoch.Month, month_i));
+        }
+        break :year_day days_preceding_month + day_num;
+    };
+
+    const day_second = (hour_num * 60 + minute_num) * 60 + second_num;
+
+    return DateTime{
+        .seconds_since_epoch = epoch_day * epoch.secs_per_day + day_second,
+    };
+}
+
 pub fn now() DateTime {
     return .{ .seconds_since_epoch = std.time.timestamp() };
 }
@@ -40,10 +79,24 @@ pub fn second(value: DateTime) u6 {
     return value.epochSeconds().getDaySeconds().getSecondsIntoMinute();
 }
 
+const array_len = 20;
+
+pub fn toCharArray(value: DateTime) [array_len + 1]u8 {
+    var buf: [array_len]u8 = undefined;
+    _ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable;
+    return buf;
+}
+
+pub fn toCharArrayZ(value: DateTime) [array_len + 1:0]u8 {
+    var buf: [array_len + 1:0]u8 = undefined;
+    _ = std.fmt.bufPrintZ(&buf, "{}", value) catch unreachable;
+    return buf;
+}
+
 pub fn format(value: DateTime, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
     return std.fmt.format(
         writer,
-        "{}-{}-{} {}:{}:{}",
+        "{:0>4}-{:0>2}-{:0>2}T{:0>2}:{:0>2}:{:0>2}Z",
         .{ value.year(), value.month().numeric(), value.day(), value.hour(), value.minute(), value.second() },
     );
 }
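format switches from the unpadded "{}-{}-{} {}:{}:{}" layout to a zero-padded, RFC 3339-style UTC timestamp, which is what the fixed-width toCharArray/toCharArrayZ buffers above assume. A small sketch of how a value now prints (the wrapper function is assumed; the example timestamp is illustrative):

const std = @import("std");
const DateTime = @import("util").DateTime;

// With the new format string a value prints as e.g. "2022-07-09T08:05:03Z"
// where the old one produced "2022-7-9 8:5:3".
fn logStartTime() void {
    std.log.info("started at {}", .{DateTime.now()});
}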
@@ -58,7 +58,7 @@ pub const ParseError = error{
 };
 
 pub fn parse(str: []const u8) ParseError!Uuid {
-    if (str.len != string_len) return error.InvalidLength;
+    if (str.len != string_len and (str.len != string_len + 1 or str[str.len - 1] != 0)) return error.InvalidLength;
 
     var data: [16]u8 = undefined;
     var str_i: usize = 0;
@@ -24,11 +24,11 @@ pub fn deepFree(alloc: ?std.mem.Allocator, val: anytype) void {
             else => @compileError("Many and C-style pointers not supported by deepfree"),
         },
         .Optional => if (val) |v| deepFree(alloc, v) else {},
-        .Struct => |struct_info| for (struct_info.fields) |field| deepFree(alloc, @field(val, field.name)),
+        .Struct => |struct_info| inline for (struct_info.fields) |field| deepFree(alloc, @field(val, field.name)),
         .Union, .ErrorUnion => @compileError("TODO: Unions not yet supported by deepFree"),
         .Array => for (val) |v| deepFree(alloc, v),
 
-        .Int, .Float, .Bool, .Void, .Type => {},
+        .Enum, .Int, .Float, .Bool, .Void, .Type => {},
 
         else => @compileError("Type " ++ @typeName(T) ++ " not supported by deepFree"),
     }
@@ -87,7 +87,7 @@ pub fn deepClone(alloc: std.mem.Allocator, val: anytype) !@TypeOf(val) {
             count += 1;
             }
         },
-        .Int, .Float, .Bool, .Void, .Type => {
+        .Enum, .Int, .Float, .Bool, .Void, .Type => {
             result = val;
         },
         else => @compileError("Type " ++ @typeName(T) ++ " not supported"),