// SessionReformulationPrep.java

package org.wikimedia.search.glent;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_list;
import static org.apache.spark.sql.functions.explode;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.struct;
import static org.apache.spark.sql.functions.sum;

import java.util.Objects;
import java.util.function.BiFunction;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import com.google.common.annotations.VisibleForTesting;

/**
 * Prepares session-based query-reformulation data (M0Prep).
 *
 * <p>Pairs up queries issued by the same identity within the search log
 * (treating a later query as a candidate reformulation of an earlier one),
 * filters the pairs through the injected {@link SuggestionFilter}, and merges
 * the result with the previously accumulated M0Prep dataset.
 */
public class SessionReformulationPrep implements BiFunction<Dataset<Row>, Dataset<Row>, Dataset<Row>> {
    private final SuggestionFilter suggestionFilter;

    /**
     * @param suggestionFilter filter applied to candidate (query, reformulation) pairs;
     *                         must not be null
     */
    public SessionReformulationPrep(SuggestionFilter suggestionFilter) {
        // Fail fast here rather than with a confusing NPE on first apply().
        this.suggestionFilter = Objects.requireNonNull(suggestionFilter, "suggestionFilter");
    }

    /**
     * Builds the updated M0Prep dataset from the raw log and the previous M0Prep.
     *
     * @param dfLog search log dataframe (one row per log entry, keyed by identity)
     * @param dfOld previous M0Prep dataframe to merge counts with
     * @return merged M0Prep dataframe
     */
    @Override
    public Dataset<Row> apply(Dataset<Row> dfLog, Dataset<Row> dfOld) {
        Dataset<Row> df = pairLogEntries(dfLog);
        df = suggestionFilter.filterWithSession(df);
        // build new M0Prep which is a combination of dfSugg and old M0Prep
        return buildM0Prep(df, dfOld);
    }

    /**
     * Build new version of M0Prep and add it to M0Prep with date = glentDfM0PrepPartNew.
     * Steps include:
     * limit to previous portion of M0Prep dataframe
     * find latest timestamp of the records in previous portion of M0Prep to avoid double counting
     * add previous M0Prep defined as M0Prep with date = glentDfM0PrepPartOld
     * create new M0Prep
     *
     * @param df sugg dataframe with struct columns q1/q2 and q1q2EditDist
     * @param dfOld M0Prep dataframe; its columns must be in the exact order of the
     *              select below, because {@code union} resolves columns by position,
     *              not by name
     * @return M0Prep dataframe
     */
    Dataset<Row> buildM0Prep(Dataset<Row> df, Dataset<Row> dfOld) {
        return df
            // Each new pair contributes a count of 1; summed with dfOld counts below.
            .withColumn("suggCount", lit(1))
            // Flatten the q1/q2 structs into the columnar M0Prep schema.
            .select(col("q1.query").alias("q1_query"),
                    col("q1.queryNorm").alias("q1_queryNorm"),
                    col("q1.wikiid").alias("q1_wikiid"),
                    col("q1.lang").alias("q1_lang"),
                    col("q2.query").alias("q2_query"),
                    col("q2.queryNorm").alias("q2_queryNorm"),
                    col("q2.wikiid").alias("q2_wikiid"),
                    col("q2.lang").alias("q2_lang"),
                    col("q1.ts").alias("q1_ts"),
                    col("q1.hitsTotal").alias("q1_hitsTotal"),
                    col("q2.hitsTotal").alias("q2_hitsTotal"),
                    col("q1q2EditDist"),
                    col("suggCount"))
            // NOTE: union is positional — dfOld must already match the column
            // order selected above.
            .union(dfOld)
            .groupBy("q1_query", "q1_queryNorm",
                    "q1_wikiid", "q1_lang",
                     "q2_query", "q2_queryNorm",
                    "q2_wikiid", "q2_lang", "q1q2EditDist")
            .agg(
                // Keep the most recent timestamp and the largest observed hit
                // counts; accumulate the pair occurrence count.
                max("q1_ts").alias("q1_ts"),
                max("q1_hitsTotal").alias("q1_hitsTotal"),
                max("q2_hitsTotal").alias("q2_hitsTotal"),
                sum("suggCount").alias("suggCount"))
            .select("q1_query", "q1_queryNorm",
                    "q1_wikiid", "q1_lang",
                    "q2_query", "q2_queryNorm",
                    "q2_wikiid", "q2_lang",
                    "q1_ts", "q1_hitsTotal", "q2_hitsTotal", "q1q2EditDist", "suggCount");
    }

    /**
     * create pairs (query, suggestionFilter) based on user's self corrected queries.
     *
     * @param df sugg dataframe
     * @return dataframe with pairs (query, suggestionFilter)
     */
    @VisibleForTesting
    Dataset<Row> pairLogEntries(Dataset<Row> df) {
        // aggregate logEntries by identity
        df = aggLogByIdentity(df);
        // create pairs of log entries
        return pairLogEntriesByTs(df);
    }

    /**
     * Expands each identity's list of log entries into all ordered (q1, q2)
     * pairs where q2 was issued after q1 and has a different normalized query.
     *
     * @param df dataframe with a {@code logEntry} array column (per identity)
     * @return dataframe with struct columns q1 and q2
     */
    @VisibleForTesting
    Dataset<Row> pairLogEntriesByTs(Dataset<Row> df) {
        return df
                // Exploding the same array twice yields the cartesian product of
                // entries within one identity's session; the filters below keep
                // only ordered, non-trivial pairs.
                .withColumn("q1", explode(col("logEntry")))
                .withColumn("q2", explode(col("logEntry")))
                // Suggestion must come after query
                .where(col("q2.ts").gt(col("q1.ts")))
                // Drop pairs that normalize to the same query (no real reformulation).
                .where(col("q2.queryNorm").notEqual(col("q1.queryNorm")))
                .select("q1", "q2");
    }

    /**
     * Groups log entries per identity into a single array column.
     *
     * @param df log dataframe with at least the struct fields listed below
     * @return dataframe with columns (identity, logEntry: array&lt;struct&gt;)
     */
    @VisibleForTesting
    Dataset<Row> aggLogByIdentity(Dataset<Row> df) {
        return df
                .groupBy("identity")
                .agg(collect_list(struct(
                        "ts", "wikiid", "query", "hitsTotal", "lang", "queryNorm"
                )).alias("logEntry"));
    }
}