From fd78300975448f9fd37f9efd90ec4c34efe3facd Mon Sep 17 00:00:00 2001
From: Luna
Date: Thu, 23 Feb 2023 23:17:31 -0300
Subject: [PATCH] add codes

---
 .gitignore   |   2 +
 README.md    |  17 ++++-
 build.zig    |  45 +++++++++++++
 src/main.zig | 177 +++++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 240 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore
 create mode 100644 build.zig
 create mode 100644 src/main.zig

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e73c965
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+zig-cache/
+zig-out/
diff --git a/README.md b/README.md
index 5a2dd7e..e820cf4 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,18 @@
 # generic_socket_ratelimiter
 
-throttling requests to specific hosts at the socket() syscall level
\ No newline at end of file
+use https://github.com/mariusae/trickle though
+
+throttling requests to specific hosts at the connect() call level
+
+this implements the token bucket algorithm in zig, storing a hashmap for each
+address. default ratelimit is 1 connect every 10 seconds.
+
+hostnames with many addresses won't be treated as a single thing.
+
+if your clients use TCP keepalive, this won't be able to prevent request
+spamming at a certain host.
+
+really, you should use trickle.
+
+quickly slapping together an `LD_PRELOAD`able library that puts custom logic
+on top of an existing function was fun to make though.
diff --git a/build.zig b/build.zig
new file mode 100644
index 0000000..e7b6d9b
--- /dev/null
+++ b/build.zig
@@ -0,0 +1,45 @@
+const std = @import("std");
+
+// Although this function looks imperative, note that its job is to
+// declaratively construct a build graph that will be executed by an external
+// runner.
+pub fn build(b: *std.Build) void {
+    // Standard target options allows the person running `zig build` to choose
+    // what target to build for. Here we do not override the defaults, which
+    // means any target is allowed, and the default is native. Other options
+    // for restricting supported target set are available.
+    const target = b.standardTargetOptions(.{});
+
+    // Standard optimization options allow the person running `zig build` to select
+    // between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. Here we do not
+    // set a preferred release mode, allowing the user to decide how to optimize.
+    const optimize = b.standardOptimizeOption(.{});
+
+    const lib = b.addSharedLibrary(.{
+        .name = "generic_socket_ratelimiter",
+        // In this case the main source file is merely a path, however, in more
+        // complicated build scripts, this could be a generated file.
+        .root_source_file = .{ .path = "src/main.zig" },
+        .target = target,
+        .optimize = optimize,
+    });
+
+    // This declares intent for the library to be installed into the standard
+    // location when the user invokes the "install" step (the default step when
+    // running `zig build`).
+    lib.install();
+    lib.linkLibC();
+
+    // Creates a step for unit testing.
+    const main_tests = b.addTest(.{
+        .root_source_file = .{ .path = "src/main.zig" },
+        .target = target,
+        .optimize = optimize,
+    });
+
+    // This creates a build step. It will be visible in the `zig build --help` menu,
+    // and can be selected like this: `zig build test`
+    // This will evaluate the `test` step rather than the default, which is "install".
+    const test_step = b.step("test", "Run library tests");
+    test_step.dependOn(&main_tests.step);
+}
diff --git a/src/main.zig b/src/main.zig
new file mode 100644
index 0000000..817f033
--- /dev/null
+++ b/src/main.zig
@@ -0,0 +1,177 @@
+//! LD_PRELOAD shim that delays connect() calls per destination address using
+//! a token bucket. The shared state is not synchronized between threads.
+
+const std = @import("std");
+const testing = std.testing;
+
+const logger = std.log.scoped(.generic_ratelimiter);
+
+var global_context: ?*Context = null;
+
+// TODO upstream this constant
+const RTLD_NEXT = @intToPtr(*anyopaque, @bitCast(usize, @as(isize, -1)));
+
+/// Buffer size for a formatted IP address. A full-length bracketed IPv6
+/// address with port needs 47 bytes, which the previous 32 could not hold.
+const ADDR_STRING_SIZE = 64;
+
+/// Resolves `name` with dlsym(RTLD_NEXT, ...) and casts the result to `T`.
+/// Panics when the symbol cannot be found.
+fn getOriginalFunction(comptime T: type, name: [:0]const u8) T {
+    const maybe_func_ptr = @call(.never_tail, std.os.system.dlsym, .{ RTLD_NEXT, name.ptr });
+    if (maybe_func_ptr) |func_ptr| {
+        return @ptrCast(T, func_ptr);
+    } else {
+        @panic("wanted to get original, but didnt get it :( alexa play despacito");
+    }
+}
+
+/// Returns true for AF_INET / AF_INET6 addresses, the only families that
+/// std.net.Address.initPosix accepts.
+fn isIpAddress(posix_addr: *const std.os.sockaddr) bool {
+    return switch (posix_addr.family) {
+        std.os.AF.INET, std.os.AF.INET6 => true,
+        else => false,
+    };
+}
+
+/// Formats `posix_addr` into `buf` and returns the written slice.
+/// The caller must check `isIpAddress` first.
+fn printAddrAsString(posix_addr: *std.os.sockaddr, buf: []u8) ![]const u8 {
+    const addr = std.net.Address.initPosix(@alignCast(4, posix_addr));
+    return try std.fmt.bufPrint(buf, "{}", .{addr});
+}
+
+/// Allows `requests` connects per window of `second` seconds.
+const RatelimitConfig = struct {
+    requests: usize,
+    second: f64,
+};
+
+const DEFAULT_RATELIMIT = RatelimitConfig{ .requests = 1, .second = 10.0 };
+
+/// Wall-clock time in seconds, matching the unit of `RatelimitConfig.second`.
+fn timestamp() f64 {
+    return @intToFloat(f64, std.time.microTimestamp()) / std.time.us_per_s;
+}
+
+/// Token bucket state for a single address.
+const Ratelimiter = struct {
+    config: RatelimitConfig,
+
+    window: f64 = 0.0,
+    tokens: usize,
+    retries: usize = 0,
+    last: f64 = 0.0,
+
+    const Self = @This();
+
+    pub fn init(config: RatelimitConfig) Self {
+        return Self{
+            .config = config,
+            .tokens = config.requests,
+        };
+    }
+
+    /// Tokens available at `given_current` (0 means "now"); the bucket
+    /// counts as full again once the current window has elapsed.
+    pub fn getTokens(self: *Self, given_current: f64) usize {
+        const current = if (given_current == 0) timestamp() else given_current;
+        var tokens = self.tokens;
+        if (current > self.window + self.config.second) {
+            tokens = self.config.requests;
+        }
+        return tokens;
+    }
+
+    /// Takes one token. Returns 0 when one was available, otherwise the
+    /// number of seconds left until the current window ends.
+    pub fn updateRateLimit(self: *Self) f64 {
+        const current = timestamp();
+        self.last = current;
+        self.tokens = self.getTokens(current);
+
+        if (self.tokens == self.config.requests) self.window = current;
+
+        if (self.tokens == 0) {
+            self.retries += 1;
+            return self.config.second - (current - self.window);
+        }
+
+        self.retries = 0;
+        self.tokens -= 1;
+        if (self.tokens == 0) self.window = current;
+        return 0;
+    }
+};
+
+const AddressMap = std.StringHashMap(Ratelimiter);
+
+/// Process-wide state: the real connect() and one ratelimiter per address.
+const Context = struct {
+    allocator: std.mem.Allocator,
+    original_connect: *const @TypeOf(std.os.system.connect),
+    address_map: AddressMap,
+
+    const Self = @This();
+
+    pub fn init() Context {
+        const allocator = std.heap.c_allocator;
+        return Context{
+            .allocator = allocator,
+            .original_connect = getOriginalFunction(@TypeOf(&std.os.system.connect), "connect"),
+            .address_map = AddressMap.init(allocator),
+        };
+    }
+
+    /// Records a connect attempt to `posix_addr` and returns how many
+    /// seconds the caller should sleep before connecting.
+    pub fn getSleepDuration(self: *Self, posix_addr: *std.os.sockaddr) !f64 {
+        var buf: [ADDR_STRING_SIZE]u8 = undefined;
+        const as_string = try printAddrAsString(posix_addr, &buf);
+
+        const ratelimiter = self.address_map.getPtr(as_string) orelse blk: {
+            // The map stores the key slice without copying it, so a key
+            // pointing into `buf` would dangle once this function returns.
+            const owned_key = try self.allocator.dupe(u8, as_string);
+            errdefer self.allocator.free(owned_key);
+            try self.address_map.putNoClobber(owned_key, Ratelimiter.init(DEFAULT_RATELIMIT));
+            break :blk self.address_map.getPtr(owned_key).?;
+        };
+
+        return ratelimiter.updateRateLimit();
+    }
+};
+
+/// Returns the global context, creating it on first use.
+fn initGlobalContext() *Context {
+    return if (global_context) |ctx| ctx else blk: {
+        const ctx = std.heap.c_allocator.create(Context) catch @panic("out of memory");
+        ctx.* = Context.init();
+        global_context = ctx;
+        break :blk ctx;
+    };
+}
+
+/// Replacement for libc connect(): sleeps as long as the ratelimiter asks,
+/// then forwards to the original connect().
+export fn connect(fd: c_int, addr: *std.os.sockaddr, len: std.os.socklen_t) c_int {
+    const ctx = initGlobalContext();
+    logger.info("connect() was called lmao", .{});
+
+    // Non-IP sockets (unix, netlink, ...) cannot be parsed by
+    // std.net.Address.initPosix, so forward them without throttling.
+    if (!isIpAddress(addr)) return ctx.original_connect(fd, addr, len);
+
+    const sleep_duration: f64 = ctx.getSleepDuration(addr) catch |err| blk: {
+        logger.warn("failed to get sleep duration: {s}", .{@errorName(err)});
+        break :blk 0;
+    };
+    var buf: [ADDR_STRING_SIZE]u8 = undefined;
+    const as_string = printAddrAsString(addr, &buf) catch "<unprintable address>";
+    logger.info("getSleepDuration({s}) => {}", .{ as_string, sleep_duration });
+    if (sleep_duration > 0) {
+        std.time.sleep(@floatToInt(u64, sleep_duration * std.time.ns_per_s));
+    }
+    return ctx.original_connect(fd, addr, len);
+}