diff --git a/src/http/lib.zig b/src/http/lib.zig index ce97ad1..26f7756 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -10,10 +10,10 @@ pub const socket = @import("./socket.zig"); pub const Method = std.http.Method; pub const Status = std.http.Status; -pub const Request = request.Request(std.net.Stream.Reader); -pub const serveConn = server.serveConn; +pub const Request = request.Request(server.Stream.Reader); pub const Response = server.Response; pub const Handler = server.Handler; +pub const Server = server.Server; pub const Fields = @import("./headers.zig").Fields; diff --git a/src/http/server.zig b/src/http/server.zig index 693c937..21a83be 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -4,80 +4,145 @@ const http = @import("./lib.zig"); const response = @import("./server/response.zig"); const request = @import("./request.zig"); +const os = std.os; pub const Response = struct { alloc: std.mem.Allocator, - stream: std.net.Stream, + stream: Stream, should_close: bool = false, - pub const Stream = response.ResponseStream(SendStream.Writer); - pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !Stream { + + pub const ResponseStream = response.ResponseStream(Stream.Writer); + pub fn open(self: *Response, status: http.Status, headers: *const http.Fields) !ResponseStream { if (headers.get("Connection")) |hdr| { if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| self.should_close = true; } - const stream = SendStream{ .sockfd = self.stream.handle }; - - return response.open(self.alloc, stream.writer(), headers, status); + return response.open(self.alloc, self.stream.writer(), headers, status); } - pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !std.net.Stream { + pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Fields) !Stream { try response.writeRequestHeader(self.stream.writer(), headers, status); return self.stream; } }; +pub const StreamKind = enum { + tcp, +}; -// Replacement for std.net.Stream that uses os.send instead of os.write -// see https://github.com/ziglang/zig/issues/5614 -const SendStream = struct { - sockfd: std.os.socket_t, +pub const Stream = struct { + kind: StreamKind, - const WriteError = std.os.SendError; - const Writer = std.io.Writer(SendStream, WriteError, write); + socket: os.socket_t, - fn write(self: SendStream, bytes: []const u8) WriteError!usize { - if (std.io.is_async) @compileError("TODO: Async not supported yet"); - return std.os.send(self.sockfd, bytes, std.os.MSG.NOSIGNAL); + pub fn close(self: Stream) void { + os.closeSocket(self.socket); } - fn writer(self: SendStream) Writer { + pub const ReadError = os.RecvFromError; + pub const WriteError = os.SendError; + + pub const Reader = std.io.Reader(Stream, ReadError, read); + pub const Writer = std.io.Writer(Stream, WriteError, write); + + pub fn read(self: Stream, buffer: []u8) ReadError!usize { + if (std.io.is_async) @compileError("TODO: async not supported"); + if (self.kind != .tcp) @panic("TODO"); + + return os.recv(self.socket, buffer, 0); + } + + pub fn write(self: Stream, buffer: []const u8) WriteError!usize { + if (std.io.is_async) @compileError("TODO: Async not supported yet"); + if (self.kind != .tcp) @panic("TODO"); + + return os.send(self.socket, buffer, os.MSG.NOSIGNAL); + } + + pub fn reader(self: Stream) Reader { + return .{ .context = self }; + } + + pub fn writer(self: Stream) Writer { return .{ .context = self }; } }; -const Request = http.Request; -const request_buf_size = 1 << 16; +pub const Server = struct { + tcp_server: std.net.StreamServer, -pub fn serveConn(conn: std.net.StreamServer.Connection, ctx: anytype, handler: anytype, alloc: std.mem.Allocator) !void { - // TODO: Timeouts - while (true) { - std.log.debug("waiting for request", .{}); - var arena = std.heap.ArenaAllocator.init(alloc); - defer arena.deinit(); - - var req = request.parse(arena.allocator(), conn.stream.reader()) catch |err| { - return handleError(conn.stream.writer(), err) catch {}; + pub fn init() Server { + return Server{ + .tcp_server = std.net.StreamServer.init(.{ .reuse_address = true }), }; - std.log.debug("done parsing", .{}); - - var res = Response{ - .alloc = arena.allocator(), - .stream = conn.stream, - }; - - std.log.debug("{any}", .{req}); - std.log.debug("Opening handler", .{}); - handler(ctx, &req, &res); - std.log.debug("done handling", .{}); - - if (req.headers.get("Connection")) |hdr| { - if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| return; - } else if (req.headers.get("Keep-Alive")) |hdr| { - std.log.debug("keep-alive: {s}", .{hdr}); - } else if (req.protocol == .http_1_0) return; - - if (res.should_close) return; } -} + + pub fn deinit(self: *Server) void { + self.tcp_server.deinit(); + } + + pub fn listen(self: *Server, address: std.net.Address) !void { + try self.tcp_server.listen(address); + } + + pub const Connection = struct { + stream: Stream, + address: std.net.Address, + }; + + pub fn handleLoop( + self: *Server, + allocator: std.mem.Allocator, + ctx: anytype, + handler: anytype, + ) void { + while (true) { + const conn = self.tcp_server.accept() catch |err| { + if (err == error.SocketNotListening) return; + + std.log.err("Error occurred accepting connection: {}", .{err}); + continue; + }; + + serveConn( + allocator, + Connection{ + .stream = Stream{ .kind = .tcp, .socket = conn.stream.handle }, + .address = conn.address, + }, + ctx, + handler, + ); + } + } + + fn serveConn( + allocator: std.mem.Allocator, + conn: Connection, + ctx: anytype, + handler: anytype, + ) void { + while (true) { + var req = request.parse(allocator, conn.stream.reader()) catch |err| { + return handleError(conn.stream.writer(), err) catch {}; + }; + + var res = Response{ + .alloc = allocator, + .stream = conn.stream, + }; + + handler(ctx, &req, &res); + + if (req.headers.get("Connection")) |hdr| { + if (std.ascii.indexOfIgnoreCase(hdr, "close")) |_| return; + } else if (req.headers.get("Keep-Alive")) |_| { + // TODO: Support this + return; + } else if (req.protocol == .http_1_0) return; + if (res.should_close) return; + } + } +}; /// Writes an error response message and requests closure of the connection fn handleError(writer: anytype, err: anyerror) !void { diff --git a/src/http/socket.zig b/src/http/socket.zig index eab1a1d..ab885bb 100644 --- a/src/http/socket.zig +++ b/src/http/socket.zig @@ -1,7 +1,7 @@ const std = @import("std"); const builtin = @import("builtin"); const http = @import("./lib.zig"); -const Stream = std.net.Stream; +const Stream = @import("./server.zig").Stream; const Opcode = enum(u4) { continuation = 0x0, diff --git a/src/main/controllers.zig b/src/main/controllers.zig index f162fc0..3fcfd6a 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -262,7 +262,7 @@ pub const Response = struct { try stream.finish(); } - pub fn open(self: *Self, status_code: http.Status) !http.Response.Stream { + pub fn open(self: *Self, status_code: http.Status) !http.Response.ResponseStream { std.debug.assert(!self.opened); self.opened = true; diff --git a/src/main/main.zig b/src/main/main.zig index 9ff01fe..c512d3c 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -66,25 +66,11 @@ fn prepareDb(pool: *sql.ConnPool, alloc: std.mem.Allocator) !void { const ConnectionId = u64; var next_conn_id = std.atomic.Atomic(ConnectionId).init(0); -fn thread_main(src: *api.ApiSource, srv: *std.net.StreamServer) void { +fn thread_main(src: *api.ApiSource, srv: *http.Server) void { util.seedThreadPrng() catch unreachable; - const thread_id = std.Thread.getCurrentId(); var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - - while (true) { - var conn = srv.accept() catch |err| { - std.log.err("Error accepting connection: {}", .{err}); - continue; - }; - defer conn.stream.close(); - const conn_id = next_conn_id.fetchAdd(1, .SeqCst); - std.log.debug("Accepting TCP connection id {} on thread {}", .{ conn_id, thread_id }); - defer std.log.debug("Closing TCP connection id {}", .{conn_id}); - - http.serveConn(conn, .{ .src = src, .conn_id = conn_id, .allocator = gpa.allocator() }, handle, gpa.allocator()) catch |err| { - std.log.err("Error occured on connection {}: {}", .{ conn_id, err }); - }; - } + defer _ = gpa.deinit(); + srv.handleLoop(gpa.allocator(), .{ .src = src, .allocator = gpa.allocator() }, handle); } fn handle(ctx: anytype, req: *http.Request, res: *http.Response) void { @@ -100,7 +86,7 @@ pub fn main() !void { try prepareDb(&pool, gpa.allocator()); var api_src = try api.ApiSource.init(&pool); - var srv = std.net.StreamServer.init(.{ .reuse_address = true }); + var srv = http.Server.init(); defer srv.deinit(); try srv.listen(std.net.Address.parseIp("::1", 8080) catch unreachable);