Compare commits

...

2 commits

Author SHA1 Message Date
0ce315368a Minor HTTP refactor 2022-10-13 02:23:57 -07:00
159f1c28cc Add sql connection pool 2022-10-12 23:19:59 -07:00
14 changed files with 317 additions and 229 deletions

View file

@ -94,15 +94,15 @@ pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password
} }
pub const ApiSource = struct { pub const ApiSource = struct {
db_conn: *sql.Conn, db_conn_pool: *sql.ConnPool,
pub const Conn = ApiConn(sql.Db); pub const Conn = ApiConn(sql.Db);
const root_username = "root"; const root_username = "root";
pub fn init(db_conn: *sql.Conn) !ApiSource { pub fn init(pool: *sql.ConnPool) !ApiSource {
return ApiSource{ return ApiSource{
.db_conn = db_conn, .db_conn_pool = pool,
}; };
} }
@ -110,7 +110,7 @@ pub const ApiSource = struct {
var arena = std.heap.ArenaAllocator.init(alloc); var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit(); errdefer arena.deinit();
const db = try self.db_conn.acquire(); const db = try self.db_conn_pool.acquire();
const community = try services.communities.getByHost(db, host, arena.allocator()); const community = try services.communities.getByHost(db, host, arena.allocator());
return Conn{ return Conn{
@ -125,7 +125,7 @@ pub const ApiSource = struct {
var arena = std.heap.ArenaAllocator.init(alloc); var arena = std.heap.ArenaAllocator.init(alloc);
errdefer arena.deinit(); errdefer arena.deinit();
const db = try self.db_conn.acquire(); const db = try self.db_conn_pool.acquire();
const community = try services.communities.getByHost(db, host, arena.allocator()); const community = try services.communities.getByHost(db, host, arena.allocator());
const token_info = try services.auth.verifyToken( const token_info = try services.auth.verifyToken(
@ -157,6 +157,7 @@ fn ApiConn(comptime DbConn: type) type {
pub fn close(self: *Self) void { pub fn close(self: *Self) void {
self.arena.deinit(); self.arena.deinit();
self.db.releaseConnection();
} }
fn isAdmin(self: *Self) bool { fn isAdmin(self: *Self) bool {

View file

@ -3,13 +3,15 @@ const ciutf8 = @import("util").ciutf8;
const request = @import("./request.zig"); const request = @import("./request.zig");
pub const server = @import("./server.zig"); const server = @import("./server.zig");
pub const Method = std.http.Method; pub const Method = std.http.Method;
pub const Status = std.http.Status; pub const Status = std.http.Status;
pub const Request = request.Request; pub const Request = request.Request;
pub const Server = server.Server; pub const serveConn = server.serveConn;
pub const Response = server.Response;
pub const Handler = server.Handler;
pub const Headers = std.HashMap([]const u8, []const u8, struct { pub const Headers = std.HashMap([]const u8, []const u8, struct {
pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {

View file

@ -4,13 +4,25 @@ const http = @import("./lib.zig");
const parser = @import("./request/parser.zig"); const parser = @import("./request/parser.zig");
pub const Request = struct { pub const Request = struct {
pub const Protocol = enum {
http_1_0,
http_1_1,
};
protocol: Protocol,
source_address: ?std.net.Address,
method: http.Method, method: http.Method,
path: []const u8, uri: []const u8,
headers: http.Headers, headers: http.Headers,
body: ?[]const u8 = null, body: ?[]const u8 = null,
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request {
return parser.parse(alloc, reader); return parser.parse(alloc, reader, addr);
}
pub fn parseFree(self: Request, alloc: std.mem.Allocator) void {
parser.parseFree(alloc, self);
} }
}; };

View file

@ -22,34 +22,45 @@ const Encoding = enum {
chunked, chunked,
}; };
pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request {
var request: Request = undefined; const method = try parseMethod(reader);
const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
try parseLine(alloc, &request, reader);
request.headers = try parseHeaders(alloc, reader);
if (request.method.requestHasBody()) {
request.body = try readBody(alloc, request.headers, reader);
} else {
request.body = null;
}
return request;
}
fn parseLine(alloc: std.mem.Allocator, request: *Request, reader: anytype) !void {
request.method = try parseMethod(reader);
request.path = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) {
error.StreamTooLong => return error.RequestUriTooLong, error.StreamTooLong => return error.RequestUriTooLong,
else => return err, else => return err,
}; };
errdefer alloc.free(request.path); errdefer alloc.free(uri);
try checkProto(reader); const proto = try parseProto(reader);
// discard \r\n // discard \r\n
_ = try reader.readByte(); _ = try reader.readByte();
_ = try reader.readByte(); _ = try reader.readByte();
var headers = try parseHeaders(alloc, reader);
errdefer freeHeaders(alloc, &headers);
const body = if (method.requestHasBody())
try readBody(alloc, headers, reader)
else
null;
errdefer if (body) |b| alloc.free(b);
const eff_addr = if (headers.get("X-Real-IP")) |ip|
std.net.Address.parseIp(ip, address.getPort()) catch {
return error.BadRequest;
}
else
address;
return Request{
.protocol = proto,
.source_address = eff_addr,
.method = method,
.uri = uri,
.headers = headers,
.body = body,
};
} }
fn parseMethod(reader: anytype) !Method { fn parseMethod(reader: anytype) !Method {
@ -68,7 +79,7 @@ fn parseMethod(reader: anytype) !Method {
return error.MethodNotImplemented; return error.MethodNotImplemented;
} }
fn checkProto(reader: anytype) !void { fn parseProto(reader: anytype) !Request.Protocol {
var buf: [8]u8 = undefined; var buf: [8]u8 = undefined;
const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
error.StreamTooLong => return error.UnknownProtocol, error.StreamTooLong => return error.UnknownProtocol,
@ -84,14 +95,24 @@ fn checkProto(reader: anytype) !void {
return error.BadRequest; return error.BadRequest;
} }
if (buf[0] != '1' or buf[2] != '1') { if (buf[0] != '1') return error.HttpVersionNotSupported;
return error.HttpVersionNotSupported; return switch (buf[2]) {
} '0' => .http_1_0,
'1' => .http_1_1,
else => error.HttpVersionNotSupported,
};
} }
fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers {
var map = Headers.init(allocator); var map = Headers.init(allocator);
errdefer map.deinit(); errdefer map.deinit();
errdefer {
var iter = map.iterator();
while (iter.next()) |it| {
allocator.free(it.key_ptr.*);
allocator.free(it.value_ptr.*);
}
}
// todo: // todo:
//errdefer { //errdefer {
//var iter = map.iterator(); //var iter = map.iterator();
@ -167,6 +188,21 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding {
return error.UnsupportedMediaType; return error.UnsupportedMediaType;
} }
pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void {
allocator.free(request.uri);
freeHeaders(allocator, &request.headers);
if (request.body) |body| allocator.free(body);
}
fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void {
var iter = headers.iterator();
while (iter.next()) |it| {
allocator.free(it.key_ptr.*);
allocator.free(it.value_ptr.*);
}
headers.deinit();
}
const _test = struct { const _test = struct {
const expectEqual = std.testing.expectEqual; const expectEqual = std.testing.expectEqual;
const expectEqualStrings = std.testing.expectEqualStrings; const expectEqualStrings = std.testing.expectEqualStrings;

View file

@ -2,67 +2,69 @@ const std = @import("std");
const util = @import("util"); const util = @import("util");
const http = @import("./lib.zig"); const http = @import("./lib.zig");
const connection = @import("./server/connection.zig");
const response = @import("./server/response.zig"); const response = @import("./server/response.zig");
pub const Connection = connection.Connection; pub const Response = struct {
pub const Response = response.ResponseStream(Connection.Writer); alloc: std.mem.Allocator,
stream: std.net.Stream,
should_close: bool = false,
pub const Stream = response.ResponseStream(std.net.Stream.Writer);
pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream {
if (headers.get("Connection")) |hdr| {
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true;
}
return response.open(self.alloc, self.stream.writer(), headers, status);
}
};
const ConnectionServer = connection.Server;
const Request = http.Request; const Request = http.Request;
const request_buf_size = 1 << 16; const request_buf_size = 1 << 16;
pub const Context = struct { pub fn Handler(comptime Ctx: type) type {
alloc: std.mem.Allocator, return fn (Ctx, Request, *Response) void;
request: Request, }
connection: Connection,
pub fn openResponse(self: *Context, headers: *const http.Headers, status: http.Status) !Response { pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void {
return try response.open(self.alloc, self.connection.stream.writer(), headers, status); // TODO: Timeouts
} while (true) {
std.log.debug("waiting for request", .{});
var arena = std.heap.ArenaAllocator.init(alloc);
defer arena.deinit();
pub fn close(self: *Context) void { const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| {
// todo: deallocate request return handleError(conn.stream.writer(), err) catch {};
self.connection.close();
}
};
pub const Server = struct {
conn_server: ConnectionServer,
pub fn listen(addr: std.net.Address) !Server {
return Server{
.conn_server = try ConnectionServer.listen(addr),
}; };
std.log.debug("done parsing", .{});
var res = Response{
.alloc = arena.allocator(),
.stream = conn.stream,
};
handler(ctx, req, &res);
std.log.debug("done handling", .{});
if (req.headers.get("Connection")) |hdr| {
if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| return;
} else if (req.headers.get("Keep-Alive")) |hdr| {
std.log.debug("keep-alive: {s}", .{hdr});
} else if (req.protocol == .http_1_0) return;
if (res.should_close) return;
} }
}
pub fn accept(self: *Server, alloc: std.mem.Allocator) !Context { /// Writes an error response message and requests closure of the connection
while (true) {
const conn = try self.conn_server.accept();
errdefer conn.close();
const req = http.Request.parse(alloc, conn.stream.reader()) catch |err| {
handleError(conn.stream.writer(), err) catch unreachable;
continue;
};
return Context{ .connection = conn, .request = req, .alloc = alloc };
}
}
pub fn shutdown(self: *Server) void {
self.conn_server.shutdown();
}
};
// TODO: We should get more specific about what type of errors can happen
fn handleError(writer: anytype, err: anyerror) !void { fn handleError(writer: anytype, err: anyerror) !void {
const status: http.Status = switch (err) { const status: http.Status = switch (err) {
error.EndOfStream => return, // Do nothing, the client closed the connection
error.BadRequest => .bad_request, error.BadRequest => .bad_request,
error.UnsupportedMediaType => .unsupported_media_type, error.UnsupportedMediaType => .unsupported_media_type,
error.HttpVersionNotSupported => .http_version_not_supported, error.HttpVersionNotSupported => .http_version_not_supported,
else => return err, else => .internal_server_error,
}; };
try writer.print("HTTP/1.1 {} {?s}\r\n\r\n", .{ @enumToInt(status), status.phrase() }); try writer.print("HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", .{ @enumToInt(status), status.phrase() });
} }

View file

@ -1,51 +0,0 @@
const std = @import("std");
pub const Connection = struct {
pub const Id = u64;
pub const Writer = std.net.Stream.Writer;
pub const Reader = std.net.Stream.Reader;
id: Id,
address: std.net.Address,
stream: std.net.Stream,
fn new(id: Id, std_conn: std.net.StreamServer.Connection) Connection {
std.log.debug("new connection conn_id={}", .{id});
return .{
.id = id,
.address = std_conn.address,
.stream = std_conn.stream,
};
}
pub fn close(self: Connection) void {
std.log.debug("terminating connection conn_id={}", .{self.id});
self.stream.close();
}
};
pub const Server = struct {
next_conn_id: std.atomic.Atomic(Connection.Id) = std.atomic.Atomic(Connection.Id).init(1),
stream_server: std.net.StreamServer,
pub fn listen(addr: std.net.Address) !Server {
var self = Server{
.stream_server = std.net.StreamServer.init(.{ .reuse_address = true }),
};
errdefer self.stream_server.deinit();
try self.stream_server.listen(addr);
return self;
}
pub fn accept(self: *Server) !Connection {
const conn = try self.stream_server.accept();
const id = self.next_conn_id.fetchAdd(1, .SeqCst);
return Connection.new(id, conn);
}
pub fn shutdown(self: *Server) void {
self.stream_server.deinit();
}
};

View file

@ -111,9 +111,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
} }
fn flushBodyUnchunked(self: *Self) Error!void { fn flushBodyUnchunked(self: *Self) Error!void {
if (self.buffer_pos != 0) { try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos});
try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos});
}
try self.base_writer.writeAll("\r\n"); try self.base_writer.writeAll("\r\n");
@ -128,6 +126,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type {
} }
pub fn finish(self: *Self) Error!void { pub fn finish(self: *Self) Error!void {
std.log.debug("finishing", .{});
if (!self.chunked) { if (!self.chunked) {
try self.flushBodyUnchunked(); try self.flushBodyUnchunked();
} else { } else {

View file

@ -13,14 +13,13 @@ pub const invites = @import("./controllers/invites.zig");
pub const users = @import("./controllers/users.zig"); pub const users = @import("./controllers/users.zig");
pub const notes = @import("./controllers/notes.zig"); pub const notes = @import("./controllers/notes.zig");
pub fn routeRequest(api_source: anytype, ctx: http.server.Context, alloc: std.mem.Allocator) void { pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void {
// TODO: hashmaps? // TODO: hashmaps?
inline for (routes) |route| { var response = Response{ .headers = http.Headers.init(alloc), .res = res };
if (Context(route).matchAndHandle(api_source, ctx, alloc)) return;
}
var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx };
defer response.headers.deinit(); defer response.headers.deinit();
inline for (routes) |route| {
if (Context(route).matchAndHandle(api_source, req, &response, alloc)) return;
}
response.status(.not_found) catch {}; response.status(.not_found) catch {};
} }
@ -53,7 +52,7 @@ pub fn Context(comptime Route: type) type {
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
method: http.Method, method: http.Method,
request_line: []const u8, uri: []const u8,
headers: http.Headers, headers: http.Headers,
args: Args, args: Args,
@ -84,20 +83,16 @@ pub fn Context(comptime Route: type) type {
@compileError("Unsupported Type " ++ @typeName(T)); @compileError("Unsupported Type " ++ @typeName(T));
} }
pub fn matchAndHandle(api_source: *api.ApiSource, ctx: http.server.Context, alloc: std.mem.Allocator) bool { pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool {
const req = ctx.request;
if (req.method != Route.method) return false; if (req.method != Route.method) return false;
var path = std.mem.sliceTo(std.mem.sliceTo(req.path, '#'), '?'); var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?');
var args: Args = parseArgs(path) orelse return false; var args: Args = parseArgs(path) orelse return false;
var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx };
defer response.headers.deinit();
var self = Self{ var self = Self{
.allocator = alloc, .allocator = alloc,
.method = req.method, .method = req.method,
.request_line = req.path, .uri = req.uri,
.headers = req.headers, .headers = req.headers,
.args = args, .args = args,
@ -105,7 +100,7 @@ pub fn Context(comptime Route: type) type {
.query = undefined, .query = undefined,
}; };
self.prepareAndHandle(api_source, req, &response); self.prepareAndHandle(api_source, req, res);
return true; return true;
} }
@ -149,15 +144,20 @@ pub fn Context(comptime Route: type) type {
fn parseQuery(self: *Self) !void { fn parseQuery(self: *Self) !void {
if (Query != void) { if (Query != void) {
const path = std.mem.sliceTo(self.request_line, '?'); const path = std.mem.sliceTo(self.uri, '?');
const q = std.mem.sliceTo(self.request_line[path.len..], '#'); const q = std.mem.sliceTo(self.uri[path.len..], '#');
self.query = try query_utils.parseQuery(Query, q); self.query = try query_utils.parseQuery(Query, q);
} }
} }
fn handle(self: Self, response: *Response, api_conn: anytype) void { fn handle(self: Self, response: *Response, api_conn: anytype) void {
Route.handler(self, response, api_conn) catch |err| std.log.err("{}", .{err}); Route.handler(self, response, api_conn) catch |err| switch (err) {
else => {
std.log.err("{}", .{err});
response.err(.internal_server_error, "", {}) catch {};
},
};
} }
fn getApiConn(self: *Self, api_source: anytype) !api.ApiSource.Conn { fn getApiConn(self: *Self, api_source: anytype) !api.ApiSource.Conn {
@ -180,18 +180,25 @@ pub fn Context(comptime Route: type) type {
pub const Response = struct { pub const Response = struct {
const Self = @This(); const Self = @This();
headers: http.Headers, headers: http.Headers,
ctx: http.server.Context, res: *http.Response,
opened: bool = false,
pub fn status(self: *Self, status_code: http.Status) !void { pub fn status(self: *Self, status_code: http.Status) !void {
var stream = try self.ctx.openResponse(&self.headers, status_code); std.debug.assert(!self.opened);
self.opened = true;
var stream = try self.res.open(status_code, &self.headers);
defer stream.close(); defer stream.close();
try stream.finish(); try stream.finish();
} }
pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void {
std.debug.assert(!self.opened);
self.opened = true;
try self.headers.put("Content-Type", "application/json"); try self.headers.put("Content-Type", "application/json");
var stream = try self.ctx.openResponse(&self.headers, status_code); var stream = try self.res.open(status_code, &self.headers);
defer stream.close(); defer stream.close();
const writer = stream.writer(); const writer = stream.writer();

View file

@ -8,36 +8,8 @@ const api = @import("api");
pub const migrations = @import("./migrations.zig"); // TODO pub const migrations = @import("./migrations.zig"); // TODO
const c = @import("./controllers.zig"); const c = @import("./controllers.zig");
pub const RequestServer = struct {
alloc: std.mem.Allocator,
api: *api.ApiSource,
config: Config,
fn init(alloc: std.mem.Allocator, src: *api.ApiSource, config: Config) !RequestServer {
return RequestServer{
.alloc = alloc,
.api = src,
.config = config,
};
}
fn listenAndRun(self: *RequestServer, addr: std.net.Address) !void {
var srv = http.Server.listen(addr) catch unreachable;
defer srv.shutdown();
while (true) {
var arena = std.heap.ArenaAllocator.init(self.alloc);
defer arena.deinit();
var ctx = try srv.accept(arena.allocator());
defer ctx.close();
c.routeRequest(self.api, ctx, arena.allocator());
}
}
};
pub const Config = struct { pub const Config = struct {
worker_threads: usize = 10,
db: sql.Config, db: sql.Config,
}; };
@ -64,7 +36,10 @@ fn runAdminSetup(db: sql.Db, alloc: std.mem.Allocator) !void {
try api.setupAdmin(db, origin, username, password, alloc); try api.setupAdmin(db, origin, username, password, alloc);
} }
fn prepareDb(db: sql.Db, alloc: std.mem.Allocator) !void { fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void {
const db = try pool.acquire();
defer db.releaseConnection();
try migrations.up(db); try migrations.up(db);
if (!try api.isAdminSetup(db)) { if (!try api.isAdminSetup(db)) {
@ -88,15 +63,46 @@ fn prepareDb(db: sql.Db, alloc: std.mem.Allocator) !void {
} }
} }
const ConnectionId = u64;
var next_conn_id = std.atomic.Atomic(ConnectionId).init(0);
fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void {
util.seedThreadPrng() catch unreachable;
const thread_id = std.Thread.getCurrentId();
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
while (true) {
var conn = srv.accept() catch |err| {
std.log.err("Error accepting connection: {}", .{err});
continue;
};
defer conn.stream.close();
const conn_id = next_conn_id.fetchAdd(1, .SeqCst);
std.log.debug("Accepting TCP connection id {} on thread {}", .{ conn_id, thread_id });
defer std.log.debug("Closing TCP connection id {}", .{conn_id});
http.serveConn(conn, .{ .src = src, .conn_id = conn_id, .allocator = gpa.allocator() }, handle, gpa.allocator()) catch |err| {
std.log.err("Error occured on connection {}: {}", .{ conn_id, err });
};
}
}
fn handle(ctx: anytype, req: http.Request, res: *http.Response) void {
c.routeRequest(ctx.src, req, res, ctx.allocator);
}
pub fn main() !void { pub fn main() !void {
try util.seedThreadPrng(); try util.seedThreadPrng();
var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var gpa = std.heap.GeneralPurposeAllocator(.{}){};
var cfg = try loadConfig(gpa.allocator()); var cfg = try loadConfig(gpa.allocator());
var db_conn = try sql.Conn.open(cfg.db); var pool = try sql.ConnPool.init(cfg.db);
try prepareDb(try db_conn.acquire(), gpa.allocator()); try prepareDb(&pool, gpa.allocator());
var api_src = try api.ApiSource.init(&db_conn); var api_src = try api.ApiSource.init(&pool);
var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); var srv = std.net.StreamServer.init(.{ .reuse_address = true });
return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); defer srv.deinit();
try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);
thread_main(&api_src, &srv);
} }

View file

@ -16,10 +16,15 @@ fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void {
} }
fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void { fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void {
const tx = try db.beginOrSavepoint();
errdefer tx.rollback();
var iter = util.SqlStmtIter.from(script); var iter = util.SqlStmtIter.from(script);
while (iter.next()) |stmt| { while (iter.next()) |stmt| {
try execStmt(db, stmt, alloc); try execStmt(tx, stmt, alloc);
} }
try tx.commitOrRelease();
} }
fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool {

View file

@ -28,6 +28,10 @@ pub const Db = struct {
unreachable; unreachable;
} }
pub fn openUri(_: anytype) common.OpenError!Db {
unreachable;
}
pub fn close(_: Db) void { pub fn close(_: Db) void {
unreachable; unreachable;
} }

View file

@ -49,7 +49,15 @@ pub const Db = struct {
db: *c.sqlite3, db: *c.sqlite3,
pub fn open(path: [:0]const u8) common.OpenError!Db { pub fn open(path: [:0]const u8) common.OpenError!Db {
const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE; return openInternal(path, false);
}
pub fn openUri(path: [:0]const u8) common.OpenError!Db {
return openInternal(path, true);
}
fn openInternal(path: [:0]const u8, is_uri: bool) common.OpenError!Db {
const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE | if (is_uri) c.SQLITE_OPEN_URI else 0;
var db: ?*c.sqlite3 = null; var db: ?*c.sqlite3 = null;
switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) {
@ -121,7 +129,6 @@ pub const Db = struct {
// of 0, and we must not bind the argument. // of 0, and we must not bind the argument.
const name = std.fmt.comptimePrint("${}", .{i + 1}); const name = std.fmt.comptimePrint("${}", .{i + 1});
const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name);
std.log.debug("param {s} got index {}", .{ name, db_idx });
if (db_idx != 0) if (db_idx != 0)
try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg)
else if (!opts.ignore_unused_arguments) else if (!opts.ignore_unused_arguments)
@ -167,7 +174,6 @@ pub const Db = struct {
else else
@compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string"); @compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string");
std.log.debug("binding type {any}: {s}", .{ T, arr });
const len = std.mem.len(&arr); const len = std.mem.len(&arr);
return self.bindString(stmt, idx, arr[0..len]); return self.bindString(stmt, idx, arr[0..len]);
}, },
@ -194,8 +200,6 @@ pub const Db = struct {
return error.BindException; return error.BindException;
}; };
std.log.debug("binding string {s} to idx {}", .{ str, idx });
switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) {
c.SQLITE_OK => {}, c.SQLITE_OK => {},
else => |result| { else => |result| {

View file

@ -10,6 +10,7 @@ const Allocator = std.mem.Allocator;
const errors = @import("./errors.zig").library_errors; const errors = @import("./errors.zig").library_errors;
pub const AcquireError = OpenError || error{NoConnectionsLeft};
pub const OpenError = errors.OpenError; pub const OpenError = errors.OpenError;
pub const QueryError = errors.QueryError; pub const QueryError = errors.QueryError;
pub const RowError = errors.RowError; pub const RowError = errors.RowError;
@ -24,12 +25,14 @@ pub const Engine = enum {
sqlite, sqlite,
}; };
// TODO: make this suck less
pub const Config = union(Engine) { pub const Config = union(Engine) {
postgres: struct { postgres: struct {
pg_conn_str: [:0]const u8, pg_conn_str: [:0]const u8,
}, },
sqlite: struct { sqlite: struct {
sqlite_file_path: [:0]const u8, sqlite_file_path: [:0]const u8,
sqlite_is_uri: bool = false,
}, },
}; };
@ -160,16 +163,58 @@ pub const ConstraintMode = enum {
immediate, immediate,
}; };
pub const Conn = struct { pub const ConnPool = struct {
engine: union(Engine) { const max_conns = 4;
postgres: postgres.Db, const Conn = struct {
sqlite: sqlite.Db, engine: union(Engine) {
}, postgres: postgres.Db,
current_tx_level: u8 = 0, sqlite: sqlite.Db,
is_tx_failed: bool = false, },
in_use: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false),
current_tx_level: u8 = 0,
};
pub fn open(cfg: Config) OpenError!Conn { config: Config,
return switch (cfg) { connections: [max_conns]Conn,
pub fn init(cfg: Config) OpenError!ConnPool {
var self = ConnPool{
.config = cfg,
.connections = undefined,
};
var count: usize = 0;
errdefer for (self.connections[0..count]) |*c| closeConn(c);
for (self.connections) |*c| {
c.* = try self.createConn();
count += 1;
}
return self;
}
pub fn deinit(self: *ConnPool) void {
for (self.connections) |*c| closeConn(c);
}
pub fn acquire(self: *ConnPool) AcquireError!Db {
for (self.connections) |*c| {
if (tryAcquire(c)) return Db{ .conn = c };
}
return error.NoConnectionsLeft;
}
fn tryAcquire(conn: *Conn) bool {
const acquired = !conn.in_use.swap(true, .AcqRel);
if (acquired) {
if (conn.current_tx_level != 0) @panic("Transaction still open on unused db connection");
return true;
}
return false;
}
fn createConn(self: *ConnPool) OpenError!Conn {
return switch (self.config) {
.postgres => |postgres_cfg| Conn{ .postgres => |postgres_cfg| Conn{
.engine = .{ .engine = .{
.postgres = try postgres.Db.open(postgres_cfg.pg_conn_str), .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str),
@ -177,27 +222,22 @@ pub const Conn = struct {
}, },
.sqlite => |lite_cfg| Conn{ .sqlite => |lite_cfg| Conn{
.engine = .{ .engine = .{
.sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path), .sqlite = if (lite_cfg.sqlite_is_uri)
try sqlite.Db.openUri(lite_cfg.sqlite_file_path)
else
try sqlite.Db.open(lite_cfg.sqlite_file_path),
}, },
}, },
}; };
} }
pub fn close(self: *Conn) void { fn closeConn(conn: *Conn) void {
switch (self.engine) { if (conn.in_use.loadUnchecked()) @panic("DB Conn still open");
switch (conn.engine) {
.postgres => |pg| pg.close(), .postgres => |pg| pg.close(),
.sqlite => |lite| lite.close(), .sqlite => |lite| lite.close(),
} }
} }
pub fn acquire(self: *Conn) !Db {
if (self.current_tx_level != 0) return error.BadTransactionState;
return Db{ .conn = self };
}
pub fn sqlEngine(self: *Conn) Engine {
return self.engine;
}
}; };
pub const Db = Tx(0); pub const Db = Tx(0);
@ -216,11 +256,22 @@ fn Tx(comptime tx_level: u8) type {
std.fmt.comptimePrint("save_{}", .{tx_level}); std.fmt.comptimePrint("save_{}", .{tx_level});
const next_savepoint_name = Tx(tx_level + 1).savepoint_name; const next_savepoint_name = Tx(tx_level + 1).savepoint_name;
conn: *Conn, conn: *ConnPool.Conn,
/// The type of SQL engine being used. Use of this function should be discouraged /// The type of SQL engine being used. Use of this function should be discouraged
pub fn sqlEngine(self: Self) Engine { pub fn sqlEngine(self: Self) Engine {
return self.conn.sqlEngine(); return self.conn.engine;
}
/// Return the connection to the pool
pub fn releaseConnection(self: Self) void {
if (tx_level != 0) @compileError("close must be called on root db");
if (self.conn.current_tx_level != 0) {
std.log.warn("Database released while transaction in progress!", .{});
self.rollbackUnchecked() catch {};
}
if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection");
} }
// ********* Transaction management functions ********** // ********* Transaction management functions **********
@ -277,7 +328,7 @@ fn Tx(comptime tx_level: u8) type {
if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint"); if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint");
if (self.conn.current_tx_level == 0) return error.BadTransactionState; if (self.conn.current_tx_level == 0) return error.BadTransactionState;
try self.exec("ROLLBACK", {}, null); try self.rollbackUnchecked();
self.conn.current_tx_level = 0; self.conn.current_tx_level = 0;
} }
@ -402,7 +453,7 @@ fn Tx(comptime tx_level: u8) type {
comptime var table_spec: []const u8 = table ++ "("; comptime var table_spec: []const u8 = table ++ "(";
comptime var value_spec: []const u8 = "("; comptime var value_spec: []const u8 = "(";
inline for (fields) |field, i| { inline for (fields) |field, i| {
// This causes a compile error. Why? // This causes a compiler crash. Why?
//const F = field.field_type; //const F = field.field_type;
const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name)); const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name));
// causes issues if F is @TypeOf(null), use dummy type // causes issues if F is @TypeOf(null), use dummy type
@ -453,5 +504,9 @@ fn Tx(comptime tx_level: u8) type {
while (try results.row()) |_| {} while (try results.row()) |_| {}
} }
fn rollbackUnchecked(self: Self) !void {
try self.exec("ROLLBACK", {}, null);
}
}; };
} }

View file

@ -5,11 +5,10 @@ const sql = @import("sql");
const util = @import("util"); const util = @import("util");
const test_config = .{ const test_config = .{
.db = .{ .db = .{ .sqlite = .{
.sqlite = .{ .sqlite_file_path = "file::memory:?cache=shared",
.sqlite_file_path = ":memory:", .sqlite_is_uri = true,
}, } },
},
}; };
const ApiSource = api.ApiSource; const ApiSource = api.ApiSource;
@ -18,12 +17,16 @@ const root_password = "password1234";
const admin_host = "example.com"; const admin_host = "example.com";
const admin_origin = "https://" ++ admin_host; const admin_origin = "https://" ++ admin_host;
fn makeDb(alloc: std.mem.Allocator) !sql.Conn { fn makeDb(alloc: std.mem.Allocator) !sql.ConnPool {
try util.seedThreadPrng(); try util.seedThreadPrng();
var db = try sql.Conn.open(test_config.db); var pool = try sql.ConnPool.init(test_config.db);
try migrations.up(try db.acquire()); {
try api.setupAdmin(try db.acquire(), admin_origin, root_user, root_password, alloc); var db = try pool.acquire();
return db; defer db.releaseConnection();
try migrations.up(db);
try api.setupAdmin(db, admin_origin, root_user, root_password, alloc);
}
return pool;
} }
fn connectAndLogin( fn connectAndLogin(
@ -42,6 +45,7 @@ fn connectAndLogin(
test "login as root" { test "login as root" {
const alloc = std.testing.allocator; const alloc = std.testing.allocator;
var db = try makeDb(alloc); var db = try makeDb(alloc);
defer db.deinit();
var src = try ApiSource.init(&db); var src = try ApiSource.init(&db);
const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc);
@ -59,6 +63,7 @@ test "login as root" {
test "create community" { test "create community" {
const alloc = std.testing.allocator; const alloc = std.testing.allocator;
var db = try makeDb(alloc); var db = try makeDb(alloc);
defer db.deinit();
var src = try ApiSource.init(&db); var src = try ApiSource.init(&db);
const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc);
@ -80,6 +85,7 @@ test "create community" {
test "create community and transfer to new owner" { test "create community and transfer to new owner" {
const alloc = std.testing.allocator; const alloc = std.testing.allocator;
var db = try makeDb(alloc); var db = try makeDb(alloc);
defer db.deinit();
var src = try ApiSource.init(&db); var src = try ApiSource.init(&db);
const root_login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); const root_login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc);