add codes

This commit is contained in:
Luna 2023-02-23 23:17:31 -03:00
parent f0247795d1
commit fd78300975
4 changed files with 199 additions and 1 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
zig-cache/
zig-out/

View File

@ -1,3 +1,18 @@
# generic_socket_ratelimiter
throttling requests to specific hosts at the socket() syscall level
use https://github.com/mariusae/trickle though
throttling requests to specific hosts at the connect() call level
this implements the token bucket algorithm in zig, storing a hashmap for each
address. default ratelimit is 1 connect every 10 seconds.
hostnames with many addresses won't be treated as a single thing.
if your clients use TCP keepalive, this won't be able to prevent request
spamming at a certain host.
really, you should use trickle.
quickly slapping together an `LD_PRELOAD`able library that puts custom logic
on top of an existing function was fun to make though.

45
build.zig Normal file
View File

@ -0,0 +1,45 @@
const std = @import("std");
// Although this function looks imperative, note that its job is to
// declaratively construct a build graph that will be executed by an external
// runner.
pub fn build(b: *std.Build) void {
// Standard target options allows the person running `zig build` to choose
// what target to build for. Here we do not override the defaults, which
// means any target is allowed, and the default is native. Other options
// for restricting supported target set are available.
const target = b.standardTargetOptions(.{});
// Standard optimization options allow the person running `zig build` to select
// between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. Here we do not
// set a preferred release mode, allowing the user to decide how to optimize.
const optimize = b.standardOptimizeOption(.{});
const lib = b.addSharedLibrary(.{
.name = "generic_socket_ratelimiter",
// In this case the main source file is merely a path, however, in more
// complicated build scripts, this could be a generated file.
.root_source_file = .{ .path = "src/main.zig" },
.target = target,
.optimize = optimize,
});
// This declares intent for the library to be installed into the standard
// location when the user invokes the "install" step (the default step when
// running `zig build`).
lib.install();
lib.linkLibC();
// Creates a step for unit testing.
const main_tests = b.addTest(.{
.root_source_file = .{ .path = "src/main.zig" },
.target = target,
.optimize = optimize,
});
// This creates a build step. It will be visible in the `zig build --help` menu,
// and can be selected like this: `zig build test`
// This will evaluate the `test` step rather than the default, which is "install".
const test_step = b.step("test", "Run library tests");
test_step.dependOn(&main_tests.step);
}

136
src/main.zig Normal file
View File

@ -0,0 +1,136 @@
const std = @import("std");
const testing = std.testing;
const logger = std.log.scoped(.generic_ratelimiter);
var global_context: ?*Context = null;
// TODO upstream this constant
const RTLD_NEXT = @intToPtr(*anyopaque, @bitCast(usize, @as(isize, -1)));
fn getOriginalFunction(comptime T: type, name: [:0]const u8) T {
var maybe_func_ptr = @call(.never_tail, std.os.system.dlsym, .{ RTLD_NEXT, name });
if (maybe_func_ptr) |func_ptr| {
return @ptrCast(T, func_ptr);
} else {
@panic("wanted to get original, but didnt get it :( alexa play despacito");
}
}
fn printAddrAsString(posix_addr: *std.os.sockaddr, buf: []u8) ![]const u8 {
const addr = std.net.Address.initPosix(@alignCast(4, posix_addr));
return try std.fmt.bufPrint(buf, "{}", .{addr});
}
const RatelimitConfig = struct {
requests: usize,
second: f64,
};
const DEFAULT_RATELIMIT = RatelimitConfig{ .requests = 1, .second = 10.0 };
fn timestamp() f64 {
return @divFloor(@intToFloat(f64, std.time.microTimestamp()), 1000);
}
const Ratelimiter = struct {
config: RatelimitConfig,
window: f64 = 0.0,
tokens: usize,
retries: usize = 0,
last: f64 = 0.0,
const Self = @This();
pub fn init(config: RatelimitConfig) Self {
return Self{
.config = config,
.tokens = config.requests,
};
}
pub fn getTokens(self: *Self, given_current: f64) usize {
const current = if (given_current == 0) timestamp() else given_current;
var tokens = self.tokens;
if (current > self.window + self.config.second) {
tokens = self.config.requests;
}
return tokens;
}
pub fn updateRateLimit(self: *Self) f64 {
const current = timestamp();
self.last = current;
self.tokens = self.getTokens(current);
if (self.tokens == self.config.requests) self.window = current;
if (self.tokens == 0) {
self.retries += 1;
return self.config.second - (current - self.window);
}
self.retries = 0;
self.tokens -= 1;
if (self.tokens == 0) self.window = current;
return 0;
}
};
const AddressMap = std.StringHashMap(Ratelimiter);
const Context = struct {
allocator: std.mem.Allocator,
original_connect: *const @TypeOf(std.os.system.connect),
address_map: AddressMap,
const Self = @This();
pub fn init() Context {
var allocator = std.heap.c_allocator;
return Context{
.allocator = allocator,
.original_connect = getOriginalFunction(@TypeOf(&std.os.system.connect), "connect"),
.address_map = AddressMap.init(allocator),
};
}
pub fn getSleepDuration(self: *Self, posix_addr: *std.os.sockaddr) !f64 {
var buf: [32]u8 = undefined;
const as_string = try printAddrAsString(posix_addr, &buf);
var maybe_ratelimiter = try self.address_map.getOrPut(as_string);
if (!maybe_ratelimiter.found_existing) {
maybe_ratelimiter.value_ptr.* = Ratelimiter.init(DEFAULT_RATELIMIT);
}
return maybe_ratelimiter.value_ptr.updateRateLimit();
}
};
fn initGlobalContext() *Context {
return if (global_context) |ctx| ctx else blk: {
var ctx = std.heap.c_allocator.create(Context) catch @panic("out of memory");
ctx.* = Context.init();
global_context = ctx;
break :blk ctx;
};
}
export fn connect(fd: c_int, addr: *std.os.sockaddr, len: std.os.socklen_t) c_int {
var ctx = initGlobalContext();
logger.info("connect() was called lmao", .{});
var sleep_duration = ctx.getSleepDuration(addr) catch |err| blk: {
logger.warn("failed to get sleep duration: {s}", .{@errorName(err)});
break :blk 0;
};
var buf: [32]u8 = undefined;
const as_string = printAddrAsString(addr, &buf) catch unreachable;
logger.info("getSleepDuration({s}) => {}", .{ as_string, sleep_duration });
if (sleep_duration > 0) {
std.time.sleep(@floatToInt(u64, sleep_duration * std.time.ns_per_s));
}
return ctx.original_connect(fd, addr, len);
}