//! Minimal evented HTTP/1.1 server: parses the request line and headers of
//! each connection, then dispatches to a comptime-built route table.

const std = @import("std");

/// Run std I/O in evented (async) mode.
pub const io_mode = .evented;

const HeaderMap = std.StringHashMap([]const u8);
const Reader = std.net.Stream.Reader;
const Writer = std.net.Stream.Writer;

/// Longest request-target accepted before answering 414 URI Too Long. Kept
/// well below the 64 KiB per-request buffer so header parsing still has room.
const max_uri_len = 8 * 1024;

/// Write a bodiless response consisting of a status line followed by the
/// blank line that terminates the header section (required for the message
/// to be well-formed; the body is delimited by closing the connection).
fn writeErrorStatus(writer: Writer, status_line: []const u8) !void {
    try writer.print("HTTP/1.1 {s}\r\n\r\n", .{status_line});
}

fn handleBadRequest(writer: Writer) !void {
    std.log.info("400 Bad Request", .{});
    try writeErrorStatus(writer, "400 Bad Request");
}

fn handleNotImplemented(writer: Writer) !void {
    std.log.info("501", .{});
    try writeErrorStatus(writer, "501 Not Implemented");
}

fn handleInternalError(writer: Writer) !void {
    std.log.info("500", .{});
    try writeErrorStatus(writer, "500 Internal Server Error");
}

fn handleUriTooLong(writer: Writer) !void {
    std.log.info("414", .{});
    try writeErrorStatus(writer, "414 URI Too Long");
}

fn handleUnsupportedMediaType(writer: Writer) !void {
    std.log.info("415", .{});
    try writeErrorStatus(writer, "415 Unsupported Media Type");
}

fn handleHttpVersionNotSupported(writer: Writer) !void {
    std.log.info("505", .{});
    try writeErrorStatus(writer, "505 HTTP Version Not Supported");
}

/// Request methods this server understands; anything else is answered with
/// 501 Not Implemented.
const Method = enum {
    GET,
    //HEAD,
    POST,
    //PUT,
    //DELETE,
    //CONNECT,
    //OPTIONS,
    //TRACE,
};

/// Byte-wise, case-sensitive string equality.
fn areStringsEqual(lhs: []const u8, rhs: []const u8) bool {
    return std.mem.eql(u8, lhs, rhs);
}

/// Read the request method (everything up to the first space).
/// Unknown or over-long method names yield error.MethodNotImplemented.
fn parseHttpMethod(reader: Reader) !Method {
    var buf: [8]u8 = undefined;
    const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) {
        error.StreamTooLong => return error.MethodNotImplemented,
        else => return err,
    };
    return std.meta.stringToEnum(Method, str) orelse error.MethodNotImplemented;
}

/// Consume and validate the "HTTP/1.1" protocol token of the request line.
/// Errors: UnknownProtocol (not "HTTP"), BadRequest (malformed or truncated
/// version), HttpVersionNotSupported (anything other than 1.1).
fn checkProto(reader: Reader) !void {
    var buf: [8]u8 = undefined;
    const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
        error.StreamTooLong => return error.UnknownProtocol,
        else => return err,
    };
    if (!areStringsEqual(proto, "HTTP")) {
        return error.UnknownProtocol;
    }

    // A single `read` may return fewer bytes than requested on a stream, so
    // insist on exactly three bytes ("1.1"); early EOF means a malformed line.
    const version = buf[0..3];
    reader.readNoEof(version) catch |err| switch (err) {
        error.EndOfStream => return error.BadRequest,
        else => return err,
    };
    if (version[1] != '.') {
        return error.BadRequest;
    }
    if (version[0] != '1' or version[2] != '1') {
        return error.HttpVersionNotSupported;
    }
}

/// Return the text before the first ':' of a header line, or null when the
/// line has no colon or the name would be empty.
fn extractHeaderName(line: []const u8) ?[]const u8 {
    // TODO: handle whitespace
    const colon = std.mem.indexOfScalar(u8, line, ':') orelse return null;
    if (colon == 0) return null;
    return line[0..colon];
}

/// One header split into its name and (trimmed) value; both borrow the line.
const HeaderLine = struct {
    name: []const u8,
    value: []const u8,
};

/// Split a raw header line into name and value. Returns null when the line
/// has no "name:" prefix. The value has surrounding spaces/tabs and the '\r'
/// left over from the CRLF terminator stripped; it may be empty.
fn parseHeaderLine(line: []const u8) ?HeaderLine {
    const name = extractHeaderName(line) orelse return null;
    // The colon sits at index name.len, so this slice is always in bounds.
    return HeaderLine{
        .name = name,
        .value = std.mem.trim(u8, line[name.len + 1 ..], " \t\r"),
    };
}

/// Free every key and value stored in `map` (all allocated with
/// `allocator`), then the map itself.
fn freeHeaders(allocator: std.mem.Allocator, map: *HeaderMap) void {
    var it = map.iterator();
    while (it.next()) |entry| {
        allocator.free(entry.key_ptr.*);
        allocator.free(entry.value_ptr.*);
    }
    map.deinit();
}

/// Read header lines until the blank line that ends the header section.
/// Names are lower-cased; names and values are copied into memory from
/// `allocator` and released again if parsing fails. Lines without a colon are
/// skipped; a line longer than 1024 bytes yields error.BadRequest. A repeated
/// header keeps only its last value.
fn parseHeaders(allocator: std.mem.Allocator, reader: Reader) !HeaderMap {
    var map = HeaderMap.init(allocator);
    errdefer freeHeaders(allocator, &map);

    var buf: [1024]u8 = undefined;
    while (true) {
        const line = reader.readUntilDelimiter(&buf, '\n') catch |err| switch (err) {
            error.StreamTooLong => return error.BadRequest,
            else => return err,
        };
        if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break;

        // TODO: handle multi-line headers
        const header = parseHeaderLine(line) orelse continue;

        const name_alloc = try allocator.dupe(u8, header.name);
        errdefer allocator.free(name_alloc);
        const value_alloc = try allocator.dupe(u8, header.value);
        errdefer allocator.free(value_alloc);

        for (name_alloc) |*ch| {
            //TODO: utf8
            ch.* = std.ascii.toLower(ch.*);
        }

        if (try map.fetchPut(name_alloc, value_alloc)) |previous| {
            // The map keeps its original key on replacement, so release our
            // fresh copy of the name along with the value being replaced.
            allocator.free(name_alloc);
            allocator.free(previous.value);
        }
    }
    return map;
}

/// Serve a single request on `conn`, then close it. Errors are logged, not
/// propagated.
fn handleConnection(conn: std.net.StreamServer.Connection) void {
    defer conn.stream.close();
    const reader = conn.stream.reader();
    const writer = conn.stream.writer();
    handleRequest(reader, writer) catch |err|
        std.log.err("unhandled error processing connection: {}", .{err});
}

/// Run the request pipeline and translate protocol errors into status codes.
fn handleRequest(reader: Reader, writer: Writer) !void {
    handleHttpRequest(reader, writer) catch |err| switch (err) {
        error.BadRequest, error.UnknownProtocol => try handleBadRequest(writer),
        error.MethodNotImplemented => try handleNotImplemented(writer),
        error.HttpVersionNotSupported => try handleHttpVersionNotSupported(writer),
        error.URITooLong => try handleUriTooLong(writer),
        error.UnsupportedMediaType => try handleUnsupportedMediaType(writer),
        else => {
            std.log.err("unknown error handling request: {}", .{err});
            try handleInternalError(writer);
        },
    };
}

/// Parse the request line and headers, build a Context, and route it.
fn handleHttpRequest(reader: Reader, writer: Writer) anyerror!void {
    const method = try parseHttpMethod(reader);

    // Path, headers, response headers and the handler frame are all carved
    // out of this buffer, so nothing allocated here outlives the request.
    var header_buf: [1 << 16]u8 = undefined;
    var fba = std.heap.FixedBufferAllocator.init(&header_buf);
    const allocator = fba.allocator();

    const path = reader.readUntilDelimiterAlloc(allocator, ' ', max_uri_len) catch |err| switch (err) {
        error.StreamTooLong => return error.URITooLong,
        else => return err,
    };
    try checkProto(reader);

    // The version must be followed directly by the line ending; a bare LF
    // is tolerated, just as parseHeaders tolerates it on header lines.
    switch (try reader.readByte()) {
        '\n' => {},
        '\r' => {
            if ((try reader.readByte()) != '\n') return error.BadRequest;
        },
        else => return error.BadRequest,
    }

    const headers = try parseHeaders(allocator, reader);

    const transfer_encoding = headers.get("transfer-encoding");
    const has_body = headers.get("content-length") != null or transfer_encoding != null;
    // Coding names are case-insensitive; only the identity coding is supported.
    if (transfer_encoding) |coding| {
        if (!std.ascii.eqlIgnoreCase(coding, "identity")) return error.UnsupportedMediaType;
    }
    if (headers.get("content-encoding")) |coding| {
        if (!std.ascii.eqlIgnoreCase(coding, "identity")) return error.UnsupportedMediaType;
    }

    var context = Context{
        .request = .{
            .method = method,
            .path = path,
            .headers = headers,
            .body = if (has_body) reader else null,
        },
        .response = .{
            .headers = HeaderMap.init(allocator),
            .writer = writer,
        },
        .allocator = allocator,
    };
    try routeRequest(&context);
}

/// Per-request state handed to route handlers.
const Context = struct {
    const Request = struct {
        method: Method,
        /// Raw request-target as read from the request line.
        path: []const u8,
        /// Route that matched; set by routeRequest before the handler runs.
        route: ?*const Route = null,
        /// Header names are lower-cased by parseHeaders.
        headers: HeaderMap,
        /// The connection reader when the request declared a body, else null.
        body: ?Reader,

        /// Value of the `:name` path parameter of the matched route.
        pub fn arg(self: *Request, name: []const u8) []const u8 {
            return self.route.?.arg(name, self.path);
        }
    };

    const Response = struct {
        /// Extra headers written verbatim after the status line.
        headers: HeaderMap,
        writer: Writer,

        fn writeHeaders(self: *Response) !void {
            var iter = self.headers.iterator();
            while (iter.next()) |entry| {
                try self.writer.print("{s}: {s}\r\n", .{ entry.key_ptr.*, entry.value_ptr.* });
            }
        }

        fn statusText(status: u16) []const u8 {
            return switch (status) {
                200 => "OK",
                204 => "No Content",
                else => "",
            };
        }

        fn openInternal(self: *Response, status: u16) !void {
            try self.writer.print("HTTP/1.1 {} {s}\r\n", .{ status, statusText(status) });
            try self.writeHeaders();
            try self.writer.writeAll("Connection: close\r\n"); // TODO
        }

        /// Write the status line and headers, then hand back the raw writer
        /// for streaming a body. No Content-Length is sent; the body ends
        /// when the connection is closed.
        pub fn open(self: *Response, status: u16) !Writer {
            try self.openInternal(status);
            try self.writer.writeAll("\r\n");
            return self.writer;
        }

        /// Write a complete response. A non-empty body gets a Content-Length
        /// and, if no content-type header was set, a default Content-Type.
        pub fn write(self: *Response, status: u16, body: []const u8) !void {
            try self.openInternal(status);
            if (body.len != 0) {
                try self.writer.print("Content-Length: {}\r\n", .{body.len});
                if (self.headers.get("content-type") == null) {
                    try self.writer.writeAll("Content-Type: application/octet-stream\r\n");
                }
            }
            try self.writer.writeAll("\r\n");
            if (body.len != 0) {
                try self.writer.writeAll(body);
            }
        }
    };

    request: Request,
    response: Response,
    allocator: std.mem.Allocator,
};

/// A method plus a path pattern bound to a handler. Pattern segments that
/// start with ':' are parameters matching any single path segment.
const Route = struct {
    const Segment = union(enum) {
        param: []const u8,
        literal: []const u8,
    };

    const Handler = fn (*Context) callconv(.Async) anyerror!void;

    /// Drop leading slashes, collapse repeated slashes and drop a trailing
    /// slash (comptime only).
    fn normalize(comptime path: []const u8) []const u8 {
        var arr: [path.len]u8 = undefined;
        var i = 0;
        for (path) |ch| {
            if (i == 0 and ch == '/') continue;
            if (i > 0 and ch == '/' and arr[i - 1] == '/') continue;
            arr[i] = ch;
            i += 1;
        }
        if (i > 0 and arr[i - 1] == '/') {
            i -= 1;
        }
        return arr[0..i];
    }

    /// Split a normalized pattern on '/' into literal and parameter segments
    /// (comptime only).
    fn parseSegments(comptime path: []const u8) []const Segment {
        var count = 1;
        for (path) |ch| {
            if (ch == '/') count += 1;
        }
        var segment_array: [count]Segment = undefined;
        var segment_start = 0;
        for (segment_array) |*seg| {
            var index = segment_start;
            while (index < path.len) : (index += 1) {
                if (path[index] == '/') {
                    break;
                }
            }
            const slice = path[segment_start..index];
            if (slice.len > 0 and slice[0] == ':') {
                // doing this kinda jankily to get around segfaults in compiler
                const param = path[segment_start + 1 .. index];
                seg.* = .{ .param = param };
            } else {
                seg.* = .{ .literal = slice };
            }
            segment_start = index + 1;
        }
        return &segment_array;
    }

    pub fn from(method: Method, comptime path: []const u8, handler: Handler) Route {
        const segments = parseSegments(normalize(path));
        return Route{ .method = method, .path = segments, .handler = handler };
    }

    /// First segment of `path` after skipping leading slashes, or null when
    /// nothing but slashes (or nothing at all) remains.
    fn nextSegment(path: []const u8) ?[]const u8 {
        var start: usize = 0;
        var end: usize = start;
        while (end < path.len) : (end += 1) {
            // skip leading slash
            if (end == start and path[start] == '/') {
                start += 1;
                continue;
            } else if (path[end] == '/') {
                break;
            }
        }
        if (start == end) return null;
        return path[start..end];
    }

    /// Whether `path` matches this route's pattern segment by segment.
    /// Repeated or trailing slashes in `path` are tolerated; extra segments
    /// are not.
    pub fn matches(self: Route, path: []const u8) bool {
        var segment_start: usize = 0;
        for (self.path) |seg| {
            // `path` ran out of segments before the pattern did; slicing from
            // here would start past the end of `path` and panic.
            if (segment_start > path.len) return false;
            var index = segment_start;
            while (index < path.len) : (index += 1) {
                // skip leading slash
                if (index == segment_start and path[index] == '/') {
                    segment_start += 1;
                    continue;
                } else if (path[index] == '/') {
                    break;
                }
            }
            const slice = path[segment_start..index];
            const match = switch (seg) {
                .literal => |str| areStringsEqual(slice, str),
                .param => true,
            };
            if (!match) return false;
            segment_start = index + 1;
        }
        // check for trailing path
        while (segment_start < path.len) : (segment_start += 1) {
            if (path[segment_start] != '/') return false;
        }
        return true;
    }

    /// Text of the path segment bound to parameter `name`, or "" when the
    /// parameter is unknown or `path` is too short.
    pub fn arg(self: Route, name: []const u8, path: []const u8) []const u8 {
        var index: usize = 0;
        for (self.path) |seg| {
            // The previous segment ended exactly at the end of `path`.
            if (index > path.len) return "";
            const slice = nextSegment(path[index..]) orelse return "";
            // Resume just past this segment and its following slash.
            index = @ptrToInt(slice.ptr) - @ptrToInt(path.ptr) + slice.len + 1;
            switch (seg) {
                .param => |param| {
                    if (areStringsEqual(param, name)) return slice;
                },
                .literal => {},
            }
        }
        std.log.err("unknown parameter {s}", .{name});
        return "";
    }

    method: Method,
    path: []const Segment,
    handler: Handler,
};

fn handleNotFound(ctx: *Context) !void {
    try ctx.response.writer.writeAll("HTTP/1.1 404 Not Found\r\n\r\n");
}

/// Dispatch to the first route matching method and path, running its handler
/// in a frame allocated from the request allocator; otherwise answer 404.
fn routeRequest(ctx: *Context) !void {
    for (routes) |*route| {
        if (route.method == ctx.request.method and route.matches(ctx.request.path)) {
            std.log.info("{s} {s}", .{ @tagName(ctx.request.method), ctx.request.path });
            ctx.request.route = route;
            var buf = try ctx.allocator.allocWithOptions(u8, @frameSize(route.handler), 8, null);
            defer ctx.allocator.free(buf);
            return await @asyncCall(buf, {}, route.handler, .{ctx});
        }
    }
    std.log.info("404 {s} {s}", .{ @tagName(ctx.request.method), ctx.request.path });
    try handleNotFound(ctx);
}

/// Route table, checked in order by routeRequest.
const routes = [_]Route{
    Route.from(.GET, "/", staticString("Index Page")),
    Route.from(.GET, "/test", staticString("some test value idfk")),
    Route.from(.GET, "/objs/:id/get", getObjIdGet),
    Route.from(.POST, "/form/submit", staticString("form submit accepted")),
};

/// Build a handler that answers 200 with the fixed plain-text body `str`.
fn staticString(comptime str: []const u8) Route.Handler {
    return (struct {
        fn func(ctx: *Context) anyerror!void {
            try ctx.response.headers.put("content-type", "text/plain");
            try ctx.response.write(200, str);
        }
    }).func;
}

fn getObjIdGet(ctx: *Context) anyerror!void {
    try ctx.response.headers.put("content-type", "text/plain");
    var writer = try ctx.response.open(200);
    try writer.print("object id {s}", .{ctx.request.arg("id")});
}

pub fn main() anyerror!void {
    var srv = std.net.StreamServer.init(.{ .reuse_address = true });
    defer srv.deinit();

    try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);

    const loop = std.event.Loop.instance.?;
    while (true) {
        const conn = try srv.accept();
        // todo: keep track of connections
        // Every connection needs its own heap-allocated frame: an `async`
        // frame stored in a loop temporary would be overwritten by the next
        // iteration while the previous connection is still suspended.
        // runDetached allocates the frame and frees it when the handler ends.
        loop.runDetached(std.heap.page_allocator, handleConnection, .{conn}) catch |err| {
            std.log.err("failed to start connection handler: {}", .{err});
            conn.stream.close();
        };
    }
}