SSX3LobbyServer/lib/base/rate_limit.hpp

99 lines
3.0 KiB
C++
Raw Permalink Normal View History

#pragma once
#include <atomic>
#include <chrono>
#include <cstdint>
#include <base/types.hpp>
namespace base {
/// Event rate limiter.
///
/// Accepts at most `maxEventCount` events inside one window of length
/// `maxRate`. Once a call pushes the count past that budget the limiter
/// enters a cool-down: every call is rejected until `cooldownTime` has
/// elapsed, after which a fresh window is started.
///
/// The limiter object itself only holds configuration. All bookkeeping lives
/// in a caller-owned State that is passed to every call, so one limiter can be
/// applied to any number of independent State objects.
///
/// NOTE(review): State's fields are std::atomic, but each call performs
/// several separate loads and stores, so concurrent calls on the *same* State
/// can interleave. Confirm whether callers serialize access per State.
///
/// @tparam Clock Clock queried through `Clock::now()`.
/// @tparam Dur   Duration type in which all time values are stored.
template <class Clock = std::chrono::steady_clock, class Dur = std::chrono::microseconds>
struct BasicRateLimiter final {
    using TimePointType = std::chrono::time_point<Clock, Dur>;
    using EventGrainType = std::uint64_t;

    /// (mostly) Opaque per-object state.
    struct State {
        /// Returns true while the limiter has this state in cool-down.
        [[nodiscard]] bool CoolingDown() const noexcept { return coolingDown.load(); }

       private:
        friend struct BasicRateLimiter;

        std::atomic_bool coolingDown {};

        /// Start of the current window, or the moment the cool-down began
        /// while [coolingDown] is set.
        std::atomic<TimePointType> lastEvent {};

        /// Events counted in the current window.
        std::atomic<EventGrainType> eventCount {};
    };

    /// @param maxEvents    Maximum number of events accepted per window.
    /// @param maxRate      Length of one window.
    /// @param cooldownTime How long events are rejected once the budget is exceeded.
    ///
    /// Both durations are implicitly converted to `Dur`.
    template <class Dur2, class Dur3>
    constexpr BasicRateLimiter(EventGrainType maxEvents, Dur2 maxRate, Dur3 cooldownTime) noexcept
        : cooldownTime(cooldownTime), maxRate(maxRate), maxEventCount(maxEvents) {}

    // Disallow copying, but allow movement, if so desired.
    constexpr BasicRateLimiter(const BasicRateLimiter&) = delete;
    constexpr BasicRateLimiter(BasicRateLimiter&&) noexcept = default;

    /// Try and take a single event, possibly activating the rate limit.
    ///
    /// Not constexpr: it always ends up calling `Clock::now()`.
    ///
    /// @return true if the event was accepted, false if it was rejected.
    [[nodiscard]] bool TryTakeEvent(State& state) const noexcept { return TryTakeEvents(state, 1); }

    /// Try and take events, possibly activating the rate limit.
    ///
    /// @param state    Bookkeeping for the entity being limited.
    /// @param nrEvents Number of events this call wants to take.
    /// @return true if the events were accepted; false if the state is
    ///         cooling down, or if this call exceeded the budget (in which
    ///         case the cool-down starts now).
    [[nodiscard]] bool TryTakeEvents(State& state, EventGrainType nrEvents) const noexcept {
        // Sample the clock once so every comparison below sees the same instant.
        const auto now = std::chrono::time_point_cast<Dur>(Clock::now());
        const auto elapsedSinceLastEvent = now - state.lastEvent.load();

        bool startNewWindow = elapsedSinceLastEvent >= maxRate;

        if(state.coolingDown.load()) {
            // Still inside the cool-down: reject without touching the state.
            if(elapsedSinceLastEvent < cooldownTime)
                return false;

            // The cool-down has passed. Always begin a fresh window here, even
            // when cooldownTime < maxRate; otherwise the stale count from
            // before the cool-down would immediately trip the limit again.
            state.coolingDown.store(false);
            startNewWindow = true;
        }

        if(startNewWindow) {
            state.eventCount.store(0);
            state.lastEvent.store(now);
        }

        // Every event counts against the window, including those of the call
        // that opened it, so an oversized first batch cannot slip through.
        // If the budget is exceeded, we start the cool-down process.
        if((state.eventCount += nrEvents) > maxEventCount) {
            state.coolingDown.store(true);
            state.lastEvent.store(now);
            return false;
        }

        return true;
    }

    /// Sets the maximum number of events accepted per window.
    constexpr void SetMaxEventCount(EventGrainType count) noexcept { maxEventCount = count; }

    /// Sets the window length (implicitly converted to `Dur`).
    template <class Dur2>
    constexpr void SetMaxRate(Dur2 dur) noexcept {
        maxRate = dur;
    }

    /// Sets the cool-down length (implicitly converted to `Dur`).
    template <class Dur2>
    constexpr void SetCooldownTime(Dur2 dur) noexcept {
        cooldownTime = dur;
    }

   private:
    Dur cooldownTime;
    Dur maxRate;
    EventGrainType maxEventCount;
};
/// Rate limiter using BasicRateLimiter's default template arguments
/// (std::chrono::steady_clock as the clock, std::chrono::microseconds as the
/// stored duration type).
using RateLimiter = BasicRateLimiter<>;
} // namespace base