diff --git a/src/api/lib.zig b/src/api/lib.zig index e6cf82a..487a4a9 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -94,15 +94,15 @@ pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password } pub const ApiSource = struct { - db_conn_pool: *sql.ConnPool, + db_conn: *sql.Conn, pub const Conn = ApiConn(sql.Db); const root_username = "root"; - pub fn init(pool: *sql.ConnPool) !ApiSource { + pub fn init(db_conn: *sql.Conn) !ApiSource { return ApiSource{ - .db_conn_pool = pool, + .db_conn = db_conn, }; } @@ -110,7 +110,7 @@ pub const ApiSource = struct { var arena = std.heap.ArenaAllocator.init(alloc); errdefer arena.deinit(); - const db = try self.db_conn_pool.acquire(); + const db = try self.db_conn.acquire(); const community = try services.communities.getByHost(db, host, arena.allocator()); return Conn{ @@ -125,7 +125,7 @@ pub const ApiSource = struct { var arena = std.heap.ArenaAllocator.init(alloc); errdefer arena.deinit(); - const db = try self.db_conn_pool.acquire(); + const db = try self.db_conn.acquire(); const community = try services.communities.getByHost(db, host, arena.allocator()); const token_info = try services.auth.verifyToken( @@ -157,7 +157,6 @@ fn ApiConn(comptime DbConn: type) type { pub fn close(self: *Self) void { self.arena.deinit(); - self.db.releaseConnection(); } fn isAdmin(self: *Self) bool { diff --git a/src/http/lib.zig b/src/http/lib.zig index 3d12cdc..9a767f0 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -3,15 +3,13 @@ const ciutf8 = @import("util").ciutf8; const request = @import("./request.zig"); -const server = @import("./server.zig"); +pub const server = @import("./server.zig"); pub const Method = std.http.Method; pub const Status = std.http.Status; pub const Request = request.Request; -pub const serveConn = server.serveConn; -pub const Response = server.Response; -pub const Handler = server.Handler; +pub const Server = server.Server; pub const Headers = std.HashMap([]const u8, []const u8, struct { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { diff --git a/src/http/request.zig b/src/http/request.zig index e6fd79c..17def73 100644 --- a/src/http/request.zig +++ b/src/http/request.zig @@ -4,25 +4,13 @@ const http = @import("./lib.zig"); const parser = @import("./request/parser.zig"); pub const Request = struct { - pub const Protocol = enum { - http_1_0, - http_1_1, - }; - - protocol: Protocol, - source_address: ?std.net.Address, - method: http.Method, - uri: []const u8, + path: []const u8, headers: http.Headers, body: ?[]const u8 = null, - pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request { - return parser.parse(alloc, reader, addr); - } - - pub fn parseFree(self: Request, alloc: std.mem.Allocator) void { - parser.parseFree(alloc, self); + pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { + return parser.parse(alloc, reader); } }; diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index 8a7c19f..9da174e 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -22,45 +22,34 @@ const Encoding = enum { chunked, }; -pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request { - const method = try parseMethod(reader); - const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { +pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { + var request: Request = undefined; + + try parseLine(alloc, &request, reader); + request.headers = try parseHeaders(alloc, reader); + + if (request.method.requestHasBody()) { + request.body = try readBody(alloc, request.headers, reader); + } else { + request.body = null; + } + + return request; +} + +fn parseLine(alloc: std.mem.Allocator, request: *Request, reader: anytype) !void { + request.method = try parseMethod(reader); + request.path = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { error.StreamTooLong => return error.RequestUriTooLong, else => return err, }; - errdefer alloc.free(uri); + errdefer alloc.free(request.path); - const proto = try parseProto(reader); + try checkProto(reader); // discard \r\n _ = try reader.readByte(); _ = try reader.readByte(); - - var headers = try parseHeaders(alloc, reader); - errdefer freeHeaders(alloc, &headers); - - const body = if (method.requestHasBody()) - try readBody(alloc, headers, reader) - else - null; - errdefer if (body) |b| alloc.free(b); - - const eff_addr = if (headers.get("X-Real-IP")) |ip| - std.net.Address.parseIp(ip, address.getPort()) catch { - return error.BadRequest; - } - else - address; - - return Request{ - .protocol = proto, - .source_address = eff_addr, - - .method = method, - .uri = uri, - .headers = headers, - .body = body, - }; } fn parseMethod(reader: anytype) !Method { @@ -79,7 +68,7 @@ fn parseMethod(reader: anytype) !Method { return error.MethodNotImplemented; } -fn parseProto(reader: anytype) !Request.Protocol { +fn checkProto(reader: anytype) !void { var buf: [8]u8 = undefined; const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { error.StreamTooLong => return error.UnknownProtocol, @@ -95,24 +84,14 @@ fn parseProto(reader: anytype) !Request.Protocol { return error.BadRequest; } - if (buf[0] != '1') return error.HttpVersionNotSupported; - return switch (buf[2]) { - '0' => .http_1_0, - '1' => .http_1_1, - else => error.HttpVersionNotSupported, - }; + if (buf[0] != '1' or buf[2] != '1') { + return error.HttpVersionNotSupported; + } } fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { var map = Headers.init(allocator); errdefer map.deinit(); - errdefer { - var iter = map.iterator(); - while (iter.next()) |it| { - allocator.free(it.key_ptr.*); - allocator.free(it.value_ptr.*); - } - } // todo: //errdefer { //var iter = map.iterator(); @@ -188,21 +167,6 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding { return error.UnsupportedMediaType; } -pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void { - allocator.free(request.uri); - freeHeaders(allocator, &request.headers); - if (request.body) |body| allocator.free(body); -} - -fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void { - var iter = headers.iterator(); - while (iter.next()) |it| { - allocator.free(it.key_ptr.*); - allocator.free(it.value_ptr.*); - } - headers.deinit(); -} - const _test = struct { const expectEqual = std.testing.expectEqual; const expectEqualStrings = std.testing.expectEqualStrings; diff --git a/src/http/server.zig b/src/http/server.zig index e9bf81e..c9095ef 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -2,69 +2,67 @@ const std = @import("std"); const util = @import("util"); const http = @import("./lib.zig"); +const connection = @import("./server/connection.zig"); const response = @import("./server/response.zig"); -pub const Response = struct { - alloc: std.mem.Allocator, - stream: std.net.Stream, - should_close: bool = false, - pub const Stream = response.ResponseStream(std.net.Stream.Writer); - pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream { - if (headers.get("Connection")) |hdr| { - if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; - } - - return response.open(self.alloc, self.stream.writer(), headers, status); - } -}; +pub const Connection = connection.Connection; +pub const Response = response.ResponseStream(Connection.Writer); +const ConnectionServer = connection.Server; const Request = http.Request; const request_buf_size = 1 << 16; -pub fn Handler(comptime Ctx: type) type { - return fn (Ctx, Request, *Response) void; -} +pub const Context = struct { + alloc: std.mem.Allocator, + request: Request, + connection: Connection, -pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void { - // TODO: Timeouts - while (true) { - std.log.debug("waiting for request", .{}); - var arena = std.heap.ArenaAllocator.init(alloc); - defer arena.deinit(); - - const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| { - return handleError(conn.stream.writer(), err) catch {}; - }; - std.log.debug("done parsing", .{}); - - var res = Response{ - .alloc = arena.allocator(), - .stream = conn.stream, - }; - - handler(ctx, req, &res); - std.log.debug("done handling", .{}); - - if (req.headers.get("Connection")) |hdr| { - if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| return; - } else if (req.headers.get("Keep-Alive")) |hdr| { - std.log.debug("keep-alive: {s}", .{hdr}); - } else if (req.protocol == .http_1_0) return; - - if (res.should_close) return; + pub fn openResponse(self: *Context, headers: *const http.Headers, status: http.Status) !Response { + return try response.open(self.alloc, self.connection.stream.writer(), headers, status); } -} -/// Writes an error response message and requests closure of the connection + pub fn close(self: *Context) void { + // todo: deallocate request + self.connection.close(); + } +}; + +pub const Server = struct { + conn_server: ConnectionServer, + pub fn listen(addr: std.net.Address) !Server { + return Server{ + .conn_server = try ConnectionServer.listen(addr), + }; + } + + pub fn accept(self: *Server, alloc: std.mem.Allocator) !Context { + while (true) { + const conn = try self.conn_server.accept(); + errdefer conn.close(); + + const req = http.Request.parse(alloc, conn.stream.reader()) catch |err| { + handleError(conn.stream.writer(), err) catch unreachable; + continue; + }; + + return Context{ .connection = conn, .request = req, .alloc = alloc }; + } + } + + pub fn shutdown(self: *Server) void { + self.conn_server.shutdown(); + } +}; + +// TODO: We should get more specific about what type of errors can happen fn handleError(writer: anytype, err: anyerror) !void { const status: http.Status = switch (err) { - error.EndOfStream => return, // Do nothing, the client closed the connection error.BadRequest => .bad_request, error.UnsupportedMediaType => .unsupported_media_type, error.HttpVersionNotSupported => .http_version_not_supported, - else => .internal_server_error, + else => return err, }; - try writer.print("HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", .{ @enumToInt(status), status.phrase() }); + try writer.print("HTTP/1.1 {} {?s}\r\n\r\n", .{ @enumToInt(status), status.phrase() }); } diff --git a/src/http/server/connection.zig b/src/http/server/connection.zig new file mode 100644 index 0000000..1ae6d69 --- /dev/null +++ b/src/http/server/connection.zig @@ -0,0 +1,51 @@ +const std = @import("std"); + +pub const Connection = struct { + pub const Id = u64; + pub const Writer = std.net.Stream.Writer; + pub const Reader = std.net.Stream.Reader; + + id: Id, + address: std.net.Address, + stream: std.net.Stream, + + fn new(id: Id, std_conn: std.net.StreamServer.Connection) Connection { + std.log.debug("new connection conn_id={}", .{id}); + return .{ + .id = id, + .address = std_conn.address, + .stream = std_conn.stream, + }; + } + + pub fn close(self: Connection) void { + std.log.debug("terminating connection conn_id={}", .{self.id}); + self.stream.close(); + } +}; + +pub const Server = struct { + next_conn_id: std.atomic.Atomic(Connection.Id) = std.atomic.Atomic(Connection.Id).init(1), + stream_server: std.net.StreamServer, + + pub fn listen(addr: std.net.Address) !Server { + var self = Server{ + .stream_server = std.net.StreamServer.init(.{ .reuse_address = true }), + }; + errdefer self.stream_server.deinit(); + + try self.stream_server.listen(addr); + return self; + } + + pub fn accept(self: *Server) !Connection { + const conn = try self.stream_server.accept(); + const id = self.next_conn_id.fetchAdd(1, .SeqCst); + + return Connection.new(id, conn); + } + + pub fn shutdown(self: *Server) void { + self.stream_server.deinit(); + } +}; diff --git a/src/http/server/response.zig b/src/http/server/response.zig index a99aa8e..aa549f8 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -111,7 +111,9 @@ pub fn ResponseStream(comptime BaseWriter: type) type { } fn flushBodyUnchunked(self: *Self) Error!void { - try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos}); + if (self.buffer_pos != 0) { + try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos}); + } try self.base_writer.writeAll("\r\n"); @@ -126,7 +128,6 @@ pub fn ResponseStream(comptime BaseWriter: type) type { } pub fn finish(self: *Self) Error!void { - std.log.debug("finishing", .{}); if (!self.chunked) { try self.flushBodyUnchunked(); } else { diff --git a/src/main/controllers.zig b/src/main/controllers.zig index b0f7800..87fb11d 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -13,14 +13,15 @@ pub const invites = @import("./controllers/invites.zig"); pub const users = @import("./controllers/users.zig"); pub const notes = @import("./controllers/notes.zig"); -pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { +pub fn routeRequest(api_source: anytype, ctx: http.server.Context, alloc: std.mem.Allocator) void { // TODO: hashmaps? - var response = Response{ .headers = http.Headers.init(alloc), .res = res }; - defer response.headers.deinit(); inline for (routes) |route| { - if (Context(route).matchAndHandle(api_source, req, &response, alloc)) return; + if (Context(route).matchAndHandle(api_source, ctx, alloc)) return; } + var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx }; + defer response.headers.deinit(); + response.status(.not_found) catch {}; } @@ -52,7 +53,7 @@ pub fn Context(comptime Route: type) type { allocator: std.mem.Allocator, method: http.Method, - uri: []const u8, + request_line: []const u8, headers: http.Headers, args: Args, @@ -83,16 +84,20 @@ pub fn Context(comptime Route: type) type { @compileError("Unsupported Type " ++ @typeName(T)); } - pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { + pub fn matchAndHandle(api_source: *api.ApiSource, ctx: http.server.Context, alloc: std.mem.Allocator) bool { + const req = ctx.request; if (req.method != Route.method) return false; - var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); + var path = std.mem.sliceTo(std.mem.sliceTo(req.path, '#'), '?'); var args: Args = parseArgs(path) orelse return false; + var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx }; + defer response.headers.deinit(); + var self = Self{ .allocator = alloc, .method = req.method, - .uri = req.uri, + .request_line = req.path, .headers = req.headers, .args = args, @@ -100,7 +105,7 @@ pub fn Context(comptime Route: type) type { .query = undefined, }; - self.prepareAndHandle(api_source, req, res); + self.prepareAndHandle(api_source, req, &response); return true; } @@ -144,20 +149,15 @@ pub fn Context(comptime Route: type) type { fn parseQuery(self: *Self) !void { if (Query != void) { - const path = std.mem.sliceTo(self.uri, '?'); - const q = std.mem.sliceTo(self.uri[path.len..], '#'); + const path = std.mem.sliceTo(self.request_line, '?'); + const q = std.mem.sliceTo(self.request_line[path.len..], '#'); self.query = try query_utils.parseQuery(Query, q); } } fn handle(self: Self, response: *Response, api_conn: anytype) void { - Route.handler(self, response, api_conn) catch |err| switch (err) { - else => { - std.log.err("{}", .{err}); - response.err(.internal_server_error, "", {}) catch {}; - }, - }; + Route.handler(self, response, api_conn) catch |err| std.log.err("{}", .{err}); } fn getApiConn(self: *Self, api_source: anytype) !api.ApiSource.Conn { @@ -180,25 +180,18 @@ pub fn Context(comptime Route: type) type { pub const Response = struct { const Self = @This(); headers: http.Headers, - res: *http.Response, - opened: bool = false, + ctx: http.server.Context, pub fn status(self: *Self, status_code: http.Status) !void { - std.debug.assert(!self.opened); - self.opened = true; - - var stream = try self.res.open(status_code, &self.headers); + var stream = try self.ctx.openResponse(&self.headers, status_code); defer stream.close(); try stream.finish(); } pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { - std.debug.assert(!self.opened); - self.opened = true; - try self.headers.put("Content-Type", "application/json"); - var stream = try self.res.open(status_code, &self.headers); + var stream = try self.ctx.openResponse(&self.headers, status_code); defer stream.close(); const writer = stream.writer(); diff --git a/src/main/main.zig b/src/main/main.zig index a82713b..f722294 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -8,8 +8,36 @@ const api = @import("api"); pub const migrations = @import("./migrations.zig"); // TODO const c = @import("./controllers.zig"); +pub const RequestServer = struct { + alloc: std.mem.Allocator, + api: *api.ApiSource, + config: Config, + + fn init(alloc: std.mem.Allocator, src: *api.ApiSource, config: Config) !RequestServer { + return RequestServer{ + .alloc = alloc, + .api = src, + .config = config, + }; + } + + fn listenAndRun(self: *RequestServer, addr: std.net.Address) !void { + var srv = http.Server.listen(addr) catch unreachable; + defer srv.shutdown(); + + while (true) { + var arena = std.heap.ArenaAllocator.init(self.alloc); + defer arena.deinit(); + + var ctx = try srv.accept(arena.allocator()); + defer ctx.close(); + + c.routeRequest(self.api, ctx, arena.allocator()); + } + } +}; + pub const Config = struct { - worker_threads: usize = 10, db: sql.Config, }; @@ -36,10 +64,7 @@ fn runAdminSetup(db: sql.Db, alloc: std.mem.Allocator) !void { try api.setupAdmin(db, origin, username, password, alloc); } -fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void { - const db = try pool.acquire(); - defer db.releaseConnection(); - +fn prepareDb(db: sql.Db, alloc: std.mem.Allocator) !void { try migrations.up(db); if (!try api.isAdminSetup(db)) { @@ -63,46 +88,15 @@ fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void { } } -const ConnectionId = u64; -var next_conn_id = std.atomic.Atomic(ConnectionId).init(0); - -fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void { - util.seedThreadPrng() catch unreachable; - const thread_id = std.Thread.getCurrentId(); - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - - while (true) { - var conn = srv.accept() catch |err| { - std.log.err("Error accepting connection: {}", .{err}); - continue; - }; - defer conn.stream.close(); - const conn_id = next_conn_id.fetchAdd(1, .SeqCst); - std.log.debug("Accepting TCP connection id {} on thread {}", .{ conn_id, thread_id }); - defer std.log.debug("Closing TCP connection id {}", .{conn_id}); - - http.serveConn(conn, .{ .src = src, .conn_id = conn_id, .allocator = gpa.allocator() }, handle, gpa.allocator()) catch |err| { - std.log.err("Error occured on connection {}: {}", .{ conn_id, err }); - }; - } -} - -fn handle(ctx: anytype, req: http.Request, res: *http.Response) void { - c.routeRequest(ctx.src, req, res, ctx.allocator); -} - pub fn main() !void { try util.seedThreadPrng(); var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var cfg = try loadConfig(gpa.allocator()); - var pool = try sql.ConnPool.init(cfg.db); - try prepareDb(&pool, gpa.allocator()); + var db_conn = try sql.Conn.open(cfg.db); + try prepareDb(try db_conn.acquire(), gpa.allocator()); - var api_src = try api.ApiSource.init(&pool); - var srv = std.net.StreamServer.init(.{ .reuse_address = true }); - defer srv.deinit(); - try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); - - thread_main(&api_src, &srv); + var api_src = try api.ApiSource.init(&db_conn); + var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); + return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); } diff --git a/src/main/migrations.zig b/src/main/migrations.zig index 459e75a..595646e 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -16,15 +16,10 @@ fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void { } fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void { - const tx = try db.beginOrSavepoint(); - errdefer tx.rollback(); - var iter = util.SqlStmtIter.from(script); while (iter.next()) |stmt| { - try execStmt(tx, stmt, alloc); + try execStmt(db, stmt, alloc); } - - try tx.commitOrRelease(); } fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { diff --git a/src/sql/engines/null.zig b/src/sql/engines/null.zig index 1c5a434..f6b2d9d 100644 --- a/src/sql/engines/null.zig +++ b/src/sql/engines/null.zig @@ -28,10 +28,6 @@ pub const Db = struct { unreachable; } - pub fn openUri(_: anytype) common.OpenError!Db { - unreachable; - } - pub fn close(_: Db) void { unreachable; } diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index d1a73c2..c191d36 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -49,15 +49,7 @@ pub const Db = struct { db: *c.sqlite3, pub fn open(path: [:0]const u8) common.OpenError!Db { - return openInternal(path, false); - } - - pub fn openUri(path: [:0]const u8) common.OpenError!Db { - return openInternal(path, true); - } - - fn openInternal(path: [:0]const u8, is_uri: bool) common.OpenError!Db { - const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE | if (is_uri) c.SQLITE_OPEN_URI else 0; + const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE; var db: ?*c.sqlite3 = null; switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { @@ -129,6 +121,7 @@ pub const Db = struct { // of 0, and we must not bind the argument. const name = std.fmt.comptimePrint("${}", .{i + 1}); const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); + std.log.debug("param {s} got index {}", .{ name, db_idx }); if (db_idx != 0) try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) else if (!opts.ignore_unused_arguments) @@ -174,6 +167,7 @@ pub const Db = struct { else @compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string"); + std.log.debug("binding type {any}: {s}", .{ T, arr }); const len = std.mem.len(&arr); return self.bindString(stmt, idx, arr[0..len]); }, @@ -200,6 +194,8 @@ pub const Db = struct { return error.BindException; }; + std.log.debug("binding string {s} to idx {}", .{ str, idx }); + switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { c.SQLITE_OK => {}, else => |result| { diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 3124863..458cf5e 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -10,7 +10,6 @@ const Allocator = std.mem.Allocator; const errors = @import("./errors.zig").library_errors; -pub const AcquireError = OpenError || error{NoConnectionsLeft}; pub const OpenError = errors.OpenError; pub const QueryError = errors.QueryError; pub const RowError = errors.RowError; @@ -25,14 +24,12 @@ pub const Engine = enum { sqlite, }; -// TODO: make this suck less pub const Config = union(Engine) { postgres: struct { pg_conn_str: [:0]const u8, }, sqlite: struct { sqlite_file_path: [:0]const u8, - sqlite_is_uri: bool = false, }, }; @@ -163,58 +160,16 @@ pub const ConstraintMode = enum { immediate, }; -pub const ConnPool = struct { - const max_conns = 4; - const Conn = struct { - engine: union(Engine) { - postgres: postgres.Db, - sqlite: sqlite.Db, - }, - in_use: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), - current_tx_level: u8 = 0, - }; +pub const Conn = struct { + engine: union(Engine) { + postgres: postgres.Db, + sqlite: sqlite.Db, + }, + current_tx_level: u8 = 0, + is_tx_failed: bool = false, - config: Config, - connections: [max_conns]Conn, - - pub fn init(cfg: Config) OpenError!ConnPool { - var self = ConnPool{ - .config = cfg, - .connections = undefined, - }; - var count: usize = 0; - errdefer for (self.connections[0..count]) |*c| closeConn(c); - for (self.connections) |*c| { - c.* = try self.createConn(); - count += 1; - } - - return self; - } - - pub fn deinit(self: *ConnPool) void { - for (self.connections) |*c| closeConn(c); - } - - pub fn acquire(self: *ConnPool) AcquireError!Db { - for (self.connections) |*c| { - if (tryAcquire(c)) return Db{ .conn = c }; - } - return error.NoConnectionsLeft; - } - - fn tryAcquire(conn: *Conn) bool { - const acquired = !conn.in_use.swap(true, .AcqRel); - if (acquired) { - if (conn.current_tx_level != 0) @panic("Transaction still open on unused db connection"); - return true; - } - - return false; - } - - fn createConn(self: *ConnPool) OpenError!Conn { - return switch (self.config) { + pub fn open(cfg: Config) OpenError!Conn { + return switch (cfg) { .postgres => |postgres_cfg| Conn{ .engine = .{ .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str), @@ -222,22 +177,27 @@ pub const ConnPool = struct { }, .sqlite => |lite_cfg| Conn{ .engine = .{ - .sqlite = if (lite_cfg.sqlite_is_uri) - try sqlite.Db.openUri(lite_cfg.sqlite_file_path) - else - try sqlite.Db.open(lite_cfg.sqlite_file_path), + .sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path), }, }, }; } - fn closeConn(conn: *Conn) void { - if (conn.in_use.loadUnchecked()) @panic("DB Conn still open"); - switch (conn.engine) { + pub fn close(self: *Conn) void { + switch (self.engine) { .postgres => |pg| pg.close(), .sqlite => |lite| lite.close(), } } + + pub fn acquire(self: *Conn) !Db { + if (self.current_tx_level != 0) return error.BadTransactionState; + return Db{ .conn = self }; + } + + pub fn sqlEngine(self: *Conn) Engine { + return self.engine; + } }; pub const Db = Tx(0); @@ -256,22 +216,11 @@ fn Tx(comptime tx_level: u8) type { std.fmt.comptimePrint("save_{}", .{tx_level}); const next_savepoint_name = Tx(tx_level + 1).savepoint_name; - conn: *ConnPool.Conn, + conn: *Conn, /// The type of SQL engine being used. Use of this function should be discouraged pub fn sqlEngine(self: Self) Engine { - return self.conn.engine; - } - - /// Return the connection to the pool - pub fn releaseConnection(self: Self) void { - if (tx_level != 0) @compileError("close must be called on root db"); - if (self.conn.current_tx_level != 0) { - std.log.warn("Database released while transaction in progress!", .{}); - self.rollbackUnchecked() catch {}; - } - - if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection"); + return self.conn.sqlEngine(); } // ********* Transaction management functions ********** @@ -328,7 +277,7 @@ fn Tx(comptime tx_level: u8) type { if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint"); if (self.conn.current_tx_level == 0) return error.BadTransactionState; - try self.rollbackUnchecked(); + try self.exec("ROLLBACK", {}, null); self.conn.current_tx_level = 0; } @@ -453,7 +402,7 @@ fn Tx(comptime tx_level: u8) type { comptime var table_spec: []const u8 = table ++ "("; comptime var value_spec: []const u8 = "("; inline for (fields) |field, i| { - // This causes a compiler crash. Why? + // This causes a compile error. Why? //const F = field.field_type; const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name)); // causes issues if F is @TypeOf(null), use dummy type @@ -504,9 +453,5 @@ fn Tx(comptime tx_level: u8) type { while (try results.row()) |_| {} } - - fn rollbackUnchecked(self: Self) !void { - try self.exec("ROLLBACK", {}, null); - } }; } diff --git a/tests/api_integration/lib.zig b/tests/api_integration/lib.zig index ea0fd5d..1027043 100644 --- a/tests/api_integration/lib.zig +++ b/tests/api_integration/lib.zig @@ -5,10 +5,11 @@ const sql = @import("sql"); const util = @import("util"); const test_config = .{ - .db = .{ .sqlite = .{ - .sqlite_file_path = "file::memory:?cache=shared", - .sqlite_is_uri = true, - } }, + .db = .{ + .sqlite = .{ + .sqlite_file_path = ":memory:", + }, + }, }; const ApiSource = api.ApiSource; @@ -17,16 +18,12 @@ const root_password = "password1234"; const admin_host = "example.com"; const admin_origin = "https://" ++ admin_host; -fn makeDb(alloc: std.mem.Allocator) !sql.ConnPool { +fn makeDb(alloc: std.mem.Allocator) !sql.Conn { try util.seedThreadPrng(); - var pool = try sql.ConnPool.init(test_config.db); - { - var db = try pool.acquire(); - defer db.releaseConnection(); - try migrations.up(db); - try api.setupAdmin(db, admin_origin, root_user, root_password, alloc); - } - return pool; + var db = try sql.Conn.open(test_config.db); + try migrations.up(try db.acquire()); + try api.setupAdmin(try db.acquire(), admin_origin, root_user, root_password, alloc); + return db; } fn connectAndLogin( @@ -45,7 +42,6 @@ fn connectAndLogin( test "login as root" { const alloc = std.testing.allocator; var db = try makeDb(alloc); - defer db.deinit(); var src = try ApiSource.init(&db); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); @@ -63,7 +59,6 @@ test "login as root" { test "create community" { const alloc = std.testing.allocator; var db = try makeDb(alloc); - defer db.deinit(); var src = try ApiSource.init(&db); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); @@ -85,7 +80,6 @@ test "create community" { test "create community and transfer to new owner" { const alloc = std.testing.allocator; var db = try makeDb(alloc); - defer db.deinit(); var src = try ApiSource.init(&db); const root_login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc);