137 lines
4.1 KiB
Zig
137 lines
4.1 KiB
Zig
const std = @import("std");
|
|
const testing = std.testing;
|
|
|
|
const logger = std.log.scoped(.generic_ratelimiter);

/// Process-wide hook state; stays null until `initGlobalContext` creates it
/// on the first hooked call.
var global_context: ?*Context = null;

// TODO upstream this constant
/// `dlsym` pseudo-handle meaning "the next object after this one".
/// NOTE(review): written as an all-ones pointer, matching the
/// `((void *) -1)` definition in common libc headers (glibc, musl, macOS);
/// confirm for any other target.
const RTLD_NEXT = @intToPtr(*anyopaque, std.math.maxInt(usize));
|
|
|
|
/// Looks up the next definition of symbol `name` after this library
/// (`dlsym(RTLD_NEXT, name)`) and reinterprets it as the pointer type `T`.
/// Panics when the symbol cannot be resolved.
fn getOriginalFunction(comptime T: type, name: [:0]const u8) T {
    // `.never_tail` forbids compiling this as a tail call into dlsym.
    // NOTE(review): presumably because RTLD_NEXT lookups are relative to the
    // calling object, so the caller frame must stay inside this library.
    const symbol = @call(.never_tail, std.os.system.dlsym, .{ RTLD_NEXT, name }) orelse
        @panic("wanted to get original, but didnt get it :( alexa play despacito");
    return @ptrCast(T, symbol);
}
|
|
|
|
/// Formats an IPv4/IPv6 socket address with `std.net.Address`'s `{}`
/// formatter into `buf` and returns the written prefix of `buf`.
///
/// Errors:
/// - `error.UnsupportedAddressFamily` for any family other than INET/INET6
///   (e.g. `AF_UNIX`). `std.net.Address.initPosix` treats those as
///   `unreachable`, so they must be rejected before calling it.
/// - `error.NoSpaceLeft` if `buf` is too small. The longest IPv6 form,
///   "[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff]:65535", is 47 bytes.
///
/// `posix_addr` must be 4-byte aligned (checked by `@alignCast` in safe
/// builds). For its family it must point at a complete
/// `sockaddr_in`/`sockaddr_in6`.
fn printAddrAsString(posix_addr: *std.os.sockaddr, buf: []u8) ![]const u8 {
    switch (posix_addr.family) {
        std.os.AF.INET, std.os.AF.INET6 => {},
        else => return error.UnsupportedAddressFamily,
    }
    const addr = std.net.Address.initPosix(@alignCast(4, posix_addr));
    return std.fmt.bufPrint(buf, "{}", .{addr});
}
|
|
|
|
/// Rate-limit parameters: at most `requests` calls per window of length
/// `second`.
const RatelimitConfig = struct {
    /// Calls allowed per window; the limiter refills to this count.
    requests: usize,
    /// Window length. `connect` converts waits derived from it with
    /// `std.time.ns_per_s`, so it is treated as seconds.
    second: f64,
};

/// Limit given to every newly seen destination: 1 request per 10.0-long
/// window.
const DEFAULT_RATELIMIT: RatelimitConfig = .{
    .requests = 1,
    .second = 10.0,
};
|
|
|
|
/// Returns the current wall-clock time as seconds since the Unix epoch,
/// with microsecond resolution.
///
/// The limiter compares this value against `RatelimitConfig.second`.
/// `connect` turns the resulting wait into nanoseconds with
/// `std.time.ns_per_s`, so the unit must be seconds. Dividing the microsecond
/// timestamp by 1000 would yield milliseconds and shrink every window 1000x.
fn timestamp() f64 {
    return @intToFloat(f64, std.time.microTimestamp()) / std.time.us_per_s;
}
|
|
|
|
/// Windowed rate limiter: allows up to `config.requests` calls per window of
/// length `config.second`, with time read from `timestamp()`.
const Ratelimiter = struct {
    /// Limits this instance enforces.
    config: RatelimitConfig,
    /// Start of the current window (0 until the first update).
    window: f64 = 0.0,
    /// Calls still allowed in the current window.
    tokens: usize,
    /// Consecutive rate-limited calls; cleared by every successful call.
    retries: usize = 0,
    /// Time of the most recent `updateRateLimit` call.
    last: f64 = 0.0,

    const Self = @This();

    /// Creates a limiter that starts with its full allowance.
    pub fn init(config: RatelimitConfig) Self {
        return .{ .config = config, .tokens = config.requests };
    }

    /// Reports how many calls would be allowed at time `given_current`
    /// (0 means "use the current time"). Returns the full allowance once the
    /// window has elapsed, otherwise the stored count. Does not modify `self`.
    pub fn getTokens(self: *Self, given_current: f64) usize {
        const now = if (given_current != 0) given_current else timestamp();
        const window_over = now > self.window + self.config.second;
        return if (window_over) self.config.requests else self.tokens;
    }

    /// Attempts to spend one token. Returns 0 on success. When no token is
    /// left, counts a retry and returns the time remaining in the window.
    pub fn updateRateLimit(self: *Self) f64 {
        const now = timestamp();
        self.last = now;
        self.tokens = self.getTokens(now);

        // A full allowance means this call opens a fresh window.
        if (self.tokens == self.config.requests) self.window = now;

        if (self.tokens > 0) {
            self.retries = 0;
            self.tokens -= 1;
            // Spending the last token restarts the window from this moment.
            if (self.tokens == 0) self.window = now;
            return 0;
        }

        self.retries += 1;
        const elapsed = now - self.window;
        return self.config.second - elapsed;
    }
};
|
|
|
|
/// Per-destination limiters, keyed by the formatted address (e.g.
/// "1.2.3.4:80"). `StringHashMap` does not copy keys, so each key stored
/// here is a heap copy owned by the map.
const AddressMap = std.StringHashMap(Ratelimiter);

/// State shared by all hooked `connect` calls.
///
/// NOTE(review): nothing here is synchronized. Concurrent `connect` calls
/// from several threads would race on `address_map`.
const Context = struct {
    /// Allocator backing `address_map` and the keys it owns.
    allocator: std.mem.Allocator,
    /// The next `connect` definition after this library (presumably libc's).
    original_connect: *const @TypeOf(std.os.system.connect),
    /// Rate limiter per destination address string.
    address_map: AddressMap,

    const Self = @This();

    /// Builds a context on the C allocator and resolves the real `connect`.
    /// Panics (inside `getOriginalFunction`) if that symbol cannot be found.
    pub fn init() Context {
        const allocator = std.heap.c_allocator;
        return Context{
            .allocator = allocator,
            .original_connect = getOriginalFunction(@TypeOf(&std.os.system.connect), "connect"),
            .address_map = AddressMap.init(allocator),
        };
    }

    /// Returns how long the caller should wait before connecting to
    /// `posix_addr`. `connect` scales it by `ns_per_s`, i.e. treats it as
    /// seconds. Returns 0 when a token was consumed.
    ///
    /// A new limiter with `DEFAULT_RATELIMIT` is created the first time an
    /// address is seen. Fails if the address cannot be formatted or on OOM.
    pub fn getSleepDuration(self: *Self, posix_addr: *std.os.sockaddr) !f64 {
        // 64 bytes fits the longest IPv6 form, "[ffff:...:ffff]:65535"
        // (47 bytes); 32 bytes made IPv6 fail with NoSpaceLeft.
        var buf: [64]u8 = undefined;
        const as_string = try printAddrAsString(posix_addr, &buf);

        const entry = try self.address_map.getOrPut(as_string);
        if (!entry.found_existing) {
            // getOrPut stored `as_string` itself, which points into `buf` on
            // this stack frame. Replace it with a heap copy so the key stays
            // valid after we return. If the copy fails, drop the
            // half-initialized entry.
            errdefer _ = self.address_map.remove(as_string);
            entry.key_ptr.* = try self.allocator.dupe(u8, as_string);
            entry.value_ptr.* = Ratelimiter.init(DEFAULT_RATELIMIT);
        }

        return entry.value_ptr.updateRateLimit();
    }
};
|
|
|
|
/// Returns the process-wide `Context`, allocating and initializing it on
/// first use. Panics if the allocation fails. Performs no locking.
fn initGlobalContext() *Context {
    if (global_context) |existing| return existing;

    const fresh = std.heap.c_allocator.create(Context) catch @panic("out of memory");
    fresh.* = Context.init();
    global_context = fresh;
    return fresh;
}
|
|
|
|
/// Interposed `connect(2)`: rate-limits outgoing IPv4/IPv6 connections per
/// destination address, then forwards to the real `connect`.
///
/// Two kinds of call are forwarded immediately, without rate limiting:
/// - other address families (e.g. `AF_UNIX`), because
///   `std.net.Address.initPosix` cannot represent them;
/// - addresses whose `len` is shorter than their family's sockaddr, because
///   formatting them would read past `len`.
///
/// When rate-limited, this sleeps and then consults the limiter again. The
/// connection therefore proceeds only after it has actually consumed a
/// token. Sleeping once and continuing would let all queued callers burst
/// through together when the window ends.
export fn connect(fd: c_int, addr: *std.os.sockaddr, len: std.os.socklen_t) c_int {
    const ctx = initGlobalContext();
    logger.info("connect() was called lmao", .{});

    const min_len: usize = switch (addr.family) {
        std.os.AF.INET => @sizeOf(std.os.sockaddr.in),
        std.os.AF.INET6 => @sizeOf(std.os.sockaddr.in6),
        else => return ctx.original_connect(fd, addr, len),
    };
    if (len < min_len) return ctx.original_connect(fd, addr, len);

    // 64 bytes fits the longest IPv6 form, "[ffff:...:ffff]:65535" (47 bytes).
    var buf: [64]u8 = undefined;
    // Only used for logging, so fall back to the error name rather than
    // asserting success (`catch unreachable` crashed on IPv6 with 32 bytes).
    const as_string = printAddrAsString(addr, &buf) catch |err| @errorName(err);

    while (true) {
        const sleep_duration = ctx.getSleepDuration(addr) catch |err| blk: {
            logger.warn("failed to get sleep duration: {s}", .{@errorName(err)});
            break :blk 0;
        };
        logger.info("getSleepDuration({s}) => {}", .{ as_string, sleep_duration });

        if (sleep_duration > 0) {
            std.time.sleep(@floatToInt(u64, sleep_duration * std.time.ns_per_s));
        } else break;
    }

    return ctx.original_connect(fd, addr, len);
}
|