//! generic_socket_ratelimiter/src/main.zig
//!
//! Interposes `connect()` and delays outgoing connections per destination address.

const std = @import("std");
const testing = std.testing;
const logger = std.log.scoped(.generic_ratelimiter);
/// Lazily-created process-wide state; `initGlobalContext` fills it on the first `connect` call.
var global_context: ?*Context = null;
// TODO: upstream this constant into std.
/// `dlsym` pseudo-handle with every bit set (`(void *)-1`), used to find the "next"
/// definition of a symbol rather than this library's own.
const RTLD_NEXT = @intToPtr(*anyopaque, std.math.maxInt(usize));
/// Look up `name` via `dlsym(RTLD_NEXT, name)` and cast the result to the function pointer type `T`.
///
/// Panics if `dlsym` returns null.
fn getOriginalFunction(comptime T: type, name: [:0]const u8) T {
    // Keep this a genuine (non-tail) call. NOTE(review): presumably so the dlsym lookup is
    // attributed to this object, which RTLD_NEXT resolution is relative to — confirm intent.
    const symbol = @call(.never_tail, std.os.system.dlsym, .{ RTLD_NEXT, name }) orelse
        @panic("wanted to get original, but didnt get it :( alexa play despacito");
    return @ptrCast(T, symbol);
}
/// Render an IPv4/IPv6 socket address as text (`ip:port`, or `[ip6]:port`) into `buf`.
///
/// Returns a slice of `buf`. Errors:
///  - `error.UnsupportedAddressFamily` for any family other than `AF.INET`/`AF.INET6`
///    (e.g. `AF.UNIX`), which `std.net.Address.initPosix` does not handle.
///  - `error.NoSpaceLeft` when `buf` is too short. The longest IPv6 form,
///    `[xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx:xxxx]:65535`, is 47 bytes.
fn printAddrAsString(posix_addr: *std.os.sockaddr, buf: []u8) ![]const u8 {
    // Check the family before converting: `initPosix` treats other families as unreachable.
    switch (posix_addr.family) {
        std.os.AF.INET, std.os.AF.INET6 => {},
        else => return error.UnsupportedAddressFamily,
    }
    const addr = std.net.Address.initPosix(@alignCast(4, posix_addr));
    return try std.fmt.bufPrint(buf, "{}", .{addr});
}
/// How many connections are allowed within one window, and how long that window lasts.
const RatelimitConfig = struct {
    /// Connections permitted per window.
    requests: usize,
    /// Window length; wait times derived from it are slept as seconds by `connect`.
    second: f64,
};
/// Limit given to every newly seen destination: one connection per 10.0-unit window.
const DEFAULT_RATELIMIT: RatelimitConfig = .{ .second = 10.0, .requests = 1 };
/// Return the current wall-clock time as fractional seconds since the Unix epoch.
///
/// Rate-limit windows (`RatelimitConfig.second`) and the sleep that `connect` performs
/// (`duration * std.time.ns_per_s`) are both in seconds, so the clock must be too.
/// Microseconds are divided by `us_per_s`. Dividing by 1000 would yield milliseconds.
fn timestamp() f64 {
    return @intToFloat(f64, std.time.microTimestamp()) / std.time.us_per_s;
}
/// Fixed-window token counter: `config.requests` tokens per `config.second` of `timestamp()` time.
const Ratelimiter = struct {
    config: RatelimitConfig,
    /// `timestamp()` value the current window is measured from.
    window: f64 = 0.0,
    /// Requests still allowed in the current window.
    tokens: usize,
    /// Consecutive calls that found no tokens; reset by the next allowed call.
    retries: usize = 0,
    /// `timestamp()` value of the most recent `updateRateLimit` call.
    last: f64 = 0.0,

    const Self = @This();

    /// Create a limiter whose bucket starts full.
    pub fn init(config: RatelimitConfig) Self {
        return .{ .tokens = config.requests, .config = config };
    }

    /// Token count as of `given_current` (`0` means "now"): refilled to `config.requests`
    /// once the window has expired, otherwise the stored count. Does not modify `self`.
    pub fn getTokens(self: *Self, given_current: f64) usize {
        const now = if (given_current != 0) given_current else timestamp();
        const window_end = self.window + self.config.second;
        return if (now > window_end) self.config.requests else self.tokens;
    }

    /// Record one request. Returns 0 when it is allowed (one token is spent). Otherwise
    /// returns how much of the window remains, in `timestamp()` units.
    pub fn updateRateLimit(self: *Self) f64 {
        const now = timestamp();
        self.last = now;
        const available = self.getTokens(now);
        self.tokens = available;
        // A full bucket means this request opens a fresh window.
        if (available == self.config.requests) self.window = now;

        if (available != 0) {
            self.retries = 0;
            self.tokens = available - 1;
            // The last token was just spent: the window restarts from this request.
            if (self.tokens == 0) self.window = now;
            return 0;
        }

        self.retries += 1;
        const elapsed = now - self.window;
        return self.config.second - elapsed;
    }
};
/// Maps a printed peer address (e.g. `127.0.0.1:80`) to its rate limiter.
///
/// `std.StringHashMap` does not copy keys; whoever inserts must keep the key bytes alive.
const AddressMap = std.StringHashMap(Ratelimiter);

/// State for the `connect` interposer: the real `connect` plus one limiter per destination.
///
/// NOTE(review): nothing here is locked; concurrent `connect` calls from several threads
/// would race on `address_map` — confirm whether target programs are multithreaded.
const Context = struct {
    allocator: std.mem.Allocator,
    /// The `connect` found via `dlsym(RTLD_NEXT, "connect")`.
    original_connect: *const @TypeOf(std.os.system.connect),
    /// Per-address limiters. Keys are duplicated with `allocator` and never freed.
    address_map: AddressMap,

    const Self = @This();

    /// Build a context backed by the C allocator and resolve the original `connect`.
    /// Panics (inside `getOriginalFunction`) if the symbol cannot be found.
    pub fn init() Context {
        const allocator = std.heap.c_allocator;
        return Context{
            .allocator = allocator,
            .original_connect = getOriginalFunction(@TypeOf(&std.os.system.connect), "connect"),
            .address_map = AddressMap.init(allocator),
        };
    }

    /// Charge one connection against the limiter for `posix_addr`. A limiter with
    /// `DEFAULT_RATELIMIT` is created on first sight of the address.
    ///
    /// Returns how long to wait before connecting (0 when allowed); `connect` sleeps it as seconds.
    /// Errors: anything `printAddrAsString` returns, and `error.OutOfMemory`.
    pub fn getSleepDuration(self: *Self, posix_addr: *std.os.sockaddr) !f64 {
        // Large enough for the longest IPv6 rendering, `[xxxx:...:xxxx]:65535` (47 bytes).
        var buf: [64]u8 = undefined;
        const as_string = try printAddrAsString(posix_addr, &buf);
        const entry = try self.address_map.getOrPut(as_string);
        if (!entry.found_existing) {
            // If the key copy fails, remove the half-built entry so the map never keeps a
            // pointer into `buf` or an uninitialized value.
            errdefer _ = self.address_map.remove(as_string);
            // `as_string` lives in this stack frame; the map must hold its own copy.
            entry.key_ptr.* = try self.allocator.dupe(u8, as_string);
            entry.value_ptr.* = Ratelimiter.init(DEFAULT_RATELIMIT);
        }
        return entry.value_ptr.updateRateLimit();
    }
};
/// Return the process-wide `Context`, allocating and initializing it on first use.
///
/// Panics with "out of memory" if the C allocator cannot provide it.
/// NOTE(review): no locking; two threads reaching here first at the same time could each
/// create a context — confirm whether that matters for the target programs.
fn initGlobalContext() *Context {
    if (global_context) |existing| return existing;

    const fresh = std.heap.c_allocator.create(Context) catch @panic("out of memory");
    fresh.* = Context.init();
    global_context = fresh;
    return fresh;
}
/// Exported replacement for `connect`: sleeps when the destination's limiter is exhausted,
/// then forwards the call unchanged to the original `connect`.
///
/// Only `AF.INET`/`AF.INET6` destinations are limited. Other families (e.g. `AF.UNIX`) are
/// forwarded immediately, because `std.net.Address.initPosix` cannot represent them.
export fn connect(fd: c_int, addr: *std.os.sockaddr, len: std.os.socklen_t) c_int {
    const ctx = initGlobalContext();
    logger.info("connect() was called lmao", .{});

    switch (addr.family) {
        std.os.AF.INET, std.os.AF.INET6 => {},
        else => return ctx.original_connect(fd, addr, len),
    }

    // Rate limiting is best-effort: on failure, log and connect without delay.
    const sleep_duration = ctx.getSleepDuration(addr) catch |err| blk: {
        logger.warn("failed to get sleep duration: {s}", .{@errorName(err)});
        break :blk 0;
    };

    // 64 bytes fits the longest IPv6 rendering (47 bytes); formatting is for logging only,
    // so a failure must not abort the call.
    var buf: [64]u8 = undefined;
    const as_string = printAddrAsString(addr, &buf) catch "<unprintable address>";
    logger.info("getSleepDuration({s}) => {}", .{ as_string, sleep_duration });

    if (sleep_duration > 0) {
        std.time.sleep(@floatToInt(u64, sleep_duration * std.time.ns_per_s));
    }
    return ctx.original_connect(fd, addr, len);
}