add codes
This commit is contained in:
parent
f0247795d1
commit
fd78300975
4 changed files with 199 additions and 1 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
zig-cache/
|
||||||
|
zig-out/
|
17
README.md
17
README.md
|
@ -1,3 +1,18 @@
|
||||||
# generic_socket_ratelimiter
|
# generic_socket_ratelimiter
|
||||||
|
|
||||||
throttling requests to specific hosts at the socket() syscall level
|
use https://github.com/mariusae/trickle though
|
||||||
|
|
||||||
|
throttling requests to specific hosts at the connect() call level
|
||||||
|
|
||||||
|
this implements the token bucket algorithm in zig, storing a hashmap for each
|
||||||
|
address. default ratelimit is 1 connect every 10 seconds.
|
||||||
|
|
||||||
|
hostnames with many addresses won't be treated as a single thing.
|
||||||
|
|
||||||
|
if your clients use TCP keepalive, this won't be able to prevent request
|
||||||
|
spamming at a certain host.
|
||||||
|
|
||||||
|
really, you should use trickle.
|
||||||
|
|
||||||
|
quickly slapping together an `LD_PRELOAD`able library that puts custom logic
|
||||||
|
on top of an existing function was fun to make though.
|
||||||
|
|
45
build.zig
Normal file
45
build.zig
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
// Although this function looks imperative, note that its job is to
|
||||||
|
// declaratively construct a build graph that will be executed by an external
|
||||||
|
// runner.
|
||||||
|
pub fn build(b: *std.Build) void {
|
||||||
|
// Standard target options allows the person running `zig build` to choose
|
||||||
|
// what target to build for. Here we do not override the defaults, which
|
||||||
|
// means any target is allowed, and the default is native. Other options
|
||||||
|
// for restricting supported target set are available.
|
||||||
|
const target = b.standardTargetOptions(.{});
|
||||||
|
|
||||||
|
// Standard optimization options allow the person running `zig build` to select
|
||||||
|
// between Debug, ReleaseSafe, ReleaseFast, and ReleaseSmall. Here we do not
|
||||||
|
// set a preferred release mode, allowing the user to decide how to optimize.
|
||||||
|
const optimize = b.standardOptimizeOption(.{});
|
||||||
|
|
||||||
|
const lib = b.addSharedLibrary(.{
|
||||||
|
.name = "generic_socket_ratelimiter",
|
||||||
|
// In this case the main source file is merely a path, however, in more
|
||||||
|
// complicated build scripts, this could be a generated file.
|
||||||
|
.root_source_file = .{ .path = "src/main.zig" },
|
||||||
|
.target = target,
|
||||||
|
.optimize = optimize,
|
||||||
|
});
|
||||||
|
|
||||||
|
// This declares intent for the library to be installed into the standard
|
||||||
|
// location when the user invokes the "install" step (the default step when
|
||||||
|
// running `zig build`).
|
||||||
|
lib.install();
|
||||||
|
lib.linkLibC();
|
||||||
|
|
||||||
|
// Creates a step for unit testing.
|
||||||
|
const main_tests = b.addTest(.{
|
||||||
|
.root_source_file = .{ .path = "src/main.zig" },
|
||||||
|
.target = target,
|
||||||
|
.optimize = optimize,
|
||||||
|
});
|
||||||
|
|
||||||
|
// This creates a build step. It will be visible in the `zig build --help` menu,
|
||||||
|
// and can be selected like this: `zig build test`
|
||||||
|
// This will evaluate the `test` step rather than the default, which is "install".
|
||||||
|
const test_step = b.step("test", "Run library tests");
|
||||||
|
test_step.dependOn(&main_tests.step);
|
||||||
|
}
|
136
src/main.zig
Normal file
136
src/main.zig
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
const std = @import("std");
|
||||||
|
const testing = std.testing;
|
||||||
|
|
||||||
|
const logger = std.log.scoped(.generic_ratelimiter);
|
||||||
|
|
||||||
|
var global_context: ?*Context = null;
|
||||||
|
|
||||||
|
// TODO upstream this constant
|
||||||
|
const RTLD_NEXT = @intToPtr(*anyopaque, @bitCast(usize, @as(isize, -1)));
|
||||||
|
|
||||||
|
fn getOriginalFunction(comptime T: type, name: [:0]const u8) T {
|
||||||
|
var maybe_func_ptr = @call(.never_tail, std.os.system.dlsym, .{ RTLD_NEXT, name });
|
||||||
|
if (maybe_func_ptr) |func_ptr| {
|
||||||
|
return @ptrCast(T, func_ptr);
|
||||||
|
} else {
|
||||||
|
@panic("wanted to get original, but didnt get it :( alexa play despacito");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn printAddrAsString(posix_addr: *std.os.sockaddr, buf: []u8) ![]const u8 {
|
||||||
|
const addr = std.net.Address.initPosix(@alignCast(4, posix_addr));
|
||||||
|
return try std.fmt.bufPrint(buf, "{}", .{addr});
|
||||||
|
}
|
||||||
|
|
||||||
|
const RatelimitConfig = struct {
|
||||||
|
requests: usize,
|
||||||
|
second: f64,
|
||||||
|
};
|
||||||
|
|
||||||
|
const DEFAULT_RATELIMIT = RatelimitConfig{ .requests = 1, .second = 10.0 };
|
||||||
|
|
||||||
|
fn timestamp() f64 {
|
||||||
|
return @divFloor(@intToFloat(f64, std.time.microTimestamp()), 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
const Ratelimiter = struct {
|
||||||
|
config: RatelimitConfig,
|
||||||
|
|
||||||
|
window: f64 = 0.0,
|
||||||
|
tokens: usize,
|
||||||
|
retries: usize = 0,
|
||||||
|
last: f64 = 0.0,
|
||||||
|
|
||||||
|
const Self = @This();
|
||||||
|
|
||||||
|
pub fn init(config: RatelimitConfig) Self {
|
||||||
|
return Self{
|
||||||
|
.config = config,
|
||||||
|
.tokens = config.requests,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getTokens(self: *Self, given_current: f64) usize {
|
||||||
|
const current = if (given_current == 0) timestamp() else given_current;
|
||||||
|
var tokens = self.tokens;
|
||||||
|
if (current > self.window + self.config.second) {
|
||||||
|
tokens = self.config.requests;
|
||||||
|
}
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn updateRateLimit(self: *Self) f64 {
|
||||||
|
const current = timestamp();
|
||||||
|
self.last = current;
|
||||||
|
self.tokens = self.getTokens(current);
|
||||||
|
|
||||||
|
if (self.tokens == self.config.requests) self.window = current;
|
||||||
|
|
||||||
|
if (self.tokens == 0) {
|
||||||
|
self.retries += 1;
|
||||||
|
return self.config.second - (current - self.window);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.retries = 0;
|
||||||
|
self.tokens -= 1;
|
||||||
|
if (self.tokens == 0) self.window = current;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const AddressMap = std.StringHashMap(Ratelimiter);
|
||||||
|
|
||||||
|
const Context = struct {
|
||||||
|
allocator: std.mem.Allocator,
|
||||||
|
original_connect: *const @TypeOf(std.os.system.connect),
|
||||||
|
address_map: AddressMap,
|
||||||
|
|
||||||
|
const Self = @This();
|
||||||
|
|
||||||
|
pub fn init() Context {
|
||||||
|
var allocator = std.heap.c_allocator;
|
||||||
|
return Context{
|
||||||
|
.allocator = allocator,
|
||||||
|
.original_connect = getOriginalFunction(@TypeOf(&std.os.system.connect), "connect"),
|
||||||
|
.address_map = AddressMap.init(allocator),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn getSleepDuration(self: *Self, posix_addr: *std.os.sockaddr) !f64 {
|
||||||
|
var buf: [32]u8 = undefined;
|
||||||
|
const as_string = try printAddrAsString(posix_addr, &buf);
|
||||||
|
|
||||||
|
var maybe_ratelimiter = try self.address_map.getOrPut(as_string);
|
||||||
|
|
||||||
|
if (!maybe_ratelimiter.found_existing) {
|
||||||
|
maybe_ratelimiter.value_ptr.* = Ratelimiter.init(DEFAULT_RATELIMIT);
|
||||||
|
}
|
||||||
|
|
||||||
|
return maybe_ratelimiter.value_ptr.updateRateLimit();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
fn initGlobalContext() *Context {
|
||||||
|
return if (global_context) |ctx| ctx else blk: {
|
||||||
|
var ctx = std.heap.c_allocator.create(Context) catch @panic("out of memory");
|
||||||
|
ctx.* = Context.init();
|
||||||
|
global_context = ctx;
|
||||||
|
break :blk ctx;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export fn connect(fd: c_int, addr: *std.os.sockaddr, len: std.os.socklen_t) c_int {
|
||||||
|
var ctx = initGlobalContext();
|
||||||
|
logger.info("connect() was called lmao", .{});
|
||||||
|
var sleep_duration = ctx.getSleepDuration(addr) catch |err| blk: {
|
||||||
|
logger.warn("failed to get sleep duration: {s}", .{@errorName(err)});
|
||||||
|
break :blk 0;
|
||||||
|
};
|
||||||
|
var buf: [32]u8 = undefined;
|
||||||
|
const as_string = printAddrAsString(addr, &buf) catch unreachable;
|
||||||
|
logger.info("getSleepDuration({s}) => {}", .{ as_string, sleep_duration });
|
||||||
|
if (sleep_duration > 0) {
|
||||||
|
std.time.sleep(@floatToInt(u64, sleep_duration * std.time.ns_per_s));
|
||||||
|
}
|
||||||
|
return ctx.original_connect(fd, addr, len);
|
||||||
|
}
|
Loading…
Reference in a new issue