// DictionarySuggester.java

package org.wikimedia.search.glent;

import static java.util.stream.Collectors.toList;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.udf;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.annotation.Nullable;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;
import org.wikimedia.search.glent.analysis.GlentTokenizer;
import org.wikimedia.search.glent.analysis.Tokenizers;

import lombok.AllArgsConstructor;

/**
 * Builds "did you mean" (dym) suggestions for CJK queries by substituting
 * visually/phonetically confusable characters and ranking candidates by
 * dictionary word frequency.
 *
 * <p>Runs as a Spark batch transform: takes the query logs since the last run
 * plus the previous run's output, and emits a deduplicated suggestion dataset.
 */
public class DictionarySuggester implements BiFunction<Dataset<Row>, Dataset<Row>, Dataset<Row>> {
    /** Rows with a timestamp older than this must be dropped for legal/data-retention reasons. */
    private final Instant earliestLegalTs;

    public DictionarySuggester(Instant earliestLegalTs) {
        this.earliestLegalTs = earliestLegalTs;
    }

    /**
     * Produce dictionary based query suggestions.
     *
     * @param dfLog CirrusSearch query logs since previous run of suggester
     * @param dfOld Previous run of DictionarySuggester
     * @return Dictionary based query suggestions
     */
    public Dataset<Row> apply(Dataset<Row> dfLog, Dataset<Row> dfOld) {
        // Verify we have the expected named columns and rename to match our outputs.
        // A more generalized way to assert the shape of inputs would be preferred.
        dfLog = dfLog.select(
            "query", "queryNorm", "lang", "wikiid", "ts", "hitsTotal")
                .withColumnRenamed("hitsTotal", "queryHitsTotal");

        // Verify the same on dfOld
        dfOld = dfOld.select(
                "query", "queryNorm", "lang", "wikiid", "ts", "queryHitsTotal");

        // The query suggestion algorithm runs in isolation, there is no cross-query interaction,
        // and the historical dataset is very small compared to the logs we are ingesting. To
        // simplify operations such as updating the algorithm always run the historical suggestions
        // through the current version of the suggester.
        Dataset<Row> df = legalReqs(dfOld.union(dfLog));
        df = findM2QueryMatch(df);
        return reshape(df);
    }

    /**
     * Tokenize a query with the language-appropriate analyzer.
     *
     * @param query text to tokenize
     * @param lang language code; only ko/ja/zh are supported
     * @return tokens, or an empty list for null/unsupported languages
     */
    static List<String> tokenizeString(String query, String lang) {
        GlentTokenizer tokenizer;
        if (null == lang) {
            return Collections.emptyList();
        } else {
            switch (lang) {
                case "ko":
                    tokenizer = Tokenizers.korean();
                    break;
                case "ja":
                    tokenizer = Tokenizers.japanese();
                    break;
                case "zh":
                    tokenizer = Tokenizers.simplifiedChinese();
                    break;
                default:
                    // Suggester only handles the CJK languages above.
                    return Collections.emptyList();
            }
        }
        return tokenizer.tokenize(query, " ");
    }

    /**
     * Wrap {@link #buildSuggsM2} as a Spark UDF, binding the confusion matrix and
     * the per-language word frequency dictionary from the shared M2 resources.
     */
    static UDF2<String, String, String> buildSuggsM2Udf() {
        return (query, lang) -> {
            // Resources are resolved lazily on the executor, not captured in the closure.
            M2Resources resources = M2Resources.getInstance();
            return buildSuggsM2(query, lang,
                resources.confusions(), resources.wordFreq().get(lang));
        };
    }

    /** A token paired with its (possibly absent) list of confusable replacements. */
    @AllArgsConstructor
    static class TokenConfusion {
        private static final Pattern CJK_CHAR_PAT =
            Pattern.compile("\\p{IsHan}|\\p{IsHangul}|\\p{IsHiragana}|\\p{IsKatakana}|[\\u3099-\\u309F\\u30FC-\\u30FF\\uFF70]");

        final String token;
        @Nullable
        private final List<String> confusions;

        /**
         * True when the token is exactly one CJK character.
         *
         * <p>NOTE(review): length() counts UTF-16 units, so supplementary-plane Han
         * characters (surrogate pairs) are treated as non-single-CJK — confirm this
         * is acceptable for the supported languages.
         */
        public boolean isSingleCJK() {
            if (token.length() > 1) {
                return false;
            }
            Matcher matcher = CJK_CHAR_PAT.matcher(token);
            return matcher.matches();
        }

        /** Confusable replacements for this token; never null. */
        List<String> confusions() {
            return confusions == null ? Collections.emptyList() : confusions;
        }
    }


    /**
     * identify candidates for dym using tokenizer + dictionary + confusion matrix.
     *
     * @param query queryNorm value (may be null when the source column is null)
     * @param lang lang value (may be null when the source column is null)
     * @param confusion map from char seq to list of replacement char sequences
     * @param dictionary per-lang map from token to frequency count
     * @return dym possible suggestion, or "" when none can be made
     *
     */
    static String buildSuggsM2(String query, String lang,
                Map<String, List<String>> confusion, Map<String, Integer> dictionary) {

        // Spark passes null for null column values; treat null like "no suggestion"
        // instead of throwing NPE and failing the stage.
        if (query == null || lang == null
                || query.isEmpty() || lang.isEmpty() || dictionary == null) {
            return "";
        }
        List<String> tokens = tokenizeString(query, lang);
        if (tokens.isEmpty()) {
            return "";
        }
        List<TokenConfusion> suggCM = new ArrayList<>();
        for (String token : tokens) {
            List<String> confusions = null;
            // Only single-character tokens can be substituted via the confusion matrix.
            if (token.length() == 1) {
                confusions = confusion.get(token);
            }
            suggCM.add(new TokenConfusion(token, confusions));
        }

        String dym = runBuildBestConfusionPiece(suggCM, dictionary);
        // Suppress suggestions identical to the (whitespace-stripped) input.
        String queryOrig = String.join("", tokens);
        return dym.equals(queryOrig) ? "" : dym;
    }

    /**
     * function that runs buildBestConfusionPiece.
     *
     * <p>Scans for maximal runs of two or more consecutive single-CJK-character
     * tokens and replaces each run with its best confusion-based rewrite; all
     * other tokens pass through unchanged.
     *
     * @param suggCM List of List of confusion values
     * @param dictionary Word frequency counts
     * @return possible "did you mean" suggestion
     *
     */
    @SuppressWarnings("ModifiedControlVariable")
    static String runBuildBestConfusionPiece(List<TokenConfusion> suggCM,
                                             Map<String, Integer> dictionary) {
        StringBuilder sb = new StringBuilder();
        int end = suggCM.size() - 1;
        for (int i = 0; i <= end; i++) {
            // Find continuous run of single cjk characters
            int j = i;
            for (; j <= end; j++) {
                if (!suggCM.get(j).isSingleCJK()) {
                    break;
                }
            }
            // We need at least two sequential cjk tokens
            if (j - i < 2) {
                sb.append(suggCM.get(i).token);
            } else {
                sb.append(buildBestConfusionPiece(suggCM.subList(i, j), dictionary));
                // Continue iteration with the non-single cjk that ended our window
                i = j - 1;
            }
        }
        return sb.toString();
    }

    /**
     * build suggestion based on list of tokens and confusions.
     *
     * <p>Generates every candidate obtained by replacing exactly one token with
     * one of its confusions, then picks the in-dictionary candidate with the
     * highest frequency. Falls back to the unmodified input when no candidate is
     * in the dictionary.
     *
     * @param suggCM list of single character tokens and their confusions
     * @param dictionary word frequency statistics for choosing best
     * @return sugg possible suggestion
     *
     */
    static String buildBestConfusionPiece(List<TokenConfusion> suggCM,
                                                Map<String, Integer> dictionary) {
        List<String> tokens = suggCM.stream().map(tc -> tc.token).collect(toList());
        List<String> suggList = new ArrayList<>();
        // Index 0 is always the unmodified input; used as the fallback below.
        suggList.add(String.join("", tokens));
        for (int i = 0; i < suggCM.size(); i++) {
            String left = String.join("", tokens.subList(0, i));
            String right = String.join("", tokens.subList(i + 1, suggCM.size()));
            for (String c : suggCM.get(i).confusions()) {
                suggList.add(left + c + right);
            }
        }

        return suggList.stream()
                .filter(dictionary::containsKey)
                .max(Comparator.comparingInt(dictionary::get))
                .orElseGet(() -> suggList.get(0));
    }

    /**
     * find query match based on M2.
     *
     * @param dfUserQuery user query dataframe
     * @return dataframe with possible suggestions that match user query
     *
     */
    static Dataset<Row> findM2QueryMatch(Dataset<Row> dfUserQuery) {
        UserDefinedFunction buildSuggsM2Udf = udf(buildSuggsM2Udf(), DataTypes.StringType);
        dfUserQuery = dfUserQuery.withColumn("dym",
                buildSuggsM2Udf.apply(col("queryNorm"), col("lang")));

        // Drop empty suggestions and suggestions identical to the normalized query;
        // the latter can differ from buildSuggsM2's own check because tokenization
        // may drop whitespace from queryNorm.
        return dfUserQuery
                .where(col("dym").notEqual(""))
                .where(col("dym").notEqual(col("queryNorm")))
                .distinct();
    }

    /**
     * Reshape for output to shared suggestions table.
     *
     * Flattens multiple occurrences of the same suggestion and renames columns to match
     * our outputs. The fields of the shared format not used here are set to 0.
     *
     * @param df M2 Suggestions dataframe
     * @return dataframe of "did you mean" results
     */
    static Dataset<Row> reshape(Dataset<Row> df) {
        return df
            .groupBy("query", "dym", "wikiid", "lang")
            .agg(
                    max("ts").alias("ts"),
                    max("queryHitsTotal").alias("queryHitsTotal"))
            .withColumn("q1q2EditDist", lit(0F))
            .withColumn("dymHitsTotal", lit(0))
            .withColumn("suggCount", lit(0));
    }

    /**
     * Removes dataframe entries that have timestamp earlier than required by legal.
     *
     * @param df M1Prep dataframe
     * @return dataframe that satisfies legal requirements
     *
     */
    Dataset<Row> legalReqs(Dataset<Row> df) {
        // ts is compared in epoch seconds; presumably the ts column is epoch seconds too.
        return df.where(col("ts").geq(earliestLegalTs.getEpochSecond()));
    }
}