// SuggestionFilter.java

package org.wikimedia.search.glent;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.length;

import java.time.Duration;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;

import com.google.common.annotations.VisibleForTesting;

public class SuggestionFilter {
    private final int minHitsDiff;
    private final int minHitsPercDiff;
    private final float maxEditDist;
    private final double maxNormEditDist;
    private final Duration maxTsDiff;
    private final EditDistanceCalc edCalc;

    interface EditDistanceCalc {
        Column editDistance(Column a, Column b);
    }

    SuggestionFilter() {
        this(100, 10, 3F, 1.5, Duration.ofMinutes(2),
                functions::levenshtein);
    }

    public SuggestionFilter(int minHitsDiff, int minHitsPercDiff, float maxEditDist, double maxNormEditDist,
                            Duration maxTsDiff, EditDistanceCalc edCalc) {
        this.minHitsDiff = minHitsDiff;
        this.minHitsPercDiff = minHitsPercDiff;
        this.maxEditDist = maxEditDist;
        this.maxNormEditDist = maxNormEditDist;
        this.maxTsDiff = maxTsDiff;
        this.edCalc = edCalc;
    }

    /**
     * Filter suggestions based on hit counts and levenshtein distance.
     */
    public Dataset<Row> filter(Dataset<Row> df) {
        // suggestion does not provide enough lift in hitsTotal
        df = filterMinHitsDiff(df);
        // suggestion does not provide enough lift in % of hitsTotal
        df = filterMinHitsPercDiff(df);
        // suggestion is too far from query
        df = filterMaxEditDist(df);
        // suggestion is too far from query relative to query length
        return filterMaxNormEditDist(df);
    }

    /**
     * Filter suggestions based on maximum session length along
     * with the default filters.
     */
    public Dataset<Row> filterWithSession(Dataset<Row> df) {
        return filterMaxTsDiff(filter(df));
    }

    @VisibleForTesting
    Dataset<Row> filterMinHitsDiff(Dataset<Row> df) {
        Column diff = col("q2.hitsTotal").minus(col("q1.hitsTotal"));
        return df.where(diff.gt(minHitsDiff));
    }

    private Dataset<Row> filterMinHitsPercDiff(Dataset<Row> df) {
        Column diff = col("q2.hitsTotal").divide(col("q1.hitsTotal").plus(1));
        Column perc = diff.multiply(100).minus(100);
        return df.where(perc.gt(minHitsPercDiff));
    }

    @VisibleForTesting
    Dataset<Row> filterMaxEditDist(Dataset<Row> df) {
        // TODO: This is adding a field to the returned df, which doesn't seem like something
        // a filter implementation should be doing. The calculation is delayed until here to
        // allow throwing out as many suggestions as possible before running a fairly expensive
        // edit distance calculation (this should be verified as useful?).
        Column dist = edCalc.editDistance(col("q1.queryNorm"), col("q2.queryNorm"));
        return df
                .withColumn("q1q2EditDist", dist)
                .where(dist.lt(maxEditDist));
    }

    /**
     * Filters based on maximum normalized edit distance (edit dist/length of original query).
     * This is to ensure that two letter words can't be replaced by unrelated three letter words (even though edit_dist == 3).
     * max is params.maxNormEditDist, that has default of 1.5
     *
     * @param df dataframe
     * @return records that match the criteria
     */
    @VisibleForTesting
    Dataset<Row> filterMaxNormEditDist(Dataset<Row> df) {
        return df
                .where(length(col("q1.queryNorm")).gt(col("q1q2EditDist").multiply(maxNormEditDist)));
    }

    @VisibleForTesting
    Dataset<Row> filterMaxTsDiff(Dataset<Row> df) {
        Column diff = col("q2.ts").minus(col("q1.ts"));
        return df.where(diff.lt(maxTsDiff.getSeconds()));
    }
}