AllPairsLevenshtein.java
package org.wikimedia.search.glent.fst;
import static org.apache.spark.sql.functions.expr;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.reverse;
import static org.apache.spark.sql.functions.substring;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.lucene.util.fst.FST;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.MapPartitionsFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import com.google.common.collect.Iterators;
/**
 * TODO: Currently we query aaa → aab, and separately aab → aaa. The FST
 * evaluation could likely be short-circuited to only generate suggestions
 * where a <= b, with the reversed pair derived from each result. Not sure
 * short-circuiting would save the full 50%, but it would probably be useful.
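 *
 * <p>Example usage (a minimal sketch; the dictionary path and the
 * {@code queries} dataframe are hypothetical):
 * <pre>{@code
 * Dataset<String> dictionary = spark.read().textFile("dictionary.txt");
 * Dataset<Row> suggested = new AllPairsLevenshtein(dictionary, 1)
 *     .apply(queries, functions.col("query"), "suggestions");
 * }</pre>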
*/
public class AllPairsLevenshtein {
private final int numFst;
private final Dataset<String> dictionary;
private UserDefinedFunction forwardLookup;
private UserDefinedFunction reverseLookup;
/**
* @param dictionary The set of findable strings
* @param numFst The number of partitions for FST indices
*/
public AllPairsLevenshtein(Dataset<String> dictionary, int numFst) {
this.dictionary = dictionary;
this.numFst = numFst;
}
public Dataset<Row> apply(Dataset<Row> input, Column query, String outputCol) {
        // The input should be cached by the caller if it is expensive to compute.
return apply(input, query, outputCol, 0, 2);
}
public Dataset<Row> apply(Dataset<Row> queries, Column query, String outputCol, int nonFuzzyPrefix, int editDistance) {
        boolean alsoSearchReversed = false;
if (nonFuzzyPrefix == 0) {
            // Searching the FST without a required prefix is quite expensive. Instead
            // construct a forward and a reversed FST and search each with a prefix of 1.
            // This means the first and last characters can't both change, but that seems
            // an acceptable limitation for an order-of-magnitude speedup.
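            // Example: for the query "kitten" the forward FST is searched with required
            // prefix "k" and the reversed FST with required prefix "n" (from "nettik").
            // A suggestion like "mitten" is reachable only through the reverse pass.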
            alsoSearchReversed = true;
nonFuzzyPrefix = 1;
preloadLookups();
}
Dataset<Row> result = query(getForwardLookup(), queries, query, outputCol,
nonFuzzyPrefix, editDistance);
        if (!alsoSearchReversed) {
return result;
}
Dataset<Row> reversed = query(
getReverseLookup(), queries, reverse(query), outputCol,
nonFuzzyPrefix, editDistance)
            // substring is 1-indexed
.withColumn("firstChar", substring(query, 1, 1))
            // Exclude suggestions that do not differ in the first (last, when reversed)
            // character; those were already found in the forward pass. Couldn't find a
            // Spark API for higher-order functions, so this has to use expr().
.withColumn(outputCol, expr(String.format(Locale.ROOT,
"filter(%s, dym -> firstChar != substring(dym, -1, 1))", outputCol)))
.drop("firstChar")
            // The lookup generated reversed suggestions; flip them back around.
.withColumn(outputCol, expr(String.format(Locale.ROOT,
"transform(%s, dym -> reverse(dym))", outputCol)));
return result.union(reversed);
}
private void preloadLookups() {
if (forwardLookup == null && reverseLookup == null) {
            // Stupid hack to load both at the same time.
Thread a = new Thread(this::getForwardLookup, "AllPairsLevenshtein-forward");
Thread b = new Thread(this::getReverseLookup, "AllPairsLevenshtein-reverse");
            // If we don't do anything, failures will print and the thread will stop,
            // but nothing else will happen. Add handling to blow up on failure.
AtomicReference<Throwable> exception = new AtomicReference<>();
a.setUncaughtExceptionHandler((t, e) -> exception.set(e));
b.setUncaughtExceptionHandler((t, e) -> exception.set(e));
            a.start();
            b.start();
            try {
a.join();
b.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
if (exception.get() != null) {
throw new RuntimeException("One or both lookups failed", exception.get());
}
}
}
/**
* Transform a dataset of strings into a UDF that will,
* when invoked with a string, return the set of all nearby strings from
* the source dataset. The data structure built to answer these
* queries is significantly compressed from the source, but still takes
* a large amount of memory and must be broadcast to each executor.
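     *
     * <p>The returned UDF takes three arguments, matching the invocation in
     * {@link #query}: the query string, the maximum edit distance, and the
     * non-fuzzy prefix length. It returns an array of matching dictionary strings.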
*/
private UserDefinedFunction toLookup(Dataset<String> df) {
MapPartitionsFunction<String, SerializableFST> stringsToFST = FSTBuilder::transform;
MapPartitionsFunction<SerializableFST, SerializableFST> mergeFsts = FSTBuilder::merge;
Encoder<SerializableFST> encoder = Encoders.javaSerialization(SerializableFST.class);
        // First transform the many input partitions into individual FSTs, then
        // reduce to the final number of partitions and merge the FSTs into one
        // per partition.
        // In the various approaches tested, Spark always distributed all FSTs to
        // all instances, so the most efficient option is to reduce to a single FST.
        // If the dataset is large enough, that merge must happen on the driver:
        // Spark can't have partitions larger than 2GB, so if the set of inputs is
        // larger than that, multiple FSTs have to be built. During a driver-side
        // merge we still first merge down to 10 partitions to reduce the total
        // memory requirements on the driver.
Dataset<SerializableFST> dfFst = df.mapPartitions(stringsToFST, encoder);
List<Broadcast<SerializableFST>> bc;
        // TODO: Better gating than simply numFst == 1. For the prod use case
        // driver-side merging always needs to be enabled when numFst == 1 to
        // save memory on the driver.
if (numFst == 1) {
dfFst = dfFst
.repartition(10)
.mapPartitions(mergeFsts, encoder);
try {
bc = Collections.singletonList(jsc().broadcast(
doDriverSideMerge(dfFst.collectAsList())));
} catch (IOException e) {
throw new RuntimeException(e);
}
} else {
dfFst = dfFst
.repartition(numFst)
.mapPartitions(mergeFsts, encoder);
bc = dfFst.collectAsList()
.stream()
.map(jsc()::broadcast)
.collect(Collectors.toList());
}
return FSTLookup.makeUdf(bc);
}
private SerializableFST doDriverSideMerge(Iterable<SerializableFST> fsts) throws IOException {
Iterator<FST<Object>> it = Iterators.transform(fsts.iterator(), SerializableFST::getFST);
return new SerializableFST(FSTBuilder.buildFromFSTs(it));
}
private Dataset<Row> query(UserDefinedFunction lookup, Dataset<Row> queries, Column query, String outputCol, int nonFuzzyPrefix, int editDistance) {
// Output format is the source dataframe plus a new outputCol
// containing an array of string suggestions.
return queries.withColumn(outputCol, lookup.apply(
query, lit(editDistance), lit(nonFuzzyPrefix)));
}
private UserDefinedFunction getForwardLookup() {
if (forwardLookup == null) {
forwardLookup = toLookup(dictionary);
}
return forwardLookup;
}
/**
* Build a lookup against the reversed dictionary.
*
     * For simplicity, the inputs and outputs to this lookup are also reversed.
* Providing valid input and transforming outputs is the responsibility of
* the caller.
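     *
     * For example, to look up suggestions for "kitten" the caller must query
     * with "nettik" and reverse each returned suggestion.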
*/
private UserDefinedFunction getReverseLookup() {
if (reverseLookup == null) {
            // Calling the getters directly builds the two lookups sequentially;
            // invoke this::preloadLookups to build them in parallel.
reverseLookup = toLookup(dictionary.map(
                // TODO: Can Spark's reverse() function be applied? Not sure how against a Dataset<String>.
(MapFunction<String, String>)s -> new StringBuilder(s).reverse().toString(),
Encoders.STRING()));
}
return reverseLookup;
}
private JavaSparkContext jsc() {
return JavaSparkContext.fromSparkContext(dictionary.sqlContext().sparkContext());
}
}