Rate limiting based on user plan in Spring Cloud Gateway

Say my users subscribe to a plan. Is it possible to use Spring Cloud Gateway to rate limit user requests based on their subscription plan? Given Silver and Gold plans, could Silver subscriptions get a replenishRate/burstCapacity of 5/10 and Gold 50/100?
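
For reference, RedisRateLimiter implements a token bucket: replenishRate is how many tokens are added to a user's bucket per second, and burstCapacity is how many tokens the bucket can hold. The sketch below is a simplified, single-process version (not the gateway's Redis/Lua code; the class name TokenBucket is hypothetical) that shows what 5/10 versus 50/100 means, with time passed in as epoch seconds so the behavior is deterministic.

```java
/**
 * Simplified, single-process token bucket with the two knobs RedisRateLimiter exposes.
 * Hypothetical illustration only: not thread-safe, and time is passed in explicitly
 * (epoch seconds) to keep it deterministic.
 */
final class TokenBucket {
    private final long replenishRate; // tokens added per second
    private final long burstCapacity; // maximum tokens the bucket can hold
    private long tokens;              // tokens currently available
    private long lastRefreshed;       // second at which 'tokens' was last updated

    TokenBucket(long replenishRate, long burstCapacity) {
        this.replenishRate = replenishRate;
        this.burstCapacity = burstCapacity;
        this.tokens = burstCapacity;  // a new user starts with a full bucket
        this.lastRefreshed = 0;
    }

    /** Tries to take one token at time nowSeconds; returns true if the request is allowed. */
    boolean tryAcquire(long nowSeconds) {
        long elapsed = Math.max(0, nowSeconds - lastRefreshed);
        // Refill for the elapsed seconds, but never above burstCapacity
        long filled = Math.min(burstCapacity, tokens + elapsed * replenishRate);
        boolean allowed = filled >= 1;
        tokens = allowed ? filled - 1 : filled;
        lastRefreshed = nowSeconds;
        return allowed;
    }
}
```

With `new TokenBucket(5, 10)`, 15 requests within the same second let 10 through, and one second later the bucket has refilled by 5 tokens. A Gold bucket, `new TokenBucket(50, 100)`, would let 100 through in the same burst.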

I naively thought of passing a new instance of RedisRateLimiter to the filter (see below, where I construct one with 5/10 settings), but I would need to get information about the user from the request somehow in order to find out whether they are on the Silver or the Gold plan.

@Bean
public RouteLocator myRoutes(RouteLocatorBuilder builder) {
    return builder.routes()
        .route(p -> p
            .path("/get")
            .filters(f -> f
                .requestRateLimiter(r -> {
                    r.setRateLimiter(new RedisRateLimiter(5, 10));
                }))
            .uri("http://httpbin.org:80"))
        .build();
}

Is what I am trying to achieve even possible with Spring Cloud Gateway? If not, what other products would you recommend checking out for this purpose?

Thanks!

Barbadoss asked Mar 05 '23 23:03


1 Answer

Okay, it is possible by creating a custom rate limiter on top of the RedisRateLimiter class. Unfortunately, the class was not designed for extensibility, so the solution is somewhat "hacky": I could only decorate the normal RedisRateLimiter and duplicate some of its code:

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.cloud.gateway.filter.ratelimit.RateLimiter;
import org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter;
import org.springframework.context.annotation.Primary;
import org.springframework.data.redis.core.ReactiveRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@Primary
@Component
public class ApiKeyRateLimiter implements RateLimiter<RedisRateLimiter.Config> {

    private final Log log = LogFactory.getLog(getClass());

    // How many requests per second do you want a user to be allowed to do?
    private static final int REPLENISH_RATE = 1;
    // How much bursting do you want to allow?
    private static final int BURST_CAPACITY = 1;

    // The stock limiter is only used for its configuration methods (see the bottom of the class)
    private final RedisRateLimiter rateLimiter;
    // The same Lua token-bucket script that RedisRateLimiter runs
    private final RedisScript<List<Long>> script;
    private final ReactiveRedisTemplate<String, String> redisTemplate;

    @Autowired
    public ApiKeyRateLimiter(
        RedisRateLimiter rateLimiter,
        @Qualifier(RedisRateLimiter.REDIS_SCRIPT_NAME) RedisScript<List<Long>> script,
        ReactiveRedisTemplate<String, String> redisTemplate) {

        this.rateLimiter = rateLimiter;
        this.script = script;
        this.redisTemplate = redisTemplate;
    }

    // These two methods are the core of the rate limiter.
    // Their purpose is to come up with the rate limits for a given API key (or user ID);
    // it is up to the implementor to return limits based on the API key passed in.
    private int getBurstCapacity(String routeId, String apiKey) {
        return BURST_CAPACITY;
    }

    private int getReplenishRate(String routeId, String apiKey) {
        return REPLENISH_RATE;
    }

    @Override
    public Mono<Response> isAllowed(String routeId, String apiKey) {

        int replenishRate = getReplenishRate(routeId, apiKey);
        int burstCapacity = getBurstCapacity(routeId, apiKey);

        try {
            List<String> keys = getKeys(apiKey);

            // The arguments to the Lua script: rate, capacity, current unix time in seconds
            // and the number of tokens requested. They are copied from RedisRateLimiter, so
            // check them against the RedisRateLimiter source of your Spring Cloud Gateway version.
            List<String> scriptArgs = Arrays.asList(replenishRate + "", burstCapacity + "",
                Instant.now().getEpochSecond() + "", "1");
            Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys, scriptArgs);

            // On a Redis error, fall back to "allowed" with -1 tokens left
            return flux.onErrorResume(throwable -> Flux.just(Arrays.asList(1L, -1L)))
                .reduce(new ArrayList<Long>(), (longs, l) -> {
                    longs.addAll(l);
                    return longs;
                })
                .map(results -> {
                    boolean allowed = results.get(0) == 1L;
                    Long tokensLeft = results.get(1);

                    Response response = new Response(allowed, getHeaders(tokensLeft, replenishRate, burstCapacity));

                    if (log.isDebugEnabled()) {
                        log.debug("response: " + response);
                    }
                    return response;
                });
        }
        catch (Exception e) {
            /*
             * We don't want a hard dependency on Redis to allow traffic. Make sure to set
             * an alert so you know if this is happening too much. Stripe's observed
             * failure rate is 0.01%.
             */
            log.error("Error determining if user allowed from redis", e);
        }
        return Mono.just(new Response(true, getHeaders(-1L, replenishRate, burstCapacity)));
    }

    // Same key layout as RedisRateLimiter; the {...} hash tag keeps both keys in one Redis Cluster slot
    private static List<String> getKeys(String id) {
        String prefix = "request_rate_limiter.{" + id;
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    // int parameters: an int argument cannot be boxed into a Long parameter
    private Map<String, String> getHeaders(Long tokensLeft, int replenish, int burst) {
        Map<String, String> headers = new HashMap<>();
        headers.put(RedisRateLimiter.REMAINING_HEADER, tokensLeft.toString());
        headers.put(RedisRateLimiter.REPLENISH_RATE_HEADER, String.valueOf(replenish));
        headers.put(RedisRateLimiter.BURST_CAPACITY_HEADER, String.valueOf(burst));
        return headers;
    }

    // Configuration handling is delegated to the decorated RedisRateLimiter
    @Override
    public Map<String, RedisRateLimiter.Config> getConfig() {
        return rateLimiter.getConfig();
    }

    @Override
    public Class<RedisRateLimiter.Config> getConfigClass() {
        return rateLimiter.getConfigClass();
    }

    @Override
    public RedisRateLimiter.Config newConfig() {
        return rateLimiter.newConfig();
    }
}
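
The two private hooks, getReplenishRate and getBurstCapacity, are where the Silver/Gold distinction from the question would go. Below is a minimal plain-Java sketch, assuming the API key returned by the key resolver can be mapped to a plan; the `Plan` enum, the `PlanLimits` class and the in-memory map are hypothetical stand-ins for whatever user or subscription store you actually have.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical subscription plans, using the limits from the question. */
enum Plan {
    SILVER(5, 10),
    GOLD(50, 100);

    final int replenishRate; // tokens per second
    final int burstCapacity; // bucket size

    Plan(int replenishRate, int burstCapacity) {
        this.replenishRate = replenishRate;
        this.burstCapacity = burstCapacity;
    }
}

/**
 * Hypothetical API-key-to-plan lookup. A real gateway would query a user or
 * subscription service (ideally cached) instead of holding an in-memory map.
 */
final class PlanLimits {
    private final Map<String, Plan> plansByApiKey;
    private final Plan defaultPlan; // used for keys with no known subscription

    PlanLimits(Map<String, Plan> plansByApiKey, Plan defaultPlan) {
        this.plansByApiKey = new HashMap<>(plansByApiKey);
        this.defaultPlan = defaultPlan;
    }

    Plan planFor(String apiKey) {
        return plansByApiKey.getOrDefault(apiKey, defaultPlan);
    }

    // Same shape as the two hooks in ApiKeyRateLimiter; routeId is unused here
    // but could be used to vary limits per route as well as per plan.
    int getReplenishRate(String routeId, String apiKey) {
        return planFor(apiKey).replenishRate;
    }

    int getBurstCapacity(String routeId, String apiKey) {
        return planFor(apiKey).burstCapacity;
    }
}
```

Inside ApiKeyRateLimiter, the two hooks could then delegate to such a lookup (for example, one injected through the constructor) instead of returning the constants. Since the rate and capacity are passed to the Lua script on every call, a user who changes plan should get the new limits from their next request on.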

So, the route would look like this:

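import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.stereotype.Component;

@Component
public class Routes {

    @Autowired
    ApiKeyRateLimiter rateLimiter;

    // ApiKeyResolver is a KeyResolver (not shown) that extracts the API key from the request;
    // the key it returns is what ApiKeyRateLimiter.isAllowed receives as apiKey
    @Autowired
    ApiKeyResolver apiKeyResolver;

    @Bean
    public RouteLocator theRoutes(RouteLocatorBuilder b) {
        return b.routes()
            .route(p -> p
                    .path("/unlimited")
                    .uri("http://httpbin.org:80/anything?route=unlimited")
            )
            .route(p -> p
                    .path("/limited")
                    .filters(f ->
                            f.requestRateLimiter(r -> {
                                r.setKeyResolver(apiKeyResolver);
                                r.setRateLimiter(rateLimiter);
                            })
                    )
                    .uri("http://httpbin.org:80/anything?route=limited")
            )
            .build();
    }

}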

Hope it saves a work day for somebody...

Barbadoss answered Mar 19 '23 02:03