//! Server-side WebSocket support: the HTTP upgrade handshake, frame-level
//! reading/writing, and message-level reader/writer types layered on a `Socket`.

const std = @import("std");
const builtin = @import("builtin");
const http = @import("./lib.zig");
const Stream = @import("./server.zig").Stream;

/// Frame opcode: the low four bits of the first frame header byte.
/// Non-exhaustive, so unknown opcodes can still be represented.
const Opcode = enum(u4) {
    continuation = 0x0,
    text = 0x1,
    binary = 0x2,
    // 3-7 are non-control frames
    close = 0x8,
    ping = 0x9,
    pong = 0xa,
    // b-f are control frames
    _,

    /// True for every opcode that is not a control opcode.
    fn isData(self: Opcode) bool {
        return !self.isControl();
    }

    /// True when the high bit of the opcode is set (0x8-0xf).
    fn isControl(self: Opcode) bool {
        return @enumToInt(self) & 0x8 == 0x8;
    }
};

/// Validates the WebSocket upgrade request in `req_headers` and, if it is
/// acceptable, upgrades `res` and returns a `Socket` wrapping the resulting stream.
///
/// Returns `error.BadHandshake` when a required header is missing or invalid
/// (including a `Sec-WebSocket-Key` that is not valid base64 for 16 bytes).
/// Errors from building the response headers or from `res.upgrade` are propagated.
/// `alloc` is used only for the temporary response header collection.
pub fn handshake(alloc: std.mem.Allocator, req_headers: *const http.Fields, res: *http.Response) !Socket {
    const upgrade = req_headers.get("Upgrade") orelse return error.BadHandshake;
    const connection = req_headers.get("Connection") orelse return error.BadHandshake;

    if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
    if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;

    const key_hdr = req_headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
    // A malformed key is a handshake failure, not an internal error, so the
    // decoder's own error is mapped to BadHandshake like the decode below.
    const key_len = std.base64.standard.Decoder.calcSizeForSlice(key_hdr) catch return error.BadHandshake;
    if (key_len != 16) return error.BadHandshake;
    // The decoded key is not used further; decoding only validates the header.
    var key: [16]u8 = undefined;
    std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;

    // NOTE(review): the lookup uses a lowercase "version"; this presumably relies
    // on http.Fields being case-insensitive -- confirm against lib.zig.
    const version = req_headers.get("Sec-WebSocket-version") orelse return error.BadHandshake;
    if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;

    var headers = http.Fields.init(alloc);
    defer headers.deinit();

    try headers.put("Upgrade", "websocket");
    try headers.put("Connection", "Upgrade");

    // Accept value = base64(SHA-1(key header text ++ fixed GUID)).
    // Hashing incrementally avoids a fixed-size scratch buffer.
    const Sha1 = std.crypto.hash.Sha1;
    var hasher = Sha1.init(.{});
    hasher.update(key_hdr);
    hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
    var hash: [Sha1.digest_length]u8 = undefined;
    hasher.final(&hash);

    var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined;
    _ = std.base64.standard.Encoder.encode(&hash_encoded, &hash);
    try headers.put("Sec-WebSocket-Accept", &hash_encoded);

    const stream = try res.upgrade(.switching_protocols, &headers);
    return Socket{ .stream = stream };
}

/// Status codes carried by a close frame.
pub const CloseReason = enum(u16) {
    ok = 1000,
    going_away = 1001,
    protocol_error = 1002,
    unsupported_frame_type = 1003,
    // reserved = 1004,

    // technically these codes can be explicitly provided but it's against spec
    no_reason_provided = 1005,
    abnormal_closure = 1006,

    invalid_frame_data = 1007, // e.g. non UTF-8 data in a text type frame
    policy_violation = 1008,
    message_too_large = 1009,
    missing_extension = 1010,
    internal_server_error = 1011,
    // unused on server side but included here for completeness
    invalid_tls_handshake = 1015,
    _,
};

/// Decoded contents of a frame header.
const FrameInfo = struct {
    /// FIN bit: this frame is the last one of its message.
    is_final: bool,
    /// The three reserved bits following FIN.
    rsv: u3,
    opcode: Opcode,
    /// Present when the MASK bit is set.
    masking_key: ?[4]u8,
    /// Payload length in bytes.
    len: usize,
};

/// Reads a frame header from `reader` and returns a `FrameReader` positioned at
/// the start of the payload.
fn frameReader(reader: anytype) !FrameReader(@TypeOf(reader)) {
    return FrameReader(@TypeOf(reader)).open(reader);
}

/// Reader over the payload of a single frame, layered on an underlying reader
/// of type `R`. Unmasks the payload when the frame carries a masking key.
fn FrameReader(comptime R: type) type {
    return struct {
        const Self = @This();

        info: FrameInfo,
        /// Payload bytes consumed so far.
        bytes_read: usize = 0,
        /// Position within the 4-byte masking key; wraps naturally as a u2.
        mask_idx: u2 = 0,
        underlying: R,

        const Reader = std.io.Reader(*Self, anyerror, read);

        /// Parses the frame header (including extended length and masking key)
        /// from `r`. Returns `error.FrameTooLarge` if the declared payload
        /// length does not fit in `usize`; read errors are propagated.
        fn open(r: R) !Self {
            var hdr_buf: [2]u8 = undefined;
            try r.readNoEof(&hdr_buf);

            const is_final = hdr_buf[0] & 0b1000_0000 != 0;
            const rsv = @intCast(u3, (hdr_buf[0] & 0b0111_0000) >> 4);
            const opcode = @intToEnum(Opcode, (hdr_buf[0] & 0b1111));

            const is_masked = hdr_buf[1] & 0b1000_0000 != 0;
            const initial_len = @intCast(u7, (hdr_buf[1] & 0b0111_1111));

            // The 127-form extended length is always 64 bits on the wire, so it
            // is read as u64 (not usize) and then checked against the platform.
            const len: usize = if (initial_len < 126)
                initial_len
            else if (initial_len == 126)
                try r.readIntBig(u16)
            else
                std.math.cast(usize, try r.readIntBig(u64)) orelse return error.FrameTooLarge;

            const masking_key = if (is_masked) try r.readBytesNoEof(4) else null;

            return Self{
                .info = .{
                    .is_final = is_final,
                    .rsv = rsv,
                    .opcode = opcode,
                    .masking_key = masking_key,
                    .len = len,
                },
                .underlying = r,
            };
        }

        /// Reads up to `buf.len` payload bytes, never past the end of this
        /// frame, unmasking in place. Returns 0 once the payload is exhausted
        /// (or the underlying reader reports end of stream).
        fn read(self: *Self, buf: []u8) !usize {
            var limited = std.io.limitedReader(self.underlying, self.info.len - self.bytes_read);
            const count = try limited.read(buf);

            if (self.info.masking_key) |mask| {
                for (buf[0..count]) |*ch| {
                    ch.* ^= mask[self.mask_idx];
                    self.mask_idx +%= 1;
                }
            }

            self.bytes_read += count;
            return count;
        }

        fn reader(self: *Self) Reader {
            return .{ .context = self };
        }

        /// Best-effort drain of the rest of the payload so the underlying
        /// reader is positioned at the next frame; errors are deliberately ignored.
        fn close(self: *Self) void {
            self.reader().skipBytes(@as(usize, 0) -% 1, .{}) catch {};
        }
    };
}

/// Serializes one frame (header, optional masking key, payload) to `writer`.
///
/// `header.len` must equal `buf.len` (asserted). When `header.masking_key` is
/// set, the payload is XOR-masked with the key, cycling through its four bytes.
/// Writer errors are propagated.
fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void {
    std.debug.assert(header.len == buf.len);

    // 7-bit length field: literal length, or 126/127 to select the 16/64-bit
    // extended length that follows the two-byte header.
    const initial_len: u7 = if (header.len < 126)
        @intCast(u7, header.len)
    else if (header.len <= std.math.maxInt(u16))
        126
    else
        127;

    const fin_bit: u8 = if (header.is_final) 0b1000_0000 else 0;
    const mask_bit: u8 = if (header.masking_key != null) 0b1000_0000 else 0;
    const hdr_buf = [2]u8{
        fin_bit | (@as(u8, header.rsv) << 4) | @enumToInt(header.opcode),
        mask_bit | initial_len,
    };
    try writer.writeAll(&hdr_buf);

    if (initial_len == 126) {
        try writer.writeIntBig(u16, @intCast(u16, buf.len));
    } else if (initial_len == 127) {
        // Always 64 bits on the wire, regardless of the platform's usize.
        try writer.writeIntBig(u64, buf.len);
    }

    if (header.masking_key) |key| {
        try writer.writeAll(&key);
        // Each payload byte is XORed with the key byte at (index mod 4).
        for (buf) |ch, i| {
            try writer.writeByte(ch ^ key[i % key.len]);
        }
    } else {
        // Unmasked payloads go out in one call instead of byte-by-byte.
        try writer.writeAll(buf);
    }
}

/// Reader over a whole message, transparently moving across the continuation
/// frames that make it up.
pub const MessageReader = struct {
    /// Opcode of the first frame of the message.
    opcode: Opcode,
    socket: *Socket,
    /// The frame currently being read.
    frame: FrameReader(Socket.Reader),

    pub const Reader = std.io.Reader(*MessageReader, anyerror, read);

    /// Fills `buf` with message payload, advancing to the next frame when the
    /// current one is exhausted. Returns fewer than `buf.len` bytes only when
    /// the final frame of the message has been fully consumed.
    ///
    /// Returns `error.SocketClosed` if the socket is closed and
    /// `error.InvalidFrame` if a follow-up data frame is not a continuation.
    pub fn read(self: *MessageReader, buf: []u8) !usize {
        if (self.socket.closed) return error.SocketClosed;

        var count: usize = 0;
        while (count < buf.len) {
            // Read into the unfilled tail so earlier bytes are not overwritten
            // and `count` can never exceed `buf.len`.
            const c = try self.frame.read(buf[count..]);
            count += c;
            if (c != 0) continue;

            self.frame.close();
            if (self.frame.info.is_final) break;

            self.frame = try self.socket.waitForDataFrame();
            if (self.frame.info.opcode != .continuation) {
                // TODO: close the socket here (original had a disabled
                // `self.socket.close()` call at this point).
                return error.InvalidFrame;
            }
        }

        return count;
    }

    pub fn reader(self: *MessageReader) Reader {
        return .{ .context = self };
    }

    /// Best-effort drain of the remainder of the message; errors are ignored.
    pub fn close(self: *MessageReader) void {
        self.reader().skipBytes(@as(usize, 0) -% 1, .{}) catch {};
    }
};

/// Buffered writer for a single outgoing message. Data is accumulated in `buf`
/// and emitted as one frame each time the buffer fills; `finish` emits the
/// final frame. `close` frees the buffer.
pub const MessageWriter = struct {
    allocator: std.mem.Allocator,
    socket: *Socket,
    /// Opcode used for the first frame; later frames use `.continuation`.
    opcode: Opcode,
    /// Frame staging buffer, owned by this writer and freed in `close`.
    buf: []u8,
    /// Number of valid bytes currently staged in `buf`.
    cursor: usize = 0,
    is_first: bool = true,

    pub const WriteError = Socket.WriteError;
    pub const Writer = std.io.Writer(*MessageWriter, WriteError, write);

    pub fn writer(self: *MessageWriter) Writer {
        return .{ .context = self };
    }

    /// Appends all of `buf` to the message, flushing a non-final frame
    /// whenever the staging buffer is full. Always returns `buf.len` on success.
    pub fn write(self: *MessageWriter, buf: []const u8) !usize {
        if (self.socket.closed) return error.SocketClosed;

        var count: usize = 0;
        while (count < buf.len) {
            if (self.cursor == self.buf.len) try self.flushFrame(false);

            const max_write = std.math.min(self.buf.len - self.cursor, buf.len - count);
            std.mem.copy(u8, self.buf[self.cursor..], buf[count .. count + max_write]);
            self.cursor += max_write;
            count += max_write;
        }

        return count;
    }

    /// Frees the staging buffer. Does not send anything.
    pub fn close(self: *MessageWriter) void {
        self.allocator.free(self.buf);
    }

    /// Sends whatever is staged as the final frame of the message.
    pub fn finish(self: *MessageWriter) !void {
        if (self.socket.closed) return error.SocketClosed;

        try self.flushFrame(true);
    }

    /// Emits the staged bytes as one frame and resets the cursor.
    fn flushFrame(self: *MessageWriter, is_final: bool) !void {
        if (self.socket.closed) return error.SocketClosed;

        const opcode = if (self.is_first) self.opcode else .continuation;
        self.is_first = false;

        const header = FrameInfo{
            .is_final = is_final,
            .rsv = 0,
            .opcode = opcode,
            .masking_key = null, // TODO: use a key on client-side
            .len = self.cursor,
        };

        std.log.debug("sending frame with size {}", .{header.len});

        try writeFrame(self.socket.writer(), header, self.buf[0..self.cursor]);
        self.cursor = 0;
    }
};

/// An upgraded WebSocket connection. Any read or write error on the underlying
/// stream marks the socket as closed; subsequent operations return
/// `error.SocketClosed`.
pub const Socket = struct {
    stream: Stream,
    closed: bool = false,

    const ReadError = Stream.ReadError || error{SocketClosed};
    const WriteError = Stream.WriteError || error{SocketClosed};
    const Reader = std.io.Reader(*Socket, ReadError, read);
    const Writer = std.io.Writer(*Socket, WriteError, write);

    fn writer(self: *Socket) Writer {
        return .{ .context = self };
    }

    fn reader(self: *Socket) Reader {
        return .{ .context = self };
    }

    /// Raw read from the stream; flags the socket closed on failure.
    fn read(self: *Socket, buf: []u8) !usize {
        if (self.closed) return error.SocketClosed;

        return self.stream.read(buf) catch |err| {
            self.closed = true;
            return err;
        };
    }

    /// Raw write to the stream; flags the socket closed on failure.
    fn write(self: *Socket, buf: []const u8) !usize {
        if (self.closed) return error.SocketClosed;

        return self.stream.write(buf) catch |err| {
            self.closed = true;
            return err;
        };
    }

    /// Waits for the next data frame and returns a reader for the message it starts.
    pub fn accept(self: *Socket) !MessageReader {
        if (self.closed) return error.SocketClosed;

        const frame = try self.waitForDataFrame();
        return MessageReader{
            .opcode = frame.info.opcode,
            .socket = self,
            .frame = frame,
        };
    }

    pub const MessageOptions = struct {
        /// Size of the frame staging buffer, i.e. the maximum payload per frame.
        buf_size: usize = 512,
    };

    /// Starts an outgoing message with the given opcode. The returned writer
    /// owns a buffer allocated from `alloc`; the caller must call its `close`
    /// to free it.
    pub fn openMessage(self: *Socket, opcode: Opcode, alloc: std.mem.Allocator, options: MessageOptions) !MessageWriter {
        if (self.closed) return error.SocketClosed;

        // A zero-length buffer would make MessageWriter.write spin forever
        // (flush, copy 0 bytes, repeat), so use at least one byte.
        const buf_len = std.math.max(options.buf_size, 1);
        const buf = try alloc.alloc(u8, buf_len);
        return MessageWriter{
            .socket = self,
            .allocator = alloc,
            .buf = buf,
            .opcode = opcode,
        };
    }

    /// Reads the next frame header, logging before and after.
    fn openFrame(self: *Socket) !FrameReader(Reader) {
        std.log.debug("waiting for frame header", .{});
        const frame = try frameReader(self.reader());
        std.log.debug("Received frame of type {any}, size {}", .{ frame.info.opcode, frame.info.len });

        return frame;
    }

    /// Reads frames until a data (non-control) frame arrives, discarding the
    /// payload of any control frame seen in between.
    fn waitForDataFrame(self: *Socket) !FrameReader(Reader) {
        while (true) {
            var frame = try frameReader(self.reader());
            if (frame.info.opcode.isData()) return frame;

            // TODO: handle control frames
            frame.close();
        }
    }
};

/// Copies everything from `reader` to `writer` until `reader` reports end of
/// stream (a zero-length read). Returns the number of bytes copied.
fn pipe(writer: anytype, reader: anytype) !usize {
    var count: usize = 0;
    var buf: [64]u8 = undefined;
    while (true) {
        const c = try reader.read(&buf);
        if (c == 0) break;

        // A short write is legal for a Writer; writeAll retries until the
        // whole chunk has been written instead of treating it as an error.
        try writer.writeAll(buf[0..c]);
        count += c;
    }

    return count;
}