commit 5fbb1b480b036e25bceb0b37563c8934e5b0953e Author: jaina heartles Date: Sat Apr 2 13:23:18 2022 -0700 Basic web server diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..33ed699 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/zig-out +/zig-cache \ No newline at end of file diff --git a/build.zig b/build.zig new file mode 100644 index 0000000..0d7ae3b --- /dev/null +++ b/build.zig @@ -0,0 +1,27 @@ +const std = @import("std"); + +pub fn build(b: *std.build.Builder) void { + // Standard target options allows the person running `zig build` to choose + // what target to build for. Here we do not override the defaults, which + // means any target is allowed, and the default is native. Other options + // for restricting supported target set are available. + const target = b.standardTargetOptions(.{}); + + // Standard release options allow the person running `zig build` to select + // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. + const mode = b.standardReleaseOptions(); + + const exe = b.addExecutable("apub", "src/main.zig"); + exe.setTarget(target); + exe.setBuildMode(mode); + exe.install(); + + const run_cmd = exe.run(); + run_cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| { + run_cmd.addArgs(args); + } + + const run_step = b.step("run", "Run the app"); + run_step.dependOn(&run_cmd.step); +} diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..2c9d9c8 --- /dev/null +++ b/src/main.zig @@ -0,0 +1,461 @@ +const std = @import("std"); + +pub const io_mode = .evented; + +const HeaderMap = std.StringHashMap([]const u8); +const Reader = std.net.Stream.Reader; +const Writer = std.net.Stream.Writer; + +fn handleBadRequest(writer: Writer) !void { + std.log.info("400 Bad Request", .{}); + try writer.writeAll("HTTP/1.1 400 Bad Request"); +} + +fn handleNotImplemented(writer: Writer) !void { + std.log.info("501", .{}); + try writer.writeAll("HTTP/1.1 501 Not Implemented"); +} + +fn handleInternalError(writer: Writer) !void { + std.log.info("500", .{}); + try writer.writeAll("HTTP/1.1 500 Internal Server Error"); +} + +const Method = enum { + GET, + //HEAD, + POST, + //PUT, + //DELETE, + //CONNECT, + //OPTIONS, + //TRACE, +}; + +fn areStringsEqual(lhs: []const u8, rhs: []const u8) bool { + if (lhs.len != rhs.len) return false; + for (lhs) |_, i| { + if (lhs[i] != rhs[i]) return false; + } + return true; +} + +fn parseHttpMethod(reader: Reader) !Method { + var buf: [8]u8 = undefined; + const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) { + error.StreamTooLong => return error.MethodNotImplemented, + else => return err, + }; + + inline for (@typeInfo(Method).Enum.fields) |method| { + if (areStringsEqual(method.name, str)) { + return @intToEnum(Method, method.value); + } + } + + return error.MethodNotImplemented; +} + +fn checkProto(reader: Reader) !void { + var buf: [8]u8 = undefined; + const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { + error.StreamTooLong => return error.UnknownProtocol, + else => return err, + }; + + if (!areStringsEqual(proto, "HTTP")) { + return error.UnknownProtocol; + } + + const count = try reader.read(buf[0..3]); + if (count != 3 or buf[1] != '.') { + return error.BadRequest; + } + + if (buf[0] != '1' or buf[2] != '1') { + return error.HttpVersionNotSupported; + } +} + +fn extractHeaderName(line: []const u8) ?[]const u8 { + var index: usize = 0; + + // TODO: handle whitespace + while (index < line.len) : (index += 1) { + if (line[index] == ':') { + if (index == 0) return null; + return line[0..index]; + } + } + + return null; +} + +fn parseHeaders(allocator: std.mem.Allocator, reader: Reader) !HeaderMap { + var map = HeaderMap.init(allocator); + errdefer map.deinit(); + // TODO: free map keys/values + + var buf: [1024]u8 = undefined; + + while (true) { + const line = try reader.readUntilDelimiter(&buf, '\n'); + if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; + + // TODO: handle multi-line headers + const name = extractHeaderName(line) orelse continue; + const value = line[name.len + 1 + 1 ..]; + + if (name.len == 0 or value.len == 0) return error.BadRequest; + + const name_alloc = try allocator.alloc(u8, name.len); + errdefer allocator.free(name_alloc); + const value_alloc = try allocator.alloc(u8, value.len); + errdefer allocator.free(value_alloc); + + @memcpy(name_alloc.ptr, name.ptr, name.len); + @memcpy(value_alloc.ptr, value.ptr, value.len); + + for (name_alloc) |*ch| { + //TODO: utf8 + ch.* = std.ascii.toLower(ch.*); + } + + try map.put(name_alloc, value_alloc); + } + + return map; +} + +fn handleConnection(conn: std.net.StreamServer.Connection) void { + defer conn.stream.close(); + const reader = conn.stream.reader(); + const writer = conn.stream.writer(); + + handleRequest(reader, writer) catch |err| std.log.err("unhandled error processing connection: {}", .{err}); +} + +fn handleRequest(reader: Reader, writer: Writer) !void { + handleHttpRequest(reader, writer) catch |err| switch (err) { + error.BadRequest, error.UnknownProtocol => try handleBadRequest(writer), + error.MethodNotImplemented, error.HttpVersionNotSupported => try handleNotImplemented(writer), + else => { + std.log.err("unknown error handling request: {}", .{err}); + try handleInternalError(writer); + }, + }; +} + +fn handleHttpRequest(reader: Reader, writer: Writer) anyerror!void { + const method = try parseHttpMethod(reader); + + var header_buf: [1 << 16]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&header_buf); + const allocator = fba.allocator(); + const path = reader.readUntilDelimiterAlloc(allocator, ' ', header_buf.len) catch |err| switch (err) { + error.StreamTooLong => return error.URITooLong, + else => return err, + }; + + try checkProto(reader); + _ = try reader.readByte(); + _ = try reader.readByte(); + + const headers = try parseHeaders(allocator, reader); + + const has_body = (headers.get("content-length") orelse headers.get("transfer-encoding")) != null; + + const tfer_encoding = headers.get("transfer-encoding"); + if (tfer_encoding != null and !areStringsEqual(tfer_encoding.?, "identity")) { + return error.UnsupportedMediaType; + } + + const encoding = headers.get("content-encoding"); + if (encoding != null and !areStringsEqual(encoding.?, "identity")) { + return error.UnsupportedMediaType; + } + + var context = Context{ + .request = .{ + .method = method, + .path = path, + .headers = headers, + .body = if (has_body) reader else null, + }, + .response = .{ + .headers = HeaderMap.init(allocator), + .writer = writer, + }, + .allocator = allocator, + }; + + try routeRequest(&context); +} + +const Context = struct { + const Request = struct { + method: Method, + path: []const u8, + + route: ?*const Route = null, + + headers: HeaderMap, + + body: ?Reader, + + pub fn arg(self: *Request, name: []const u8) []const u8 { + return self.route.?.arg(name, self.path); + } + }; + + const Response = struct { + headers: HeaderMap, + writer: Writer, + + fn writeHeaders(self: *Response) !void { + var iter = self.headers.iterator(); + var it = iter.next(); + while (it != null) : (it = iter.next()) { + try self.writer.print("{s}: {s}\r\n", .{ it.?.key_ptr.*, it.?.value_ptr.* }); + } + } + + fn statusText(status: u16) []const u8 { + return switch (status) { + 200 => "OK", + 204 => "No Content", + else => "", + }; + } + + fn openInternal(self: *Response, status: u16) !void { + try self.writer.print("HTTP/1.1 {} {s}\r\n", .{ status, statusText(status) }); + try self.writeHeaders(); + try self.writer.writeAll("Connection: close\r\n"); // TODO + } + + pub fn open(self: *Response, status: u16) !Writer { + try self.openInternal(status); + try self.writer.writeAll("\r\n"); + + return self.writer; + } + + pub fn write(self: *Response, status: u16, body: []const u8) !void { + try self.openInternal(status); + if (body.len != 0) { + try self.writer.print("Content-Length: {}\r\n", .{body.len}); + if (self.headers.get("content-type") == null) { + try self.writer.writeAll("Content-Type: application/octet-stream\r\n"); + } + } + + try self.writer.writeAll("\r\n"); + if (body.len != 0) { + try self.writer.writeAll(body); + } + } + }; + + request: Request, + response: Response, + allocator: std.mem.Allocator, +}; + +const Route = struct { + const Segment = union(enum) { + param: []const u8, + literal: []const u8, + }; + + const Handler = fn (*Context) callconv(.Async) anyerror!void; + + fn normalize(comptime path: []const u8) []const u8 { + var arr: [path.len]u8 = undefined; + + var i = 0; + for (path) |ch| { + if (i == 0 and ch == '/') continue; + if (i > 0 and ch == '/' and arr[i - 1] == '/') continue; + + arr[i] = ch; + i += 1; + } + + if (i > 0 and arr[i - 1] == '/') { + i -= 1; + } + + return arr[0..i]; + } + + fn parseSegments(comptime path: []const u8) []const Segment { + var count = 1; + for (path) |ch| { + if (ch == '/') count += 1; + } + + var segment_array: [count]Segment = undefined; + + var segment_start = 0; + for (segment_array) |*seg| { + var index = segment_start; + while (index < path.len) : (index += 1) { + if (path[index] == '/') { + break; + } + } + + const slice = path[segment_start..index]; + if (slice.len > 0 and slice[0] == ':') { + // doing this kinda jankily to get around segfaults in compiler + const param = path[segment_start + 1 .. index]; + seg.* = .{ .param = param }; + } else { + seg.* = .{ .literal = slice }; + } + + segment_start = index + 1; + } + + return &segment_array; + } + + pub fn from(method: Method, comptime path: []const u8, handler: Handler) Route { + const segments = parseSegments(normalize(path)); + return Route{ .method = method, .path = segments, .handler = handler }; + } + + fn nextSegment(path: []const u8) ?[]const u8 { + var start: usize = 0; + var end: usize = start; + while (end < path.len) : (end += 1) { + // skip leading slash + if (end == start and path[start] == '/') { + start += 1; + continue; + } else if (path[end] == '/') { + break; + } + } + + if (start == end) return null; + + return path[start..end]; + } + + pub fn matches(self: Route, path: []const u8) bool { + var segment_start: usize = 0; + for (self.path) |seg| { + var index = segment_start; + while (index < path.len) : (index += 1) { + // skip leading slash + if (index == segment_start and path[index] == '/') { + segment_start += 1; + continue; + } else if (path[index] == '/') { + break; + } + } + + const slice = path[segment_start..index]; + const match = switch (seg) { + .literal => |str| areStringsEqual(slice, str), + .param => true, + }; + + if (!match) return false; + + segment_start = index + 1; + } + + // check for trailing path + while (segment_start < path.len) : (segment_start += 1) { + if (path[segment_start] != '/') return false; + } + + return true; + } + + pub fn arg(self: Route, name: []const u8, path: []const u8) []const u8 { + var index: usize = 0; + for (self.path) |seg| { + const slice = nextSegment(path[index..]); + if (slice == null) return ""; + + index = @ptrToInt(slice.?.ptr) - @ptrToInt(path.ptr) + slice.?.len + 1; + + switch (seg) { + .param => |param| { + if (areStringsEqual(param, name)) { + return slice.?; + } + }, + .literal => continue, + } + } + + std.log.err("unknown parameter {s}", .{name}); + return ""; + } + + method: Method, + path: []const Segment, + handler: Handler, +}; + +fn handleNotFound(ctx: *Context) !void { + try ctx.response.writer.writeAll("HTTP/1.1 404 Not Found\r\n\r\n"); +} + +fn routeRequest(ctx: *Context) !void { + for (routes) |*route| { + if (route.method == ctx.request.method and route.matches(ctx.request.path)) { + std.log.info("{s} {s}", .{ @tagName(ctx.request.method), ctx.request.path }); + ctx.request.route = route; + + var buf = try ctx.allocator.allocWithOptions(u8, @frameSize(route.handler), 8, null); + defer ctx.allocator.free(buf); + return await @asyncCall(buf, {}, route.handler, .{ctx}); + } + } + + std.log.info("404 {s} {s}", .{ @tagName(ctx.request.method), ctx.request.path }); + try handleNotFound(ctx); +} + +const routes = [_]Route{ + Route.from(.GET, "/", staticString("Index Page")), + Route.from(.GET, "/test", staticString("some test value idfk")), + Route.from(.GET, "/objs/:id/get", getObjIdGet), + Route.from(.POST, "/form/submit", staticString("form submit accepted")), +}; + +fn staticString(comptime str: []const u8) Route.Handler { + return (struct { + fn func(ctx: *Context) anyerror!void { + try ctx.response.headers.put("content-type", "text/plain"); + try ctx.response.write(200, str); + } + }).func; +} + +fn getObjIdGet(ctx: *Context) anyerror!void { + try ctx.response.headers.put("content-type", "text/plain"); + var writer = try ctx.response.open(200); + try writer.print("object id {s}", .{ctx.request.arg("id")}); +} + +pub fn main() anyerror!void { + var srv = std.net.StreamServer.init(.{ .reuse_address = true }); + defer srv.deinit(); + + try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable); + + while (true) { + const conn = try srv.accept(); + + // todo: keep track of connections + _ = async handleConnection(conn); + } +}