// SuggestionAggregator.java
package org.wikimedia.search.glent;
import static org.apache.spark.sql.functions.array_contains;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_set;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.log10;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.row_number;
import static org.apache.spark.sql.functions.sum;
import static org.apache.spark.sql.functions.when;
import java.util.function.Function;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
/**
 * Collapses per-algorithm spelling suggestions into a single best suggestion
 * per (wiki, query) pair, scores each candidate, and renames the columns to
 * match the CirrusSearch output template.
 */
public class SuggestionAggregator implements Function<Dataset<Row>, Dataset<Row>> {
    /** Schema the output of {@link #apply(Dataset)} must conform to. */
    static final StructType SCHEMA_OUT = DataTypes.createStructType(new StructField[] {
        DataTypes.createStructField("method",
                DataTypes.createArrayType(DataTypes.StringType, false), false),
        DataTypes.createStructField("wiki", DataTypes.StringType, false),
        DataTypes.createStructField("suggestion_score", DataTypes.FloatType, false),
        DataTypes.createStructField("query", DataTypes.StringType, false),
        DataTypes.createStructField("dym", DataTypes.StringType, false)
    });

    /** Signals that the produced dataset does not match {@link #SCHEMA_OUT}. */
    private static class SchemaInconsistencyException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        SchemaInconsistencyException(String message, StructType schema) {
            // Append the offending schema so the failure is self-describing.
            super(message + "\n" + schema.treeString());
        }
    }

    /**
     * Merges, scores and filters suggestions, then validates the result.
     *
     * @param df per-algorithm suggestions; expected to carry at least the
     *     columns query, dym, q1q2EditDist, wikiid, lang, algo,
     *     queryHitsTotal, dymHitsTotal and suggCount
     * @return one row per (wiki, query) with columns matching {@link #SCHEMA_OUT}
     * @throws SchemaInconsistencyException if the output schema drifts from
     *     the CirrusSearch template
     */
    @Override
    public Dataset<Row> apply(Dataset<Row> df) {
        df = mergeAlgorithms(df);
        df = withSuggestionScore(df);
        df = filterToBestSuggestions(df);
        // Rename to match CirrusSearch template
        df = df.select(
            col("algos").alias("method"),
            col("wikiid").alias("wiki"),
            col("suggestion_score"),
            col("query"), col("dym"));
        StructType schema = df.schema();
        if (!schema.sameType(SCHEMA_OUT)) {
            throw new SchemaInconsistencyException(
                "Output does not match expected schema!", schema);
        }
        return df;
    }

    /**
     * Merge results from multiple algorithms so we have a single row per
     * (query, suggestion) pair.
     */
    private Dataset<Row> mergeAlgorithms(Dataset<Row> df) {
        return df
            .groupBy("query", "dym", "q1q2EditDist", "wikiid", "lang")
            .agg(
                collect_set("algo").alias("algos"),
                max("queryHitsTotal").alias("queryHitsTotal"),
                max("dymHitsTotal").alias("dymHitsTotal"),
                sum("suggCount").alias("suggCount"));
    }

    /** Attaches the ranking column used by {@link #filterToBestSuggestions(Dataset)}. */
    private Dataset<Row> withSuggestionScore(Dataset<Row> df) {
        return df.withColumn("suggestion_score",
            score("algos", "q1q2EditDist", "dymHitsTotal", "suggCount"));
    }

    /**
     * Choose best suggestion for all (wiki, query) sets.
     */
    private Dataset<Row> filterToBestSuggestions(Dataset<Row> df) {
        WindowSpec w = Window
            .partitionBy("wikiid", "query")
            .orderBy(col("suggestion_score").desc());
        return df
            // row_number is 1-indexed. Go figure.
            .withColumn("rn", row_number().over(w))
            .where(col("rn").equalTo(1))
            .drop("rn");
    }

    /**
     * FRHED-style score: log of the geometric mean of suggCount², and a
     * hit count dampened above 10k, penalized by the edit distance.
     */
    private static Column frhedScore(Column suggCount, Column suggHits, Column editDist) {
        // adjHits = suggHits < 10_000 ? suggHits : 10_000 + suggHits/10_000;
        Column tenK = lit(10_000F);
        Column adjHits = when(suggHits.lt(tenK), suggHits)
            .otherwise(tenK.plus(suggHits.divide(tenK)));
        // logGeoMean = log(∛(suggCount^2 * adjHits)) = (2*log(suggCount) + log(adjHits))/3
        Column logGeoMean = log10(suggCount).multiply(lit(2)).plus(log10(adjHits)).divide(lit(3));
        // frhedScore = log(suggCount^2 * adjHits)/3 - editDist
        return logGeoMean.minus(editDist);
    }

    /** Convenience overload resolving column names before scoring. */
    private static Column score(String algos, String dist, String hitsTotal, String suggCount) {
        return score(col(algos), col(dist), col(hitsTotal), col(suggCount));
    }

    /**
     * Computes the float suggestion score for a merged row.
     * M2 rows are scored as (10 - dist); M0/M1 rows get the frhedScore,
     * with a flat +20 boost when M0 agrees (higher precision than M1).
     */
    private static Column score(Column algos, Column dist, Column hitsTotal, Column suggCount) {
        // make sure M0 beats M1; M2 should be acting alone
        // Convert M2 to 10 - dist; for M0, M1, calculate frhedScore
        Column adjustedDist = when(array_contains(algos, Method.M2Run.ALGO),
            lit(10F).minus(dist)).otherwise(frhedScore(suggCount, hitsTotal, dist));
        // M0 has higher precision than M1, give it a boost
        Column m0runBoost = when(array_contains(algos, Method.M0Run.ALGO), lit(20F))
            .otherwise(lit(0F));
        return adjustedDist.cast(DataTypes.FloatType).plus(m0runBoost);
    }
}