From e4cbb0acc3486e1c7721d81a83275ca2cacbd02c Mon Sep 17 00:00:00 2001 From: jaina heartles Date: Sun, 16 Oct 2022 05:48:12 -0700 Subject: [PATCH] Add basic websocket support --- src/api/lib.zig | 2 + src/http/lib.zig | 2 + src/http/server.zig | 5 + src/http/server/response.zig | 6 + src/http/socket.zig | 372 +++++++++++++++++++++++++++++ src/main/controllers.zig | 28 ++- src/main/controllers/streaming.zig | 36 +++ 7 files changed, 448 insertions(+), 3 deletions(-) create mode 100644 src/http/socket.zig create mode 100644 src/main/controllers/streaming.zig diff --git a/src/api/lib.zig b/src/api/lib.zig index e6cf82a..a48170a 100644 --- a/src/api/lib.zig +++ b/src/api/lib.zig @@ -111,6 +111,7 @@ pub const ApiSource = struct { errdefer arena.deinit(); const db = try self.db_conn_pool.acquire(); + errdefer db.releaseConnection(); const community = try services.communities.getByHost(db, host, arena.allocator()); return Conn{ @@ -126,6 +127,7 @@ pub const ApiSource = struct { errdefer arena.deinit(); const db = try self.db_conn_pool.acquire(); + errdefer db.releaseConnection(); const community = try services.communities.getByHost(db, host, arena.allocator()); const token_info = try services.auth.verifyToken( diff --git a/src/http/lib.zig b/src/http/lib.zig index 3d12cdc..eff1e45 100644 --- a/src/http/lib.zig +++ b/src/http/lib.zig @@ -5,6 +5,8 @@ const request = @import("./request.zig"); const server = @import("./server.zig"); +pub const socket = @import("./socket.zig"); + pub const Method = std.http.Method; pub const Status = std.http.Status; diff --git a/src/http/server.zig b/src/http/server.zig index e9bf81e..269fc56 100644 --- a/src/http/server.zig +++ b/src/http/server.zig @@ -16,6 +16,11 @@ pub const Response = struct { return response.open(self.alloc, self.stream.writer(), headers, status); } + + pub fn upgrade(self: *Response, status: http.Status, headers: *const http.Headers) !std.net.Stream { + try response.writeRequestHeader(self.stream.writer(), headers, status); + return self.stream; + } }; const Request = http.Request; diff --git a/src/http/server/response.zig b/src/http/server/response.zig index a99aa8e..615f0c1 100644 --- a/src/http/server/response.zig +++ b/src/http/server/response.zig @@ -25,6 +25,12 @@ pub fn open( }; } +pub fn writeRequestHeader(writer: anytype, headers: *const Headers, status: Status) !void { + try writeStatusLine(writer, status); + try writeHeaders(writer, headers); + try writer.writeAll("\r\n"); +} + fn writeStatusLine(writer: anytype, status: Status) !void { const status_text = status.phrase() orelse ""; try writer.print("HTTP/1.1 {} {s}\r\n", .{ @enumToInt(status), status_text }); diff --git a/src/http/socket.zig b/src/http/socket.zig new file mode 100644 index 0000000..24c10d4 --- /dev/null +++ b/src/http/socket.zig @@ -0,0 +1,372 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const http = @import("./lib.zig"); +const Stream = std.net.Stream; + +const Opcode = enum(u4) { + continuation = 0x0, + text = 0x1, + binary = 0x2, + // 3-7 are non-control frames + close = 0x8, + ping = 0x9, + pong = 0xa, + // b-f are control frames + _, + + fn isData(self: Opcode) bool { + return !self.isControl(); + } + + fn isControl(self: Opcode) bool { + return @enumToInt(self) & 0x8 == 0x8; + } +}; + +pub fn handshake(alloc: std.mem.Allocator, req: http.Request, res: *http.Response) !Socket { + const upgrade = req.headers.get("Upgrade") orelse return error.BadHandshake; + const connection = req.headers.get("Connection") orelse return error.BadHandshake; + if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake; + if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake; + + const key_hdr = req.headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake; + if (try std.base64.standard.Decoder.calcSizeForSlice(key_hdr) != 16) return error.BadHandshake; + var key: [16]u8 = undefined; + std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake; + + const version = req.headers.get("Sec-WebSocket-version") orelse return error.BadHandshake; + if (!std.mem.eql(u8, "13", version)) return error.BadHandshake; + + var headers = http.Headers.init(alloc); + defer headers.deinit(); + + try headers.put("Upgrade", "websocket"); + try headers.put("Connection", "Upgrade"); + var response_key = std.mem.zeroes([60]u8); + std.mem.copy(u8, &response_key, key_hdr); + std.mem.copy(u8, response_key[key_hdr.len..], "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + const Sha1 = std.crypto.hash.Sha1; + var hash: [Sha1.digest_length]u8 = undefined; + Sha1.hash(response_key[0..], &hash, .{}); + var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined; + _ = std.base64.standard.Encoder.encode(&hash_encoded, &hash); + try headers.put("Sec-WebSocket-Accept", &hash_encoded); + const stream = try res.upgrade(.switching_protcols, &headers); + + return Socket{ .stream = stream }; +} + +pub const CloseReason = enum(u16) { + ok = 1000, + going_away = 1001, + protocol_error = 1002, + unsupported_frame_type = 1003, + // reserved = 1004, + + // technically these codes can be explicitly provided but it's against spec + no_reason_provided = 1005, + abnormal_closure = 1006, + + invalid_frame_data = 1007, // e.g. non UTF-8 data in a text type frame + policy_violation = 1008, + message_too_large = 1009, + missing_extension = 1010, + internal_server_error = 1011, + + // unused on server side but included here for completeness + invalid_tls_handshake = 1015, + + _, +}; + +const FrameInfo = struct { + is_final: bool, + rsv: u3, + opcode: Opcode, + masking_key: ?[4]u8, + len: usize, +}; + +fn frameReader(reader: anytype) !FrameReader(@TypeOf(reader)) { + return FrameReader(@TypeOf(reader)).open(reader); +} + +fn FrameReader(comptime R: type) type { + return struct { + const Self = @This(); + info: FrameInfo, + + bytes_read: usize = 0, + mask_idx: u2 = 0, + underlying: R, + + const Reader = std.io.Reader(*Self, anyerror, read); + + fn open(r: R) !Self { + var hdr_buf: [2]u8 = undefined; + try r.readNoEof(&hdr_buf); + const is_final = hdr_buf[0] & 0b1000_0000 != 0; + const rsv = @intCast(u3, (hdr_buf[0] & 0b0111_0000) >> 4); + const opcode = @intToEnum(Opcode, (hdr_buf[0] & 0b1111)); + const is_masked = hdr_buf[1] & 0b1000_0000 != 0; + const initial_len = @intCast(u7, (hdr_buf[1] & 0b0111_1111)); + + const len: usize = if (initial_len < 126) + initial_len + else if (initial_len == 126) + try r.readIntBig(u16) + else + try r.readIntBig(usize); + + const masking_key = if (is_masked) + try r.readBytesNoEof(4) + else + null; + + return Self{ + .info = .{ + .is_final = is_final, + .rsv = rsv, + .opcode = opcode, + .masking_key = masking_key, + .len = len, + }, + .underlying = r, + }; + } + + fn read(self: *Self, buf: []u8) !usize { + var r = std.io.limitedReader(self.underlying, self.info.len - self.bytes_read); + const count = try r.read(buf); + if (self.info.masking_key) |mask| { + for (buf[0..count]) |*ch| { + ch.* ^= mask[self.mask_idx]; + self.mask_idx +%= 1; + } + } + self.bytes_read += count; + return count; + } + + fn reader(self: *Self) Reader { + return .{ .context = self }; + } + + fn close(self: *Self) void { + self.reader().skipBytes(@as(usize, 0) -% 1, .{}) catch {}; + } + }; +} + +fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void { + std.debug.assert(header.len == buf.len); + + const initial_len: u7 = if (header.len < 126) + @intCast(u7, header.len) + else if (std.math.cast(u16, header.len)) |_| + 126 + else + 127; + + var hdr_buf = [2]u8{ 0, 0 }; + hdr_buf[0] |= if (header.is_final) 0b1000_0000 else 0; + hdr_buf[0] |= @as(u8, header.rsv) << 4; + hdr_buf[0] |= @enumToInt(header.opcode); + hdr_buf[1] |= if (header.masking_key) |_| 0b1000_0000 else 0; + hdr_buf[1] |= initial_len; + try writer.writeAll(&hdr_buf); + if (initial_len == 126) + try writer.writeIntBig(u16, std.math.cast(u16, buf.len).?) + else if (initial_len == 127) + try writer.writeIntBig(usize, buf.len); + + if (header.masking_key) |key| try writer.writeAll(&key); + + const mask = header.masking_key orelse std.mem.zeroes([4]u8); + var mask_idx: u2 = 0; + for (buf) |ch| { + try writer.writeByte(mask[mask_idx] ^ ch); + } +} + +pub const MessageReader = struct { + opcode: Opcode, + socket: *Socket, + frame: FrameReader(Socket.Reader), + + pub const Reader = std.io.Reader(*MessageReader, anyerror, read); + + pub fn read(self: *MessageReader, buf: []u8) !usize { + if (self.socket.closed) return error.SocketClosed; + + var count: usize = 0; + while (count < buf.len) { + const c = try self.frame.read(buf); + count += c; + + if (c == 0) { + self.frame.close(); + if (self.frame.info.is_final) break; + + self.frame = try self.socket.waitForDataFrame(); + if (self.frame.info.opcode != .continuation) { + //self.socket.close(); + return error.InvalidFrame; + } + } + } + + return count; + } + + pub fn reader(self: *MessageReader) Reader { + return .{ .context = self }; + } + + pub fn close(self: *MessageReader) void { + self.reader().skipBytes(@as(usize, 0) -% 1, .{}) catch {}; + } +}; + +pub const MessageWriter = struct { + allocator: std.mem.Allocator, + socket: *Socket, + opcode: Opcode, + buf: []u8, + + cursor: usize = 0, + is_first: bool = true, + + pub const WriteError = Socket.WriteError; + pub const Writer = std.io.Writer(*MessageWriter, WriteError, write); + + pub fn writer(self: *MessageWriter) Writer { + return .{ .context = self }; + } + + pub fn write(self: *MessageWriter, buf: []const u8) !usize { + if (self.socket.closed) return error.SocketClosed; + var count: usize = 0; + while (count < buf.len) { + if (self.cursor == self.buf.len) try self.flushFrame(false); + + const max_write = std.math.min(self.buf.len - self.cursor, buf.len - count); + std.mem.copy(u8, self.buf[self.cursor..], buf[count .. count + max_write]); + + self.cursor += max_write; + count += max_write; + } + + return count; + } + + pub fn close(self: *MessageWriter) void { + self.allocator.free(self.buf); + } + + pub fn finish(self: *MessageWriter) !void { + if (self.socket.closed) return error.SocketClosed; + try self.flushFrame(true); + } + + fn flushFrame(self: *MessageWriter, is_final: bool) !void { + if (self.socket.closed) return error.SocketClosed; + const opcode = if (self.is_first) self.opcode else .continuation; + self.is_first = false; + const header = FrameInfo{ + .is_final = is_final, + .rsv = 0, + .opcode = opcode, + .masking_key = null, // TODO: use a key on client-side + .len = self.cursor, + }; + std.log.debug("sending frame with size {}", .{header.len}); + try writeFrame(self.socket.writer(), header, self.buf[0..self.cursor]); + self.cursor = 0; + } +}; + +pub const Socket = struct { + stream: Stream, + closed: bool = false, + + const ReadError = Stream.ReadError || error{SocketClosed}; + const WriteError = Stream.WriteError || error{SocketClosed}; + const Reader = std.io.Reader(*Socket, ReadError, read); + const Writer = std.io.Writer(*Socket, WriteError, write); + + fn writer(self: *Socket) Writer { + return .{ .context = self }; + } + + fn reader(self: *Socket) Reader { + return .{ .context = self }; + } + + fn read(self: *Socket, buf: []u8) !usize { + if (self.closed) return error.SocketClosed; + + return self.stream.read(buf) catch |err| { + self.closed = true; + return err; + }; + } + + fn write(self: *Socket, buf: []const u8) !usize { + if (self.closed) return error.SocketClosed; + + return self.stream.write(buf) catch |err| { + self.closed = true; + return err; + }; + } + + pub fn accept(self: *Socket) !MessageReader { + if (self.closed) return error.SocketClosed; + const frame = try self.waitForDataFrame(); + return MessageReader{ .opcode = frame.info.opcode, .socket = self, .frame = frame }; + } + + pub const MessageOptions = struct { + buf_size: usize = 512, + }; + pub fn openMessage(self: *Socket, opcode: Opcode, alloc: std.mem.Allocator, options: MessageOptions) !MessageWriter { + if (self.closed) return error.SocketClosed; + var buf = try alloc.alloc(u8, options.buf_size); + return MessageWriter{ .socket = self, .allocator = alloc, .buf = buf, .opcode = opcode }; + } + + fn openFrame(self: *Socket) !FrameReader(Reader) { + std.log.debug("waiting for frame header", .{}); + + const frame = try frameReader(self.reader()); + + std.log.debug("Received frame of type {any}, size {}", .{ frame.info.opcode, frame.info.len }); + + return frame; + } + + fn waitForDataFrame(self: *Socket) !FrameReader(Reader) { + while (true) { + var frame = try frameReader(self.reader()); + if (frame.info.opcode.isData()) return frame; + + // TODO: handle control frames + frame.close(); + } + } +}; + +fn pipe(writer: anytype, reader: anytype) !usize { + var count: usize = 0; + + var buf: [64]u8 = undefined; + while (true) { + const c = try reader.read(&buf); + if (c == 0) break; + + if (try writer.write(buf[0..c]) != c) return error.EndOfStream; + count += c; + } + return count; +} diff --git a/src/main/controllers.zig b/src/main/controllers.zig index b0f7800..31fbe90 100644 --- a/src/main/controllers.zig +++ b/src/main/controllers.zig @@ -12,16 +12,24 @@ pub const communities = @import("./controllers/communities.zig"); pub const invites = @import("./controllers/invites.zig"); pub const users = @import("./controllers/users.zig"); pub const notes = @import("./controllers/notes.zig"); +pub const streaming = @import("./controllers/streaming.zig"); pub fn routeRequest(api_source: anytype, req: http.Request, res: *http.Response, alloc: std.mem.Allocator) void { // TODO: hashmaps? var response = Response{ .headers = http.Headers.init(alloc), .res = res }; defer response.headers.deinit(); + + const found = routeRequestInternal(api_source, req, &response, alloc); + + if (!found) response.status(.not_found) catch {}; +} + +fn routeRequestInternal(api_source: anytype, req: http.Request, res: *Response, alloc: std.mem.Allocator) bool { inline for (routes) |route| { - if (Context(route).matchAndHandle(api_source, req, &response, alloc)) return; + if (Context(route).matchAndHandle(api_source, req, res, alloc)) return true; } - response.status(.not_found) catch {}; + return false; } const routes = .{ @@ -33,6 +41,7 @@ const routes = .{ users.create, notes.create, notes.get, + streaming.streaming, }; pub fn Context(comptime Route: type) type { @@ -49,6 +58,8 @@ pub fn Context(comptime Route: type) type { // leave it as a simple string instead of void pub const Query = if (@hasDecl(Route, "Query")) Route.Query else void; + base_request: http.Request, + allocator: std.mem.Allocator, method: http.Method, @@ -90,6 +101,7 @@ pub fn Context(comptime Route: type) type { var self = Self{ .allocator = alloc, + .base_request = req, .method = req.method, .uri = req.uri, @@ -155,7 +167,7 @@ pub fn Context(comptime Route: type) type { Route.handler(self, response, api_conn) catch |err| switch (err) { else => { std.log.err("{}", .{err}); - response.err(.internal_server_error, "", {}) catch {}; + if (!response.opened) response.err(.internal_server_error, "", {}) catch {}; }, }; } @@ -183,6 +195,7 @@ pub const Response = struct { res: *http.Response, opened: bool = false, + /// Write a response with no body, only a given status pub fn status(self: *Self, status_code: http.Status) !void { std.debug.assert(!self.opened); self.opened = true; @@ -192,6 +205,7 @@ pub const Response = struct { try stream.finish(); } + /// Write a request body as json pub fn json(self: *Self, status_code: http.Status, response_body: anytype) !void { std.debug.assert(!self.opened); self.opened = true; @@ -207,12 +221,20 @@ pub const Response = struct { try stream.finish(); } + /// Prints the given error as json pub fn err(self: *Self, status_code: http.Status, message: []const u8, details: anytype) !void { return self.json(status_code, .{ .message = message, .details = details, }); } + + /// Signals that the HTTP connection should be hijacked without writing a + /// response beforehand. + pub fn hijack(self: *Self) *http.Response { + self.opened = true; + return self.res; + } }; const json_options = if (builtin.mode == .Debug) diff --git a/src/main/controllers/streaming.zig b/src/main/controllers/streaming.zig new file mode 100644 index 0000000..4b8745b --- /dev/null +++ b/src/main/controllers/streaming.zig @@ -0,0 +1,36 @@ +const http = @import("http"); +const std = @import("std"); + +pub const streaming = struct { + pub const method = .GET; + pub const path = "/streaming"; + + pub fn handler(req: anytype, response: anytype, _: anytype) !void { + var iter = req.headers.iterator(); + std.log.debug("--Headers--", .{}); + while (iter.next()) |pair| { + std.log.debug("{s}: {s}", .{ pair.key_ptr.*, pair.value_ptr.* }); + } + + const res = response.hijack(); + + var socket = try http.socket.handshake(req.allocator, req.base_request, res); + + while (true) { + var message = try socket.accept(); + defer message.close(); + std.log.debug("Message received", .{}); + + const reader = message.reader(); + const msg = try reader.readAllAlloc(req.allocator, 1 << 63); + defer req.allocator.free(msg); + + var response_msg = try socket.openMessage(.text, req.allocator, .{}); + defer response_msg.close(); + + try response_msg.writer().writeAll(msg); + try response_msg.finish(); + std.log.debug("{s}", .{msg}); + } + } +};