diff --git a/src/api/lib.zig b/src/api/lib.zig index a48170a..e6cf82a 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -111,7 +111,6 @@ pub const ApiSource = struct { errdefer arena.deinit(); const db = try self.db_conn_pool.acquire(); - errdefer db.releaseConnection(); const community = try services.communities.getByHost(db, host, arena.allocator()); return Conn{ @@ -127,7 +126,6 @@ pub const ApiSource = struct { errdefer arena.deinit(); const db = try self.db_conn_pool.acquire(); - errdefer db.releaseConnection(); const community = try services.communities.getByHost(db, host, arena.allocator()); const token_info = try services.auth.verifyToken( diff --git a/src/http/lib.zig b/src/http/lib.zig index eff1e45..3d12cdc 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -5,8 +5,6 @@ const request = @import("./request.zig"); const server = @import("./server.zig"); -pub const socket = @import("./socket.zig"); - pub const Method = std.http.Method; pub const Status = std.http.Status; diff --git a/src/http/server.zig b/src/http/server.zig index 269fc56..e9bf81e 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -16,11 +16,6 @@ pub const Response = struct { return response.open(self.alloc, self.stream.writer(), headers, status); } - - pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream { - try response.writeRequestHeader(self.stream.writer(), headers, status); - return self.stream; - } }; const Request = http.Request; diff --git a/src/http/server/response.zig b/src/http/server/response.zig index 615f0c1..a99aa8e 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -25,12 +25,6 @@ pub fn open( }; } -pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void { - try writeStatusLine(writer, status); - try writeHeaders(writer, headers); - try writer.writeAll("\r\n"); -} - fn writeStatusLine(writer: anytype, status: Status) !void { const status_text = status.phrase() orelse ""; try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text }); diff --git a/src/http/socket.zig b/src/http/socket.zig deleted file mode 100644 index 24c10d4..0000000 --- a/src/http/socket.zig +++ /dev/null @@ -1,372 +0,0 @@ -const std = @import("std"); -const builtin = @import("builtin"); -const http = @import("./lib.zig"); -const Stream = std.net.Stream; - -const Opcode = enum(u4) { - continuation = 0x0, - text = 0x1, - binary = 0x2, - // 3-7 are non-control frames - close = 0x8, - ping = 0x9, - pong = 0xa, - // b-f are control frames - _, - - fn isData(self: Opcode) bool { - return !self.isControl(); - } - - fn isControl(self: Opcode) bool { - return @enumToInt(self) & 0x8 == 0x8; - } -}; - -pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket { - const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake; - const connection = req.headers.get("Connection") orelse return error.BadHandshake; - if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake; - if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake; - - const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake; - if (try std.base64.standard.Decoder.calcSizeForSlice(key_hdr) != 16) return error.BadHandshake; - var key: [16]u8 = undefined; - std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake; - - const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake; - if (!std.mem.eql(u8, "13", version)) return error.BadHandshake; - - var headers = http.Headers.init(alloc); - defer headers.deinit(); - - try headers.put("Upgrade", "websocket"); - try headers.put("Connection", "Upgrade"); - var response_key = std.mem.zeroes([60]u8); - std.mem.copy(u8, &response_key, key_hdr); - std.mem.copy(u8, response_key[key_hdr.len..], "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - const Sha1 = std.crypto.hash.Sha1; - var hash: [Sha1.digest_length]u8 = undefined; - Sha1.hash(response_key[0..], &hash, .{}); - var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined; - _ = std.base64.standard.Encoder.encode(&hash_encoded, &hash); - try headers.put("Sec-WebSocket-Accept", &hash_encoded); - const stream = try res.upgrade(.switching_protcols, &headers); - - return Socket{ .stream = stream }; -} - -pub const CloseReason = enum(u16) { - ok = 1000, - going_away = 1001, - protocol_error = 1002, - unsupported_frame_type = 1003, - // reserved = 1004, - - // technically these codes can be explicitly provided but it's against spec - no_reason_provided = 1005, - abnormal_closure = 1006, - - invalid_frame_data = 1007, // e.g. non UTF-8 data in a text type frame - policy_violation = 1008, - message_too_large = 1009, - missing_extension = 1010, - internal_server_error = 1011, - - // unused on server side but included here for completeness - invalid_tls_handshake = 1015, - - _, -}; - -const FrameInfo = struct { - is_final: bool, - rsv: u3, - opcode: Opcode, - masking_key: ?[4]u8, - len: usize, -}; - -fn frameReader(reader: anytype) !FrameReader(@TypeOf(reader)) { - return FrameReader(@TypeOf(reader)).open(reader); -} - -fn FrameReader(comptime R: type) type { - return struct { - const Self = @This(); - info: FrameInfo, - - bytes_read: usize = 0, - mask_idx: u2 = 0, - underlying: R, - - const Reader = std.io.Reader(*Self, anyerror, read); - - fn open(r: R) !Self { - var hdr_buf: [2]u8 = undefined; - try r.readNoEof(&hdr_buf); - const is_final = hdr_buf[0] & 0b1000_0000 != 0; - const rsv = @intCast(u3, (hdr_buf[0] & 0b0111_0000) >> 4); - const opcode = @intToEnum(Opcode, (hdr_buf[0] & 0b1111)); - const is_masked = hdr_buf[1] & 0b1000_0000 != 0; - const initial_len = @intCast(u7, (hdr_buf[1] & 0b0111_1111)); - - const len: usize = if (initial_len < 126) - initial_len - else if (initial_len == 126) - try r.readIntBig(u16) - else - try r.readIntBig(usize); - - const masking_key = if (is_masked) - try r.readBytesNoEof(4) - else - null; - - return Self{ - .info = .{ - .is_final = is_final, - .rsv = rsv, - .opcode = opcode, - .masking_key = masking_key, - .len = len, - }, - .underlying = r, - }; - } - - fn read(self: *Self, buf: []u8) !usize { - var r = std.io.limitedReader(self.underlying, self.info.len - self.bytes_read); - const count = try r.read(buf); - if (self.info.masking_key) |mask| { - for (buf[0..count]) |*ch| { - ch.* ^= mask[self.mask_idx]; - self.mask_idx +%= 1; - } - } - self.bytes_read += count; - return count; - } - - fn reader(self: *Self) Reader { - return .{ .context = self }; - } - - fn close(self: *Self) void { - self.reader().skipBytes(@as(usize, 0) -% 1, .{}) catch {}; - } - }; -} - -fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void { - std.debug.assert(header.len == buf.len); - - const initial_len: u7 = if (header.len < 126) - @intCast(u7, header.len) - else if (std.math.cast(u16, header.len)) |_| - 126 - else - 127; - - var hdr_buf = [2]u8{ 0, 0 }; - hdr_buf[0] |= if (header.is_final) 0b1000_0000 else 0; - hdr_buf[0] |= @as(u8, header.rsv) << 4; - hdr_buf[0] |= @enumToInt(header.opcode); - hdr_buf[1] |= if (header.masking_key) |_| 0b1000_0000 else 0; - hdr_buf[1] |= initial_len; - try writer.writeAll(&hdr_buf); - if (initial_len == 126) - try writer.writeIntBig(u16, std.math.cast(u16, buf.len).?) - else if (initial_len == 127) - try writer.writeIntBig(usize, buf.len); - - if (header.masking_key) |key| try writer.writeAll(&key); - - const mask = header.masking_key orelse std.mem.zeroes([4]u8); - var mask_idx: u2 = 0; - for (buf) |ch| { - try writer.writeByte(mask[mask_idx] ^ ch); - } -} - -pub const MessageReader = struct { - opcode: Opcode, - socket: *Socket, - frame: FrameReader(Socket.Reader), - - pub const Reader = std.io.Reader(*MessageReader, anyerror, read); - - pub fn read(self: *MessageReader, buf: []u8) !usize { - if (self.socket.closed) return error.SocketClosed; - - var count: usize = 0; - while (count < buf.len) { - const c = try self.frame.read(buf); - count += c; - - if (c == 0) { - self.frame.close(); - if (self.frame.info.is_final) break; - - self.frame = try self.socket.waitForDataFrame(); - if (self.frame.info.opcode != .continuation) { - //self.socket.close(); - return error.InvalidFrame; - } - } - } - - return count; - } - - pub fn reader(self: *MessageReader) Reader { - return .{ .context = self }; - } - - pub fn close(self: *MessageReader) void { - self.reader().skipBytes(@as(usize, 0) -% 1, .{}) catch {}; - } -}; - -pub const MessageWriter = struct { - allocator: std.mem.Allocator, - socket: *Socket, - opcode: Opcode, - buf: []u8, - - cursor: usize = 0, - is_first: bool = true, - - pub const WriteError = Socket.WriteError; - pub const Writer = std.io.Writer(*MessageWriter, WriteError, write); - - pub fn writer(self: *MessageWriter) Writer { - return .{ .context = self }; - } - - pub fn write(self: *MessageWriter, buf: []const u8) !usize { - if (self.socket.closed) return error.SocketClosed; - var count: usize = 0; - while (count < buf.len) { - if (self.cursor == self.buf.len) try self.flushFrame(false); - - const max_write = std.math.min(self.buf.len - self.cursor, buf.len - count); - std.mem.copy(u8, self.buf[self.cursor..], buf[count .. count + max_write]); - - self.cursor += max_write; - count += max_write; - } - - return count; - } - - pub fn close(self: *MessageWriter) void { - self.allocator.free(self.buf); - } - - pub fn finish(self: *MessageWriter) !void { - if (self.socket.closed) return error.SocketClosed; - try self.flushFrame(true); - } - - fn flushFrame(self: *MessageWriter, is_final: bool) !void { - if (self.socket.closed) return error.SocketClosed; - const opcode = if (self.is_first) self.opcode else .continuation; - self.is_first = false; - const header = FrameInfo{ - .is_final = is_final, - .rsv = 0, - .opcode = opcode, - .masking_key = null, // TODO: use a key on client-side - .len = self.cursor, - }; - std.log.debug("sending frame with size {}", .{header.len}); - try writeFrame(self.socket.writer(), header, self.buf[0..self.cursor]); - self.cursor = 0; - } -}; - -pub const Socket = struct { - stream: Stream, - closed: bool = false, - - const ReadError = Stream.ReadError || error{SocketClosed}; - const WriteError = Stream.WriteError || error{SocketClosed}; - const Reader = std.io.Reader(*Socket, ReadError, read); - const Writer = std.io.Writer(*Socket, WriteError, write); - - fn writer(self: *Socket) Writer { - return .{ .context = self }; - } - - fn reader(self: *Socket) Reader { - return .{ .context = self }; - } - - fn read(self: *Socket, buf: []u8) !usize { - if (self.closed) return error.SocketClosed; - - return self.stream.read(buf) catch |err| { - self.closed = true; - return err; - }; - } - - fn write(self: *Socket, buf: []const u8) !usize { - if (self.closed) return error.SocketClosed; - - return self.stream.write(buf) catch |err| { - self.closed = true; - return err; - }; - } - - pub fn accept(self: *Socket) !MessageReader { - if (self.closed) return error.SocketClosed; - const frame = try self.waitForDataFrame(); - return MessageReader{ .opcode = frame.info.opcode, .socket = self, .frame = frame }; - } - - pub const MessageOptions = struct { - buf_size: usize = 512, - }; - pub fn openMessage(self: *Socket, opcode: Opcode, alloc: std.mem.Allocator, options: MessageOptions) !MessageWriter { - if (self.closed) return error.SocketClosed; - var buf = try alloc.alloc(u8, options.buf_size); - return MessageWriter{ .socket = self, .allocator = alloc, .buf = buf, .opcode = opcode }; - } - - fn openFrame(self: *Socket) !FrameReader(Reader) { - std.log.debug("waiting for frame header", .{}); - - const frame = try frameReader(self.reader()); - - std.log.debug("Received frame of type {any}, size {}", .{ frame.info.opcode, frame.info.len }); - - return frame; - } - - fn waitForDataFrame(self: *Socket) !FrameReader(Reader) { - while (true) { - var frame = try frameReader(self.reader()); - if (frame.info.opcode.isData()) return frame; - - // TODO: handle control frames - frame.close(); - } - } -}; - -fn pipe(writer: anytype, reader: anytype) !usize { - var count: usize = 0; - - var buf: [64]u8 = undefined; - while (true) { - const c = try reader.read(&buf); - if (c == 0) break; - - if (try writer.write(buf[0..c]) != c) return error.EndOfStream; - count += c; - } - return count; -} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index 31fbe90..b0f7800 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -12,24 +12,16 @@ pub const communities = @import("./controllers/communities.zig"); pub const invites = @import("./controllers/invites.zig"); pub const users = @import("./controllers/users.zig"); pub const notes = @import("./controllers/notes.zig"); -pub const streaming = @import("./controllers/streaming.zig"); pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? var response = Response{ .headers = http.Headers.init(alloc), .res = res }; defer response.headers.deinit(); - - const found = routeRequestInternal(api_source, req, &response, alloc); - - if (!found) response.status(.not_found) catch {}; -} - -fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { inline for (routes) |route| { - if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true; + if (Context(route).matchAndHandle(api_source, req, &response, alloc)) return; } - return false; + response.status(.not_found) catch {}; } const routes = .{ @@ -41,7 +33,6 @@ const routes = .{ users.create, notes.create, notes.get, - streaming.streaming, }; pub fn Context(comptime Route: type) type { @@ -58,8 +49,6 @@ pub fn Context(comptime Route: type) type { // leave it as a simple string instead of void pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; - base_request: http.Request, - allocator: std.mem.Allocator, method: http.Method, @@ -101,7 +90,6 @@ pub fn Context(comptime Route: type) type { var self = Self{ .allocator = alloc, - .base_request = req, .method = req.method, .uri = req.uri, @@ -167,7 +155,7 @@ pub fn Context(comptime Route: type) type { Route.handler(self, response, api_conn) catch |err| switch (err) { else => { std.log.err("{}", .{err}); - if (!response.opened) response.err(.internal_server_error, "", {}) catch {}; + response.err(.internal_server_error, "", {}) catch {}; }, }; } @@ -195,7 +183,6 @@ pub const Response = struct { res: *http.Response, opened: bool = false, - /// Write a response with no body, only a given status pub fn status(self: *Self, status_code: http.Status) !void { std.debug.assert(!self.opened); self.opened = true; @@ -205,7 +192,6 @@ pub const Response = struct { try stream.finish(); } - /// Write a request body as json pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { std.debug.assert(!self.opened); self.opened = true; @@ -221,20 +207,12 @@ pub const Response = struct { try stream.finish(); } - /// Prints the given error as json pub fn err(self: *Self, status_code: http.Status, message: []const u8, details: anytype) !void { return self.json(status_code, .{ .message = message, .details = details, }); } - - /// Signals that the HTTP connection should be hijacked without writing a - /// response beforehand. - pub fn hijack(self: *Self) *http.Response { - self.opened = true; - return self.res; - } }; const json_options = if (builtin.mode == .Debug) diff --git a/src/main/controllers/streaming.zig b/src/main/controllers/streaming.zig deleted file mode 100644 index 4b8745b..0000000 --- a/src/main/controllers/streaming.zig +++ /dev/null @@ -1,36 +0,0 @@ -const http = @import("http"); -const std = @import("std"); - -pub const streaming = struct { - pub const method = .GET; - pub const path = "/streaming"; - - pub fn handler(req: anytype, response: anytype, _: anytype) !void { - var iter = req.headers.iterator(); - std.log.debug("--Headers--", .{}); - while (iter.next()) |pair| { - std.log.debug("{s}: {s}", .{ pair.key_ptr.*, pair.value_ptr.* }); - } - - const res = response.hijack(); - - var socket = try http.socket.handshake(req.allocator, req.base_request, res); - - while (true) { - var message = try socket.accept(); - defer message.close(); - std.log.debug("Message received", .{}); - - const reader = message.reader(); - const msg = try reader.readAllAlloc(req.allocator, 1 << 63); - defer req.allocator.free(msg); - - var response_msg = try socket.openMessage(.text, req.allocator, .{}); - defer response_msg.close(); - - try response_msg.writer().writeAll(msg); - try response_msg.finish(); - std.log.debug("{s}", .{msg}); - } - } -}; diff --git a/src/main/main.zig b/src/main/main.zig index 5870bc2..f3c9728 100644 --- a/src/main/main.zig +++ b/src/main/main.zig @@ -102,7 +102,7 @@ pub fn main() !void { var api_src = try api.ApiSource.init(&pool); var srv = std.net.StreamServer.init(.{ .reuse_address = true }); defer srv.deinit(); - try srv.listen(std.net.Address.parseIp("::1", 8080) catch unreachable); + try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); var i: usize = 0; while (i < cfg.worker_threads - 1) : (i += 1) {