// SuggestionFilter.java
package org.wikimedia.search.glent;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.length;
import java.time.Duration;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import com.google.common.annotations.VisibleForTesting;
/**
 * Chain of DataFrame filters that discards query suggestion candidates.
 *
 * <p>Every filter reads columns of two struct-like prefixes, {@code q1} and {@code q2},
 * from the incoming dataframe. Judging by the comments in {@link #filter(Dataset)},
 * {@code q1} is the original query and {@code q2} the suggested one.
 *
 * <p>All thresholds are strict: a row whose value is exactly equal to a threshold is dropped.
 */
public class SuggestionFilter {
    /**
     * Column added by {@link #filterMaxEditDist(Dataset)} and read back by
     * {@link #filterMaxNormEditDist(Dataset)}.
     */
    private static final String EDIT_DIST_COL = "q1q2EditDist";

    /** Exclusive lower bound on {@code q2.hitsTotal - q1.hitsTotal}. */
    private final int minHitsDiff;
    /** Exclusive lower bound on the relative hit increase, in percent. */
    private final int minHitsPercDiff;
    /** Exclusive upper bound on the edit distance between the two normalized queries. */
    private final float maxEditDist;
    /** Factor applied to the edit distance before comparing it to the length of {@code q1.queryNorm}. */
    private final double maxNormEditDist;
    /** Exclusive upper bound on {@code q2.ts - q1.ts}; only its whole-second value is used. */
    private final Duration maxTsDiff;
    /** Strategy building the edit distance column expression. */
    private final EditDistanceCalc edCalc;

    /**
     * Builds the column expression computing the edit distance between two string columns.
     *
     * <p>Public because it is a parameter type of the public constructor; a narrower
     * visibility would prevent code outside this package from naming the type.
     */
    @FunctionalInterface
    public interface EditDistanceCalc {
        /**
         * @param a first string column
         * @param b second string column
         * @return column expression evaluating to the edit distance between {@code a} and {@code b}
         */
        Column editDistance(Column a, Column b);
    }

    /**
     * Creates a filter with the default thresholds: hits difference above 100, hits percentage
     * difference above 10, edit distance below 3, normalized edit distance factor 1.5,
     * timestamp difference below two minutes, and Spark's {@code levenshtein} function as the
     * edit distance.
     */
    SuggestionFilter() {
        this(100, 10, 3F, 1.5, Duration.ofMinutes(2),
                functions::levenshtein);
    }

    /**
     * @param minHitsDiff exclusive lower bound on {@code q2.hitsTotal - q1.hitsTotal}
     * @param minHitsPercDiff exclusive lower bound on the relative hit increase, in percent
     * @param maxEditDist exclusive upper bound on the edit distance
     * @param maxNormEditDist factor the edit distance is multiplied by before being compared
     *                        to the length of {@code q1.queryNorm}
     * @param maxTsDiff exclusive upper bound on {@code q2.ts - q1.ts}
     * @param edCalc edit distance implementation
     */
    public SuggestionFilter(int minHitsDiff, int minHitsPercDiff, float maxEditDist, double maxNormEditDist,
                            Duration maxTsDiff, EditDistanceCalc edCalc) {
        this.minHitsDiff = minHitsDiff;
        this.minHitsPercDiff = minHitsPercDiff;
        this.maxEditDist = maxEditDist;
        this.maxNormEditDist = maxNormEditDist;
        this.maxTsDiff = maxTsDiff;
        this.edCalc = edCalc;
    }

    /**
     * Filter suggestions based on hit counts and edit distance.
     *
     * <p>The filters run from the hit-count checks to the edit distance checks. The order
     * matters for the last two: {@link #filterMaxNormEditDist(Dataset)} reads the
     * {@code q1q2EditDist} column that {@link #filterMaxEditDist(Dataset)} adds, so the
     * returned dataframe carries that extra column.
     *
     * @param df dataframe exposing {@code q1.hitsTotal}, {@code q2.hitsTotal},
     *           {@code q1.queryNorm} and {@code q2.queryNorm}
     * @return rows that pass every filter, with the {@code q1q2EditDist} column appended
     */
    public Dataset<Row> filter(Dataset<Row> df) {
        // suggestion does not provide enough lift in hitsTotal
        df = filterMinHitsDiff(df);
        // suggestion does not provide enough lift in % of hitsTotal
        df = filterMinHitsPercDiff(df);
        // suggestion is too far from query
        df = filterMaxEditDist(df);
        // suggestion is too far from query relative to query length
        return filterMaxNormEditDist(df);
    }

    /**
     * Filter suggestions based on maximum session length along
     * with the default filters.
     *
     * @param df dataframe exposing the columns required by {@link #filter(Dataset)}
     *           plus {@code q1.ts} and {@code q2.ts}
     * @return rows that pass {@link #filter(Dataset)} and the timestamp difference check
     */
    public Dataset<Row> filterWithSession(Dataset<Row> df) {
        return filterMaxTsDiff(filter(df));
    }

    /**
     * Keeps rows where {@code q2.hitsTotal - q1.hitsTotal} is strictly greater than
     * {@code minHitsDiff}.
     */
    @VisibleForTesting
    Dataset<Row> filterMinHitsDiff(Dataset<Row> df) {
        Column diff = col("q2.hitsTotal").minus(col("q1.hitsTotal"));
        return df.where(diff.gt(minHitsDiff));
    }

    /**
     * Keeps rows where {@code q2.hitsTotal / (q1.hitsTotal + 1) * 100 - 100} is strictly
     * greater than {@code minHitsPercDiff}.
     */
    private Dataset<Row> filterMinHitsPercDiff(Dataset<Row> df) {
        // NOTE(review): the +1 presumably guards against q1.hitsTotal == 0 - confirm.
        Column ratio = col("q2.hitsTotal").divide(col("q1.hitsTotal").plus(1));
        Column perc = ratio.multiply(100).minus(100);
        return df.where(perc.gt(minHitsPercDiff));
    }

    /**
     * Adds the {@code q1q2EditDist} column (edit distance between {@code q1.queryNorm} and
     * {@code q2.queryNorm}) and keeps rows where it is strictly less than {@code maxEditDist}.
     */
    @VisibleForTesting
    Dataset<Row> filterMaxEditDist(Dataset<Row> df) {
        // TODO: This is adding a field to the returned df, which doesn't seem like something
        // a filter implementation should be doing. The calculation is delayed until here to
        // allow throwing out as many suggestions as possible before running a fairly expensive
        // edit distance calculation (this should be verified as useful?).
        Column dist = edCalc.editDistance(col("q1.queryNorm"), col("q2.queryNorm"));
        return df
                .withColumn(EDIT_DIST_COL, dist)
                .where(dist.lt(maxEditDist));
    }

    /**
     * Filters on the edit distance relative to the length of the original query.
     *
     * <p>A row is kept when {@code length(q1.queryNorm) > q1q2EditDist * maxNormEditDist},
     * which for a positive {@code maxNormEditDist} is the same as requiring
     * {@code q1q2EditDist / length(q1.queryNorm) < 1 / maxNormEditDist}. With the default
     * factor of 1.5 a two letter query at edit distance 3 is dropped (2 is not greater
     * than 4.5), so short words can't be replaced by unrelated ones.
     *
     * <p>The {@code q1q2EditDist} column must already exist; it is added by
     * {@link #filterMaxEditDist(Dataset)}.
     *
     * @param df dataframe exposing {@code q1.queryNorm} and {@code q1q2EditDist}
     * @return records that match the criteria
     */
    @VisibleForTesting
    Dataset<Row> filterMaxNormEditDist(Dataset<Row> df) {
        Column scaledDist = col(EDIT_DIST_COL).multiply(maxNormEditDist);
        return df.where(length(col("q1.queryNorm")).gt(scaledDist));
    }

    /**
     * Keeps rows where {@code q2.ts - q1.ts} is strictly less than {@code maxTsDiff} in
     * whole seconds. Only an upper bound is applied to the difference.
     */
    @VisibleForTesting
    Dataset<Row> filterMaxTsDiff(Dataset<Row> df) {
        // NOTE(review): assumes q1.ts / q2.ts are numeric timestamps in seconds - confirm.
        Column diff = col("q2.ts").minus(col("q1.ts"));
        return df.where(diff.lt(maxTsDiff.getSeconds()));
    }
}