const std = @import("std");
const db = @import("./db.zig");

pub const io_mode = .evented;

const Reader = std.net.Stream.Reader;
const Writer = std.net.Stream.Writer;

/// A UUID held as 16 raw bytes, in the same order as the hex pairs of its text form.
pub const Uuid = struct {
    data: [16]u8,

    pub fn eql(lhs: Uuid, rhs: Uuid) bool {
        return std.mem.eql(u8, &lhs.data, &rhs.data);
    }

    /// Formats as lowercase hex in the 8-4-4-4-12 dashed layout.
    pub fn format(value: Uuid, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
        try std.fmt.format(writer, "{x:0>2}{x:0>2}{x:0>2}{x:0>2}-{x:0>2}{x:0>2}-{x:0>2}{x:0>2}-{x:0>2}{x:0>2}-{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}{x:0>2}", .{
            value.data[0],  value.data[1],  value.data[2],  value.data[3],
            value.data[4],  value.data[5],  value.data[6],  value.data[7],
            value.data[8],  value.data[9],  value.data[10], value.data[11],
            value.data[12], value.data[13], value.data[14], value.data[15],
        });
    }

    const ParseError = error{
        InvalidCharacter,
    };

    /// Parses the 36-character `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx` form; hex digits
    /// may be upper or lower case. Any other length, a misplaced dash, or a non-hex
    /// character yields `error.InvalidCharacter`.
    pub fn parse(str: []const u8) ParseError!Uuid {
        // Exact length: trailing bytes must not be silently ignored.
        if (str.len != 36) return error.InvalidCharacter;

        var uuid: Uuid = undefined;
        var str_i: usize = 0;
        var i: usize = 0;
        while (i < uuid.data.len) : (i += 1) {
            // Dashes precede bytes 4, 6, 8 and 10 (string offsets 8, 13, 18 and 23).
            if (i == 4 or i == 6 or i == 8 or i == 10) {
                if (str[str_i] != '-') return error.InvalidCharacter;
                str_i += 1;
            }
            // charToDigit accepts only hex digits (no sign characters).
            const hi = std.fmt.charToDigit(str[str_i], 16) catch return error.InvalidCharacter;
            const lo = std.fmt.charToDigit(str[str_i + 1], 16) catch return error.InvalidCharacter;
            uuid.data[i] = hi * 16 + lo;
            str_i += 2;
        }
        return uuid;
    }
};

/// Hashing, comparison and lowering of strings with ASCII-only case folding.
/// Non-ASCII code points compare byte-for-byte. Invalid UTF-8 falls back to exact
/// byte hashing/comparison, so `hash` and `eql` stay consistent with each other.
const ciutf8 = struct {
    const Hash = std.hash.Wyhash;
    const View = std.unicode.Utf8View;
    const toLower = std.ascii.toLower;
    const isAscii = std.ascii.isASCII;
    const seed = 1;

    pub fn hash(str: []const u8) u64 {
        // fallback to regular hash on invalid utf8
        const view = View.init(str) catch return Hash.hash(seed, str);
        var iter = view.iterator();
        var h = Hash.init(seed);
        while (iter.nextCodepointSlice()) |cp| {
            if (cp.len == 1 and isAscii(cp[0])) {
                const ch = [1]u8{toLower(cp[0])};
                h.update(&ch);
            } else {
                h.update(cp);
            }
        }
        return h.final();
    }

    pub fn eql(a: []const u8, b: []const u8) bool {
        if (a.len != b.len) return false;
        const va = View.init(a) catch return std.mem.eql(u8, a, b);
        const vb = View.init(b) catch return false;
        var iter_a = va.iterator();
        var iter_b = vb.iterator();
        while (true) {
            // Equal only if both run out of code points at the same time.
            const cp_a = iter_a.nextCodepointSlice() orelse return iter_b.nextCodepointSlice() == null;
            const cp_b = iter_b.nextCodepointSlice() orelse return false;
            if (cp_a.len != cp_b.len) return false;
            if (cp_a.len == 1 and isAscii(cp_a[0]) and isAscii(cp_b[0])) {
                if (toLower(cp_a[0]) != toLower(cp_b[0])) return false;
            } else if (!std.mem.eql(u8, cp_a, cp_b)) {
                return false;
            }
        }
    }

    /// Lowercases ASCII letters in place; invalid UTF-8 is left untouched.
    pub fn lowerInPlace(str: []u8) void {
        if (!std.unicode.utf8ValidateSlice(str)) return;
        // ASCII bytes never appear inside a multi-byte UTF-8 sequence, so lowering
        // byte-by-byte only ever changes single-byte code points.
        for (str) |*ch| ch.* = toLower(ch.*);
    }
};

/// Header name -> value map whose keys compare case-insensitively (see `ciutf8`).
const HeaderMap = std.HashMap([]const u8, []const u8, struct {
    pub fn eql(_: @This(), a: []const u8, b: []const u8) bool {
        return ciutf8.eql(a, b);
    }
    pub fn hash(_: @This(), str: []const u8) u64 {
        return ciutf8.hash(str);
    }
}, std.hash_map.default_max_load_percentage);

fn handleBadRequest(writer: Writer) !void {
    std.log.info("400 Bad Request", .{});
    try writer.writeAll("HTTP/1.1 400 Bad Request\r\n\r\n");
}

fn handleNotImplemented(writer: Writer) !void {
    std.log.info("501", .{});
    try writer.writeAll("HTTP/1.1 501 Not Implemented\r\n\r\n");
}

fn handleInternalError(writer: Writer) !void {
    std.log.info("500", .{});
    try writer.writeAll("HTTP/1.1 500 Internal Server Error\r\n\r\n");
}

/// Logs `status` and writes a complete, bodiless `HTTP/1.1 <status>` response.
fn respondStatusOnly(writer: Writer, comptime status: []const u8) !void {
    std.log.info("{s}", .{status});
    try writer.writeAll("HTTP/1.1 " ++ status ++ "\r\n\r\n");
}

const Method = enum {
    GET,
    //HEAD,
    POST,
    //PUT,
    //DELETE,
    //CONNECT,
    //OPTIONS,
    //TRACE,
};

/// Reads the request method (terminated by a space) and maps it to `Method`.
fn parseHttpMethod(reader: Reader) !Method {
    var buf: [8]u8 = undefined;
    const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) {
        error.StreamTooLong => return error.MethodNotImplemented,
        else => return err,
    };
    return std.meta.stringToEnum(Method, str) orelse error.MethodNotImplemented;
}

/// Reads and validates the `HTTP/1.1` protocol token at the end of the request line.
fn checkProto(reader: Reader) !void {
    var buf: [8]u8 = undefined;
    const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
        error.StreamTooLong => return error.UnknownProtocol,
        else => return err,
    };
    if (!std.mem.eql(u8, proto, "HTTP")) {
        return error.UnknownProtocol;
    }

    // The version is exactly "<major>.<minor>". readNoEof retries short reads, which a
    // single `read` call would misreport as a malformed request.
    const version = buf[0..3];
    reader.readNoEof(version) catch |err| switch (err) {
        error.EndOfStream => return error.BadRequest,
        else => return err,
    };
    if (version[1] != '.') {
        return error.BadRequest;
    }
    if (version[0] != '1' or version[2] != '1') {
        return error.HttpVersionNotSupported;
    }
}

/// Consumes the request-line terminator: CRLF, or a bare LF.
fn expectLineEnd(reader: Reader) !void {
    var byte = try reader.readByte();
    if (byte == '\r') byte = try reader.readByte();
    if (byte != '\n') return error.BadRequest;
}

/// Returns the text before the first ':' in `line`, or null if there is no ':'
/// or the name would be empty.
fn extractHeaderName(line: []const u8) ?[]const u8 {
    // TODO: handle whitespace
    const colon = std.mem.indexOfScalar(u8, line, ':') orelse return null;
    if (colon == 0) return null;
    return line[0..colon];
}

/// Reads header lines up to the blank line that ends the header section.
/// Names and values are copied with `allocator`. Lines without a ':' are skipped;
/// an empty value or an over-long line yields `error.BadRequest`.
fn parseHeaders(allocator: std.mem.Allocator, reader: Reader) !HeaderMap {
    var map = HeaderMap.init(allocator);
    errdefer map.deinit(); // TODO: free map keys/values

    var buf: [1024]u8 = undefined;
    while (true) {
        const raw_line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
            error.StreamTooLong => return error.BadRequest,
            else => return err,
        };
        const line = if (raw_line.len > 0 and raw_line[raw_line.len - 1] == '\r')
            raw_line[0 .. raw_line.len - 1]
        else
            raw_line;
        if (line.len == 0) break;

        // TODO: handle multi-line headers
        const name = extractHeaderName(line) orelse continue;
        // The value may be surrounded by optional spaces/tabs (or none at all). Never
        // assume exactly one space: `X:` with nothing after it would slice out of bounds.
        const value = std.mem.trim(u8, line[name.len + 1 ..], " \t");
        if (value.len == 0) return error.BadRequest;

        const name_copy = try allocator.dupe(u8, name);
        errdefer allocator.free(name_copy);
        const value_copy = try allocator.dupe(u8, value);
        errdefer allocator.free(value_copy);
        try map.put(name_copy, value_copy);
    }
    return map;
}

/// Serves one connection, then closes it. Errors are logged, not propagated.
fn handleConnection(conn: std.net.StreamServer.Connection) void {
    defer conn.stream.close();
    const reader = conn.stream.reader();
    const writer = conn.stream.writer();
    handleRequest(reader, writer) catch |err| std.log.err("unhandled error processing connection: {}", .{err});
}

/// Runs `handleHttpRequest` and turns the protocol errors it reports into status responses.
fn handleRequest(reader: Reader, writer: Writer) !void {
    handleHttpRequest(reader, writer) catch |err| switch (err) {
        error.BadRequest, error.UnknownProtocol => try handleBadRequest(writer),
        error.MethodNotImplemented => try handleNotImplemented(writer),
        error.HttpVersionNotSupported => try respondStatusOnly(writer, "505 HTTP Version Not Supported"),
        error.URITooLong => try respondStatusOnly(writer, "414 URI Too Long"),
        error.UnsupportedMediaType => try respondStatusOnly(writer, "415 Unsupported Media Type"),
        else => {
            std.log.err("unknown error handling request: {}", .{err});
            try handleInternalError(writer);
        },
    };
}

/// Parses the request line and headers, builds a `Context`, and dispatches it to a route.
fn handleHttpRequest(reader: Reader, writer: Writer) anyerror!void {
    const method = try parseHttpMethod(reader);

    // Every per-request allocation (path, header maps, handler frame) comes from this
    // buffer, which lives in this function's frame, so none of it outlives the request.
    var header_buf: [1 << 16]u8 = undefined;
    var fba = std.heap.FixedBufferAllocator.init(&header_buf);
    const allocator = fba.allocator();

    const path = reader.readUntilDelimiterAlloc(allocator, ' ', header_buf.len) catch |err| switch (err) {
        error.StreamTooLong => return error.URITooLong,
        else => return err,
    };
    try checkProto(reader);
    try expectLineEnd(reader);

    const headers = try parseHeaders(allocator, reader);

    const tfer_encoding = headers.get("Transfer-Encoding");
    const has_body = tfer_encoding != null or headers.get("Content-Length") != null;
    if (tfer_encoding) |te| {
        if (!std.mem.eql(u8, te, "identity")) return error.UnsupportedMediaType;
    }
    if (headers.get("Content-Encoding")) |encoding| {
        if (!std.mem.eql(u8, encoding, "identity")) return error.UnsupportedMediaType;
    }

    var context = Context{
        .request = .{
            .method = method,
            .path = path,
            .headers = headers,
            .body = if (has_body) reader else null,
        },
        .response = .{
            .headers = HeaderMap.init(allocator),
            .writer = writer,
        },
        .allocator = allocator,
    };

    try routeRequest(&context);
}

/// Per-request state handed to route handlers.
const Context = struct {
    const Request = struct {
        method: Method,
        path: []const u8,
        /// Set by `routeRequest` to the route that matched.
        route: ?*const Route = null,
        headers: HeaderMap,
        /// The connection reader when Content-Length or Transfer-Encoding was sent, else null.
        body: ?Reader,

        /// Returns the path segment captured by the matched route's `:name` parameter.
        pub fn arg(self: *Request, name: []const u8) []const u8 {
            return self.route.?.arg(name, self.path);
        }
    };

    const Response = struct {
        /// Extra headers emitted by `open`, `write` and `statusOnly`.
        headers: HeaderMap,
        writer: Writer,

        fn writeHeaders(self: *Response) !void {
            var iter = self.headers.iterator();
            while (iter.next()) |entry| {
                try self.writer.print("{s}: {s}\r\n", .{ entry.key_ptr.*, entry.value_ptr.* });
            }
        }

        fn statusText(status: u16) []const u8 {
            return switch (status) {
                200 => "OK",
                204 => "No Content",
                400 => "Bad Request",
                404 => "Not Found",
                414 => "URI Too Long",
                415 => "Unsupported Media Type",
                500 => "Internal Server Error",
                501 => "Not Implemented",
                505 => "HTTP Version Not Supported",
                else => "",
            };
        }

        /// Writes the status line, the user headers and `Connection: close`. The header
        /// section is left open so callers can append more fields.
        fn openInternal(self: *Response, status: u16) !void {
            try self.writer.print("HTTP/1.1 {} {s}\r\n", .{ status, statusText(status) });
            try self.writeHeaders();
            try self.writer.writeAll("Connection: close\r\n"); // TODO
        }

        /// Writes the status and headers, ends the header section, and returns the
        /// writer so the caller can stream the body.
        pub fn open(self: *Response, status: u16) !Writer {
            try self.openInternal(status);
            try self.writer.writeAll("\r\n");
            return self.writer;
        }

        /// Writes a complete response. A non-empty body gets a Content-Length header and,
        /// unless one was set, a default Content-Type.
        pub fn write(self: *Response, status: u16, body: []const u8) !void {
            try self.openInternal(status);
            if (body.len != 0) {
                try self.writer.print("Content-Length: {}\r\n", .{body.len});
                if (self.headers.get("Content-Type") == null) {
                    try self.writer.writeAll("Content-Type: application/octet-stream\r\n");
                }
            }
            try self.writer.writeAll("\r\n");
            if (body.len != 0) {
                try self.writer.writeAll(body);
            }
        }

        /// Writes a complete response with no body. This includes the blank line that ends
        /// the header section, which was previously missing.
        pub fn statusOnly(self: *Response, status: u16) !void {
            _ = try self.open(status);
        }
    };

    request: Request,
    response: Response,
    allocator: std.mem.Allocator,
};

/// A method plus a path pattern, bound to an async handler. Patterns are '/'-separated;
/// a segment starting with ':' captures that path segment as a named parameter.
const Route = struct {
    const Segment = union(enum) {
        param: []const u8,
        literal: []const u8,
    };
    const Handler = fn (*Context) callconv(.Async) anyerror!void;

    /// Drops leading and trailing slashes and collapses repeated slashes (comptime only).
    fn normalize(comptime path: []const u8) []const u8 {
        var arr: [path.len]u8 = undefined;
        var i = 0;
        for (path) |ch| {
            if (i == 0 and ch == '/') continue;
            if (i > 0 and ch == '/' and arr[i - 1] == '/') continue;
            arr[i] = ch;
            i += 1;
        }
        if (i > 0 and arr[i - 1] == '/') {
            i -= 1;
        }
        return arr[0..i];
    }

    /// Splits a normalized comptime path into literal and `:param` segments.
    /// The root path (normalized to "") has no segments.
    fn parseSegments(comptime path: []const u8) []const Segment {
        if (path.len == 0) return &[_]Segment{};
        var count = 1;
        for (path) |ch| {
            if (ch == '/') count += 1;
        }
        var segment_array: [count]Segment = undefined;
        var segment_start = 0;
        for (segment_array) |*seg| {
            var index = segment_start;
            while (index < path.len) : (index += 1) {
                if (path[index] == '/') {
                    break;
                }
            }
            const slice = path[segment_start..index];
            if (slice.len > 0 and slice[0] == ':') {
                // doing this kinda jankily to get around segfaults in compiler
                const param = path[segment_start + 1 .. index];
                seg.* = .{ .param = param };
            } else {
                seg.* = .{ .literal = slice };
            }
            segment_start = index + 1;
        }
        return &segment_array;
    }

    pub fn from(method: Method, comptime path: []const u8, handler: Handler) Route {
        const segments = parseSegments(normalize(path));
        return Route{ .method = method, .path = segments, .handler = handler };
    }

    /// Yields the non-empty '/'-separated segments of a request path. Runs of slashes
    /// (leading, trailing or repeated) are skipped. Never indexes past the end of the path.
    const PathIterator = struct {
        rest: []const u8,

        fn next(self: *PathIterator) ?[]const u8 {
            var start: usize = 0;
            while (start < self.rest.len and self.rest[start] == '/') : (start += 1) {}
            var end = start;
            while (end < self.rest.len and self.rest[end] != '/') : (end += 1) {}
            const segment = self.rest[start..end];
            self.rest = self.rest[end..];
            return if (segment.len == 0) null else segment;
        }
    };

    /// True when `path` has exactly as many segments as the route, and each literal
    /// segment matches case-insensitively. A path with too few segments (e.g. `/user`
    /// for `/user/:id`) simply does not match; it used to slice out of bounds.
    pub fn matches(self: Route, path: []const u8) bool {
        var it = PathIterator{ .rest = path };
        for (self.path) |seg| {
            const part = it.next() orelse return false;
            switch (seg) {
                .literal => |str| {
                    if (!ciutf8.eql(part, str)) return false;
                },
                .param => {},
            }
        }
        return it.next() == null;
    }

    /// Returns the segment of `path` captured by parameter `name`. Returns "" if the
    /// path runs out first, and logs then returns "" if the route has no such parameter.
    pub fn arg(self: Route, name: []const u8, path: []const u8) []const u8 {
        var it = PathIterator{ .rest = path };
        for (self.path) |seg| {
            const part = it.next() orelse return "";
            switch (seg) {
                .param => |param| {
                    if (std.mem.eql(u8, param, name)) return part;
                },
                .literal => {},
            }
        }
        std.log.err("unknown parameter {s}", .{name});
        return "";
    }

    method: Method,
    path: []const Segment,
    handler: Handler,
};

fn handleNotFound(ctx: *Context) !void {
    try ctx.response.writer.writeAll("HTTP/1.1 404 Not Found\r\n\r\n");
}

/// Runs the first route whose method and path match, or responds 404.
fn routeRequest(ctx: *Context) !void {
    for (routes) |*route| {
        if (route.method == ctx.request.method and route.matches(ctx.request.path)) {
            std.log.info("{s} {s}", .{ @tagName(ctx.request.method), ctx.request.path });
            ctx.request.route = route;
            // The handler is called through a function pointer, so its frame is sized
            // with @frameSize and allocated from the per-request allocator.
            const buf = try ctx.allocator.allocWithOptions(u8, @frameSize(route.handler), 8, null);
            defer ctx.allocator.free(buf);
            return await @asyncCall(buf, {}, route.handler, .{ctx});
        }
    }

    std.log.info("404 {s} {s}", .{ @tagName(ctx.request.method), ctx.request.path });
    try handleNotFound(ctx);
}

const routes = [_]Route{
    Route.from(.GET, "/", staticString("Index Page")),
    Route.from(.GET, "/abc", staticString("abc")),
    Route.from(.GET, "/user/:id", getUser),
};

const this_scheme = "http";
const this_host = "localhost:8080";

/// GET /user/:id. Responds with a minimal `Person` JSON document for the actor with
/// that id, provided the actor's host equals the request's Host header.
/// Responds 400 on a missing Host or a malformed id, and 404 on a miss.
fn getUser(ctx: *Context) anyerror!void {
    const id_str = ctx.request.arg("id");
    const host = ctx.request.headers.get("host") orelse {
        try ctx.response.statusOnly(400);
        return;
    };
    const id = Uuid.parse(id_str) catch {
        try ctx.response.statusOnly(400);
        return;
    };
    const actor = (try db.getActorById(id)) orelse {
        try ctx.response.statusOnly(404);
        return;
    };
    if (!std.mem.eql(u8, actor.host, host)) {
        try ctx.response.statusOnly(404);
        return;
    }
    // NOTE(review): assumes `handle` is a byte slice, as the previous `{s}` formatting implied.
    const handle: []const u8 = actor.handle;

    try ctx.response.headers.put("Content-Type", "application/ld+json; profile=\"https://www.w3.org/ns/activitystreams\"");
    const writer = try ctx.response.open(200);
    try writer.writeAll("{\"type\":\"Person\",");
    try writer.print("\"id\":\"{s}://{s}/user/{}\",", .{ this_scheme, this_host, id });
    // The handle is stored data: JSON-encode it so quotes or backslashes cannot break
    // out of the string literal.
    try writer.writeAll("\"preferredUsername\":");
    try std.json.stringify(handle, .{}, writer);
    try writer.writeAll("}");
}

/// Returns a handler that always responds 200 with `str` as a text/plain body.
fn staticString(comptime str: []const u8) Route.Handler {
    return (struct {
        fn func(ctx: *Context) anyerror!void {
            try ctx.response.headers.put("Content-Type", "text/plain");
            try ctx.response.write(200, str);
        }
    }).func;
}

pub fn main() anyerror!void {
    var srv = std.net.StreamServer.init(.{ .reuse_address = true });
    defer srv.deinit();

    const uuid = try Uuid.parse("f75f5160-12d3-42c2-a81d-ad2245b7a74b");
    std.log.debug("{}", .{uuid});

    try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);

    const loop = std.event.Loop.instance.?;
    while (true) {
        const conn = try srv.accept();
        // Each connection needs its own heap-allocated frame. `_ = async ...` would
        // reuse one slot in main's frame and overwrite a still-suspended earlier
        // connection. runDetached frees the frame when handleConnection returns.
        // todo: keep track of connections
        loop.runDetached(std.heap.page_allocator, handleConnection, .{conn}) catch |err| {
            std.log.err("could not start connection handler: {}", .{err});
            conn.stream.close();
        };
    }
}