Basic web server

This commit is contained in:
jaina heartles 2022-04-02 13:23:18 -07:00
commit 5fbb1b480b
3 changed files with 490 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/zig-out
/zig-cache

27
build.zig Normal file
View File

@ -0,0 +1,27 @@
const std = @import("std");
pub fn build(b: *std.build.Builder) void {
// Standard target options allows the person running `zig build` to choose
// what target to build for. Here we do not override the defaults, which
// means any target is allowed, and the default is native. Other options
// for restricting supported target set are available.
const target = b.standardTargetOptions(.{});
// Standard release options allow the person running `zig build` to select
// between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall.
const mode = b.standardReleaseOptions();
const exe = b.addExecutable("apub", "src/main.zig");
exe.setTarget(target);
exe.setBuildMode(mode);
exe.install();
const run_cmd = exe.run();
run_cmd.step.dependOn(b.getInstallStep());
if (b.args) |args| {
run_cmd.addArgs(args);
}
const run_step = b.step("run", "Run the app");
run_step.dependOn(&run_cmd.step);
}

461
src/main.zig Normal file
View File

@ -0,0 +1,461 @@
const std = @import("std");
pub const io_mode = .evented;
const HeaderMap = std.StringHashMap([]const u8);
const Reader = std.net.Stream.Reader;
const Writer = std.net.Stream.Writer;
fn handleBadRequest(writer: Writer) !void {
std.log.info("400 Bad Request", .{});
try writer.writeAll("HTTP/1.1 400 Bad Request");
}
fn handleNotImplemented(writer: Writer) !void {
std.log.info("501", .{});
try writer.writeAll("HTTP/1.1 501 Not Implemented");
}
fn handleInternalError(writer: Writer) !void {
std.log.info("500", .{});
try writer.writeAll("HTTP/1.1 500 Internal Server Error");
}
const Method = enum {
GET,
//HEAD,
POST,
//PUT,
//DELETE,
//CONNECT,
//OPTIONS,
//TRACE,
};
fn areStringsEqual(lhs: []const u8, rhs: []const u8) bool {
if (lhs.len != rhs.len) return false;
for (lhs) |_, i| {
if (lhs[i] != rhs[i]) return false;
}
return true;
}
fn parseHttpMethod(reader: Reader) !Method {
var buf: [8]u8 = undefined;
const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) {
error.StreamTooLong => return error.MethodNotImplemented,
else => return err,
};
inline for (@typeInfo(Method).Enum.fields) |method| {
if (areStringsEqual(method.name, str)) {
return @intToEnum(Method, method.value);
}
}
return error.MethodNotImplemented;
}
fn checkProto(reader: Reader) !void {
var buf: [8]u8 = undefined;
const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) {
error.StreamTooLong => return error.UnknownProtocol,
else => return err,
};
if (!areStringsEqual(proto, "HTTP")) {
return error.UnknownProtocol;
}
const count = try reader.read(buf[0..3]);
if (count != 3 or buf[1] != '.') {
return error.BadRequest;
}
if (buf[0] != '1' or buf[2] != '1') {
return error.HttpVersionNotSupported;
}
}
fn extractHeaderName(line: []const u8) ?[]const u8 {
var index: usize = 0;
// TODO: handle whitespace
while (index < line.len) : (index += 1) {
if (line[index] == ':') {
if (index == 0) return null;
return line[0..index];
}
}
return null;
}
fn parseHeaders(allocator: std.mem.Allocator, reader: Reader) !HeaderMap {
var map = HeaderMap.init(allocator);
errdefer map.deinit();
// TODO: free map keys/values
var buf: [1024]u8 = undefined;
while (true) {
const line = try reader.readUntilDelimiter(&buf, '\n');
if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break;
// TODO: handle multi-line headers
const name = extractHeaderName(line) orelse continue;
const value = line[name.len + 1 + 1 ..];
if (name.len == 0 or value.len == 0) return error.BadRequest;
const name_alloc = try allocator.alloc(u8, name.len);
errdefer allocator.free(name_alloc);
const value_alloc = try allocator.alloc(u8, value.len);
errdefer allocator.free(value_alloc);
@memcpy(name_alloc.ptr, name.ptr, name.len);
@memcpy(value_alloc.ptr, value.ptr, value.len);
for (name_alloc) |*ch| {
//TODO: utf8
ch.* = std.ascii.toLower(ch.*);
}
try map.put(name_alloc, value_alloc);
}
return map;
}
fn handleConnection(conn: std.net.StreamServer.Connection) void {
defer conn.stream.close();
const reader = conn.stream.reader();
const writer = conn.stream.writer();
handleRequest(reader, writer) catch |err| std.log.err("unhandled error processing connection: {}", .{err});
}
fn handleRequest(reader: Reader, writer: Writer) !void {
handleHttpRequest(reader, writer) catch |err| switch (err) {
error.BadRequest, error.UnknownProtocol => try handleBadRequest(writer),
error.MethodNotImplemented, error.HttpVersionNotSupported => try handleNotImplemented(writer),
else => {
std.log.err("unknown error handling request: {}", .{err});
try handleInternalError(writer);
},
};
}
fn handleHttpRequest(reader: Reader, writer: Writer) anyerror!void {
const method = try parseHttpMethod(reader);
var header_buf: [1 << 16]u8 = undefined;
var fba = std.heap.FixedBufferAllocator.init(&header_buf);
const allocator = fba.allocator();
const path = reader.readUntilDelimiterAlloc(allocator, ' ', header_buf.len) catch |err| switch (err) {
error.StreamTooLong => return error.URITooLong,
else => return err,
};
try checkProto(reader);
_ = try reader.readByte();
_ = try reader.readByte();
const headers = try parseHeaders(allocator, reader);
const has_body = (headers.get("content-length") orelse headers.get("transfer-encoding")) != null;
const tfer_encoding = headers.get("transfer-encoding");
if (tfer_encoding != null and !areStringsEqual(tfer_encoding.?, "identity")) {
return error.UnsupportedMediaType;
}
const encoding = headers.get("content-encoding");
if (encoding != null and !areStringsEqual(encoding.?, "identity")) {
return error.UnsupportedMediaType;
}
var context = Context{
.request = .{
.method = method,
.path = path,
.headers = headers,
.body = if (has_body) reader else null,
},
.response = .{
.headers = HeaderMap.init(allocator),
.writer = writer,
},
.allocator = allocator,
};
try routeRequest(&context);
}
const Context = struct {
const Request = struct {
method: Method,
path: []const u8,
route: ?*const Route = null,
headers: HeaderMap,
body: ?Reader,
pub fn arg(self: *Request, name: []const u8) []const u8 {
return self.route.?.arg(name, self.path);
}
};
const Response = struct {
headers: HeaderMap,
writer: Writer,
fn writeHeaders(self: *Response) !void {
var iter = self.headers.iterator();
var it = iter.next();
while (it != null) : (it = iter.next()) {
try self.writer.print("{s}: {s}\r\n", .{ it.?.key_ptr.*, it.?.value_ptr.* });
}
}
fn statusText(status: u16) []const u8 {
return switch (status) {
200 => "OK",
204 => "No Content",
else => "",
};
}
fn openInternal(self: *Response, status: u16) !void {
try self.writer.print("HTTP/1.1 {} {s}\r\n", .{ status, statusText(status) });
try self.writeHeaders();
try self.writer.writeAll("Connection: close\r\n"); // TODO
}
pub fn open(self: *Response, status: u16) !Writer {
try self.openInternal(status);
try self.writer.writeAll("\r\n");
return self.writer;
}
pub fn write(self: *Response, status: u16, body: []const u8) !void {
try self.openInternal(status);
if (body.len != 0) {
try self.writer.print("Content-Length: {}\r\n", .{body.len});
if (self.headers.get("content-type") == null) {
try self.writer.writeAll("Content-Type: application/octet-stream\r\n");
}
}
try self.writer.writeAll("\r\n");
if (body.len != 0) {
try self.writer.writeAll(body);
}
}
};
request: Request,
response: Response,
allocator: std.mem.Allocator,
};
const Route = struct {
const Segment = union(enum) {
param: []const u8,
literal: []const u8,
};
const Handler = fn (*Context) callconv(.Async) anyerror!void;
fn normalize(comptime path: []const u8) []const u8 {
var arr: [path.len]u8 = undefined;
var i = 0;
for (path) |ch| {
if (i == 0 and ch == '/') continue;
if (i > 0 and ch == '/' and arr[i - 1] == '/') continue;
arr[i] = ch;
i += 1;
}
if (i > 0 and arr[i - 1] == '/') {
i -= 1;
}
return arr[0..i];
}
fn parseSegments(comptime path: []const u8) []const Segment {
var count = 1;
for (path) |ch| {
if (ch == '/') count += 1;
}
var segment_array: [count]Segment = undefined;
var segment_start = 0;
for (segment_array) |*seg| {
var index = segment_start;
while (index < path.len) : (index += 1) {
if (path[index] == '/') {
break;
}
}
const slice = path[segment_start..index];
if (slice.len > 0 and slice[0] == ':') {
// doing this kinda jankily to get around segfaults in compiler
const param = path[segment_start + 1 .. index];
seg.* = .{ .param = param };
} else {
seg.* = .{ .literal = slice };
}
segment_start = index + 1;
}
return &segment_array;
}
pub fn from(method: Method, comptime path: []const u8, handler: Handler) Route {
const segments = parseSegments(normalize(path));
return Route{ .method = method, .path = segments, .handler = handler };
}
fn nextSegment(path: []const u8) ?[]const u8 {
var start: usize = 0;
var end: usize = start;
while (end < path.len) : (end += 1) {
// skip leading slash
if (end == start and path[start] == '/') {
start += 1;
continue;
} else if (path[end] == '/') {
break;
}
}
if (start == end) return null;
return path[start..end];
}
pub fn matches(self: Route, path: []const u8) bool {
var segment_start: usize = 0;
for (self.path) |seg| {
var index = segment_start;
while (index < path.len) : (index += 1) {
// skip leading slash
if (index == segment_start and path[index] == '/') {
segment_start += 1;
continue;
} else if (path[index] == '/') {
break;
}
}
const slice = path[segment_start..index];
const match = switch (seg) {
.literal => |str| areStringsEqual(slice, str),
.param => true,
};
if (!match) return false;
segment_start = index + 1;
}
// check for trailing path
while (segment_start < path.len) : (segment_start += 1) {
if (path[segment_start] != '/') return false;
}
return true;
}
pub fn arg(self: Route, name: []const u8, path: []const u8) []const u8 {
var index: usize = 0;
for (self.path) |seg| {
const slice = nextSegment(path[index..]);
if (slice == null) return "";
index = @ptrToInt(slice.?.ptr) - @ptrToInt(path.ptr) + slice.?.len + 1;
switch (seg) {
.param => |param| {
if (areStringsEqual(param, name)) {
return slice.?;
}
},
.literal => continue,
}
}
std.log.err("unknown parameter {s}", .{name});
return "";
}
method: Method,
path: []const Segment,
handler: Handler,
};
fn handleNotFound(ctx: *Context) !void {
try ctx.response.writer.writeAll("HTTP/1.1 404 Not Found\r\n\r\n");
}
fn routeRequest(ctx: *Context) !void {
for (routes) |*route| {
if (route.method == ctx.request.method and route.matches(ctx.request.path)) {
std.log.info("{s} {s}", .{ @tagName(ctx.request.method), ctx.request.path });
ctx.request.route = route;
var buf = try ctx.allocator.allocWithOptions(u8, @frameSize(route.handler), 8, null);
defer ctx.allocator.free(buf);
return await @asyncCall(buf, {}, route.handler, .{ctx});
}
}
std.log.info("404 {s} {s}", .{ @tagName(ctx.request.method), ctx.request.path });
try handleNotFound(ctx);
}
const routes = [_]Route{
Route.from(.GET, "/", staticString("Index Page")),
Route.from(.GET, "/test", staticString("some test value idfk")),
Route.from(.GET, "/objs/:id/get", getObjIdGet),
Route.from(.POST, "/form/submit", staticString("form submit accepted")),
};
fn staticString(comptime str: []const u8) Route.Handler {
return (struct {
fn func(ctx: *Context) anyerror!void {
try ctx.response.headers.put("content-type", "text/plain");
try ctx.response.write(200, str);
}
}).func;
}
fn getObjIdGet(ctx: *Context) anyerror!void {
try ctx.response.headers.put("content-type", "text/plain");
var writer = try ctx.response.open(200);
try writer.print("object id {s}", .{ctx.request.arg("id")});
}
pub fn main() anyerror!void {
var srv = std.net.StreamServer.init(.{ .reuse_address = true });
defer srv.deinit();
try srv.listen(std.net.Address.parseIp("0.0.0.0", 8080) catch unreachable);
while (true) {
const conn = try srv.accept();
// todo: keep track of connections
_ = async handleConnection(conn);
}
}