Rate limiting fundamentals

Leaky Bucket

/**
 * Common base for the rate limiters below: stores the per-second request budget and
 * declares the single admission check every strategy must implement.
 */
public abstract class RateLimiter {

  /** Request budget per second, exactly as supplied to the constructor. */
  protected final int maxRequestPerSec;

  /**
   * Records the per-second request budget for use by subclasses.
   *
   * @param maxRequestPerSec requests allowed per second (stored as-is, not validated here)
   */
  protected RateLimiter(int maxRequestPerSec) {
    this.maxRequestPerSec = maxRequestPerSec;
  }

  /**
   * Decides whether the current request may proceed.
   *
   * @return {@code true} to admit the request, {@code false} to reject it
   */
  abstract boolean allow();
}

/**
 * Leaky-bucket limiter that spaces admitted requests at least {@code 1 / maxRequestPerSec}
 * seconds apart. There is no burst allowance: once a request is admitted, the next one is
 * admitted only after a full interval has elapsed.
 *
 * <p>Admission decisions are made while holding this instance's monitor.
 */
public class LeakyBucket extends RateLimiter {

  private static final long NANOS_PER_SECOND = 1_000_000_000L;

  /** {@link System#nanoTime()} value at or after which the next request may be admitted. */
  private long nextAllowedTime;
  /** Minimum spacing between admitted requests, in nanoseconds. */
  private final long requestIntervalNanos;

  /**
   * Creates a limiter admitting at most {@code maxRequestPerSec} evenly spaced requests per second.
   *
   * @param maxRequestPerSec maximum admitted requests per second; must be positive
   * @throws IllegalArgumentException if {@code maxRequestPerSec} is not positive
   */
  protected LeakyBucket(int maxRequestPerSec) {
    super(maxRequestPerSec);
    if (maxRequestPerSec <= 0) {
      throw new IllegalArgumentException("maxRequestPerSec must be positive: " + maxRequestPerSec);
    }
    // Nanosecond resolution: the previous millisecond interval (1000 / rate) truncated to 0 for
    // rates above 1000/s, which admitted every request.
    requestIntervalNanos = NANOS_PER_SECOND / maxRequestPerSec;
    nextAllowedTime = System.nanoTime();
  }

  /**
   * Admits the request if at least one interval has passed since the previously admitted one.
   *
   * @return {@code true} if admitted, {@code false} otherwise
   */
  @Override
  boolean allow() {
    synchronized (this) {
      // Read the clock under the lock so concurrent callers observe time in lock order.
      // nanoTime is monotonic, unlike currentTimeMillis which can jump with wall-clock changes.
      long curTime = System.nanoTime();
      // Compare by difference: nanoTime values may wrap; differences stay correct.
      if (curTime - nextAllowedTime >= 0) {
        nextAllowedTime = curTime + requestIntervalNanos;
        return true;
      }
      return false;
    }
  }
}

Token Bucket

Eager mode

/**
 * Token-bucket limiter with eager refill: a background thread tops the bucket up by
 * {@code maxRequestsPerSec} tokens once per second (capped at {@code maxRequestsPerSec}),
 * and each admitted request consumes one token.
 *
 * <p>NOTE(review): each instance starts its own refill thread that runs until it is
 * interrupted. The thread is a daemon, so it does not keep the JVM alive, but nothing
 * exposes a way to stop it. Prefer {@code TokenBucketLazyRefill} when many instances are
 * created.
 */
public class TokenBucket extends RateLimiter {

  /** Tokens currently available; guarded by this instance's monitor. */
  private int tokens;

  /**
   * Creates a full bucket and starts the once-per-second refill thread.
   *
   * @param maxRequestsPerSec bucket capacity and tokens added per refill
   */
  public TokenBucket(int maxRequestsPerSec) {
    super(maxRequestsPerSec);
    this.tokens = maxRequestsPerSec;
    Thread refiller = new Thread(() -> {
      while (!Thread.currentThread().isInterrupted()) {
        try {
          TimeUnit.SECONDS.sleep(1);
        } catch (InterruptedException e) {
          // Interruption is the stop signal: restore the flag and end the refill loop
          // instead of logging and spinning forever.
          Thread.currentThread().interrupt();
          return;
        }
        refillTokens(maxRequestsPerSec);
      }
    }, "token-bucket-refiller");
    // Daemon so an otherwise idle limiter never prevents JVM shutdown.
    refiller.setDaemon(true);
    refiller.start();
  }

  /**
   * Consumes one token if any are available.
   *
   * @return {@code true} if a token was consumed, {@code false} if the bucket is empty
   */
  @Override
  public boolean allow() {
    synchronized (this) {
      if (tokens == 0) {
        return false;
      }
      tokens--;
      return true;
    }
  }

  /** Adds {@code cnt} tokens, never exceeding the bucket capacity. */
  private void refillTokens(int cnt) {
    synchronized (this) {
      tokens = Math.min(tokens + cnt, maxRequestPerSec);
    }
  }
}

Lazy mode

/**
 * Token-bucket limiter with lazy refill: no background thread. Tokens accrue continuously
 * at {@code maxRequestPerSec} per second (capped at {@code maxRequestPerSec}) and are
 * credited on each {@link #allow()} call.
 *
 * <p>All state is accessed while holding this instance's monitor.
 */
public class TokenBucketLazyRefill extends RateLimiter {

  private static final double NANOS_PER_SECOND = 1_000_000_000.0;

  /**
   * Tokens currently available. This is fractional so partial refills carry over between
   * calls. The previous integer version discarded the leftover fraction each time it
   * refilled, so it systematically under-refilled.
   */
  private double tokens;
  /** {@link System#nanoTime()} value at which tokens were last credited. */
  private long lastRefillTime;

  /**
   * Creates a full bucket.
   *
   * @param maxRequestPerSec bucket capacity and refill rate in tokens per second
   */
  public TokenBucketLazyRefill(int maxRequestPerSec) {
    super(maxRequestPerSec);
    this.tokens = maxRequestPerSec;
    this.lastRefillTime = System.nanoTime();
  }

  /**
   * Credits tokens accrued since the last call, then consumes one whole token if available.
   *
   * @return {@code true} if a token was consumed, {@code false} otherwise
   */
  @Override
  public boolean allow() {
    synchronized (this) {
      refillTokens();
      if (tokens < 1) {
        return false;
      }
      tokens--;
      return true;
    }
  }

  /** Credits exactly the (possibly fractional) tokens accrued since {@code lastRefillTime}. */
  private void refillTokens() {
    // Monotonic clock: a backwards wall-clock jump must not stall or distort refills.
    long curTime = System.nanoTime();
    long elapsed = curTime - lastRefillTime;
    if (elapsed <= 0) {
      return;
    }
    tokens = Math.min(tokens + elapsed / NANOS_PER_SECOND * maxRequestPerSec, maxRequestPerSec);
    lastRefillTime = curTime;
  }
}

Fixed Window Counter

/**
 * Fixed-window counter: admits up to {@code maxRequestPerSec} requests per aligned
 * one-second window (window start = epoch millis truncated to a whole second).
 *
 * <p>Counters are {@link AtomicInteger}s in a {@link ConcurrentMap}. Counters for windows
 * that have ended are evicted on each call, so the map stays tiny instead of growing
 * forever.
 */
public class FixedWindowCounter extends RateLimiter {

  /** Request counters keyed by window start (epoch millis, a multiple of 1000). */
  private final ConcurrentMap<Long, AtomicInteger> windows = new ConcurrentHashMap<>();

  /**
   * @param maxRequestPerSec maximum admitted requests per one-second window
   */
  protected FixedWindowCounter(int maxRequestPerSec) {
    super(maxRequestPerSec);
  }

  /**
   * Counts this request against the current window.
   *
   * @return {@code true} if the window's count, including this request, is within the limit
   */
  @Override
  boolean allow() {
    long windowKey = System.currentTimeMillis() / 1000 * 1000;
    // Windows that have already ended can never be consulted again; drop them.
    windows.keySet().removeIf(key -> key < windowKey);
    // computeIfAbsent returns the live counter directly: no throwaway AtomicInteger per call
    // and no second lookup.
    return windows.computeIfAbsent(windowKey, key -> new AtomicInteger()).incrementAndGet()
        <= maxRequestPerSec;
  }
}

Sliding Window Log

/**
 * Sliding-window log: keeps the timestamp of every admitted request from the last second
 * and admits a new request only if fewer than {@code maxRequestPerSec} remain in the log.
 *
 * <p>Exact, but memory is proportional to the limit. All access to the log is serialized
 * on the log itself.
 */
public class SlidingWindowLog extends RateLimiter {

  /** Epoch-millis timestamps of admitted requests, oldest first. */
  private final Queue<Long> log = new java.util.ArrayDeque<>();

  /**
   * @param maxRequestPerSec maximum admitted requests within any one-second sliding window
   */
  protected SlidingWindowLog(int maxRequestPerSec) {
    super(maxRequestPerSec);
  }

  /**
   * Evicts entries older than one second, then admits the request if capacity remains.
   *
   * @return {@code true} if admitted, {@code false} otherwise
   */
  @Override
  boolean allow() {
    synchronized (log) {
      // Read the clock under the lock so timestamps are enqueued in non-decreasing order,
      // which the head-only eviction below relies on.
      long curTime = System.currentTimeMillis();
      long boundary = curTime - 1000;
      while (!log.isEmpty() && log.peek() <= boundary) {
        log.poll();
      }
      // Log only admitted requests. Logging rejections too let the queue grow without bound
      // under overload and kept a throttled caller locked out for as long as it kept retrying.
      if (log.size() >= maxRequestPerSec) {
        return false;
      }
      log.add(curTime);
      return true;
    }
  }
}

Sliding Window

This is still not exact, because it assumes that requests in the previous window were evenly distributed, which may not be true. But compared to the fixed window counter, which only guarantees the rate within each window, and the sliding window log, which has a huge memory footprint, the sliding window is more practical.

/**
 * Sliding-window counter: estimates the request count over the last second as
 * {@code previousWindowCount * overlap + currentWindowCount}. Here {@code overlap} is the
 * fraction of the previous one-second window still inside the sliding window. The request
 * is admitted if the truncated estimate is within {@code maxRequestPerSec}.
 *
 * <p>Every call, admitted or not, increments the current window's counter. Counters older
 * than the previous window are evicted on each call, so the map stays tiny.
 */
public class SlidingWindow extends RateLimiter {

  /** Request counters keyed by window start (epoch millis, a multiple of 1000). */
  private final ConcurrentMap<Long, AtomicInteger> windows = new ConcurrentHashMap<>();

  /**
   * @param maxRequestPerSec maximum estimated requests within a one-second sliding window
   */
  protected SlidingWindow(int maxRequestPerSec) {
    super(maxRequestPerSec);
  }

  /**
   * Counts this request and checks the weighted two-window estimate against the limit.
   *
   * @return {@code true} if the estimate, including this request, is within the limit
   */
  @Override
  boolean allow() {
    long curTime = System.currentTimeMillis();
    long curWindowKey = curTime / 1000 * 1000;
    long preWindowKey = curWindowKey - 1000;
    // Only the current and previous windows feed the estimate; anything older is garbage.
    windows.keySet().removeIf(key -> key < preWindowKey);
    int curCount =
        windows.computeIfAbsent(curWindowKey, key -> new AtomicInteger()).incrementAndGet();
    AtomicInteger preCount = windows.get(preWindowKey);
    if (preCount == null) {
      return curCount <= maxRequestPerSec;
    }

    // Fraction of the previous window that still overlaps the last 1000 ms.
    double preWeight = 1 - (curTime - curWindowKey) / 1000.0;
    long count = (long) (preCount.get() * preWeight + curCount);
    return count <= maxRequestPerSec;
  }
}