FSTLookup.java

package org.wikimedia.search.glent.fst;

import static org.apache.spark.sql.functions.udf;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.lucene.search.suggest.analyzing.FSTUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.LevenshteinAutomata;
import org.apache.lucene.util.automaton.Operations;
import org.apache.lucene.util.automaton.TooComplexToDeterminizeException;
import org.apache.lucene.util.automaton.Transition;
import org.apache.lucene.util.automaton.UTF32ToUTF8;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.Util;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;

/**
 * Spark UDF returning every term stored in a list of broadcast FSTs that lies within a
 * given edit distance (transpositions counted as one edit) of the query string.
 *
 * <p>The FSTs travel to executors as Spark broadcast variables; the underlying
 * {@link FST} objects are pulled out of them lazily on the first {@link #call}.
 */
public class FSTLookup implements UDF3<String, Integer, Integer, String[]> {
    private static final long serialVersionUID = 1L;

    private final List<Broadcast<SerializableFST>> sharedFsts;
    // Transient: rebuilt from sharedFsts on first use after (de)serialization.
    private transient List<FST<Object>> dictionaries;

    public FSTLookup(List<Broadcast<SerializableFST>> fsts) {
        this.sharedFsts = fsts;
    }

    /** Wraps a new {@code FSTLookup} over {@code fsts} as a UDF returning {@code array<string>}. */
    static UserDefinedFunction makeUdf(List<Broadcast<SerializableFST>> fsts) {
        return udf(new FSTLookup(fsts),
                DataTypes.createArrayType(DataTypes.StringType));
    }

    /**
     * Finds all dictionary terms within {@code editDistance} edits of {@code s}.
     *
     * @param s query string; {@code null} yields {@code null}
     * @param editDistance maximum edit distance, from 0 to
     *     {@link LevenshteinAutomata#MAXIMUM_SUPPORTED_DISTANCE}; {@code null} yields {@code null}
     * @param nonFuzzyPrefix number of leading UTF-16 chars of {@code s} that must match
     *     exactly; clamped to {@code [0, s.length()]} and never splits a surrogate pair;
     *     {@code null} yields {@code null}
     * @return matches from every dictionary, concatenated in dictionary order (duplicates
     *     across dictionaries are not removed); an empty array when {@code s} is empty;
     *     {@code null} when the query automaton is too complex to determinize
     * @throws IllegalArgumentException if {@code editDistance} is outside the supported range
     */
    @Override
    public String[] call(String s, Integer editDistance, Integer nonFuzzyPrefix) throws IOException {
        // SQL semantics: a null query propagates as null rather than failing the task.
        if (s == null) {
            return null;
        }
        if (s.isEmpty()) {
            return new String[0];
        }
        if (editDistance == null || nonFuzzyPrefix == null) {
            return null;
        }

        if (dictionaries == null) {
            dictionaries = sharedFsts.stream()
                    .map(x -> x.value().getFST())
                    .collect(Collectors.toList());
        }

        Automaton utf8automaton;
        try {
            utf8automaton = makeAutomaton(s, editDistance, nonFuzzyPrefix);
        } catch (TooComplexToDeterminizeException e) {
            return null;
        }

        // One scratch buffer is safe here: the stream is sequential and each BytesRef is
        // converted to a String before the next path is processed.
        BytesRefBuilder bytesRefScratch = new BytesRefBuilder();
        return dictionaries.stream()
                .flatMap(dict -> streamIntersectPaths(utf8automaton, dict))
                .map(result -> Util.toBytesRef(result.input.get(), bytesRefScratch))
                .map(BytesRef::utf8ToString)
                .toArray(String[]::new);
    }

    /**
     * Builds a deterministic UTF-8 byte automaton accepting every string that starts with the
     * exact prefix of {@code s} and is within {@code editDistance} edits of the remainder.
     *
     * @throws IllegalArgumentException if {@code editDistance} is negative or exceeds
     *     {@link LevenshteinAutomata#MAXIMUM_SUPPORTED_DISTANCE}
     */
    private Automaton makeAutomaton(String s, int editDistance, int nonFuzzyPrefix) {
        // LevenshteinAutomata.toAutomaton returns null above the supported distance (and
        // indexes out of bounds when negative), which previously surfaced as an NPE.
        if (editDistance < 0 || editDistance > LevenshteinAutomata.MAXIMUM_SUPPORTED_DISTANCE) {
            throw new IllegalArgumentException("editDistance must be between 0 and "
                    + LevenshteinAutomata.MAXIMUM_SUPPORTED_DISTANCE + ", got " + editDistance);
        }
        int split = prefixEnd(s, nonFuzzyPrefix);
        String prefix = s.substring(0, split);
        String searchable = s.substring(split);

        LevenshteinAutomata lev = new LevenshteinAutomata(searchable, true);
        Automaton utf32automaton = lev.toAutomaton(editDistance, prefix);
        Automaton utf8automaton = new UTF32ToUTF8().convert(utf32automaton);
        return Operations.determinize(utf8automaton, Operations.DEFAULT_MAX_DETERMINIZED_STATES);
    }

    /**
     * Returns the index at which {@code s} is split into exact prefix and fuzzy remainder.
     * The requested length is clamped to {@code [0, s.length()]}. The split is moved one char
     * forward when it would separate a surrogate pair, so both halves are valid UTF-16.
     */
    private static int prefixEnd(String s, int nonFuzzyPrefix) {
        int end = Math.max(0, Math.min(nonFuzzyPrefix, s.length()));
        if (end > 0 && end < s.length()
                && Character.isHighSurrogate(s.charAt(end - 1))
                && Character.isLowSurrogate(s.charAt(end))) {
            end++;
        }
        return end;
    }

    /**
     * Fit intersectPaths into stream processing.
     *
     * IOExceptions should be impossible in our configuration. They come from the
     * underlying storage and we are doing everything in-memory. Should one occur
     * anyway, it is rethrown as an {@link UncheckedIOException}.
     */
    private static <T> Stream<FSTUtil.Path<T>> streamIntersectPaths(Automaton a, FST<T> fst) {
        try {
            return intersectPaths(a, fst).stream();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * This is almost exactly FSTUtil.intersectPrefixPaths, but the
     * condition to add to endNodes has changed such that we continue
     * searching and we only add full input queries to endNodes. This
     * converts from a prefix search to exact.
     *
     * TODO: We currently generate suggestions in both directions,
     * aa → ab and ab → aa. It seems a short circuiting condition
     * could be added here to only traverse arcs where q1 ⇐ q2, and
     * later processing can decide if aa → ab or the reversed ab → aa
     * is the preferred direction of suggestion. This code is hairy
     * and a bit scary though.
     *
     * @param a deterministic automaton to intersect with; asserted, not checked
     * @param fst dictionary to walk
     * @return one path per FST input that is both an FST final node and an
     *     accepting automaton state; empty if {@code a} has no states
     * @throws IOException propagated from reading the FST's bytes
     */
    @SuppressWarnings("CyclomaticComplexity")
    @SuppressFBWarnings(value = {"CNC_COLLECTION_NAMING_CONFUSION", "SPP_USE_ISEMPTY"},
            justification = "not our code, make minimal changes")
    public static <T> List<FSTUtil.Path<T>> intersectPaths(Automaton a, FST<T> fst)
            throws IOException {
        assert a.isDeterministic();
        // Used as a LIFO stack (remove from the end): depth-first traversal.
        final List<FSTUtil.Path<T>> queue = new ArrayList<>();
        final List<FSTUtil.Path<T>> endNodes = new ArrayList<>();
        if (a.getNumStates() == 0) {
            return endNodes;
        }

        queue.add(new FSTUtil.Path<>(0, fst
                .getFirstArc(new FST.Arc<T>()), fst.outputs.getNoOutput(),
                new IntsRefBuilder()));

        final FST.Arc<T> scratchArc = new FST.Arc<>();
        final FST.BytesReader fstReader = fst.getBytesReader();

        Transition t = new Transition();


        while (!queue.isEmpty()) {
            final FSTUtil.Path<T> path = queue.remove(queue.size() - 1);
            // The only change from FSTUtil.intersectPrefixPaths is here, to
            // add only final nodes and continue searching after finding matches.
            if (a.isAccept(path.state) && path.fstNode.isFinal()) {
                endNodes.add(path);
            }

            IntsRefBuilder currentInput = path.input;
            int count = a.initTransition(path.state, t);
            for (int i = 0; i < count; i++) {
                a.getNextTransition(t);
                final int min = t.min;
                final int max = t.max;
                if (min == max) {
                    // Single-label transition: direct lookup of the matching FST arc.
                    final FST.Arc<T> nextArc = fst.findTargetArc(t.min,
                            path.fstNode, scratchArc, fstReader);
                    if (nextArc != null) {
                        final IntsRefBuilder newInput = new IntsRefBuilder();
                        newInput.copyInts(currentInput.get());
                        newInput.append(t.min);
                        queue.add(new FSTUtil.Path<>(t.dest, new FST.Arc<T>()
                                .copyFrom(nextArc), fst.outputs
                                .add(path.output, nextArc.output), newInput));
                    }
                } else {
                    // TODO: if this transition's TO state is accepting, and
                    // it accepts the entire range possible in the FST (ie. 0 to 255),
                    // we can simply use the prefix as the accepted state instead of
                    // looking up all the ranges and terminate early
                    // here.  This just shifts the work from one queue
                    // (this one) to another (the completion search
                    // done in AnalyzingSuggester).
                    // Range transition: walk the FST arcs with labels in [min, max].
                    FST.Arc<T> nextArc = Util.readCeilArc(min, fst, path.fstNode,
                            scratchArc, fstReader);
                    while (nextArc != null && nextArc.label <= max) {
                        assert nextArc.label <=  max;
                        assert nextArc.label >= min : nextArc.label + " "
                                + min;
                        final IntsRefBuilder newInput = new IntsRefBuilder();
                        newInput.copyInts(currentInput.get());
                        newInput.append(nextArc.label);
                        queue.add(new FSTUtil.Path<>(t.dest, new FST.Arc<T>()
                                .copyFrom(nextArc), fst.outputs
                                .add(path.output, nextArc.output), newInput));
                        final int label = nextArc.label; // used in assert
                        nextArc = nextArc.isLast() ? null : fst.readNextRealArc(nextArc,
                                fstReader);
                        assert nextArc == null || label < nextArc.label : "last: " + label
                                + " next: " + nextArc.label;
                    }
                }
            }
        }
        return endNodes;
    }
}