// SimilarQueriesSuggester.java

package org.wikimedia.search.glent;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.explode;
import static org.apache.spark.sql.functions.length;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.regexp_replace;
import static org.apache.spark.sql.functions.struct;
import static org.apache.spark.sql.functions.sum;

import java.time.Duration;
import java.util.function.Function;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.wikimedia.search.glent.fst.AllPairsLevenshtein;

import com.google.common.annotations.VisibleForTesting;

/**
 * Looks for queries in the request log that have a very high string
 * similarity and considers them as query suggestions.
 *
 * Applies a form of all-pairs matching to generate candidate pairs that are a small
 * edit distance apart. To increase the allowed edit distance beyond the limit
 * of 2 imposed by the underlying libraries, candidate generation is applied
 * against strings that have had duplicate characters squashed together (e.g. abbbc
 * to abc) during matching. This squashing also supports our custom edit distance,
 * applied in SuggestionFilter, which treats duplicated characters as costing less
 * than a full character of distance.
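 *
 * A minimal usage sketch ({@code filter}, {@code numFst} and {@code m1prepOutput}
 * are illustrative names, not part of this class):
 * <pre>{@code
 *   SimilarQueriesSuggester suggester = new SimilarQueriesSuggester(filter, numFst);
 *   Dataset<Row> suggestions = suggester.apply(m1prepOutput);
 * }</pre>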
 */
public class SimilarQueriesSuggester implements Function<Dataset<Row>, Dataset<Row>> {
    public static final int MIN_STRING_LENGTH = 4;
    private final SuggestionFilter suggestionFilter;
    private final int numFst;

    public SimilarQueriesSuggester(SuggestionFilter suggestionFilter, int numFst) {
        // TODO: We allow suggestionFilter to be nullable since candidate
        // generation doesn't need it. It makes the API unclear, though.
        this.suggestionFilter = suggestionFilter;
        this.numFst = numFst;
    }

    public Dataset<Row> apply(Dataset<Row> df) {
        return apply(df, null);
    }

    /**
     * Transform m1prep output into suggestions.
     *
     * Candidates may optionally be provided to allow different executor
     * resourcing on large jobs.
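     *
     * A rough sketch of that split usage (variable names are illustrative):
     * <pre>{@code
     *   Dataset<Row> candidates = suggester.generateCandidates(m1prepOutput);
     *   Dataset<Row> suggestions = suggester.apply(m1prepOutput, candidates);
     * }</pre>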
     */
    public Dataset<Row> apply(Dataset<Row> df, Dataset<Row> candidates) {
        Dataset<Row> structs = prepareInput(df);
        if (candidates == null) {
            candidates = generateCandidatesFromStructs(structs);
        }
        return aggregate(suggestionFilter.filter(
                resolveCandidates(candidates, structs)));
    }

    /**
     * Prepare input dataframe for transformation.
     *
     * Prepare a version of the normalized query with duplicate characters squashed
     * to extend the matchable set of strings. Further bundle up the row into a struct
     * to prepare for having two separate queries, the query and the suggestion, in
     * the same row. Each query will be represented by a column containing one of
     * these structs.
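     *
     * As a purely illustrative example, an input row with queryNorm "foo bar"
     * becomes (queryNormDedup="fo bar", sugg=struct(query, queryNorm, hitsTotal,
     * suggCount, ts, wikiid, lang)).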
     *
     * @param df m1prep output
     */
    private Dataset<Row> prepareInput(Dataset<Row> df) {
        return df.select(
                dedupChars(col("queryNorm")).alias("queryNormDedup"),
                struct(
                        "query", "queryNorm", "hitsTotal", "suggCount", "ts", "wikiid", "lang"
                ).alias("sugg"));
    }

    /**
     * Apply a form of all-pairs matching to generate candidate pairs.
     *
     * Candidate generation has dramatically different resource requirements
     * from normal Spark tasks. Prefer large executors with many cores and
     * 1GB/core, rather than the standard Spark configuration of low core count
     * and 2GB/core. Most of the memory is needed for a data structure shared
     * between all tasks running on the same executor.
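     *
     * For illustration only (exact values depend on the cluster), that guidance
     * might translate to executor settings along the lines of:
     * <pre>{@code
     *   spark.executor.cores=8
     *   spark.executor.memory=8g
     * }</pre>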
     *
     * @param df m1prep output
     */
    public Dataset<Row> generateCandidates(Dataset<Row> df) {
        return generateCandidatesFromStructs(prepareInput(df));
    }

    private Dataset<Row> generateCandidatesFromStructs(Dataset<Row> structs) {
        int numParts = estimateQueryPartitions(structs.count());

        // The FST building process deduplicates, so there is no need to have
        // Spark do it too.
        Dataset<String> dictionary = structs
                .select(col("queryNormDedup"))
                .map((MapFunction<Row, String>) row -> row.getString(0), Encoders.STRING());

        Dataset<Row> queriesToLookup = structs
                // Short queries generate massive numbers of useless suggestions; a 3 character
                // query returns everything within edit distance of 2. Filter them out of the
                // source queries. Short queries are only filtered from the lookup side; we
                // still index them into the dictionary, which will return them as candidates.
                .where(length(col("sugg.queryNorm")).geq(MIN_STRING_LENGTH))
                // Send only the query string and deduplicate to have the smallest
                // evaluation set possible. We use a single FST for all wikis and
                // let later processing in spark filter out cross-wiki suggestions.
                // This reduces the size of the FST and the number of queries we
                // need to execute, moving the processing to vanilla spark code.
                .select(col("queryNormDedup"))
                .dropDuplicates()
                .repartition(numParts);

        // Find all pairs of queryNormDedup within Levenshtein distance of 2.
        return new AllPairsLevenshtein(dictionary, numFst)
                .apply(queriesToLookup, col("queryNormDedup"), "dymDedup");
    }

    /**
     * Transform raw FST output to valid suggestion pairs.
     *
     * Map the candidates generated by all pairs matching back to the source
     * queries. Filter candidates down to queries from the same wiki and language,
     * along with removal of queries suggesting themselves.
     *
     * For reference, candidates has two fields: queryNormDedup of the string that
     * was searched for in the FST, and an array of nearby queryNormDedup strings
     * identified by candidate generation.
     *
     * While not explicitly called out, this process also results in queries with
     * the same queryNormDedup being suggested for each other since the FST emits
     * the source query as a candidate.
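     *
     * For illustration, a candidates row might look like
     * (queryNormDedup="example quer", dymDedup=["example quer", "example query"]),
     * with each array entry joined back to the full set of source queries sharing
     * that dedup'd form.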
     */
    private Dataset<Row> resolveCandidates(Dataset<Row> candidates, Dataset<Row> structs) {
        Dataset<Row> queries = structs
                .select(
                        col("queryNormDedup").alias("query"),
                        col("sugg").alias("q1"))
                .alias("queries");
        Dataset<Row> dym = structs
                .select(
                        col("queryNormDedup").alias("dym"),
                        col("sugg").alias("q2"))
                .alias("dym");

        return candidates.alias("c")
                // bring in metadata about the source query
                .join(queries, col("c.queryNormDedup").equalTo(col("queries.query")))
                // prepare for join by exploding array of candidates into
                // row per candidate.
                .select(
                        col("queries.q1"),
                        explode(col("c.dymDedup")).alias("dymDedup"))
                // bring in metadata about the candidate
                .join(dym, col("dymDedup").equalTo(col("dym.dym")))
                // simplify to (q1, q2). This finishes dropping the dedup'd strings we added.
                .select(col("queries.q1"), col("dym.q2"))
                // A query can't suggest itself
                .where(col("q1.queryNorm").notEqual(col("q2.queryNorm")))
                // filter to same context
                .where(col("q1.wikiid").equalTo(col("q2.wikiid")))
                .where(col("q1.lang").equalTo(col("q2.lang")));
    }

    /**
     * Aggregate multiple suggestions for the same query/dym pair into
     * a single result. Reshape into final flattened output without
     * our per-query structs.
     *
     * TODO: Does this make sense? Can the processing employed generate
     * multiple suggestions for the same pair? It seems that as long as there
     * is one row per queryNorm in the input, there should be no duplicates
     * here.
     */
    private Dataset<Row> aggregate(Dataset<Row> suggestions) {
        // TODO: Copy/paste of SessionReformulationSuggester.apply?
        return suggestions
            .withColumn("dym", col("q2.queryNorm"))
            .groupBy("q1.query", "dym", "q1.wikiid", "q1.lang",
                     "q1q2EditDist")
            .agg(
                max("q2.ts").alias("ts"),
                max("q1.hitsTotal").alias("queryHitsTotal"),
                max("q2.hitsTotal").alias("dymHitsTotal"),
                sum("q2.suggCount").alias("suggCount"))
            .select("query", "dym", "suggCount", "q1q2EditDist",
                    "queryHitsTotal", "dymHitsTotal", "wikiid",
                    "lang", "ts");
    }

    /**
     * Strip adjacent duplicate characters in a string.
     *
     * Example: `abbbbc   deffff` becomes `abc def`
     */
    @VisibleForTesting
    static Column dedupChars(Column string) {
        return regexp_replace(string, "(.)\\1+", "$1");
    }

    /**
     * Rough estimate of the number of FST lookup partitions needed so that
     * each takes ~10 minutes to run. Otherwise, on large datasets we might
     * end up with individual partitions that take 3 hours.
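     *
     * As a worked example: with 50 million input rows each lookup is estimated
     * at 2 + 50e6/5e6 = 12ms, a 10 minute partition then holds roughly
     * 600,000ms / 12ms = 50,000 rows, giving about 1,000 partitions.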
     *
     * @param fstSize Number of rows in input dataset
     * @return Number of partitions to perform query with
     */
    private int estimateQueryPartitions(long fstSize) {
        float m = 1F / 5_000_000F;
        // ballpark estimate, this isn't really linear.
        float msPerQuery = 2 + m * fstSize;
        float queryPerMs = 1 / msPerQuery;
        int rowPerPartition = Math.round(Duration.ofMinutes(10).toMillis() * queryPerMs);
        return Math.min(40000, Math.max(2, (int)(fstSize / rowPerPartition)));
    }

}