JWTIdentityFilter.java

package org.wikidata.query.rdf.blazegraph.filters;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Stream;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.auth0.jwt.JWT;
import com.auth0.jwt.JWTVerifier;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.JWTVerificationException;
import com.auth0.jwt.interfaces.DecodedJWT;

/**
 * Overrides {@code getRemoteUser()} for requests carrying a signed JWT whose claims include the caller's identity.
 *
 * <p>The token must be created and handed to the user as a cookie by some external service,
 * such as mw-oauth-proxy. Tokens are verified with HMAC256 against a shared secret, and the
 * configured identity claim must be present and represented as a string.
 *
 * <p>Configuration (read through {@link FilterConfiguration} with the WDQS config prefix):
 * <ul>
 *   <li>{@code jwt-identity-cookie-name} - name of the cookie holding the token;</li>
 *   <li>{@code jwt-identity-claim} - name of the claim holding the identity;</li>
 *   <li>{@code jwt-identity-secret} - HMAC256 secret used to verify the token.</li>
 * </ul>
 * If the secret is absent the filter is a no-op. If the secret is present but either of the other
 * two parameters is missing, {@link #init(FilterConfig)} throws {@link IllegalArgumentException}.
 */
public class JWTIdentityFilter implements Filter {

    private static final Logger LOG = LoggerFactory.getLogger(JWTIdentityFilter.class);

    /** Resolves the identity for a request; assigned once in {@link #init(FilterConfig)}. */
    private Function<HttpServletRequest, Optional<String>> usernames;

    /**
     * Builds the username provider described by the filter configuration.
     *
     * @param config filter configuration to read the jwt-identity-* parameters from
     * @return a provider reading the identity from the JWT cookie, or a provider that always
     *     returns {@link Optional#empty()} when no secret is configured
     * @throws IllegalArgumentException if the secret is set but the cookie name or claim is not
     */
    private Function<HttpServletRequest, Optional<String>> usernameProvider(FilterConfiguration config) {
        String cookieName = config.loadStringParam("jwt-identity-cookie-name");
        String identityClaim = config.loadStringParam("jwt-identity-claim");
        String secret = config.loadStringParam("jwt-identity-secret");
        if (allNotNull(cookieName, identityClaim, secret)) {
            LOG.info("Configured filter against {} claim of jwt token in the {} cookie", identityClaim, cookieName);
            return new UsernameFromJWTCookie(cookieName, identityClaim, secret);
        } else if (secret == null) {
            LOG.info("Filter disabled, no configuration available.");
            // A provider that never yields a user keeps doFilter free of null checks.
            return r -> Optional.empty();
        } else {
            throw new IllegalArgumentException(
                "All three of jwt-identity-cookie-name, jwt-identity-claim, and jwt-identity-secret " +
                    "must be provided");
        }
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        usernames = usernameProvider(new FilterConfiguration(filterConfig, FilterConfiguration.WDQS_CONFIG_PREFIX));
    }

    /**
     * Passes the request down the chain, wrapped so that {@code getRemoteUser()} returns the
     * identity from the JWT cookie when one could be verified; otherwise the request is unchanged.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        filterChain.doFilter(ProvidedRemoteUserHttpRequestWrapper.wrap(servletRequest, usernames), servletResponse);
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /** Returns true when none of the given values is null. */
    private static boolean allNotNull(Object... values) {
        return Arrays.stream(values).allMatch(Objects::nonNull);
    }

    /**
     * Extracts an identity from a JWT stored in a named cookie.
     *
     * <p>The token must verify against the HMAC256 secret and contain the identity claim. A claim
     * that is not a string yields no identity.
     */
    public static class UsernameFromJWTCookie implements Function<HttpServletRequest, Optional<String>> {
        private final String cookieName;
        private final String identityClaim;
        private final JWTVerifier verifier;

        /**
         * @param cookieName name of the cookie carrying the token; must not be null
         * @param identityClaim name of the claim carrying the identity; must not be null
         * @param secret HMAC256 secret used to verify the token signature
         * @throws NullPointerException if {@code cookieName} or {@code identityClaim} is null
         */
        public UsernameFromJWTCookie(String cookieName, String identityClaim, String secret) {
            // Fail at construction rather than with an NPE on the first request that has cookies.
            this.cookieName = Objects.requireNonNull(cookieName, "cookieName");
            this.identityClaim = Objects.requireNonNull(identityClaim, "identityClaim");
            this.verifier = JWT.require(Algorithm.HMAC256(secret)).withClaimPresence(identityClaim).build();
        }

        /**
         * Returns the identity claim of the first cookie named {@code cookieName}, if that cookie
         * holds a verifiable token whose identity claim is a string.
         */
        @Override
        public Optional<String> apply(HttpServletRequest request) {
            // getCookies() returns null rather than an empty array when the request has no cookies.
            return Optional.ofNullable(request.getCookies())
                .map(Arrays::stream)
                .orElseGet(Stream::empty)
                .filter(c -> cookieName.equals(c.getName()))
                .findFirst()
                .map(Cookie::getValue)
                .flatMap(this::decode)
                .flatMap(t -> Optional.ofNullable(t.getClaim(identityClaim).asString()));
        }

        /** Verifies the token, returning empty (and logging) when verification fails. */
        private Optional<DecodedJWT> decode(String token) {
            try {
                return Optional.ofNullable(verifier.verify(token));
            } catch (JWTVerificationException e) {
                // JWTVerificationException is the common base for several failure modes (bad signature,
                // expiry, malformed token, ...); log the concrete type instead of assuming a wrong secret.
                // The token itself is user-supplied and deliberately not logged.
                LOG.info("Received invalid JWT token: {}", e.getClass().getSimpleName());
                return Optional.empty();
            }
        }
    }

    /** Request wrapper reporting a fixed remote user. */
    private static final class ProvidedRemoteUserHttpRequestWrapper extends HttpServletRequestWrapper {
        private final String remoteUser;

        /**
         * Wraps the request when it is an HTTP request; any other request is returned unchanged.
         */
        public static ServletRequest wrap(ServletRequest servletRequest, Function<HttpServletRequest, Optional<String>> provider) {
            // Use an explicit type check: catching ClassCastException around the call would also
            // swallow any ClassCastException raised by the provider itself.
            if (servletRequest instanceof HttpServletRequest) {
                return wrap((HttpServletRequest) servletRequest, provider);
            }
            return servletRequest;
        }

        /**
         * Returns a wrapper reporting the provider's identity as the remote user, or the original
         * request when the provider yields no identity.
         */
        public static HttpServletRequest wrap(HttpServletRequest servletRequest, Function<HttpServletRequest, Optional<String>> provider) {
            return provider.apply(servletRequest)
                .<HttpServletRequest>map(username -> new ProvidedRemoteUserHttpRequestWrapper(servletRequest, username))
                .orElse(servletRequest);
        }

        ProvidedRemoteUserHttpRequestWrapper(HttpServletRequest request, String remoteUser) {
            super(request);
            this.remoteUser = remoteUser;
        }

        @Override
        public String getRemoteUser() {
            return remoteUser;
        }
    }
}