// SessionReformulationPrep.java
package org.wikimedia.search.glent;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_list;
import static org.apache.spark.sql.functions.explode;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.struct;
import static org.apache.spark.sql.functions.sum;
import java.util.function.BiFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import com.google.common.annotations.VisibleForTesting;
/**
 * Derives query-reformulation suggestion candidates from search session logs.
 *
 * <p>Pairs of (query, later query) within the same user identity are built from
 * the session log, filtered by a {@link SuggestionFilter}, and then merged into
 * the accumulated M0Prep dataset (counts summed, timestamps/hit counts maxed).
 */
public class SessionReformulationPrep implements BiFunction<Dataset<Row>, Dataset<Row>, Dataset<Row>> {
    /** Filter applied to candidate (q1, q2) pairs before aggregation. */
    private final SuggestionFilter suggestionFilter;

    /**
     * @param suggestionFilter filter applied to candidate query pairs
     */
    public SessionReformulationPrep(SuggestionFilter suggestionFilter) {
        this.suggestionFilter = suggestionFilter;
    }

    /**
     * Builds the updated M0Prep dataset from new session logs plus the previous M0Prep.
     *
     * @param dfLog session log dataframe
     * @param dfOld previous M0Prep dataframe
     * @return merged M0Prep dataframe
     */
    @Override
    public Dataset<Row> apply(Dataset<Row> dfLog, Dataset<Row> dfOld) {
        Dataset<Row> df = pairLogEntries(dfLog);
        df = suggestionFilter.filterWithSession(df);
        // build new M0Prep which is a combination of dfSugg and old M0Prep
        return buildM0Prep(df, dfOld);
    }

    /**
     * Build new version of M0Prep and add it to M0Prep with date = glentDfM0PrepPartNew.
     * Steps include:
     * limit to previous portion of M0Prep dataframe
     * find latest timestamp of the records in previous portion of M0Prep to avoid double counting
     * add previous M0Prep defined as M0Prep with date = glentDfM0PrepPartOld
     * create new M0Prep
     *
     * @param df sugg dataframe
     * @param dfOld M0Prep dataframe
     * @return M0Prep dataframe
     */
    Dataset<Row> buildM0Prep(Dataset<Row> df, Dataset<Row> dfOld) {
        return df
            // each surviving (q1, q2) pair contributes one observation
            .withColumn("suggCount", lit(1))
            // flatten the nested q1/q2 structs so the schema lines up with dfOld for union
            .select(col("q1.query").alias("q1_query"),
                col("q1.queryNorm").alias("q1_queryNorm"),
                col("q1.wikiid").alias("q1_wikiid"),
                col("q1.lang").alias("q1_lang"),
                col("q2.query").alias("q2_query"),
                col("q2.queryNorm").alias("q2_queryNorm"),
                col("q2.wikiid").alias("q2_wikiid"),
                col("q2.lang").alias("q2_lang"),
                col("q1.ts").alias("q1_ts"),
                col("q1.hitsTotal").alias("q1_hitsTotal"),
                col("q2.hitsTotal").alias("q2_hitsTotal"),
                col("q1q2EditDist"),
                col("suggCount"))
            .union(dfOld)
            // merge new observations with the prior M0Prep rows for the same pair
            .groupBy("q1_query", "q1_queryNorm",
                "q1_wikiid", "q1_lang",
                "q2_query", "q2_queryNorm",
                "q2_wikiid", "q2_lang", "q1q2EditDist")
            .agg(
                max("q1_ts").alias("q1_ts"),
                max("q1_hitsTotal").alias("q1_hitsTotal"),
                max("q2_hitsTotal").alias("q2_hitsTotal"),
                sum("suggCount").alias("suggCount"))
            // fix the output column order
            .select("q1_query", "q1_queryNorm",
                "q1_wikiid", "q1_lang",
                "q2_query", "q2_queryNorm",
                "q2_wikiid", "q2_lang",
                "q1_ts", "q1_hitsTotal", "q2_hitsTotal", "q1q2EditDist", "suggCount");
    }

    /**
     * Create pairs (q1, q2) of log entries based on a user's self-corrected queries:
     * entries are grouped per identity, then every later entry is paired with every
     * earlier one within that identity.
     *
     * @param df sugg dataframe
     * @return dataframe with (q1, q2) log-entry pairs
     */
    @VisibleForTesting
    Dataset<Row> pairLogEntries(Dataset<Row> df) {
        // aggregate logEntries by identity
        df = aggLogByIdentity(df);
        // create pairs of log entries
        return pairLogEntriesByTs(df);
    }

    /**
     * Expands each identity's log-entry list into all ordered pairs (q1, q2)
     * where q2 occurred strictly after q1 and has a different normalized query.
     *
     * @param df dataframe with a {@code logEntry} array column (see {@link #aggLogByIdentity})
     * @return dataframe with struct columns {@code q1} and {@code q2}
     */
    @VisibleForTesting
    Dataset<Row> pairLogEntriesByTs(Dataset<Row> df) {
        return df
            // double explode yields the cross product of entries within an identity
            .withColumn("q1", explode(col("logEntry")))
            .withColumn("q2", explode(col("logEntry")))
            // Suggestion must come after query
            .where(col("q2.ts").gt(col("q1.ts")))
            // a reformulation must actually change the normalized query
            .where(col("q2.queryNorm").notEqual(col("q1.queryNorm")))
            .select("q1", "q2");
    }

    /**
     * Collects each identity's log entries into a single array column.
     *
     * @param df session log dataframe with an {@code identity} column
     * @return one row per identity with a {@code logEntry} array of entry structs
     */
    @VisibleForTesting
    Dataset<Row> aggLogByIdentity(Dataset<Row> df) {
        return df
            .groupBy("identity")
            .agg(collect_list(struct(
                "ts", "wikiid", "query", "hitsTotal", "lang", "queryNorm"
            )).alias("logEntry"));
    }
}