TimeAndErrorsThrottler.java
package org.wikidata.query.rdf.blazegraph.throttling;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import javax.servlet.http.HttpServletRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.cache.Cache;
/**
* Throttle users based on the cumulative request processing time and number of errors.
*/
public class TimeAndErrorsThrottler<S extends TimeAndErrorsState> extends Throttler<S> {
private static final Logger LOG = LoggerFactory.getLogger(TimeAndErrorsThrottler.class);
/** Requests longer than this will trigger tracking resource consumption. */
private final Duration requestTimeThreshold;
/**
* Constructor.
*
* Note that a bucket represent our approximation of a single client.
*
* @param requestTimeThreshold requests longer than this will trigger
* tracking resource consumption
* @param createThrottlingState how to create the initial throttling state
* when we start tracking a specific client
* @param stateStore the cache in which we store the per client state of
* throttling
* @param enableThrottlingIfHeader throttling is only enabled if this header is present
* @param alwaysThrottleParam this query parameter will cause throttling no matter what
*/
public TimeAndErrorsThrottler(
Duration requestTimeThreshold,
Callable<S> createThrottlingState,
Cache<Object, S> stateStore,
String enableThrottlingIfHeader,
String alwaysThrottleParam,
Clock clock) {
super(createThrottlingState, stateStore, enableThrottlingIfHeader, alwaysThrottleParam, clock);
this.requestTimeThreshold = requestTimeThreshold;
}
/**
* Notify this throttler that a request has been completed successfully.
*
* @param bucket the bucket to which this request belongs
* @param request the request
* @param elapsed how long that request took
*/
public void success(Object bucket, HttpServletRequest request, Duration elapsed) {
if (shouldBypassThrottling(request)) {
return;
}
try {
S state;
// only start to keep track of time usage if requests are expensive
if (elapsed.compareTo(requestTimeThreshold) > 0) {
state = getState(bucket);
} else {
state = getStateIfPresent(bucket);
}
if (state != null) {
state.consumeTime(elapsed);
}
} catch (ExecutionException e) {
LOG.warn("Could not create throttling state", e);
}
}
/**
* Notify this throttler that a request has completed in error.
*
* @param bucket the bucket to which this request belongs
* @param request the request
* @param elapsed how long that request took
*/
public void failure(Object bucket, HttpServletRequest request, Duration elapsed) {
if (shouldBypassThrottling(request)) return;
try {
S state = getState(bucket);
state.consumeError();
state.consumeTime(elapsed);
} catch (ExecutionException e) {
LOG.warn("Could not create throttling state", e);
}
}
@Override
protected Instant internalThrottledUntil(Object bucket, HttpServletRequest request) {
S state = getStateIfPresent(bucket);
if (state == null) return Instant.MIN;
return state.throttledUntil();
}
}