From 159f1c28ccb461042808a7423ce9a1a94a9769e3 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Wed, 12 Oct 2022 23:19:59 -0700 Subject: [PATCH 1/2] Add sql connection pool --- src/api/lib.zig | 11 ++-- src/main/main.zig | 11 ++-- src/main/migrations.zig | 7 ++- src/sql/engines/null.zig | 4 ++ src/sql/engines/sqlite.zig | 14 +++-- src/sql/lib.zig | 103 ++++++++++++++++++++++++++-------- tests/api_integration/lib.zig | 26 +++++---- 7 files changed, 127 insertions(+), 49 deletions(-) diff --git a/src/api/lib.zig b/src/api/lib.zig index 487a4a9..e6cf82a 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -94,15 +94,15 @@ pub fn setupAdmin(db: sql.Db, origin: []const u8, username: []const u8, password } pub const ApiSource = struct { - db_conn: *sql.Conn, + db_conn_pool: *sql.ConnPool, pub const Conn = ApiConn(sql.Db); const root_username = "root"; - pub fn init(db_conn: *sql.Conn) !ApiSource { + pub fn init(pool: *sql.ConnPool) !ApiSource { return ApiSource{ - .db_conn = db_conn, + .db_conn_pool = pool, }; } @@ -110,7 +110,7 @@ pub const ApiSource = struct { var arena = std.heap.ArenaAllocator.init(alloc); errdefer arena.deinit(); - const db = try self.db_conn.acquire(); + const db = try self.db_conn_pool.acquire(); const community = try services.communities.getByHost(db, host, arena.allocator()); return Conn{ @@ -125,7 +125,7 @@ pub const ApiSource = struct { var arena = std.heap.ArenaAllocator.init(alloc); errdefer arena.deinit(); - const db = try self.db_conn.acquire(); + const db = try self.db_conn_pool.acquire(); const community = try services.communities.getByHost(db, host, arena.allocator()); const token_info = try services.auth.verifyToken( @@ -157,6 +157,7 @@ fn ApiConn(comptime DbConn: type) type { pub fn close(self: *Self) void { self.arena.deinit(); + self.db.releaseConnection(); } fn isAdmin(self: *Self) bool { diff --git a/src/main/main.zig b/src/main/main.zig index f722294..1d427f0 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -64,7 +64,10 @@ fn runAdminSetup(db: sql.Db, alloc: std.mem.Allocator) !void { try api.setupAdmin(db, origin, username, password, alloc); } -fn prepareDb(db: sql.Db, alloc: std.mem.Allocator) !void { +fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void { + const db = try pool.acquire(); + defer db.releaseConnection(); + try migrations.up(db); if (!try api.isAdminSetup(db)) { @@ -93,10 +96,10 @@ pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; var cfg = try loadConfig(gpa.allocator()); - var db_conn = try sql.Conn.open(cfg.db); - try prepareDb(try db_conn.acquire(), gpa.allocator()); + var pool = try sql.ConnPool.init(cfg.db); + try prepareDb(&pool, gpa.allocator()); - var api_src = try api.ApiSource.init(&db_conn); + var api_src = try api.ApiSource.init(&pool); var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); } diff --git a/src/main/migrations.zig b/src/main/migrations.zig index 595646e..459e75a 100644 --- a/src/main/migrations.zig +++ b/src/main/migrations.zig @@ -16,10 +16,15 @@ fn execStmt(tx: anytype, stmt: []const u8, alloc: std.mem.Allocator) !void { } fn execScript(db: anytype, script: []const u8, alloc: std.mem.Allocator) !void { + const tx = try db.beginOrSavepoint(); + errdefer tx.rollback(); + var iter = util.SqlStmtIter.from(script); while (iter.next()) |stmt| { - try execStmt(db, stmt, alloc); + try execStmt(tx, stmt, alloc); } + + try tx.commitOrRelease(); } fn wasMigrationRan(db: anytype, name: []const u8, alloc: std.mem.Allocator) !bool { diff --git a/src/sql/engines/null.zig b/src/sql/engines/null.zig index f6b2d9d..1c5a434 100644 --- a/src/sql/engines/null.zig +++ b/src/sql/engines/null.zig @@ -28,6 +28,10 @@ pub const Db = struct { unreachable; } + pub fn openUri(_: anytype) common.OpenError!Db { + unreachable; + } + pub fn close(_: Db) void { unreachable; } diff --git a/src/sql/engines/sqlite.zig b/src/sql/engines/sqlite.zig index c191d36..d1a73c2 100644 --- a/src/sql/engines/sqlite.zig +++ b/src/sql/engines/sqlite.zig @@ -49,7 +49,15 @@ pub const Db = struct { db: *c.sqlite3, pub fn open(path: [:0]const u8) common.OpenError!Db { - const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE; + return openInternal(path, false); + } + + pub fn openUri(path: [:0]const u8) common.OpenError!Db { + return openInternal(path, true); + } + + fn openInternal(path: [:0]const u8, is_uri: bool) common.OpenError!Db { + const flags = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_EXRESCODE | if (is_uri) c.SQLITE_OPEN_URI else 0; var db: ?*c.sqlite3 = null; switch (c.sqlite3_open_v2(@ptrCast([*c]const u8, path), &db, flags, null)) { @@ -121,7 +129,6 @@ pub const Db = struct { // of 0, and we must not bind the argument. const name = std.fmt.comptimePrint("${}", .{i + 1}); const db_idx = c.sqlite3_bind_parameter_index(stmt.?, name); - std.log.debug("param {s} got index {}", .{ name, db_idx }); if (db_idx != 0) try self.bindArgument(stmt.?, @intCast(u15, db_idx), arg) else if (!opts.ignore_unused_arguments) @@ -167,7 +174,6 @@ pub const Db = struct { else @compileError("SQLite: Could not serialize " ++ @typeName(T) ++ " into staticly sized string"); - std.log.debug("binding type {any}: {s}", .{ T, arr }); const len = std.mem.len(&arr); return self.bindString(stmt, idx, arr[0..len]); }, @@ -194,8 +200,6 @@ pub const Db = struct { return error.BindException; }; - std.log.debug("binding string {s} to idx {}", .{ str, idx }); - switch (c.sqlite3_bind_text(stmt, idx, str.ptr, len, c.SQLITE_TRANSIENT)) { c.SQLITE_OK => {}, else => |result| { diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 458cf5e..1208c9c 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -10,6 +10,7 @@ const Allocator = std.mem.Allocator; const errors = @import("./errors.zig").library_errors; +pub const AcquireError = OpenError || error{NoConnectionsLeft}; pub const OpenError = errors.OpenError; pub const QueryError = errors.QueryError; pub const RowError = errors.RowError; @@ -24,12 +25,14 @@ pub const Engine = enum { sqlite, }; +// TODO: make this suck less pub const Config = union(Engine) { postgres: struct { pg_conn_str: [:0]const u8, }, sqlite: struct { sqlite_file_path: [:0]const u8, + sqlite_is_uri: bool = false, }, }; @@ -160,16 +163,58 @@ pub const ConstraintMode = enum { immediate, }; -pub const Conn = struct { - engine: union(Engine) { - postgres: postgres.Db, - sqlite: sqlite.Db, - }, - current_tx_level: u8 = 0, - is_tx_failed: bool = false, +pub const ConnPool = struct { + const max_conns = 4; + const Conn = struct { + engine: union(Engine) { + postgres: postgres.Db, + sqlite: sqlite.Db, + }, + in_use: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false), + current_tx_level: u8 = 0, + }; - pub fn open(cfg: Config) OpenError!Conn { - return switch (cfg) { + config: Config, + connections: [max_conns]Conn, + + pub fn init(cfg: Config) OpenError!ConnPool { + var self = ConnPool{ + .config = cfg, + .connections = undefined, + }; + var count: usize = 0; + errdefer for (self.connections[0..count]) |*c| closeConn(c); + for (self.connections) |*c| { + c.* = try self.createConn(); + count += 1; + } + + return self; + } + + pub fn deinit(self: *ConnPool) void { + for (self.connections) |*c| closeConn(c); + } + + pub fn acquire(self: *ConnPool) AcquireError!Db { + for (self.connections) |*c| { + if (tryAcquire(c)) return Db{ .conn = c }; + } + return error.NoConnectionsLeft; + } + + fn tryAcquire(conn: *Conn) bool { + const acquired = !conn.in_use.swap(true, .AcqRel); + if (acquired) { + if (conn.current_tx_level != 0) @panic("Transaction still open on unused db connection"); + return true; + } + + return false; + } + + fn createConn(self: *ConnPool) OpenError!Conn { + return switch (self.config) { .postgres => |postgres_cfg| Conn{ .engine = .{ .postgres = try postgres.Db.open(postgres_cfg.pg_conn_str), @@ -177,27 +222,22 @@ pub const Conn = struct { }, .sqlite => |lite_cfg| Conn{ .engine = .{ - .sqlite = try sqlite.Db.open(lite_cfg.sqlite_file_path), + .sqlite = if (lite_cfg.sqlite_is_uri) + try sqlite.Db.openUri(lite_cfg.sqlite_file_path) + else + try sqlite.Db.open(lite_cfg.sqlite_file_path), }, }, }; } - pub fn close(self: *Conn) void { - switch (self.engine) { + fn closeConn(conn: *Conn) void { + if (conn.in_use.loadUnchecked()) @panic("DB Conn still open"); + switch (conn.engine) { .postgres => |pg| pg.close(), .sqlite => |lite| lite.close(), } } - - pub fn acquire(self: *Conn) !Db { - if (self.current_tx_level != 0) return error.BadTransactionState; - return Db{ .conn = self }; - } - - pub fn sqlEngine(self: *Conn) Engine { - return self.engine; - } }; pub const Db = Tx(0); @@ -216,11 +256,22 @@ fn Tx(comptime tx_level: u8) type { std.fmt.comptimePrint("save_{}", .{tx_level}); const next_savepoint_name = Tx(tx_level + 1).savepoint_name; - conn: *Conn, + conn: *ConnPool.Conn, /// The type of SQL engine being used. Use of this function should be discouraged pub fn sqlEngine(self: Self) Engine { - return self.conn.sqlEngine(); + return self.conn.engine; + } + + /// Return the connection to the pool + pub fn releaseConnection(self: Self) void { + if (tx_level != 0) @compileError("close must be called on root db"); + if (self.conn.current_tx_level != 0) { + std.log.warn("Database released while transaction in progress!", .{}); + self.rollbackUnchecked() catch {}; + } + + if (!self.conn.in_use.swap(false, .AcqRel)) @panic("Double close on db conection"); } // ********* Transaction management functions ********** @@ -277,7 +328,7 @@ fn Tx(comptime tx_level: u8) type { if (tx_level >= 2) @compileError("Cannot rollback a transaction using a savepoint"); if (self.conn.current_tx_level == 0) return error.BadTransactionState; - try self.exec("ROLLBACK", {}, null); + try self.rollbackUnchecked(); self.conn.current_tx_level = 0; } @@ -453,5 +504,9 @@ fn Tx(comptime tx_level: u8) type { while (try results.row()) |_| {} } + + fn rollbackUnchecked(self: Self) !void { + try self.exec("ROLLBACK", {}, null); + } }; } diff --git a/tests/api_integration/lib.zig b/tests/api_integration/lib.zig index 1027043..ea0fd5d 100644 --- a/tests/api_integration/lib.zig +++ b/tests/api_integration/lib.zig @@ -5,11 +5,10 @@ const sql = @import("sql"); const util = @import("util"); const test_config = .{ - .db = .{ - .sqlite = .{ - .sqlite_file_path = ":memory:", - }, - }, + .db = .{ .sqlite = .{ + .sqlite_file_path = "file::memory:?cache=shared", + .sqlite_is_uri = true, + } }, }; const ApiSource = api.ApiSource; @@ -18,12 +17,16 @@ const root_password = "password1234"; const admin_host = "example.com"; const admin_origin = "https://" ++ admin_host; -fn makeDb(alloc: std.mem.Allocator) !sql.Conn { +fn makeDb(alloc: std.mem.Allocator) !sql.ConnPool { try util.seedThreadPrng(); - var db = try sql.Conn.open(test_config.db); - try migrations.up(try db.acquire()); - try api.setupAdmin(try db.acquire(), admin_origin, root_user, root_password, alloc); - return db; + var pool = try sql.ConnPool.init(test_config.db); + { + var db = try pool.acquire(); + defer db.releaseConnection(); + try migrations.up(db); + try api.setupAdmin(db, admin_origin, root_user, root_password, alloc); + } + return pool; } fn connectAndLogin( @@ -42,6 +45,7 @@ fn connectAndLogin( test "login as root" { const alloc = std.testing.allocator; var db = try makeDb(alloc); + defer db.deinit(); var src = try ApiSource.init(&db); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); @@ -59,6 +63,7 @@ test "login as root" { test "create community" { const alloc = std.testing.allocator; var db = try makeDb(alloc); + defer db.deinit(); var src = try ApiSource.init(&db); const login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); @@ -80,6 +85,7 @@ test "create community" { test "create community and transfer to new owner" { const alloc = std.testing.allocator; var db = try makeDb(alloc); + defer db.deinit(); var src = try ApiSource.init(&db); const root_login = try connectAndLogin(&src, admin_host, root_user, root_password, alloc); From 0ce315368ab0b787254267dca38d55411e32f586 Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Thu, 13 Oct 2022 02:23:57 -0700 Subject: [PATCH 2/2] Minor HTTP refactor --- src/http/lib.zig | 6 ++- src/http/request.zig | 18 +++++-- src/http/request/parser.zig | 84 ++++++++++++++++++++++--------- src/http/server.zig | 90 +++++++++++++++++----------------- src/http/server/connection.zig | 51 ------------------- src/http/server/response.zig | 5 +- src/main/controllers.zig | 49 ++++++++++-------- src/main/main.zig | 65 ++++++++++++------------ src/sql/lib.zig | 2 +- 9 files changed, 190 insertions(+), 180 deletions(-) delete mode 100644 src/http/server/connection.zig diff --git a/src/http/lib.zig b/src/http/lib.zig index 9a767f0..3d12cdc 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -3,13 +3,15 @@ const ciutf8 = @import("util").ciutf8; const request = @import("./request.zig"); -pub const server = @import("./server.zig"); +const server = @import("./server.zig"); pub const Method = std.http.Method; pub const Status = std.http.Status; pub const Request = request.Request; -pub const Server = server.Server; +pub const serveConn = server.serveConn; +pub const Response = server.Response; +pub const Handler = server.Handler; pub const Headers = std.HashMap([]const u8, []const u8, struct { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { diff --git a/src/http/request.zig b/src/http/request.zig index 17def73..e6fd79c 100644 --- a/src/http/request.zig +++ b/src/http/request.zig @@ -4,13 +4,25 @@ const http = @import("./lib.zig"); const parser = @import("./request/parser.zig"); pub const Request = struct { + pub const Protocol = enum { + http_1_0, + http_1_1, + }; + + protocol: Protocol, + source_address: ?std.net.Address, + method: http.Method, - path: []const u8, + uri: []const u8, headers: http.Headers, body: ?[]const u8 = null, - pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { - return parser.parse(alloc, reader); + pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request { + return parser.parse(alloc, reader, addr); + } + + pub fn parseFree(self: Request, alloc: std.mem.Allocator) void { + parser.parseFree(alloc, self); } }; diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index 9da174e..8a7c19f 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -22,34 +22,45 @@ const Encoding = enum { chunked, }; -pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { - var request: Request = undefined; - - try parseLine(alloc, &request, reader); - request.headers = try parseHeaders(alloc, reader); - - if (request.method.requestHasBody()) { - request.body = try readBody(alloc, request.headers, reader); - } else { - request.body = null; - } - - return request; -} - -fn parseLine(alloc: std.mem.Allocator, request: *Request, reader: anytype) !void { - request.method = try parseMethod(reader); - request.path = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { +pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request { + const method = try parseMethod(reader); + const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { error.StreamTooLong => return error.RequestUriTooLong, else => return err, }; - errdefer alloc.free(request.path); + errdefer alloc.free(uri); - try checkProto(reader); + const proto = try parseProto(reader); // discard \r\n _ = try reader.readByte(); _ = try reader.readByte(); + + var headers = try parseHeaders(alloc, reader); + errdefer freeHeaders(alloc, &headers); + + const body = if (method.requestHasBody()) + try readBody(alloc, headers, reader) + else + null; + errdefer if (body) |b| alloc.free(b); + + const eff_addr = if (headers.get("X-Real-IP")) |ip| + std.net.Address.parseIp(ip, address.getPort()) catch { + return error.BadRequest; + } + else + address; + + return Request{ + .protocol = proto, + .source_address = eff_addr, + + .method = method, + .uri = uri, + .headers = headers, + .body = body, + }; } fn parseMethod(reader: anytype) !Method { @@ -68,7 +79,7 @@ fn parseMethod(reader: anytype) !Method { return error.MethodNotImplemented; } -fn checkProto(reader: anytype) !void { +fn parseProto(reader: anytype) !Request.Protocol { var buf: [8]u8 = undefined; const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { error.StreamTooLong => return error.UnknownProtocol, @@ -84,14 +95,24 @@ fn checkProto(reader: anytype) !void { return error.BadRequest; } - if (buf[0] != '1' or buf[2] != '1') { - return error.HttpVersionNotSupported; - } + if (buf[0] != '1') return error.HttpVersionNotSupported; + return switch (buf[2]) { + '0' => .http_1_0, + '1' => .http_1_1, + else => error.HttpVersionNotSupported, + }; } fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { var map = Headers.init(allocator); errdefer map.deinit(); + errdefer { + var iter = map.iterator(); + while (iter.next()) |it| { + allocator.free(it.key_ptr.*); + allocator.free(it.value_ptr.*); + } + } // todo: //errdefer { //var iter = map.iterator(); @@ -167,6 +188,21 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding { return error.UnsupportedMediaType; } +pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void { + allocator.free(request.uri); + freeHeaders(allocator, &request.headers); + if (request.body) |body| allocator.free(body); +} + +fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void { + var iter = headers.iterator(); + while (iter.next()) |it| { + allocator.free(it.key_ptr.*); + allocator.free(it.value_ptr.*); + } + headers.deinit(); +} + const _test = struct { const expectEqual = std.testing.expectEqual; const expectEqualStrings = std.testing.expectEqualStrings; diff --git a/src/http/server.zig b/src/http/server.zig index c9095ef..e9bf81e 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -2,67 +2,69 @@ const std = @import("std"); const util = @import("util"); const http = @import("./lib.zig"); -const connection = @import("./server/connection.zig"); const response = @import("./server/response.zig"); -pub const Connection = connection.Connection; -pub const Response = response.ResponseStream(Connection.Writer); +pub const Response = struct { + alloc: std.mem.Allocator, + stream: std.net.Stream, + should_close: bool = false, + pub const Stream = response.ResponseStream(std.net.Stream.Writer); + pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream { + if (headers.get("Connection")) |hdr| { + if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; + } + + return response.open(self.alloc, self.stream.writer(), headers, status); + } +}; -const ConnectionServer = connection.Server; const Request = http.Request; const request_buf_size = 1 << 16; -pub const Context = struct { - alloc: std.mem.Allocator, - request: Request, - connection: Connection, +pub fn Handler(comptime Ctx: type) type { + return fn (Ctx, Request, *Response) void; +} - pub fn openResponse(self: *Context, headers: *const http.Headers, status: http.Status) !Response { - return try response.open(self.alloc, self.connection.stream.writer(), headers, status); - } +pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void { + // TODO: Timeouts + while (true) { + std.log.debug("waiting for request", .{}); + var arena = std.heap.ArenaAllocator.init(alloc); + defer arena.deinit(); - pub fn close(self: *Context) void { - // todo: deallocate request - self.connection.close(); - } -}; - -pub const Server = struct { - conn_server: ConnectionServer, - pub fn listen(addr: std.net.Address) !Server { - return Server{ - .conn_server = try ConnectionServer.listen(addr), + const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| { + return handleError(conn.stream.writer(), err) catch {}; }; + std.log.debug("done parsing", .{}); + + var res = Response{ + .alloc = arena.allocator(), + .stream = conn.stream, + }; + + handler(ctx, req, &res); + std.log.debug("done handling", .{}); + + if (req.headers.get("Connection")) |hdr| { + if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| return; + } else if (req.headers.get("Keep-Alive")) |hdr| { + std.log.debug("keep-alive: {s}", .{hdr}); + } else if (req.protocol == .http_1_0) return; + + if (res.should_close) return; } +} - pub fn accept(self: *Server, alloc: std.mem.Allocator) !Context { - while (true) { - const conn = try self.conn_server.accept(); - errdefer conn.close(); - - const req = http.Request.parse(alloc, conn.stream.reader()) catch |err| { - handleError(conn.stream.writer(), err) catch unreachable; - continue; - }; - - return Context{ .connection = conn, .request = req, .alloc = alloc }; - } - } - - pub fn shutdown(self: *Server) void { - self.conn_server.shutdown(); - } -}; - -// TODO: We should get more specific about what type of errors can happen +/// Writes an error response message and requests closure of the connection fn handleError(writer: anytype, err: anyerror) !void { const status: http.Status = switch (err) { + error.EndOfStream => return, // Do nothing, the client closed the connection error.BadRequest => .bad_request, error.UnsupportedMediaType => .unsupported_media_type, error.HttpVersionNotSupported => .http_version_not_supported, - else => return err, + else => .internal_server_error, }; - try writer.print("HTTP/1.1 {} {?s}\r\n\r\n", .{ @enumToInt(status), status.phrase() }); + try writer.print("HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", .{ @enumToInt(status), status.phrase() }); } diff --git a/src/http/server/connection.zig b/src/http/server/connection.zig deleted file mode 100644 index 1ae6d69..0000000 --- a/src/http/server/connection.zig +++ /dev/null @@ -1,51 +0,0 @@ -const std = @import("std"); - -pub const Connection = struct { - pub const Id = u64; - pub const Writer = std.net.Stream.Writer; - pub const Reader = std.net.Stream.Reader; - - id: Id, - address: std.net.Address, - stream: std.net.Stream, - - fn new(id: Id, std_conn: std.net.StreamServer.Connection) Connection { - std.log.debug("new connection conn_id={}", .{id}); - return .{ - .id = id, - .address = std_conn.address, - .stream = std_conn.stream, - }; - } - - pub fn close(self: Connection) void { - std.log.debug("terminating connection conn_id={}", .{self.id}); - self.stream.close(); - } -}; - -pub const Server = struct { - next_conn_id: std.atomic.Atomic(Connection.Id) = std.atomic.Atomic(Connection.Id).init(1), - stream_server: std.net.StreamServer, - - pub fn listen(addr: std.net.Address) !Server { - var self = Server{ - .stream_server = std.net.StreamServer.init(.{ .reuse_address = true }), - }; - errdefer self.stream_server.deinit(); - - try self.stream_server.listen(addr); - return self; - } - - pub fn accept(self: *Server) !Connection { - const conn = try self.stream_server.accept(); - const id = self.next_conn_id.fetchAdd(1, .SeqCst); - - return Connection.new(id, conn); - } - - pub fn shutdown(self: *Server) void { - self.stream_server.deinit(); - } -}; diff --git a/src/http/server/response.zig b/src/http/server/response.zig index aa549f8..a99aa8e 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -111,9 +111,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type { } fn flushBodyUnchunked(self: *Self) Error!void { - if (self.buffer_pos != 0) { - try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos}); - } + try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos}); try self.base_writer.writeAll("\r\n"); @@ -128,6 +126,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type { } pub fn finish(self: *Self) Error!void { + std.log.debug("finishing", .{}); if (!self.chunked) { try self.flushBodyUnchunked(); } else { diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 87fb11d..b0f7800 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -13,14 +13,13 @@ pub const invites = @import("./controllers/invites.zig"); pub const users = @import("./controllers/users.zig"); pub const notes = @import("./controllers/notes.zig"); -pub fn routeRequest(api_source: anytype, ctx: http.server.Context, alloc: std.mem.Allocator) void { +pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? - inline for (routes) |route| { - if (Context(route).matchAndHandle(api_source, ctx, alloc)) return; - } - - var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx }; + var response = Response{ .headers = http.Headers.init(alloc), .res = res }; defer response.headers.deinit(); + inline for (routes) |route| { + if (Context(route).matchAndHandle(api_source, req, &response, alloc)) return; + } response.status(.not_found) catch {}; } @@ -53,7 +52,7 @@ pub fn Context(comptime Route: type) type { allocator: std.mem.Allocator, method: http.Method, - request_line: []const u8, + uri: []const u8, headers: http.Headers, args: Args, @@ -84,20 +83,16 @@ pub fn Context(comptime Route: type) type { @compileError("Unsupported Type " ++ @typeName(T)); } - pub fn matchAndHandle(api_source: *api.ApiSource, ctx: http.server.Context, alloc: std.mem.Allocator) bool { - const req = ctx.request; + pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { if (req.method != Route.method) return false; - var path = std.mem.sliceTo(std.mem.sliceTo(req.path, '#'), '?'); + var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); var args: Args = parseArgs(path) orelse return false; - var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx }; - defer response.headers.deinit(); - var self = Self{ .allocator = alloc, .method = req.method, - .request_line = req.path, + .uri = req.uri, .headers = req.headers, .args = args, @@ -105,7 +100,7 @@ pub fn Context(comptime Route: type) type { .query = undefined, }; - self.prepareAndHandle(api_source, req, &response); + self.prepareAndHandle(api_source, req, res); return true; } @@ -149,15 +144,20 @@ pub fn Context(comptime Route: type) type { fn parseQuery(self: *Self) !void { if (Query != void) { - const path = std.mem.sliceTo(self.request_line, '?'); - const q = std.mem.sliceTo(self.request_line[path.len..], '#'); + const path = std.mem.sliceTo(self.uri, '?'); + const q = std.mem.sliceTo(self.uri[path.len..], '#'); self.query = try query_utils.parseQuery(Query, q); } } fn handle(self: Self, response: *Response, api_conn: anytype) void { - Route.handler(self, response, api_conn) catch |err| std.log.err("{}", .{err}); + Route.handler(self, response, api_conn) catch |err| switch (err) { + else => { + std.log.err("{}", .{err}); + response.err(.internal_server_error, "", {}) catch {}; + }, + }; } fn getApiConn(self: *Self, api_source: anytype) !api.ApiSource.Conn { @@ -180,18 +180,25 @@ pub fn Context(comptime Route: type) type { pub const Response = struct { const Self = @This(); headers: http.Headers, - ctx: http.server.Context, + res: *http.Response, + opened: bool = false, pub fn status(self: *Self, status_code: http.Status) !void { - var stream = try self.ctx.openResponse(&self.headers, status_code); + std.debug.assert(!self.opened); + self.opened = true; + + var stream = try self.res.open(status_code, &self.headers); defer stream.close(); try stream.finish(); } pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { + std.debug.assert(!self.opened); + self.opened = true; + try self.headers.put("Content-Type", "application/json"); - var stream = try self.ctx.openResponse(&self.headers, status_code); + var stream = try self.res.open(status_code, &self.headers); defer stream.close(); const writer = stream.writer(); diff --git a/src/main/main.zig b/src/main/main.zig index 1d427f0..a82713b 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -8,36 +8,8 @@ const api = @import("api"); pub const migrations = @import("./migrations.zig"); // TODO const c = @import("./controllers.zig"); -pub const RequestServer = struct { - alloc: std.mem.Allocator, - api: *api.ApiSource, - config: Config, - - fn init(alloc: std.mem.Allocator, src: *api.ApiSource, config: Config) !RequestServer { - return RequestServer{ - .alloc = alloc, - .api = src, - .config = config, - }; - } - - fn listenAndRun(self: *RequestServer, addr: std.net.Address) !void { - var srv = http.Server.listen(addr) catch unreachable; - defer srv.shutdown(); - - while (true) { - var arena = std.heap.ArenaAllocator.init(self.alloc); - defer arena.deinit(); - - var ctx = try srv.accept(arena.allocator()); - defer ctx.close(); - - c.routeRequest(self.api, ctx, arena.allocator()); - } - } -}; - pub const Config = struct { + worker_threads: usize = 10, db: sql.Config, }; @@ -91,6 +63,34 @@ fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void { } } +const ConnectionId = u64; +var next_conn_id = std.atomic.Atomic(ConnectionId).init(0); + +fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void { + util.seedThreadPrng() catch unreachable; + const thread_id = std.Thread.getCurrentId(); + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + + while (true) { + var conn = srv.accept() catch |err| { + std.log.err("Error accepting connection: {}", .{err}); + continue; + }; + defer conn.stream.close(); + const conn_id = next_conn_id.fetchAdd(1, .SeqCst); + std.log.debug("Accepting TCP connection id {} on thread {}", .{ conn_id, thread_id }); + defer std.log.debug("Closing TCP connection id {}", .{conn_id}); + + http.serveConn(conn, .{ .src = src, .conn_id = conn_id, .allocator = gpa.allocator() }, handle, gpa.allocator()) catch |err| { + std.log.err("Error occured on connection {}: {}", .{ conn_id, err }); + }; + } +} + +fn handle(ctx: anytype, req: http.Request, res: *http.Response) void { + c.routeRequest(ctx.src, req, res, ctx.allocator); +} + pub fn main() !void { try util.seedThreadPrng(); @@ -100,6 +100,9 @@ pub fn main() !void { try prepareDb(&pool, gpa.allocator()); var api_src = try api.ApiSource.init(&pool); - var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); - return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); + var srv = std.net.StreamServer.init(.{ .reuse_address = true }); + defer srv.deinit(); + try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); + + thread_main(&api_src, &srv); } diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 1208c9c..3124863 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -453,7 +453,7 @@ fn Tx(comptime tx_level: u8) type { comptime var table_spec: []const u8 = table ++ "("; comptime var value_spec: []const u8 = "("; inline for (fields) |field, i| { - // This causes a compile error. Why? + // This causes a compiler crash. Why? //const F = field.field_type; const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name)); // causes issues if F is @TypeOf(null), use dummy type