fediglam/src/http/socket.zig

373 lines
11 KiB
Zig

const std = @import("std");
const builtin = @import("builtin");
const http = @import("./lib.zig");
const Stream = @import("./server.zig").Stream;
/// WebSocket frame opcode (the low nibble of the first header byte).
/// Opcodes with the high bit (0x8) set are control frames; the rest carry data.
const Opcode = enum(u4) {
    // Data opcodes. 0x3-0x7 are reserved for future non-control frames.
    continuation = 0x0,
    text = 0x1,
    binary = 0x2,

    // Control opcodes. 0xb-0xf are reserved for future control frames.
    close = 0x8,
    ping = 0x9,
    pong = 0xa,

    _,

    /// True for continuation/text/binary and the reserved data range 0x0-0x7.
    fn isData(self: Opcode) bool {
        return @enumToInt(self) < 0x8;
    }

    /// True for close/ping/pong and the reserved control range 0x8-0xf.
    fn isControl(self: Opcode) bool {
        return @enumToInt(self) >= 0x8;
    }
};
/// Performs the server side of the WebSocket opening handshake.
///
/// Checks that the request asks for a `websocket` upgrade, carries a
/// `Sec-WebSocket-Key` that base64-decodes to exactly 16 bytes, and declares
/// `Sec-WebSocket-Version: 13`. It then upgrades `res` with
/// `101 Switching Protocols` and the matching `Sec-WebSocket-Accept` value.
///
/// Returns error.BadHandshake when any of those request checks fail. Other
/// errors come from building the response headers or from `res.upgrade`.
pub fn handshake(alloc: std.mem.Allocator, req_headers: *const http.Fields, res: *http.Response) !Socket {
    const upgrade = req_headers.get("Upgrade") orelse return error.BadHandshake;
    const connection = req_headers.get("Connection") orelse return error.BadHandshake;
    if (std.ascii.indexOfIgnoreCase(upgrade, "websocket") == null) return error.BadHandshake;
    if (std.ascii.indexOfIgnoreCase(connection, "Upgrade") == null) return error.BadHandshake;

    // The key must be base64 for exactly 16 bytes. Malformed base64 is a client
    // error, so report it as BadHandshake instead of leaking decoder errors.
    const key_hdr = req_headers.get("Sec-WebSocket-Key") orelse return error.BadHandshake;
    const key_size = std.base64.standard.Decoder.calcSizeForSlice(key_hdr) catch return error.BadHandshake;
    if (key_size != 16) return error.BadHandshake;
    var key: [16]u8 = undefined;
    std.base64.standard.Decoder.decode(&key, key_hdr) catch return error.BadHandshake;

    const version = req_headers.get("Sec-WebSocket-Version") orelse return error.BadHandshake;
    if (!std.mem.eql(u8, "13", version)) return error.BadHandshake;

    var headers = http.Fields.init(alloc);
    defer headers.deinit();
    try headers.put("Upgrade", "websocket");
    try headers.put("Connection", "Upgrade");

    // Sec-WebSocket-Accept = base64(SHA-1(key header ++ fixed GUID)).
    // Hash both parts incrementally so no fixed-size concatenation buffer is
    // needed.
    const Sha1 = std.crypto.hash.Sha1;
    var hasher = Sha1.init(.{});
    hasher.update(key_hdr);
    hasher.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
    var hash: [Sha1.digest_length]u8 = undefined;
    hasher.final(&hash);

    var hash_encoded: [std.base64.standard.Encoder.calcSize(Sha1.digest_length)]u8 = undefined;
    _ = std.base64.standard.Encoder.encode(&hash_encoded, &hash);
    // hash_encoded lives on this stack frame; it is only referenced by `headers`
    // until `res.upgrade` returns, and `headers` is freed by the defer above.
    try headers.put("Sec-WebSocket-Accept", &hash_encoded);

    const stream = try res.upgrade(.switching_protocols, &headers);
    return Socket{ .stream = stream };
}
/// Status codes that may appear in the payload of a close frame.
pub const CloseReason = enum(u16) {
    ok = 1000,
    going_away = 1001,
    protocol_error = 1002,
    unsupported_frame_type = 1003,
    // 1004 is reserved and intentionally omitted.

    // The next two describe local conditions. The spec forbids sending them in
    // a close frame, even though nothing here prevents it.
    no_reason_provided = 1005,
    abnormal_closure = 1006,

    invalid_frame_data = 1007, // e.g. a text frame whose payload is not valid UTF-8
    policy_violation = 1008,
    message_too_large = 1009,
    missing_extension = 1010,
    internal_server_error = 1011,

    // Not produced by a server; listed only so every defined code has a name.
    invalid_tls_handshake = 1015,

    _,
};
/// Decoded (or to-be-encoded) header of a single WebSocket frame.
const FrameInfo = struct {
    /// Frame type from the low nibble of the first header byte.
    opcode: Opcode,
    /// FIN bit: set on the last frame of a message.
    is_final: bool,
    /// The three RSV bits; callers in this file always write them as 0.
    rsv: u3,
    /// Payload length in bytes.
    len: usize,
    /// Present exactly when the MASK bit is set.
    masking_key: ?[4]u8,
};
/// Reads the next frame header from `reader` and returns a reader over that
/// frame's payload. Convenience wrapper that infers `FrameReader`'s type.
fn frameReader(reader: anytype) !FrameReader(@TypeOf(reader)) {
    const Frame = FrameReader(@TypeOf(reader));
    return Frame.open(reader);
}
/// Returns a type that reads the payload of one WebSocket frame from an
/// underlying reader of type `R`. If the header carried a masking key, the
/// payload is unmasked as it is read.
fn FrameReader(comptime R: type) type {
    return struct {
        const Self = @This();

        /// Header fields decoded by `open`.
        info: FrameInfo,
        /// Payload bytes handed out by `read` so far.
        bytes_read: usize = 0,
        /// Position within the 4-byte masking key; wraps naturally as a u2.
        mask_idx: u2 = 0,
        underlying: R,

        const Reader = std.io.Reader(*Self, anyerror, read);

        /// Reads and decodes a frame header from `r`, leaving `r` positioned at
        /// the first payload byte.
        ///
        /// Errors: anything `r` returns (e.g. EndOfStream on a truncated
        /// header), error.InvalidFrame if the 64-bit length has its most
        /// significant bit set, and error.FrameTooLarge if the length does not
        /// fit in a usize on this target.
        fn open(r: R) !Self {
            var hdr_buf: [2]u8 = undefined;
            try r.readNoEof(&hdr_buf);

            const is_final = hdr_buf[0] & 0b1000_0000 != 0;
            const rsv = @intCast(u3, (hdr_buf[0] & 0b0111_0000) >> 4);
            const opcode = @intToEnum(Opcode, (hdr_buf[0] & 0b1111));
            const is_masked = hdr_buf[1] & 0b1000_0000 != 0;
            const initial_len = @intCast(u7, (hdr_buf[1] & 0b0111_1111));

            const len: usize = switch (initial_len) {
                0...125 => initial_len,
                126 => try r.readIntBig(u16),
                // The extended length always takes 8 bytes on the wire, whatever
                // the host's usize width, and its most significant bit must be 0.
                127 => blk: {
                    const wide_len = try r.readIntBig(u64);
                    if (wide_len & (@as(u64, 1) << 63) != 0) return error.InvalidFrame;
                    break :blk std.math.cast(usize, wide_len) orelse return error.FrameTooLarge;
                },
            };

            const masking_key = if (is_masked)
                try r.readBytesNoEof(4)
            else
                null;

            return Self{
                .info = .{
                    .is_final = is_final,
                    .rsv = rsv,
                    .opcode = opcode,
                    .masking_key = masking_key,
                    .len = len,
                },
                .underlying = r,
            };
        }

        /// Reads up to `buf.len` payload bytes, unmasking them if needed.
        ///
        /// Returns 0 once the whole payload has been consumed (or if `buf` is
        /// empty). Returns error.EndOfStream if the underlying reader ends before
        /// the declared payload length has been delivered.
        fn read(self: *Self, buf: []u8) !usize {
            const remaining = self.info.len - self.bytes_read;
            if (remaining == 0 or buf.len == 0) return 0;

            var limited = std.io.limitedReader(self.underlying, remaining);
            const count = try limited.read(buf);
            // Payload bytes are still outstanding, so a zero-length read means the
            // peer went away mid-frame. Don't report that as a clean frame end.
            if (count == 0) return error.EndOfStream;

            if (self.info.masking_key) |mask| {
                for (buf[0..count]) |*ch| {
                    ch.* ^= mask[self.mask_idx];
                    self.mask_idx +%= 1;
                }
            }
            self.bytes_read += count;
            return count;
        }

        fn reader(self: *Self) Reader {
            return .{ .context = self };
        }

        /// Discards any unread payload so the underlying reader is positioned at
        /// the next frame header. Best-effort: errors are logged, not returned.
        fn close(self: *Self) void {
            const remaining = self.info.len - self.bytes_read;
            self.reader().skipBytes(remaining, .{}) catch |err| {
                std.log.debug("failed to drain websocket frame: {s}", .{@errorName(err)});
            };
        }
    };
}
/// Serializes one frame (header, optional masking key, payload) to `writer`.
///
/// `header.len` must equal `buf.len` (asserted). The smallest length encoding
/// is chosen automatically. When `header.masking_key` is set, the payload is
/// XORed with the key, cycling through all four key bytes. Otherwise the
/// payload is written unchanged.
fn writeFrame(writer: anytype, header: FrameInfo, buf: []const u8) !void {
    std.debug.assert(header.len == buf.len);

    // 7-bit inline length, or the 126 (16-bit) / 127 (64-bit) extension markers.
    const initial_len: u7 = if (buf.len < 126)
        @intCast(u7, buf.len)
    else if (buf.len <= std.math.maxInt(u16))
        @as(u7, 126)
    else
        @as(u7, 127);

    var hdr_buf = [2]u8{ 0, 0 };
    if (header.is_final) hdr_buf[0] |= 0b1000_0000;
    hdr_buf[0] |= @as(u8, header.rsv) << 4;
    hdr_buf[0] |= @enumToInt(header.opcode);
    if (header.masking_key != null) hdr_buf[1] |= 0b1000_0000;
    hdr_buf[1] |= initial_len;
    try writer.writeAll(&hdr_buf);

    switch (initial_len) {
        126 => try writer.writeIntBig(u16, @intCast(u16, buf.len)),
        // Always 8 bytes on the wire, independent of the host's usize width.
        127 => try writer.writeIntBig(u64, buf.len),
        else => {},
    }

    if (header.masking_key) |mask| {
        try writer.writeAll(&mask);
        // Mask through a small stack buffer so the payload isn't written one
        // byte per call. The key index follows the absolute payload offset.
        var chunk: [256]u8 = undefined;
        var offset: usize = 0;
        while (offset < buf.len) {
            const n = std.math.min(chunk.len, buf.len - offset);
            var i: usize = 0;
            while (i < n) : (i += 1) {
                chunk[i] = buf[offset + i] ^ mask[(offset + i) % 4];
            }
            try writer.writeAll(chunk[0..n]);
            offset += n;
        }
    } else {
        try writer.writeAll(buf);
    }
}
/// Reads one (possibly fragmented) WebSocket message, spanning its initial data
/// frame and any continuation frames that follow.
pub const MessageReader = struct {
    /// Opcode of the frame that started the message.
    opcode: Opcode,
    socket: *Socket,
    /// Frame currently being consumed.
    frame: FrameReader(Socket.Reader),

    pub const Reader = std.io.Reader(*MessageReader, anyerror, read);

    /// Fills `buf` with message payload. It moves on to continuation frames as
    /// each frame is exhausted, and stops early at the end of the final frame.
    ///
    /// Returns the number of bytes written to `buf`; 0 means the message is
    /// complete. Returns error.SocketClosed if the socket is already closed, and
    /// error.InvalidFrame if a non-continuation data frame arrives mid-message.
    pub fn read(self: *MessageReader, buf: []u8) !usize {
        if (self.socket.closed) return error.SocketClosed;
        var count: usize = 0;
        while (count < buf.len) {
            // Append after the bytes already read in this call rather than
            // overwriting them when a frame boundary is crossed.
            const c = try self.frame.read(buf[count..]);
            count += c;
            if (c == 0) {
                self.frame.close();
                if (self.frame.info.is_final) break;
                self.frame = try self.socket.waitForDataFrame();
                if (self.frame.info.opcode != .continuation) {
                    // TODO: fail the connection (close code 1002) instead of
                    // only reporting the error to the caller.
                    return error.InvalidFrame;
                }
            }
        }
        return count;
    }

    pub fn reader(self: *MessageReader) Reader {
        return .{ .context = self };
    }

    /// Discards the rest of the message so the socket is ready for the next one.
    /// Best-effort: stops silently on any read error.
    pub fn close(self: *MessageReader) void {
        var scratch: [512]u8 = undefined;
        while (true) {
            const n = self.read(&scratch) catch return;
            if (n == 0) return;
        }
    }
};
/// Buffers an outgoing message and sends it as one or more frames.
/// Each time the buffer fills, a non-final frame is sent; `finish` sends the
/// remainder as the final frame.
pub const MessageWriter = struct {
    allocator: std.mem.Allocator,
    socket: *Socket,
    /// Opcode used for the first frame; later frames are continuations.
    opcode: Opcode,
    /// Staging buffer, owned by this writer and released by `close`.
    buf: []u8,
    /// Number of staged bytes in `buf`.
    cursor: usize = 0,
    /// Whether the next frame sent will be the message's first.
    is_first: bool = true,

    pub const WriteError = Socket.WriteError;
    pub const Writer = std.io.Writer(*MessageWriter, WriteError, write);

    pub fn writer(self: *MessageWriter) Writer {
        return .{ .context = self };
    }

    /// Stages all of `buf`, flushing a non-final frame whenever the staging
    /// buffer is full. Always reports `buf.len` bytes written on success.
    pub fn write(self: *MessageWriter, buf: []const u8) !usize {
        if (self.socket.closed) return error.SocketClosed;
        var rest = buf;
        while (rest.len != 0) {
            // Flush before copying, so the last full buffer is held back and
            // can still become the final frame in `finish`.
            if (self.cursor == self.buf.len) try self.flushFrame(false);
            const room = self.buf[self.cursor..];
            const n = std.math.min(room.len, rest.len);
            std.mem.copy(u8, room, rest[0..n]);
            self.cursor += n;
            rest = rest[n..];
        }
        return buf.len;
    }

    /// Frees the staging buffer. Does not send anything.
    pub fn close(self: *MessageWriter) void {
        self.allocator.free(self.buf);
    }

    /// Sends whatever is staged as the message's final frame.
    pub fn finish(self: *MessageWriter) !void {
        if (self.socket.closed) return error.SocketClosed;
        try self.flushFrame(true);
    }

    /// Sends the staged bytes as one frame and empties the staging buffer.
    fn flushFrame(self: *MessageWriter, is_final: bool) !void {
        if (self.socket.closed) return error.SocketClosed;
        const payload = self.buf[0..self.cursor];
        const header = FrameInfo{
            .is_final = is_final,
            .rsv = 0,
            .opcode = if (self.is_first) self.opcode else Opcode.continuation,
            .masking_key = null, // TODO: use a key on client-side
            .len = payload.len,
        };
        self.is_first = false;
        std.log.debug("sending frame with size {}", .{header.len});
        try writeFrame(self.socket.writer(), header, payload);
        self.cursor = 0;
    }
};
/// A WebSocket connection over an upgraded HTTP stream.
/// Any I/O error from the underlying stream marks the socket as closed.
pub const Socket = struct {
    stream: Stream,
    /// Set after the first stream I/O error; later operations fail with SocketClosed.
    closed: bool = false,

    const ReadError = Stream.ReadError || error{SocketClosed};
    const WriteError = Stream.WriteError || error{SocketClosed};
    const Reader = std.io.Reader(*Socket, ReadError, read);
    const Writer = std.io.Writer(*Socket, WriteError, write);

    fn writer(self: *Socket) Writer {
        return .{ .context = self };
    }

    fn reader(self: *Socket) Reader {
        return .{ .context = self };
    }

    fn read(self: *Socket, buf: []u8) !usize {
        if (self.closed) return error.SocketClosed;
        return self.stream.read(buf) catch |err| {
            self.closed = true;
            return err;
        };
    }

    fn write(self: *Socket, buf: []const u8) !usize {
        if (self.closed) return error.SocketClosed;
        return self.stream.write(buf) catch |err| {
            self.closed = true;
            return err;
        };
    }

    /// Waits for the next incoming message and returns a reader over it.
    ///
    /// Returns error.InvalidFrame if the message does not start with a text or
    /// binary frame (a stray continuation or an unknown data opcode). The
    /// offending frame's payload is skipped first.
    pub fn accept(self: *Socket) !MessageReader {
        if (self.closed) return error.SocketClosed;
        var frame = try self.waitForDataFrame();
        switch (frame.info.opcode) {
            .text, .binary => {},
            else => {
                frame.close();
                return error.InvalidFrame;
            },
        }
        return MessageReader{ .opcode = frame.info.opcode, .socket = self, .frame = frame };
    }

    pub const MessageOptions = struct {
        /// Bytes buffered before a frame is sent; must be non-zero.
        buf_size: usize = 512,
    };

    /// Starts an outgoing message with the given opcode. The returned writer owns
    /// a `buf_size`-byte buffer allocated from `alloc`; release it with
    /// `MessageWriter.close`.
    ///
    /// Returns error.InvalidBufferSize if `options.buf_size` is 0.
    pub fn openMessage(self: *Socket, opcode: Opcode, alloc: std.mem.Allocator, options: MessageOptions) !MessageWriter {
        if (self.closed) return error.SocketClosed;
        // With a zero-capacity buffer, MessageWriter.write would loop forever
        // flushing empty frames without making progress.
        if (options.buf_size == 0) return error.InvalidBufferSize;
        const buf = try alloc.alloc(u8, options.buf_size);
        return MessageWriter{ .socket = self, .allocator = alloc, .buf = buf, .opcode = opcode };
    }

    fn openFrame(self: *Socket) !FrameReader(Reader) {
        std.log.debug("waiting for frame header", .{});
        const frame = try frameReader(self.reader());
        std.log.debug("Received frame of type {any}, size {}", .{ frame.info.opcode, frame.info.len });
        return frame;
    }

    /// Reads frames until one with a data opcode (0x0-0x7) arrives and returns it.
    /// Control frames are currently drained and dropped without any response.
    fn waitForDataFrame(self: *Socket) !FrameReader(Reader) {
        while (true) {
            var frame = try frameReader(self.reader());
            if (frame.info.opcode.isData()) return frame;
            // TODO: handle control frames (reply to ping, honour close)
            frame.close();
        }
    }
};
/// Copies everything from `reader` to `writer` until `reader` returns 0 (end of
/// stream). Returns the total number of bytes copied.
///
/// Uses `writeAll`, so a writer that accepts fewer bytes per call than offered
/// is retried rather than being mistaken for a closed stream.
fn pipe(writer: anytype, reader: anytype) !usize {
    var total: usize = 0;
    var buf: [64]u8 = undefined;
    while (true) {
        const n = try reader.read(&buf);
        if (n == 0) break;
        try writer.writeAll(buf[0..n]);
        total += n;
    }
    return total;
}