SessionReformulationSuggester.java

package org.wikimedia.search.glent;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.sum;

import java.util.function.Function;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Builds "did you mean" (dym) suggestion rows by aggregating pairs of queries
 * (q1 and q2).
 *
 * <p>The class is a {@link Function} from the input dataframe to the
 * suggestions dataframe, so it can be passed around and composed like any
 * other dataframe transformation.
 */
public class SessionReformulationSuggester implements Function<Dataset<Row>, Dataset<Row>> {

    /**
     * Aggregates query pairs into one row per suggestion and ranks them.
     *
     * <p>Rows are grouped by {@code q1_query}, {@code q2_queryNorm},
     * {@code q1q2EditDist}, {@code q1_wikiid} and {@code q1_lang}. Within each
     * group the latest {@code q1_ts}, the largest {@code q1_hitsTotal} and
     * {@code q2_hitsTotal}, and the sum of {@code suggCount} are kept.
     *
     * <p>For now the result is ordered by wiki and query, then by ascending
     * {@code q1q2EditDist} (per the original note, the Levenshtein distance
     * between the normalized q1 and q2 queries) and descending
     * {@code dymHitsTotal} (the hits total of q2).
     *
     * @param df M0Prep dataframe; must provide the columns {@code q1_query},
     *           {@code q2_queryNorm}, {@code q1q2EditDist}, {@code q1_wikiid},
     *           {@code q1_lang}, {@code q1_ts}, {@code q1_hitsTotal},
     *           {@code q2_hitsTotal} and {@code suggCount}
     * @return dataframe of "did you mean" results with the columns
     *         {@code query}, {@code dym}, {@code suggCount},
     *         {@code q1q2EditDist}, {@code queryHitsTotal},
     *         {@code dymHitsTotal}, {@code wikiid}, {@code lang} and {@code ts}
     */
    @Override
    public Dataset<Row> apply(Dataset<Row> df) {
        // One row per (query, suggestion, edit distance, wiki, language).
        // The grouping keys are renamed here so that later stages can refer
        // to them by their output names.
        Dataset<Row> aggregated = df
            .groupBy(
                col("q1_query").alias("query"),
                col("q2_queryNorm").alias("dym"),
                col("q1q2EditDist"),
                col("q1_wikiid").alias("wikiid"),
                col("q1_lang").alias("lang"))
            .agg(
                max("q1_ts").alias("ts"),
                max("q1_hitsTotal").alias("queryHitsTotal"),
                max("q2_hitsTotal").alias("dymHitsTotal"),
                sum("suggCount").alias("suggCount"));

        // Within each wiki/query, closest suggestions come first; ties are
        // broken in favour of the suggestion with more hits.
        Dataset<Row> ranked = aggregated.orderBy(
            col("wikiid"),
            col("query"),
            col("q1q2EditDist").asc(),
            col("dymHitsTotal").desc());

        // Fix the output column order.
        return ranked.select(
            "query", "dym", "suggCount", "q1q2EditDist", "queryHitsTotal",
            "dymHitsTotal", "wikiid", "lang", "ts");
    }
}