// FSTBuilder.java

package org.wikimedia.search.glent.fst;

import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.PriorityQueue;

import org.apache.lucene.search.suggest.InMemorySorter;
import org.apache.lucene.search.suggest.fst.BytesRefSorter;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.fst.Builder;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.IntsRefFSTEnum;
import org.apache.lucene.util.fst.NoOutputs;
import org.apache.lucene.util.fst.Outputs;
import org.apache.lucene.util.fst.Util;

import com.google.common.collect.Iterators;

/**
 * Static helpers for building output-less FSTs, either from raw strings or by
 * merging already-built FSTs. The public entry points are shaped to be used as
 * Spark MapPartitionsFunction implementations.
 */
public final class FSTBuilder {

    private FSTBuilder() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Spark MapPartitionsFunction for transforming strings into an fst.
     *
     * @param iterator Strings that should be individually findable
     * @return 0 or 1 FST instances
     * @throws IOException if sorting the input or building the FST fails
     */
    public static Iterator<SerializableFST> transform(Iterator<String> iterator) throws IOException {
        return toIterator(buildFromStrings(iterator));
    }

    /**
     * Spark MapPartitionsFunction for transforming multiple fst's
     * into a single fst.
     *
     * @param fsts Set of FSTs to merge
     * @return 0 or 1 FST instances
     * @throws IOException if enumerating the inputs or building the FST fails
     * @throws IllegalArgumentException if one of the inputs enumerates no entries
     */
    public static Iterator<SerializableFST> merge(Iterator<SerializableFST> fsts)  throws IOException {
        return toIterator(buildFromFSTs(
            Iterators.transform(fsts, SerializableFST::getFST)));
    }

    /**
     * This is a bit silly, but we convert from nullable to an iterator with 0 or
     * 1 elements to fit with spark's MapPartitionsFunction.
     *
     * @param automaton FST to wrap, or null when nothing was built
     * @return an empty iterator for null, otherwise a single wrapped FST
     */
    private static Iterator<SerializableFST> toIterator(FST<Object> automaton) {
        if (automaton == null) {
            // JDK replacement for Guava's deprecated Iterators.emptyIterator()
            return Collections.emptyIterator();
        }
        return Iterators.singletonIterator(new SerializableFST(automaton));
    }

    /**
     * Build an FST out of an unsorted iterator over strings. Will pull
     * the full set of strings into memory to sort.
     *
     * @param iterator strings to insert, in any order; duplicates are allowed
     * @return the built FST, or null if the iterator was empty
     * @throws IOException if sorting or building fails
     */
    static FST<Object> buildFromStrings(Iterator<String> iterator) throws IOException {
        final IntsRefBuilder scratchIntsRef = new IntsRefBuilder();
        final BytesRefSorter sorter = new InMemorySorter(Comparator.naturalOrder());

        while (iterator.hasNext()) {
            sorter.add(new BytesRef(iterator.next()));
        }

        final BytesRefIterator sorted = sorter.iterator();
        // Adapt the sorted byte sequences to the IntsRef supplier that
        // buildAutomaton() consumes; a null entry marks the end of the input.
        return buildAutomaton(() -> {
            BytesRef entry = sorted.next();
            return entry == null ? null : Util.toIntsRef(entry, scratchIntsRef);
        });
    }

    /**
     * Build a single FST out of any number of FST's by merging their entries.
     * No input yields null and a single input is returned as-is; otherwise
     * the inputs are walked in parallel through a priority queue so that
     * entries reach the builder in sorted order.
     *
     * @param iterator the FSTs to merge
     * @return the merged FST, the sole input FST, or null if there were no inputs
     * @throws IOException if enumerating an input or building the result fails
     * @throws IllegalArgumentException if an input FST enumerates no entries
     */
    static FST<Object> buildFromFSTs(Iterator<FST<Object>> iterator) throws IOException {
        final PriorityQueue<IntsRefFSTEnum<Object>> pq = new PriorityQueue<>(
                Comparator.comparing(x -> x.current().input, Comparator.naturalOrder()));
        // We hold a reference to the last seen FST to gracefully handle
        // a single-fst merge
        FST<Object> lastSeenFST = null;
        // Initialize a PQ with iterators for each input FST.
        while (iterator.hasNext()) {
            lastSeenFST = iterator.next();
            IntsRefFSTEnum<Object> fstIterator = new IntsRefFSTEnum<>(lastSeenFST);
            // We need to call next() to initialize and get the first element
            // into current(). When iterating we return the current value and put
            // the next element into the queue. I'm not sure it's possible to have
            // an empty FST, but we null check for completeness.
            if (fstIterator.next() == null) {
                // Still a RuntimeException, so existing handlers keep working.
                throw new IllegalArgumentException("Empty FST?");
            }
            pq.add(fstIterator);
        }

        // Avoid transforming an fst into itself at great expense. All other
        // conditions handled by buildAutomaton()
        if (pq.size() == 1) {
            assert lastSeenFST != null;
            return lastSeenFST;
        }

        final IntsRefBuilder scratch = new IntsRefBuilder();
        return buildAutomaton(() -> {
            IntsRefFSTEnum<Object> fstIterator = pq.poll();
            if (fstIterator == null) {
                // queue is empty, iteration is complete
                return null;
            }
            // We need to copy the content as the next() call will overwrite.
            scratch.copyInts(fstIterator.current().input);
            // Re-queue the enum only while it still has entries; exhausted
            // enums simply drop out of the merge.
            if (fstIterator.next() != null) {
                pq.add(fstIterator);
            }
            return scratch.get();
        });
    }

    /**
     * A supplier whose {@code get()} is allowed to throw {@link IOException}.
     *
     * @param <T> type of the supplied values
     */
    @FunctionalInterface
    public interface SupplierWithIO<T> {
        T get() throws IOException;
    }

    /**
     * Drain a supplier of entries into a new FST without outputs.
     *
     * <p>Duplicates are detected by comparing each entry only against the
     * previously inserted one, so callers are expected to supply entries in
     * sorted order (both callers in this class do).
     *
     * @param iter source of entries; returns null to signal the end of input.
     *             The returned instance may be reused between calls since its
     *             content is copied before the next call.
     * @return the finished FST, or null if the supplier yielded no entries
     * @throws IOException if the supplier or the FST builder fails
     */
    static FST<Object> buildAutomaton(SupplierWithIO<IntsRef> iter) throws IOException {
        final Outputs<Object> outputs = NoOutputs.getSingleton();
        final Object empty = outputs.getNoOutput();
        final Builder<Object> builder = new Builder<>(FST.INPUT_TYPE.BYTE1, outputs);

        // Scratch space to recognize and drop duplicates in input stream. Holds the
        // last value inserted to automaton.
        // NOTE(review): scratch starts out empty, so a zero-length entry appears
        // to compare equal to it and is silently dropped - confirm this is intended.
        final IntsRefBuilder scratch = new IntsRefBuilder();
        boolean sawEntry = false;
        IntsRef entry;
        while ((entry = iter.get()) != null) {
            sawEntry = true;
            // drop any duplicates
            if (scratch.get().compareTo(entry) != 0) {
                builder.add(entry, empty);
                scratch.copyInts(entry);
            }
        }

        return sawEntry ? builder.finish() : null;
    }
}