SimilarQueriesSuggester.java
package org.wikimedia.search.glent;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.explode;
import static org.apache.spark.sql.functions.length;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.regexp_replace;
import static org.apache.spark.sql.functions.struct;
import static org.apache.spark.sql.functions.sum;
import java.time.Duration;
import java.util.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.wikimedia.search.glent.fst.AllPairsLevenshtein;
import com.google.common.annotations.VisibleForTesting;
/**
* Looks for queries in the request log that have a very high string
* similarity and considers them as query suggestions.
*
* Applies a form of all-pairs matching to generate candidate pairs that are a
* small edit distance apart. To increase the allowed edit distance beyond the
* limit of 2 provided by the underlying libraries, candidate generation is
* applied against strings that have had duplicate characters squashed together
* (abbbc to abc) during matching. This squashing also supports our custom edit
* distance applied in SuggestionFilter, which counts duplicates as less than a
* full character of distance.
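*
* As an illustration, "committee" and "comity" are raw edit distance 4 apart,
* beyond the FST limit of 2, but their squashed forms are within distance 1
* and so still match during candidate generation:
*
* <pre>{@code
* "committee" -> dedupChars -> "comite"
* "comity"    -> dedupChars -> "comity" (unchanged, no duplicate characters)
* levenshtein("committee", "comity") = 4, levenshtein("comite", "comity") = 1
* }</pre>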
*/
public class SimilarQueriesSuggester implements Function<Dataset<Row>, Dataset<Row>> {
public static final int MIN_STRING_LENGTH = 4;
private final SuggestionFilter suggestionFilter;
private final int numFst;
public SimilarQueriesSuggester(SuggestionFilter suggestionFilter, int numFst) {
// TODO: We allow suggestionFilter to be nullable since candidate
// generation doesn't need it. It makes the API unclear though.
this.suggestionFilter = suggestionFilter;
this.numFst = numFst;
}
@Override
public Dataset<Row> apply(Dataset<Row> df) {
return apply(df, null);
}
/**
* Transform m1prep output into suggestions.
*
* Candidates may optionally be provided to allow different executor
* resourcing on large jobs.
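*
* For example (sketch; variable names are illustrative), candidates can be
* precomputed by a separately configured job and passed back in:
*
* <pre>{@code
* SimilarQueriesSuggester suggester = new SimilarQueriesSuggester(filter, numFst);
* Dataset<Row> candidates = suggester.generateCandidates(m1prepOutput);
* Dataset<Row> suggestions = suggester.apply(m1prepOutput, candidates);
* }</pre>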
*/
public Dataset<Row> apply(Dataset<Row> df, Dataset<Row> candidates) {
Dataset<Row> structs = prepareInput(df);
if (candidates == null) {
candidates = generateCandidatesFromStructs(structs);
}
return aggregate(suggestionFilter.filter(
resolveCandidates(candidates, structs)));
}
/**
* Prepare input dataframe for transformation.
*
* Prepare a version of the normalized query with duplicate characters squashed
* to extend the matchable set of strings, then bundle the row up into a struct
* in preparation for holding two separate queries, the query and the
* suggestion, in the same row. Each query will be represented by a column
* containing one of these structs.
*
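* The output has roughly this shape (struct member types elided):
*
* <pre>{@code
* root
*  |-- queryNormDedup: string
*  |-- sugg: struct<query, queryNorm, hitsTotal, suggCount, ts, wikiid, lang>
* }</pre>
*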
* @param df m1prep output
*/
private Dataset<Row> prepareInput(Dataset<Row> df) {
return df.select(
dedupChars(col("queryNorm")).alias("queryNormDedup"),
struct(
"query", "queryNorm", "hitsTotal", "suggCount", "ts", "wikiid", "lang"
).alias("sugg"));
}
/**
* Apply a form of all-pairs matching to generate candidate pairs.
*
* Candidate generation has dramatically different resource requirements
* from normal spark tasks. Prefer large executors with many cores and
* 1GB/core, rather than the standard spark configuration of low core count
* and 2GB/core. Most memory needed is for a data structure shared between
* all tasks running on the same executor.
*
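* An illustrative, not prescriptive, submission along those lines:
*
* <pre>{@code
* spark-submit --executor-cores 8 --executor-memory 8g ...
* }</pre>
*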
* @param df m1prep output
*/
public Dataset<Row> generateCandidates(Dataset<Row> df) {
return generateCandidatesFromStructs(prepareInput(df));
}
private Dataset<Row> generateCandidatesFromStructs(Dataset<Row> structs) {
int numParts = estimateQueryPartitions(structs.count());
// The fst building process will deduplicate; no need to have spark
// do it too.
Dataset<String> dictionary = structs
.select(col("queryNormDedup"))
.map((MapFunction<Row, String>) row -> row.getString(0), Encoders.STRING());
Dataset<Row> queriesToLookup = structs
// Short queries generate massive numbers of useless suggestions; a 3-character
// query matches everything within edit distance of 2. Filter them out of the
// source queries. The short queries are only filtered from the lookup side;
// we still index them into the dictionary, which will return them as candidates.
.where(length(col("sugg.queryNorm")).geq(MIN_STRING_LENGTH))
// Send only the query string and deduplicate to have the smallest
// evaluation set possible. We use a single FST for all wikis and
// let later processing in spark filter out cross-wiki suggestions.
// This reduces the size of the FST and the number of queries we
// need to execute, moving the processing to vanilla spark code.
.select(col("queryNormDedup"))
.dropDuplicates()
.repartition(numParts);
// Find all pairs of queryNorm within Levenshtein distance of 2.
return new AllPairsLevenshtein(dictionary, numFst)
.apply(queriesToLookup, col("queryNormDedup"), "dymDedup");
}
/**
* Transform raw FST output into valid suggestion pairs.
*
* Map the candidates generated by all-pairs matching back to the source
* queries. Filter candidates down to queries from the same wiki and language,
* and remove queries that suggest themselves.
*
* For reference, candidates has two fields: the queryNormDedup of the string
* that was searched for in the FST, and an array of nearby queryNormDedup
* strings identified by candidate generation.
*
* While not explicitly called out, this process also results in queries with
* the same queryNormDedup being suggested for each other since the FST emits
* the source query as a candidate.
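*
* As a sketch, with made-up values, a single candidates row might be:
*
* <pre>{@code
* queryNormDedup: "kitens"
* dymDedup:       ["kitens", "mitens", "kites"]
* }</pre>
*
* Each dymDedup entry is exploded into its own row and joined back against
* the source structs, yielding one (q1, q2) pair per candidate that survives
* the filters below.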
*/
private Dataset<Row> resolveCandidates(Dataset<Row> candidates, Dataset<Row> structs) {
Dataset<Row> queries = structs
.select(
col("queryNormDedup").alias("query"),
col("sugg").alias("q1"))
.alias("queries");
Dataset<Row> dym = structs
.select(
col("queryNormDedup").alias("dym"),
col("sugg").alias("q2"))
.alias("dym");
return candidates.alias("c")
// bring in metadata about the source query
.join(queries, col("c.queryNormDedup").equalTo(col("queries.query")))
// prepare for join by exploding the array of candidates into
// one row per candidate.
.select(
col("queries.q1"),
explode(col("c.dymDedup")).alias("dymDedup"))
// bring in metadata about the candidate
.join(dym, col("dymDedup").equalTo(col("dym.dym")))
// simplify to (q1, q2). This finishes dropping the dedup'd strings we added.
.select(col("queries.q1"), col("dym.q2"))
// A query can't suggest itself
.where(col("q1.queryNorm").notEqual(col("q2.queryNorm")))
// filter to same context
.where(col("q1.wikiid").equalTo(col("q2.wikiid")))
.where(col("q1.lang").equalTo(col("q2.lang")));
}
/**
* Aggregate multiple suggestions for the same query/dym pair into
* a single result. Reshape into the final flattened output without
* our per-query structs.
*
* TODO: Does this make sense? Can the processing employed generate
* multiple suggestions for the same pair? It seems that as long as
* there is one row per queryNorm in the input, there should be no
* duplicates here.
*/
private Dataset<Row> aggregate(Dataset<Row> suggestions) {
// TODO: Copy/paste of SessionReformulationSuggester.apply?
return suggestions
.withColumn("dym", col("q2.queryNorm"))
.groupBy("q1.query", "dym", "q1.wikiid", "q1.lang",
"q1q2EditDist")
.agg(
max("q2.ts").alias("ts"),
max("q1.hitsTotal").alias("queryHitsTotal"),
max("q2.hitsTotal").alias("dymHitsTotal"),
sum("q2.suggCount").alias("suggCount"))
.select("query", "dym", "suggCount", "q1q2EditDist",
"queryHitsTotal", "dymHitsTotal", "wikiid",
"lang", "ts");
}
/**
* Strip duplicate characters in string.
*
* Example: `abbbbc deffff` becomes `abc def`
*/
@VisibleForTesting
static Column dedupChars(Column string) {
return regexp_replace(string, "(.)\\1+", "$1");
}
/**
* Rough estimate of the number of fst lookup partitions needed for each
* partition to take ~10 minutes to run. Otherwise, on large datasets we
* might end up with individual partitions that take 3 hours.
*
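* Worked example: with a 50M row fst, msPerQuery = 2 + 50M / 5M = 12ms,
* so rowPerPartition = 600,000ms / 12ms = 50,000 and the lookup runs in
* 50M / 50,000 = 1,000 partitions (clamped to the range [2, 40,000]).
*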
* @param fstSize Number of rows in input dataset
* @return Number of partitions to perform query with
*/
private int estimateQueryPartitions(long fstSize) {
// Slope: every additional 5M rows in the fst adds ~1ms to each lookup.
float m = 1F / 5_000_000F;
// ballpark estimate, this isn't really linear.
float msPerQuery = 2 + m * fstSize;
float queryPerMs = 1 / msPerQuery;
int rowPerPartition = Math.round(Duration.ofMinutes(10).toMillis() * queryPerMs);
return Math.min(40_000, Math.max(2, (int) (fstSize / rowPerPartition)));
}
}