diff --git a/src/http/lib.zig b/src/http/lib.zig index 9a767f0..3d12cdc 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -3,13 +3,15 @@ const ciutf8 = @import("util").ciutf8; const request = @import("./request.zig"); -pub const server = @import("./server.zig"); +const server = @import("./server.zig"); pub const Method = std.http.Method; pub const Status = std.http.Status; pub const Request = request.Request; -pub const Server = server.Server; +pub const serveConn = server.serveConn; +pub const Response = server.Response; +pub const Handler = server.Handler; pub const Headers = std.HashMap([]const u8, []const u8, struct { pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { diff --git a/src/http/request.zig b/src/http/request.zig index 17def73..e6fd79c 100644 --- a/src/http/request.zig +++ b/src/http/request.zig @@ -4,13 +4,25 @@ const http = @import("./lib.zig"); const parser = @import("./request/parser.zig"); pub const Request = struct { + pub const Protocol = enum { + http_1_0, + http_1_1, + }; + + protocol: Protocol, + source_address: ?std.net.Address, + method: http.Method, - path: []const u8, + uri: []const u8, headers: http.Headers, body: ?[]const u8 = null, - pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { - return parser.parse(alloc, reader); + pub fn parse(alloc: std.mem.Allocator, reader: anytype, addr: std.net.Address) !Request { + return parser.parse(alloc, reader, addr); + } + + pub fn parseFree(self: Request, alloc: std.mem.Allocator) void { + parser.parseFree(alloc, self); } }; diff --git a/src/http/request/parser.zig b/src/http/request/parser.zig index 9da174e..8a7c19f 100644 --- a/src/http/request/parser.zig +++ b/src/http/request/parser.zig @@ -22,34 +22,45 @@ const Encoding = enum { chunked, }; -pub fn parse(alloc: std.mem.Allocator, reader: anytype) !Request { - var request: Request = undefined; - - try parseLine(alloc, &request, reader); - request.headers = try parseHeaders(alloc, reader); - - if (request.method.requestHasBody()) { - request.body = try readBody(alloc, request.headers, reader); - } else { - request.body = null; - } - - return request; -} - -fn parseLine(alloc: std.mem.Allocator, request: *Request, reader: anytype) !void { - request.method = try parseMethod(reader); - request.path = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { +pub fn parse(alloc: std.mem.Allocator, reader: anytype, address: std.net.Address) !Request { + const method = try parseMethod(reader); + const uri = reader.readUntilDelimiterAlloc(alloc, ' ', max_path_len) catch |err| switch (err) { error.StreamTooLong => return error.RequestUriTooLong, else => return err, }; - errdefer alloc.free(request.path); + errdefer alloc.free(uri); - try checkProto(reader); + const proto = try parseProto(reader); // discard \r\n _ = try reader.readByte(); _ = try reader.readByte(); + + var headers = try parseHeaders(alloc, reader); + errdefer freeHeaders(alloc, &headers); + + const body = if (method.requestHasBody()) + try readBody(alloc, headers, reader) + else + null; + errdefer if (body) |b| alloc.free(b); + + const eff_addr = if (headers.get("X-Real-IP")) |ip| + std.net.Address.parseIp(ip, address.getPort()) catch { + return error.BadRequest; + } + else + address; + + return Request{ + .protocol = proto, + .source_address = eff_addr, + + .method = method, + .uri = uri, + .headers = headers, + .body = body, + }; } fn parseMethod(reader: anytype) !Method { @@ -68,7 +79,7 @@ fn parseMethod(reader: anytype) !Method { return error.MethodNotImplemented; } -fn checkProto(reader: anytype) !void { +fn parseProto(reader: anytype) !Request.Protocol { var buf: [8]u8 = undefined; const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { error.StreamTooLong => return error.UnknownProtocol, @@ -84,14 +95,24 @@ fn checkProto(reader: anytype) !void { return error.BadRequest; } - if (buf[0] != '1' or buf[2] != '1') { - return error.HttpVersionNotSupported; - } + if (buf[0] != '1') return error.HttpVersionNotSupported; + return switch (buf[2]) { + '0' => .http_1_0, + '1' => .http_1_1, + else => error.HttpVersionNotSupported, + }; } fn parseHeaders(allocator: std.mem.Allocator, reader: anytype) !Headers { var map = Headers.init(allocator); errdefer map.deinit(); + errdefer { + var iter = map.iterator(); + while (iter.next()) |it| { + allocator.free(it.key_ptr.*); + allocator.free(it.value_ptr.*); + } + } // todo: //errdefer { //var iter = map.iterator(); @@ -167,6 +188,21 @@ fn parseEncoding(encoding: ?[]const u8) !Encoding { return error.UnsupportedMediaType; } +pub fn parseFree(allocator: std.mem.Allocator, request: *Request) void { + allocator.free(request.uri); + freeHeaders(allocator, &request.headers); + if (request.body) |body| allocator.free(body); +} + +fn freeHeaders(allocator: std.mem.Allocator, headers: *http.Headers) void { + var iter = headers.iterator(); + while (iter.next()) |it| { + allocator.free(it.key_ptr.*); + allocator.free(it.value_ptr.*); + } + headers.deinit(); +} + const _test = struct { const expectEqual = std.testing.expectEqual; const expectEqualStrings = std.testing.expectEqualStrings; diff --git a/src/http/server.zig b/src/http/server.zig index c9095ef..e9bf81e 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -2,67 +2,69 @@ const std = @import("std"); const util = @import("util"); const http = @import("./lib.zig"); -const connection = @import("./server/connection.zig"); const response = @import("./server/response.zig"); -pub const Connection = connection.Connection; -pub const Response = response.ResponseStream(Connection.Writer); +pub const Response = struct { + alloc: std.mem.Allocator, + stream: std.net.Stream, + should_close: bool = false, + pub const Stream = response.ResponseStream(std.net.Stream.Writer); + pub fn open(self: *Response, status: http.Status, headers: *const http.Headers) !Stream { + if (headers.get("Connection")) |hdr| { + if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; + } + + return response.open(self.alloc, self.stream.writer(), headers, status); + } +}; -const ConnectionServer = connection.Server; const Request = http.Request; const request_buf_size = 1 << 16; -pub const Context = struct { - alloc: std.mem.Allocator, - request: Request, - connection: Connection, +pub fn Handler(comptime Ctx: type) type { + return fn (Ctx, Request, *Response) void; +} - pub fn openResponse(self: *Context, headers: *const http.Headers, status: http.Status) !Response { - return try response.open(self.alloc, self.connection.stream.writer(), headers, status); - } +pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void { + // TODO: Timeouts + while (true) { + std.log.debug("waiting for request", .{}); + var arena = std.heap.ArenaAllocator.init(alloc); + defer arena.deinit(); - pub fn close(self: *Context) void { - // todo: deallocate request - self.connection.close(); - } -}; - -pub const Server = struct { - conn_server: ConnectionServer, - pub fn listen(addr: std.net.Address) !Server { - return Server{ - .conn_server = try ConnectionServer.listen(addr), + const req = Request.parse(arena.allocator(), conn.stream.reader(), conn.address) catch |err| { + return handleError(conn.stream.writer(), err) catch {}; }; + std.log.debug("done parsing", .{}); + + var res = Response{ + .alloc = arena.allocator(), + .stream = conn.stream, + }; + + handler(ctx, req, &res); + std.log.debug("done handling", .{}); + + if (req.headers.get("Connection")) |hdr| { + if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| return; + } else if (req.headers.get("Keep-Alive")) |hdr| { + std.log.debug("keep-alive: {s}", .{hdr}); + } else if (req.protocol == .http_1_0) return; + + if (res.should_close) return; } +} - pub fn accept(self: *Server, alloc: std.mem.Allocator) !Context { - while (true) { - const conn = try self.conn_server.accept(); - errdefer conn.close(); - - const req = http.Request.parse(alloc, conn.stream.reader()) catch |err| { - handleError(conn.stream.writer(), err) catch unreachable; - continue; - }; - - return Context{ .connection = conn, .request = req, .alloc = alloc }; - } - } - - pub fn shutdown(self: *Server) void { - self.conn_server.shutdown(); - } -}; - -// TODO: We should get more specific about what type of errors can happen +/// Writes an error response message and requests closure of the connection fn handleError(writer: anytype, err: anyerror) !void { const status: http.Status = switch (err) { + error.EndOfStream => return, // Do nothing, the client closed the connection error.BadRequest => .bad_request, error.UnsupportedMediaType => .unsupported_media_type, error.HttpVersionNotSupported => .http_version_not_supported, - else => return err, + else => .internal_server_error, }; - try writer.print("HTTP/1.1 {} {?s}\r\n\r\n", .{ @enumToInt(status), status.phrase() }); + try writer.print("HTTP/1.1 {} {?s}\r\nConnection: close\r\n\r\n", .{ @enumToInt(status), status.phrase() }); } diff --git a/src/http/server/connection.zig b/src/http/server/connection.zig deleted file mode 100644 index 1ae6d69..0000000 --- a/src/http/server/connection.zig +++ /dev/null @@ -1,51 +0,0 @@ -const std = @import("std"); - -pub const Connection = struct { - pub const Id = u64; - pub const Writer = std.net.Stream.Writer; - pub const Reader = std.net.Stream.Reader; - - id: Id, - address: std.net.Address, - stream: std.net.Stream, - - fn new(id: Id, std_conn: std.net.StreamServer.Connection) Connection { - std.log.debug("new connection conn_id={}", .{id}); - return .{ - .id = id, - .address = std_conn.address, - .stream = std_conn.stream, - }; - } - - pub fn close(self: Connection) void { - std.log.debug("terminating connection conn_id={}", .{self.id}); - self.stream.close(); - } -}; - -pub const Server = struct { - next_conn_id: std.atomic.Atomic(Connection.Id) = std.atomic.Atomic(Connection.Id).init(1), - stream_server: std.net.StreamServer, - - pub fn listen(addr: std.net.Address) !Server { - var self = Server{ - .stream_server = std.net.StreamServer.init(.{ .reuse_address = true }), - }; - errdefer self.stream_server.deinit(); - - try self.stream_server.listen(addr); - return self; - } - - pub fn accept(self: *Server) !Connection { - const conn = try self.stream_server.accept(); - const id = self.next_conn_id.fetchAdd(1, .SeqCst); - - return Connection.new(id, conn); - } - - pub fn shutdown(self: *Server) void { - self.stream_server.deinit(); - } -}; diff --git a/src/http/server/response.zig b/src/http/server/response.zig index aa549f8..a99aa8e 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -111,9 +111,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type { } fn flushBodyUnchunked(self: *Self) Error!void { - if (self.buffer_pos != 0) { - try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos}); - } + try self.base_writer.print("Content-Length: {}\r\n", .{self.buffer_pos}); try self.base_writer.writeAll("\r\n"); @@ -128,6 +126,7 @@ pub fn ResponseStream(comptime BaseWriter: type) type { } pub fn finish(self: *Self) Error!void { + std.log.debug("finishing", .{}); if (!self.chunked) { try self.flushBodyUnchunked(); } else { diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 87fb11d..b0f7800 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -13,14 +13,13 @@ pub const invites = @import("./controllers/invites.zig"); pub const users = @import("./controllers/users.zig"); pub const notes = @import("./controllers/notes.zig"); -pub fn routeRequest(api_source: anytype, ctx: http.server.Context, alloc: std.mem.Allocator) void { +pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? - inline for (routes) |route| { - if (Context(route).matchAndHandle(api_source, ctx, alloc)) return; - } - - var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx }; + var response = Response{ .headers = http.Headers.init(alloc), .res = res }; defer response.headers.deinit(); + inline for (routes) |route| { + if (Context(route).matchAndHandle(api_source, req, &response, alloc)) return; + } response.status(.not_found) catch {}; } @@ -53,7 +52,7 @@ pub fn Context(comptime Route: type) type { allocator: std.mem.Allocator, method: http.Method, - request_line: []const u8, + uri: []const u8, headers: http.Headers, args: Args, @@ -84,20 +83,16 @@ pub fn Context(comptime Route: type) type { @compileError("Unsupported Type " ++ @typeName(T)); } - pub fn matchAndHandle(api_source: *api.ApiSource, ctx: http.server.Context, alloc: std.mem.Allocator) bool { - const req = ctx.request; + pub fn matchAndHandle(api_source: *api.ApiSource, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { if (req.method != Route.method) return false; - var path = std.mem.sliceTo(std.mem.sliceTo(req.path, '#'), '?'); + var path = std.mem.sliceTo(std.mem.sliceTo(req.uri, '#'), '?'); var args: Args = parseArgs(path) orelse return false; - var response = Response{ .headers = http.Headers.init(alloc), .ctx = ctx }; - defer response.headers.deinit(); - var self = Self{ .allocator = alloc, .method = req.method, - .request_line = req.path, + .uri = req.uri, .headers = req.headers, .args = args, @@ -105,7 +100,7 @@ pub fn Context(comptime Route: type) type { .query = undefined, }; - self.prepareAndHandle(api_source, req, &response); + self.prepareAndHandle(api_source, req, res); return true; } @@ -149,15 +144,20 @@ pub fn Context(comptime Route: type) type { fn parseQuery(self: *Self) !void { if (Query != void) { - const path = std.mem.sliceTo(self.request_line, '?'); - const q = std.mem.sliceTo(self.request_line[path.len..], '#'); + const path = std.mem.sliceTo(self.uri, '?'); + const q = std.mem.sliceTo(self.uri[path.len..], '#'); self.query = try query_utils.parseQuery(Query, q); } } fn handle(self: Self, response: *Response, api_conn: anytype) void { - Route.handler(self, response, api_conn) catch |err| std.log.err("{}", .{err}); + Route.handler(self, response, api_conn) catch |err| switch (err) { + else => { + std.log.err("{}", .{err}); + response.err(.internal_server_error, "", {}) catch {}; + }, + }; } fn getApiConn(self: *Self, api_source: anytype) !api.ApiSource.Conn { @@ -180,18 +180,25 @@ pub fn Context(comptime Route: type) type { pub const Response = struct { const Self = @This(); headers: http.Headers, - ctx: http.server.Context, + res: *http.Response, + opened: bool = false, pub fn status(self: *Self, status_code: http.Status) !void { - var stream = try self.ctx.openResponse(&self.headers, status_code); + std.debug.assert(!self.opened); + self.opened = true; + + var stream = try self.res.open(status_code, &self.headers); defer stream.close(); try stream.finish(); } pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { + std.debug.assert(!self.opened); + self.opened = true; + try self.headers.put("Content-Type", "application/json"); - var stream = try self.ctx.openResponse(&self.headers, status_code); + var stream = try self.res.open(status_code, &self.headers); defer stream.close(); const writer = stream.writer(); diff --git a/src/main/main.zig b/src/main/main.zig index 1d427f0..a82713b 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -8,36 +8,8 @@ const api = @import("api"); pub const migrations = @import("./migrations.zig"); // TODO const c = @import("./controllers.zig"); -pub const RequestServer = struct { - alloc: std.mem.Allocator, - api: *api.ApiSource, - config: Config, - - fn init(alloc: std.mem.Allocator, src: *api.ApiSource, config: Config) !RequestServer { - return RequestServer{ - .alloc = alloc, - .api = src, - .config = config, - }; - } - - fn listenAndRun(self: *RequestServer, addr: std.net.Address) !void { - var srv = http.Server.listen(addr) catch unreachable; - defer srv.shutdown(); - - while (true) { - var arena = std.heap.ArenaAllocator.init(self.alloc); - defer arena.deinit(); - - var ctx = try srv.accept(arena.allocator()); - defer ctx.close(); - - c.routeRequest(self.api, ctx, arena.allocator()); - } - } -}; - pub const Config = struct { + worker_threads: usize = 10, db: sql.Config, }; @@ -91,6 +63,34 @@ fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void { } } +const ConnectionId = u64; +var next_conn_id = std.atomic.Atomic(ConnectionId).init(0); + +fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void { + util.seedThreadPrng() catch unreachable; + const thread_id = std.Thread.getCurrentId(); + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + + while (true) { + var conn = srv.accept() catch |err| { + std.log.err("Error accepting connection: {}", .{err}); + continue; + }; + defer conn.stream.close(); + const conn_id = next_conn_id.fetchAdd(1, .SeqCst); + std.log.debug("Accepting TCP connection id {} on thread {}", .{ conn_id, thread_id }); + defer std.log.debug("Closing TCP connection id {}", .{conn_id}); + + http.serveConn(conn, .{ .src = src, .conn_id = conn_id, .allocator = gpa.allocator() }, handle, gpa.allocator()) catch |err| { + std.log.err("Error occured on connection {}: {}", .{ conn_id, err }); + }; + } +} + +fn handle(ctx: anytype, req: http.Request, res: *http.Response) void { + c.routeRequest(ctx.src, req, res, ctx.allocator); +} + pub fn main() !void { try util.seedThreadPrng(); @@ -100,6 +100,9 @@ pub fn main() !void { try prepareDb(&pool, gpa.allocator()); var api_src = try api.ApiSource.init(&pool); - var srv = try RequestServer.init(gpa.allocator(), &api_src, cfg); - return srv.listenAndRun(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); + var srv = std.net.StreamServer.init(.{ .reuse_address = true }); + defer srv.deinit(); + try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); + + thread_main(&api_src, &srv); } diff --git a/src/sql/lib.zig b/src/sql/lib.zig index 1208c9c..3124863 100644 --- a/src/sql/lib.zig +++ b/src/sql/lib.zig @@ -453,7 +453,7 @@ fn Tx(comptime tx_level: u8) type { comptime var table_spec: []const u8 = table ++ "("; comptime var value_spec: []const u8 = "("; inline for (fields) |field, i| { - // This causes a compile error. Why? + // This causes a compiler crash. Why? //const F = field.field_type; const F = @TypeOf(@field(std.mem.zeroes(ValueType), field.name)); // causes issues if F is @TypeOf(null), use dummy type