AllPairsLevenshtein.java

package org.wikimedia.search.glent.fst;

import static org.apache.spark.sql.functions.expr;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.reverse;
import static org.apache.spark.sql.functions.substring;

import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import org.apache.lucene.util.fst.FST;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.MapPartitionsFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;

import com.google.common.collect.Iterators;


/**
 * TODO: Currently we query aaa → aab, and separately aab → aaa. It seems the FST
 * evaluation could be short circuited to only generate suggestions where a <= b,
 * with the reversed pair generated by flipping each result. Not sure if short
 * circuiting would save 50%, but it would probably be useful.
 */
public class AllPairsLevenshtein {
    private final int numFst;
    private final Dataset<String> dictionary;
    private UserDefinedFunction forwardLookup;
    private UserDefinedFunction reverseLookup;

    /**
     * @param dictionary The set of findable strings
     * @param numFst The number of partitions for FST indices
     */
    public AllPairsLevenshtein(Dataset<String> dictionary, int numFst) {
        this.dictionary = dictionary;
        this.numFst = numFst;
    }

    public Dataset<Row> apply(Dataset<Row> input, Column query, String outputCol) {
        // input should be cached if expensive to calculate
        return apply(input, query, outputCol, 0, 2);
    }
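
    // A minimal usage sketch (the names spark, queries and the "query" column are
    // illustrative, not part of this class):
    //
    //   Dataset<String> dictionary = spark.read().textFile("dictionary.txt");
    //   Dataset<Row> withSuggestions = new AllPairsLevenshtein(dictionary, 1)
    //           .apply(queries, functions.col("query"), "suggestions");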

    public Dataset<Row> apply(Dataset<Row> queries, Column query, String outputCol, int nonFuzzyPrefix, int editDistance) {
        boolean reverse = false;
        if (nonFuzzyPrefix == 0) {
            // Searching the FST without a prefix is quite expensive. Instead we construct
            // a forward and a reverse FST and search each with a prefix of 1. This means the
            // first and last char can't both change, but that seems an acceptable limitation
            // for an order-of-magnitude speedup.
            reverse = true;
            nonFuzzyPrefix = 1;
            preloadLookups();
        }
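        // Worked example (illustrative): with editDistance 2 the query "banana" is
        // matched forward only against dictionary entries starting with "b", while its
        // reversal "ananab" is matched against the reversed dictionary, i.e. entries
        // ending in "a". A dictionary entry that differs from the query in both its
        // first and last character is therefore never suggested.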


        Dataset<Row> result = query(getForwardLookup(), queries, query, outputCol,
                nonFuzzyPrefix, editDistance);
        if (!reverse) {
            return result;
        }
        Dataset<Row> reversed = query(
                getReverseLookup(), queries, reverse(query), outputCol,
                nonFuzzyPrefix, editDistance)
            // substring is 1 indexed
            .withColumn("firstChar", substring(query, 1, 1))
            // Exclude suggestions that do not differ in the first (last, when reversed) character;
            // they were already calculated in the forward pass. Couldn't find a Spark API for
            // higher-order functions, so this has to use expr().
            .withColumn(outputCol, expr(String.format(Locale.ROOT,
                    "filter(%s, dym -> firstChar != substring(dym, -1, 1))", outputCol)))
            .drop("firstChar")
            // The lookup generated reversed suggestions; flip them back around.
            .withColumn(outputCol, expr(String.format(Locale.ROOT,
                    "transform(%s, dym -> reverse(dym))", outputCol)));

        return result.union(reversed);
    }

    private void preloadLookups() {
        if (forwardLookup == null && reverseLookup == null) {
            // Stupid hack to load both at the same time.
            Thread a = new Thread(this::getForwardLookup, "AllPairsLevenshtein-forward");
            Thread b = new Thread(this::getReverseLookup, "AllPairsLevenshtein-reverse");
            // By default an uncaught failure only prints a stack trace and kills the thread.
            // Add handling so we blow up on failure instead.
            AtomicReference<Throwable> exception = new AtomicReference<>();
            a.setUncaughtExceptionHandler((t, e) -> exception.set(e));
            b.setUncaughtExceptionHandler((t, e) -> exception.set(e));
            a.start();
            b.start();
            try {
                a.join();
                b.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            if (exception.get() != null) {
                throw new RuntimeException("One or both lookups failed", exception.get());
            }
        }
    }

    /**
     * Transform a dataset of strings into a UDF that will,
     * when invoked with a string, return the set of all nearby strings from
     * the source dataset. The data structure built to answer these
     * queries is significantly compressed from the source, but still takes
     * a large amount of memory and must be broadcast to each executor.
     */
    private UserDefinedFunction toLookup(Dataset<String> df) {
        MapPartitionsFunction<String, SerializableFST> stringsToFST = FSTBuilder::transform;
        MapPartitionsFunction<SerializableFST, SerializableFST> mergeFsts = FSTBuilder::merge;
        Encoder<SerializableFST> encoder = Encoders.javaSerialization(SerializableFST.class);

        // First transform the many input partitions into individual FSTs, then
        // reduce to the final number of partitions and merge the FSTs into one
        // per partition.
        // In the various approaches tested Spark distributed all FSTs to all executors,
        // so the most efficient option is to reduce to a single FST. If the dataset
        // is large enough it must be merged on the driver: Spark can't have input
        // partitions larger than 2GB, so if the set of inputs is larger than that,
        // Spark has to build multiple FSTs. During a driver-side merge we still
        // merge down to 10 partitions first to reduce total memory requirements on the driver.
        Dataset<SerializableFST> dfFst = df.mapPartitions(stringsToFST, encoder);

        List<Broadcast<SerializableFST>> bc;
        // save memory on the driver.
        // TODO: Better gating than simply numFst == 1. For the prod use case
        // this always needs to be enabled when numFst == 1.
        if (numFst == 1) {
            dfFst = dfFst
                    .repartition(10)
                    .mapPartitions(mergeFsts, encoder);
            try {
                bc = Collections.singletonList(jsc().broadcast(
                        doDriverSideMerge(dfFst.collectAsList())));
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        } else {
            dfFst = dfFst
                    .repartition(numFst)
                    .mapPartitions(mergeFsts, encoder);

            bc = dfFst.collectAsList()
                    .stream()
                    .map(jsc()::broadcast)
                    .collect(Collectors.toList());
        }

        return FSTLookup.makeUdf(bc);
    }

    private SerializableFST doDriverSideMerge(Iterable<SerializableFST> fsts) throws IOException {
        Iterator<FST<Object>> it = Iterators.transform(fsts.iterator(), SerializableFST::getFST);
        return new SerializableFST(FSTBuilder.buildFromFSTs(it));
    }

    private Dataset<Row> query(UserDefinedFunction lookup, Dataset<Row> queries, Column query, String outputCol, int nonFuzzyPrefix, int editDistance) {
        // Output format is the source dataframe plus a new outputCol
        // containing an array of string suggestions.
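        // Illustrative (hypothetical values): a row {query: "exmple"} looked up with
        // editDistance=2 and nonFuzzyPrefix=1 would gain something like
        // {suggestions: ["example", "exemple", ...]}, i.e. dictionary entries within
        // edit distance 2 that share the non-fuzzy prefix "e".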
        return queries.withColumn(outputCol, lookup.apply(
                query, lit(editDistance), lit(nonFuzzyPrefix)));
    }

    private UserDefinedFunction getForwardLookup() {
        if (forwardLookup == null) {
            forwardLookup = toLookup(dictionary);
        }
        return forwardLookup;
    }

    /**
     * Build a lookup against the reversed dictionary.
     *
     * For simplicity the inputs and outputs of this lookup are also reversed.
     * Providing reversed inputs and un-reversing the outputs is the
     * responsibility of the caller.
     */
    private UserDefinedFunction getReverseLookup() {
        if (reverseLookup == null) {
            // Invoking this directly causes forward and reverse to be built sequentially;
            // invoke preloadLookups() instead to build them in parallel.
            reverseLookup = toLookup(dictionary.map(
                    // TODO: Can Spark's reverse() function be applied? Not sure how against a Dataset<String>.
                    (MapFunction<String, String>)s -> new StringBuilder(s).reverse().toString(),
                    Encoders.STRING()));
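            // One possible answer to the TODO above (an untested sketch, not necessarily
            // the intended approach): a Dataset<String> exposes its contents as a "value"
            // column, so something like
            //   dictionary.select(reverse(col("value"))).as(Encoders.STRING())
            // might work, since Spark's reverse() also reverses string columns
            // (col is org.apache.spark.sql.functions.col, not imported here).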
        }
        return reverseLookup;
    }

    private JavaSparkContext jsc() {
        return JavaSparkContext.fromSparkContext(dictionary.sqlContext().sparkContext());
    }
}