const std = @import("std");
const testing = std.testing;
const logger = std.log.scoped(.generic_ratelimiter);

/// Lazily-created process-wide state shared by every `connect()` call.
var global_context: ?*Context = null;
/// Serializes creation of `global_context`; `connect()` may be entered from
/// several threads of the host process at the same time.
var global_context_mutex: std.Thread.Mutex = .{};

// TODO upstream this constant
/// Pseudo-handle asking `dlsym` for the next definition of a symbol after the
/// calling object, i.e. the function this library shadows.
const RTLD_NEXT = @intToPtr(*anyopaque, @bitCast(usize, @as(isize, -1)));

/// Buffer size for a formatted IPv4/IPv6 `std.net.Address`, with headroom
/// (the longest IPv6 form, "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535",
/// is 47 bytes).
const ADDR_STRING_BUF_LEN = 64;

/// Looks up the next definition of `name` after this library (`RTLD_NEXT`)
/// and casts it to the function-pointer type `T`. Panics if it is not found.
fn getOriginalFunction(comptime T: type, name: [:0]const u8) T {
    // Must not be a tail call: glibc's dlsym identifies the calling object from
    // its return address when resolving RTLD_NEXT, so that address has to
    // point into this library.
    const maybe_func_ptr = @call(.never_tail, std.os.system.dlsym, .{ RTLD_NEXT, name });
    const func_ptr = maybe_func_ptr orelse
        @panic("wanted to get original, but didnt get it :( alexa play despacito");
    return @ptrCast(T, func_ptr);
}

/// Formats an IPv4/IPv6 socket address as "host:port" into `buf` and returns
/// the written slice (which aliases `buf`).
/// Errors: `error.UnsupportedAddressFamily` for any other address family
/// (`std.net.Address.initPosix` treats those as unreachable), or
/// `error.NoSpaceLeft` if `buf` is too small.
fn printAddrAsString(posix_addr: *const std.os.sockaddr, buf: []u8) ![]const u8 {
    switch (posix_addr.family) {
        std.os.AF.INET, std.os.AF.INET6 => {},
        else => return error.UnsupportedAddressFamily,
    }
    const addr = std.net.Address.initPosix(@alignCast(4, posix_addr));
    return try std.fmt.bufPrint(buf, "{}", .{addr});
}

/// Returns true when `addr`/`len` describe an IPv4 or IPv6 address that is
/// long enough to be parsed, i.e. a destination this library rate-limits.
fn isRateLimitable(addr: *const std.os.sockaddr, len: std.os.socklen_t) bool {
    const min_len: usize = switch (addr.family) {
        std.os.AF.INET => @sizeOf(std.os.sockaddr.in),
        std.os.AF.INET6 => @sizeOf(std.os.sockaddr.in6),
        else => return false,
    };
    return len >= min_len;
}

/// Allow at most `requests` connections per `second`-long window.
const RatelimitConfig = struct {
    /// Connections allowed per window; must be non-zero, otherwise no token
    /// is ever granted.
    requests: usize,
    /// Window length in seconds.
    second: f64,
};

const DEFAULT_RATELIMIT = RatelimitConfig{ .requests = 1, .second = 10.0 };

comptime {
    // `connect()` waits until a token is granted; with zero requests it would
    // wait forever.
    std.debug.assert(DEFAULT_RATELIMIT.requests > 0);
}

/// Current wall-clock time as fractional seconds since the Unix epoch.
fn timestamp() f64 {
    return @intToFloat(f64, std.time.microTimestamp()) / std.time.us_per_s;
}

/// Per-destination limiter: a bucket of `config.requests` tokens that is
/// refilled once `config.second` seconds have passed since `window`.
const Ratelimiter = struct {
    config: RatelimitConfig,
    /// Timestamp (seconds) the current window is measured from.
    window: f64 = 0.0,
    /// Tokens left in the current window.
    tokens: usize,
    /// Number of consecutive calls that were told to wait.
    retries: usize = 0,
    /// Timestamp of the most recent `updateRateLimit` call.
    last: f64 = 0.0,

    const Self = @This();

    pub fn init(config: RatelimitConfig) Self {
        return Self{
            .config = config,
            .tokens = config.requests,
        };
    }

    /// Returns the tokens available at `given_current` (seconds; `0` means
    /// "now") without modifying the limiter.
    pub fn getTokens(self: *Self, given_current: f64) usize {
        const current = if (given_current == 0) timestamp() else given_current;
        // `>=`: a caller told to wait exactly until the window's end must get a
        // token when it returns. `current < window` means the wall clock was
        // set backwards; refill instead of stalling until it catches up.
        const window_over = current >= self.window + self.config.second or current < self.window;
        return if (window_over) self.config.requests else self.tokens;
    }

    /// Tries to take one token for a request made now.
    /// Returns 0 when a token was taken. Otherwise returns the seconds (> 0)
    /// to wait before calling again; no token is reserved for the caller.
    pub fn updateRateLimit(self: *Self) f64 {
        const current = timestamp();
        self.last = current;
        self.tokens = self.getTokens(current);
        if (self.tokens == self.config.requests) self.window = current;
        if (self.tokens == 0) {
            self.retries += 1;
            return self.config.second - (current - self.window);
        }
        self.retries = 0;
        self.tokens -= 1;
        // The window restarts when the bucket empties, so the refill happens
        // `second` seconds after the last token was spent.
        if (self.tokens == 0) self.window = current;
        return 0;
    }
};

const AddressMap = std.StringHashMap(Ratelimiter);

/// Process-wide state: the real `connect` plus one `Ratelimiter` per
/// destination address string.
const Context = struct {
    allocator: std.mem.Allocator,
    original_connect: *const @TypeOf(std.os.system.connect),
    /// Keys are owned copies allocated with `allocator`; `StringHashMap` does
    /// not copy the keys it is given.
    address_map: AddressMap,
    /// Guards `address_map`; `connect()` may run on many threads concurrently.
    mutex: std.Thread.Mutex = .{},

    const Self = @This();

    pub fn init() Context {
        const allocator = std.heap.c_allocator;
        return Context{
            .allocator = allocator,
            .original_connect = getOriginalFunction(@TypeOf(&std.os.system.connect), "connect"),
            .address_map = AddressMap.init(allocator),
        };
    }

    /// Frees every owned key and the map itself.
    pub fn deinit(self: *Self) void {
        var keys = self.address_map.keyIterator();
        while (keys.next()) |key| self.allocator.free(key.*);
        self.address_map.deinit();
        self.* = undefined;
    }

    /// Charges one request to the limiter for `posix_addr`, creating the
    /// limiter (with `DEFAULT_RATELIMIT`) on first use.
    /// Returns 0 if the request may proceed now, otherwise the seconds to wait
    /// before asking again. Errors from formatting the address or allocating
    /// map storage are returned to the caller.
    pub fn getSleepDuration(self: *Self, posix_addr: *std.os.sockaddr) !f64 {
        var buf: [ADDR_STRING_BUF_LEN]u8 = undefined;
        const as_string = try printAddrAsString(posix_addr, &buf);

        self.mutex.lock();
        defer self.mutex.unlock();

        const entry = try self.address_map.getOrPut(as_string);
        if (!entry.found_existing) {
            // getOrPut stored `as_string`, which points into `buf` on this
            // stack frame; replace it with an owned copy before returning.
            const owned_key = self.allocator.dupe(u8, as_string) catch |err| {
                _ = self.address_map.remove(as_string);
                return err;
            };
            entry.key_ptr.* = owned_key;
            entry.value_ptr.* = Ratelimiter.init(DEFAULT_RATELIMIT);
        }
        return entry.value_ptr.updateRateLimit();
    }
};

/// Returns the global context, creating it on first use (thread-safe).
/// Panics if the context cannot be allocated.
fn initGlobalContext() *Context {
    global_context_mutex.lock();
    defer global_context_mutex.unlock();

    if (global_context) |ctx| return ctx;
    const ctx = std.heap.c_allocator.create(Context) catch @panic("out of memory");
    ctx.* = Context.init();
    global_context = ctx;
    return ctx;
}

/// Interposes libc's `connect`: IPv4/IPv6 destinations are delayed until
/// their rate limiter grants a token, then the real `connect` is called.
/// Other address families (e.g. Unix sockets) and malformed addresses are
/// passed straight through, as are requests whose limiter lookup fails.
export fn connect(fd: c_int, addr: *std.os.sockaddr, len: std.os.socklen_t) c_int {
    const ctx = initGlobalContext();
    logger.info("connect() was called lmao", .{});

    if (!isRateLimitable(addr, len)) return ctx.original_connect(fd, addr, len);

    var buf: [ADDR_STRING_BUF_LEN]u8 = undefined;
    const as_string = printAddrAsString(addr, &buf) catch "<unprintable>";

    // Sleeping does not reserve a token, so ask again after every wait until
    // one is actually granted; otherwise woken waiters would exceed the limit.
    while (true) {
        const sleep_duration = ctx.getSleepDuration(addr) catch |err| {
            logger.warn("failed to get sleep duration: {s}", .{@errorName(err)});
            break;
        };
        logger.info("getSleepDuration({s}) => {}", .{ as_string, sleep_duration });
        if (sleep_duration <= 0) break;
        std.time.sleep(@floatToInt(u64, sleep_duration * std.time.ns_per_s));
    }

    return ctx.original_connect(fd, addr, len);
}