diff --git a/src/http.zig b/src/http.zig new file mode 100644 index 0000000..546537b --- /dev/null +++ b/src/http.zig @@ -0,0 +1,427 @@ +const std = @import("std"); +const root = @import("root"); + +const ciutf8 = root.ciutf8; +const Reader = std.net.Stream.Reader; +const Writer = std.net.Stream.Writer; + +const HeaderMap = std.HashMap([]const u8, []const u8, struct { + pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { + return ciutf8.eql(a, b); + } + + pub fn hash(_: @This(), str: []const u8) u64 { + return ciutf8.hash(str); + } +}, std.hash_map.default_max_load_percentage); + +fn handleBadRequest(writer: Writer) !void { + std.log.info("400 Bad Request", .{}); + try writer.writeAll("HTTP/1.1 400 Bad Request"); +} + +fn handleNotImplemented(writer: Writer) !void { + std.log.info("501", .{}); + try writer.writeAll("HTTP/1.1 501 Not Implemented"); +} + +fn handleInternalError(writer: Writer) !void { + std.log.info("500", .{}); + try writer.writeAll("HTTP/1.1 500 Internal Server Error"); +} + +pub const Method = enum { + GET, + //HEAD, + POST, + //PUT, + //DELETE, + //CONNECT, + //OPTIONS, + //TRACE, +}; + +fn parseHttpMethod(reader: Reader) !Method { + var buf: [8]u8 = undefined; + const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) { + error.StreamTooLong => return error.MethodNotImplemented, + else => return err, + }; + + inline for (@typeInfo(Method).Enum.fields) |method| { + if (std.mem.eql(u8, method.name, str)) { + return @intToEnum(Method, method.value); + } + } + + return error.MethodNotImplemented; +} + +fn checkProto(reader: Reader) !void { + var buf: [8]u8 = undefined; + const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { + error.StreamTooLong => return error.UnknownProtocol, + else => return err, + }; + + if (!std.mem.eql(u8, proto, "HTTP")) { + return error.UnknownProtocol; + } + + const count = try reader.read(buf[0..3]); + if (count != 3 or buf[1] != '.') { + return error.BadRequest; + } + + if (buf[0] != '1' or buf[2] != '1') { + return error.HttpVersionNotSupported; + } +} + +fn extractHeaderName(line: []const u8) ?[]const u8 { + var index: usize = 0; + + // TODO: handle whitespace + while (index < line.len) : (index += 1) { + if (line[index] == ':') { + if (index == 0) return null; + return line[0..index]; + } + } + + return null; +} + +fn parseHeaders(allocator: std.mem.Allocator, reader: Reader) !HeaderMap { + var map = HeaderMap.init(allocator); + errdefer map.deinit(); + // TODO: free map keys/values + + var buf: [1024]u8 = undefined; + + while (true) { + const line = try reader.readUntilDelimiter(&buf, '\n'); + if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; + + // TODO: handle multi-line headers + const name = extractHeaderName(line) orelse continue; + const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len; + const value = line[name.len + 1 + 1 .. value_end]; + + if (name.len == 0 or value.len == 0) return error.BadRequest; + + const name_alloc = try allocator.alloc(u8, name.len); + errdefer allocator.free(name_alloc); + const value_alloc = try allocator.alloc(u8, value.len); + errdefer allocator.free(value_alloc); + + @memcpy(name_alloc.ptr, name.ptr, name.len); + @memcpy(value_alloc.ptr, value.ptr, value.len); + + try map.put(name_alloc, value_alloc); + } + + return map; +} + +pub fn handleConnection(conn: std.net.StreamServer.Connection) void { + defer conn.stream.close(); + const reader = conn.stream.reader(); + const writer = conn.stream.writer(); + + handleRequest(reader, writer) catch |err| std.log.err("unhandled error processing connection: {}", .{err}); +} + +fn handleRequest(reader: Reader, writer: Writer) !void { + handleHttpRequest(reader, writer) catch |err| switch (err) { + error.BadRequest, error.UnknownProtocol => try handleBadRequest(writer), + error.MethodNotImplemented, error.HttpVersionNotSupported => try handleNotImplemented(writer), + else => { + std.log.err("unknown error handling request: {}", .{err}); + try handleInternalError(writer); + }, + }; +} + +fn handleHttpRequest(reader: Reader, writer: Writer) anyerror!void { + const method = try parseHttpMethod(reader); + + var header_buf: [1 << 16]u8 = undefined; + var fba = std.heap.FixedBufferAllocator.init(&header_buf); + const allocator = fba.allocator(); + const path = reader.readUntilDelimiterAlloc(allocator, ' ', header_buf.len) catch |err| switch (err) { + error.StreamTooLong => return error.URITooLong, + else => return err, + }; + + try checkProto(reader); + _ = try reader.readByte(); + _ = try reader.readByte(); + + const headers = try parseHeaders(allocator, reader); + + const has_body = (headers.get("Content-Length") orelse headers.get("Transfer-Encoding")) != null; + + const tfer_encoding = headers.get("Transfer-Encoding"); + if (tfer_encoding != null and !std.mem.eql(u8, tfer_encoding.?, "identity")) { + return error.UnsupportedMediaType; + } + + const encoding = headers.get("Content-Encoding"); + if (encoding != null and !std.mem.eql(u8, encoding.?, "identity")) { + return error.UnsupportedMediaType; + } + + var context = Context{ + .request = .{ + .method = method, + .path = path, + .headers = headers, + .body = if (has_body) reader else null, + }, + .response = .{ + .headers = HeaderMap.init(allocator), + .writer = writer, + }, + .allocator = allocator, + }; + + try routeRequest(&context); +} + +pub const Context = struct { + const Request = struct { + method: Method, + path: []const u8, + + route: ?*const Route = null, + + headers: HeaderMap, + + body: ?Reader, + + pub fn arg(self: *Request, name: []const u8) []const u8 { + return self.route.?.arg(name, self.path); + } + }; + + const Response = struct { + headers: HeaderMap, + writer: Writer, + + fn writeHeaders(self: *Response) !void { + var iter = self.headers.iterator(); + var it = iter.next(); + while (it != null) : (it = iter.next()) { + try self.writer.print("{s}: {s}\r\n", .{ it.?.key_ptr.*, it.?.value_ptr.* }); + } + } + + fn statusText(status: u16) []const u8 { + return switch (status) { + 200 => "OK", + 204 => "No Content", + 404 => "Not Found", + else => "", + }; + } + + fn openInternal(self: *Response, status: u16) !void { + try self.writer.print("HTTP/1.1 {} {s}\r\n", .{ status, statusText(status) }); + try self.writeHeaders(); + try self.writer.writeAll("Connection: close\r\n"); // TODO + } + + pub fn open(self: *Response, status: u16) !Writer { + try self.openInternal(status); + try self.writer.writeAll("\r\n"); + + return self.writer; + } + + pub fn write(self: *Response, status: u16, body: []const u8) !void { + try self.openInternal(status); + if (body.len != 0) { + try self.writer.print("Content-Length: {}\r\n", .{body.len}); + if (self.headers.get("Content-Type") == null) { + try self.writer.writeAll("Content-Type: application/octet-stream\r\n"); + } + } + + try self.writer.writeAll("\r\n"); + if (body.len != 0) { + try self.writer.writeAll(body); + } + } + + pub fn statusOnly(self: *Response, status: u16) !void { + try self.openInternal(status); + } + }; + + request: Request, + response: Response, + allocator: std.mem.Allocator, +}; + +pub const Route = struct { + const Segment = union(enum) { + param: []const u8, + literal: []const u8, + }; + + pub const Handler = fn (*Context) callconv(.Async) anyerror!void; + + fn normalize(comptime path: []const u8) []const u8 { + var arr: [path.len]u8 = undefined; + + var i = 0; + for (path) |ch| { + if (i == 0 and ch == '/') continue; + if (i > 0 and ch == '/' and arr[i - 1] == '/') continue; + + arr[i] = ch; + i += 1; + } + + if (i > 0 and arr[i - 1] == '/') { + i -= 1; + } + + return arr[0..i]; + } + + fn parseSegments(comptime path: []const u8) []const Segment { + var count = 1; + for (path) |ch| { + if (ch == '/') count += 1; + } + + var segment_array: [count]Segment = undefined; + + var segment_start = 0; + for (segment_array) |*seg| { + var index = segment_start; + while (index < path.len) : (index += 1) { + if (path[index] == '/') { + break; + } + } + + const slice = path[segment_start..index]; + if (slice.len > 0 and slice[0] == ':') { + // doing this kinda jankily to get around segfaults in compiler + const param = path[segment_start + 1 .. index]; + seg.* = .{ .param = param }; + } else { + seg.* = .{ .literal = slice }; + } + + segment_start = index + 1; + } + + return &segment_array; + } + + pub fn from(method: Method, comptime path: []const u8, handler: Handler) Route { + const segments = parseSegments(normalize(path)); + return Route{ .method = method, .path = segments, .handler = handler }; + } + + fn nextSegment(path: []const u8) ?[]const u8 { + var start: usize = 0; + var end: usize = start; + while (end < path.len) : (end += 1) { + // skip leading slash + if (end == start and path[start] == '/') { + start += 1; + continue; + } else if (path[end] == '/') { + break; + } + } + + if (start == end) return null; + + return path[start..end]; + } + + pub fn matches(self: Route, path: []const u8) bool { + var segment_start: usize = 0; + for (self.path) |seg| { + var index = segment_start; + while (index < path.len) : (index += 1) { + // skip leading slash + if (index == segment_start and path[index] == '/') { + segment_start += 1; + continue; + } else if (path[index] == '/') { + break; + } + } + + const slice = path[segment_start..index]; + const match = switch (seg) { + .literal => |str| ciutf8.eql(slice, str), + .param => true, + }; + + if (!match) return false; + + segment_start = index + 1; + } + + // check for trailing path + while (segment_start < path.len) : (segment_start += 1) { + if (path[segment_start] != '/') return false; + } + + return true; + } + + pub fn arg(self: Route, name: []const u8, path: []const u8) []const u8 { + var index: usize = 0; + for (self.path) |seg| { + const slice = nextSegment(path[index..]); + if (slice == null) return ""; + + index = @ptrToInt(slice.?.ptr) - @ptrToInt(path.ptr) + slice.?.len + 1; + + switch (seg) { + .param => |param| { + if (std.mem.eql(u8, param, name)) { + return slice.?; + } + }, + .literal => continue, + } + } + + std.log.err("unknown parameter {s}", .{name}); + return ""; + } + + method: Method, + path: []const Segment, + handler: Handler, +}; + +fn handleNotFound(ctx: *Context) !void { + try ctx.response.writer.writeAll("HTTP/1.1 404 Not Found\r\n\r\n"); +} + +fn routeRequest(ctx: *Context) !void { + for (root.routes) |*route| { + if (route.method == ctx.request.method and route.matches(ctx.request.path)) { + std.log.info("{s} {s}", .{ @tagName(ctx.request.method), ctx.request.path }); + ctx.request.route = route; + + var buf = try ctx.allocator.allocWithOptions(u8, @frameSize(route.handler), 8, null); + defer ctx.allocator.free(buf); + return await @asyncCall(buf, {}, route.handler, .{ctx}); + } + } + + std.log.info("404 {s} {s}", .{ @tagName(ctx.request.method), ctx.request.path }); + try handleNotFound(ctx); +} diff --git a/src/main.zig b/src/main.zig index b9610c9..8e4a0ff 100644 --- a/src/main.zig +++ b/src/main.zig @@ -2,436 +2,16 @@ const std = @import("std"); pub const db = @import("./db.zig"); pub const util = @import("./util.zig"); +pub const http = @import("./http.zig"); + +pub const Uuid = util.Uuid; +pub const ciutf8 = util.ciutf8; pub const io_mode = .evented; -const Uuid = util.Uuid; -const ciutf8 = util.ciutf8; -const Reader = std.net.Stream.Reader; -const Writer = std.net.Stream.Writer; +const Route = http.Route; -const HeaderMap = std.HashMap([]const u8, []const u8, struct { - pub fn eql(_: @This(), a: []const u8, b: []const u8) bool { - return ciutf8.eql(a, b); - } - - pub fn hash(_: @This(), str: []const u8) u64 { - return ciutf8.hash(str); - } -}, std.hash_map.default_max_load_percentage); - -fn handleBadRequest(writer: Writer) !void { - std.log.info("400 Bad Request", .{}); - try writer.writeAll("HTTP/1.1 400 Bad Request"); -} - -fn handleNotImplemented(writer: Writer) !void { - std.log.info("501", .{}); - try writer.writeAll("HTTP/1.1 501 Not Implemented"); -} - -fn handleInternalError(writer: Writer) !void { - std.log.info("500", .{}); - try writer.writeAll("HTTP/1.1 500 Internal Server Error"); -} - -const Method = enum { - GET, - //HEAD, - POST, - //PUT, - //DELETE, - //CONNECT, - //OPTIONS, - //TRACE, -}; - -fn parseHttpMethod(reader: Reader) !Method { - var buf: [8]u8 = undefined; - const str = reader.readUntilDelimiter(&buf, ' ') catch |err| switch (err) { - error.StreamTooLong => return error.MethodNotImplemented, - else => return err, - }; - - inline for (@typeInfo(Method).Enum.fields) |method| { - if (std.mem.eql(u8, method.name, str)) { - return @intToEnum(Method, method.value); - } - } - - return error.MethodNotImplemented; -} - -fn checkProto(reader: Reader) !void { - var buf: [8]u8 = undefined; - const proto = reader.readUntilDelimiter(&buf, '/') catch |err| switch (err) { - error.StreamTooLong => return error.UnknownProtocol, - else => return err, - }; - - if (!std.mem.eql(u8, proto, "HTTP")) { - return error.UnknownProtocol; - } - - const count = try reader.read(buf[0..3]); - if (count != 3 or buf[1] != '.') { - return error.BadRequest; - } - - if (buf[0] != '1' or buf[2] != '1') { - return error.HttpVersionNotSupported; - } -} - -fn extractHeaderName(line: []const u8) ?[]const u8 { - var index: usize = 0; - - // TODO: handle whitespace - while (index < line.len) : (index += 1) { - if (line[index] == ':') { - if (index == 0) return null; - return line[0..index]; - } - } - - return null; -} - -fn parseHeaders(allocator: std.mem.Allocator, reader: Reader) !HeaderMap { - var map = HeaderMap.init(allocator); - errdefer map.deinit(); - // TODO: free map keys/values - - var buf: [1024]u8 = undefined; - - while (true) { - const line = try reader.readUntilDelimiter(&buf, '\n'); - if (line.len == 0 or (line.len == 1 and line[0] == '\r')) break; - - // TODO: handle multi-line headers - const name = extractHeaderName(line) orelse continue; - const value_end = if (line[line.len - 1] == '\r') line.len - 1 else line.len; - const value = line[name.len + 1 + 1 .. value_end]; - - if (name.len == 0 or value.len == 0) return error.BadRequest; - - const name_alloc = try allocator.alloc(u8, name.len); - errdefer allocator.free(name_alloc); - const value_alloc = try allocator.alloc(u8, value.len); - errdefer allocator.free(value_alloc); - - @memcpy(name_alloc.ptr, name.ptr, name.len); - @memcpy(value_alloc.ptr, value.ptr, value.len); - - try map.put(name_alloc, value_alloc); - } - - return map; -} - -fn handleConnection(conn: std.net.StreamServer.Connection) void { - defer conn.stream.close(); - const reader = conn.stream.reader(); - const writer = conn.stream.writer(); - - handleRequest(reader, writer) catch |err| std.log.err("unhandled error processing connection: {}", .{err}); -} - -fn handleRequest(reader: Reader, writer: Writer) !void { - handleHttpRequest(reader, writer) catch |err| switch (err) { - error.BadRequest, error.UnknownProtocol => try handleBadRequest(writer), - error.MethodNotImplemented, error.HttpVersionNotSupported => try handleNotImplemented(writer), - else => { - std.log.err("unknown error handling request: {}", .{err}); - try handleInternalError(writer); - }, - }; -} - -fn handleHttpRequest(reader: Reader, writer: Writer) anyerror!void { - const method = try parseHttpMethod(reader); - - var header_buf: [1 << 16]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&header_buf); - const allocator = fba.allocator(); - const path = reader.readUntilDelimiterAlloc(allocator, ' ', header_buf.len) catch |err| switch (err) { - error.StreamTooLong => return error.URITooLong, - else => return err, - }; - - try checkProto(reader); - _ = try reader.readByte(); - _ = try reader.readByte(); - - const headers = try parseHeaders(allocator, reader); - - const has_body = (headers.get("Content-Length") orelse headers.get("Transfer-Encoding")) != null; - - const tfer_encoding = headers.get("Transfer-Encoding"); - if (tfer_encoding != null and !std.mem.eql(u8, tfer_encoding.?, "identity")) { - return error.UnsupportedMediaType; - } - - const encoding = headers.get("Content-Encoding"); - if (encoding != null and !std.mem.eql(u8, encoding.?, "identity")) { - return error.UnsupportedMediaType; - } - - var context = Context{ - .request = .{ - .method = method, - .path = path, - .headers = headers, - .body = if (has_body) reader else null, - }, - .response = .{ - .headers = HeaderMap.init(allocator), - .writer = writer, - }, - .allocator = allocator, - }; - - try routeRequest(&context); -} - -const Context = struct { - const Request = struct { - method: Method, - path: []const u8, - - route: ?*const Route = null, - - headers: HeaderMap, - - body: ?Reader, - - pub fn arg(self: *Request, name: []const u8) []const u8 { - return self.route.?.arg(name, self.path); - } - }; - - const Response = struct { - headers: HeaderMap, - writer: Writer, - - fn writeHeaders(self: *Response) !void { - var iter = self.headers.iterator(); - var it = iter.next(); - while (it != null) : (it = iter.next()) { - try self.writer.print("{s}: {s}\r\n", .{ it.?.key_ptr.*, it.?.value_ptr.* }); - } - } - - fn statusText(status: u16) []const u8 { - return switch (status) { - 200 => "OK", - 204 => "No Content", - 404 => "Not Found", - else => "", - }; - } - - fn openInternal(self: *Response, status: u16) !void { - try self.writer.print("HTTP/1.1 {} {s}\r\n", .{ status, statusText(status) }); - try self.writeHeaders(); - try self.writer.writeAll("Connection: close\r\n"); // TODO - } - - pub fn open(self: *Response, status: u16) !Writer { - try self.openInternal(status); - try self.writer.writeAll("\r\n"); - - return self.writer; - } - - pub fn write(self: *Response, status: u16, body: []const u8) !void { - try self.openInternal(status); - if (body.len != 0) { - try self.writer.print("Content-Length: {}\r\n", .{body.len}); - if (self.headers.get("Content-Type") == null) { - try self.writer.writeAll("Content-Type: application/octet-stream\r\n"); - } - } - - try self.writer.writeAll("\r\n"); - if (body.len != 0) { - try self.writer.writeAll(body); - } - } - - pub fn statusOnly(self: *Response, status: u16) !void { - try self.openInternal(status); - } - }; - - request: Request, - response: Response, - allocator: std.mem.Allocator, -}; - -const Route = struct { - const Segment = union(enum) { - param: []const u8, - literal: []const u8, - }; - - const Handler = fn (*Context) callconv(.Async) anyerror!void; - - fn normalize(comptime path: []const u8) []const u8 { - var arr: [path.len]u8 = undefined; - - var i = 0; - for (path) |ch| { - if (i == 0 and ch == '/') continue; - if (i > 0 and ch == '/' and arr[i - 1] == '/') continue; - - arr[i] = ch; - i += 1; - } - - if (i > 0 and arr[i - 1] == '/') { - i -= 1; - } - - return arr[0..i]; - } - - fn parseSegments(comptime path: []const u8) []const Segment { - var count = 1; - for (path) |ch| { - if (ch == '/') count += 1; - } - - var segment_array: [count]Segment = undefined; - - var segment_start = 0; - for (segment_array) |*seg| { - var index = segment_start; - while (index < path.len) : (index += 1) { - if (path[index] == '/') { - break; - } - } - - const slice = path[segment_start..index]; - if (slice.len > 0 and slice[0] == ':') { - // doing this kinda jankily to get around segfaults in compiler - const param = path[segment_start + 1 .. index]; - seg.* = .{ .param = param }; - } else { - seg.* = .{ .literal = slice }; - } - - segment_start = index + 1; - } - - return &segment_array; - } - - pub fn from(method: Method, comptime path: []const u8, handler: Handler) Route { - const segments = parseSegments(normalize(path)); - return Route{ .method = method, .path = segments, .handler = handler }; - } - - fn nextSegment(path: []const u8) ?[]const u8 { - var start: usize = 0; - var end: usize = start; - while (end < path.len) : (end += 1) { - // skip leading slash - if (end == start and path[start] == '/') { - start += 1; - continue; - } else if (path[end] == '/') { - break; - } - } - - if (start == end) return null; - - return path[start..end]; - } - - pub fn matches(self: Route, path: []const u8) bool { - var segment_start: usize = 0; - for (self.path) |seg| { - var index = segment_start; - while (index < path.len) : (index += 1) { - // skip leading slash - if (index == segment_start and path[index] == '/') { - segment_start += 1; - continue; - } else if (path[index] == '/') { - break; - } - } - - const slice = path[segment_start..index]; - const match = switch (seg) { - .literal => |str| ciutf8.eql(slice, str), - .param => true, - }; - - if (!match) return false; - - segment_start = index + 1; - } - - // check for trailing path - while (segment_start < path.len) : (segment_start += 1) { - if (path[segment_start] != '/') return false; - } - - return true; - } - - pub fn arg(self: Route, name: []const u8, path: []const u8) []const u8 { - var index: usize = 0; - for (self.path) |seg| { - const slice = nextSegment(path[index..]); - if (slice == null) return ""; - - index = @ptrToInt(slice.?.ptr) - @ptrToInt(path.ptr) + slice.?.len + 1; - - switch (seg) { - .param => |param| { - if (std.mem.eql(u8, param, name)) { - return slice.?; - } - }, - .literal => continue, - } - } - - std.log.err("unknown parameter {s}", .{name}); - return ""; - } - - method: Method, - path: []const Segment, - handler: Handler, -}; - -fn handleNotFound(ctx: *Context) !void { - try ctx.response.writer.writeAll("HTTP/1.1 404 Not Found\r\n\r\n"); -} - -fn routeRequest(ctx: *Context) !void { - for (routes) |*route| { - if (route.method == ctx.request.method and route.matches(ctx.request.path)) { - std.log.info("{s} {s}", .{ @tagName(ctx.request.method), ctx.request.path }); - ctx.request.route = route; - - var buf = try ctx.allocator.allocWithOptions(u8, @frameSize(route.handler), 8, null); - defer ctx.allocator.free(buf); - return await @asyncCall(buf, {}, route.handler, .{ctx}); - } - } - - std.log.info("404 {s} {s}", .{ @tagName(ctx.request.method), ctx.request.path }); - try handleNotFound(ctx); -} - -const routes = [_]Route{ +pub const routes = [_]Route{ Route.from(.GET, "/", staticString("Index Page")), Route.from(.GET, "/abc", staticString("abc")), Route.from(.GET, "/user/:id", getUser), @@ -440,7 +20,7 @@ const routes = [_]Route{ const this_scheme = "http"; const this_host = "localhost:8080"; -fn getUser(ctx: *Context) anyerror!void { +fn getUser(ctx: *http.Context) anyerror!void { const id_str = ctx.request.arg("id"); const host = ctx.request.headers.get("host") orelse { @@ -471,7 +51,7 @@ fn getUser(ctx: *Context) anyerror!void { fn staticString(comptime str: []const u8) Route.Handler { return (struct { - fn func(ctx: *Context) anyerror!void { + fn func(ctx: *http.Context) anyerror!void { try ctx.response.headers.put("Content-Type", "text/plain"); try ctx.response.write(200, str); } @@ -491,6 +71,6 @@ pub fn main() anyerror!void { const conn = try srv.accept(); // todo: keep track of connections - _ = async handleConnection(conn); + _ = async http.handleConnection(conn); } }