// FSTLookup.java
package org.wikimedia.search.glent.fst;
import static org.apache.spark.sql.functions.udf;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.lucene.search.suggest.analyzing.FSTUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.LevenshteinAutomata;
import org.apache.lucene.util.automaton.Operations;
import org.apache.lucene.util.automaton.TooComplexToDeterminizeException;
import org.apache.lucene.util.automaton.Transition;
import org.apache.lucene.util.automaton.UTF32ToUTF8;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.Util;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
public class FSTLookup implements UDF3<String, Integer, Integer, String[]> {
private static final long serialVersionUID = 1L;
private final List<Broadcast<SerializableFST>> sharedFsts;
private transient List<FST<Object>> dictionaries;
public FSTLookup(List<Broadcast<SerializableFST>> fsts) {
this.sharedFsts = fsts;
}
static UserDefinedFunction makeUdf(List<Broadcast<SerializableFST>> fsts) {
return udf(new FSTLookup(fsts),
DataTypes.createArrayType(DataTypes.StringType));
}
@Override
public String[] call(String s, Integer editDistance, Integer nonFuzzyPrefix) throws IOException {
if (s.length() == 0) {
return new String[0];
}
if (dictionaries == null) {
dictionaries = sharedFsts.stream()
.map(x -> x.value().getFST())
.collect(Collectors.toList());
}
Automaton utf8automaton;
try {
utf8automaton = makeAutomaton(s, editDistance, nonFuzzyPrefix);
} catch (TooComplexToDeterminizeException e) {
return null;
}
BytesRefBuilder bytesRefScratch = new BytesRefBuilder();
return dictionaries.stream()
.flatMap(dict -> streamIntersectPaths(utf8automaton, dict))
.map(result -> Util.toBytesRef(result.input.get(), bytesRefScratch))
.map(BytesRef::utf8ToString)
.toArray(String[]::new);
}
private Automaton makeAutomaton(String s, int editDistance, int nonFuzzyPrefix) {
String prefix = s.substring(0, nonFuzzyPrefix);
String searchable = s.substring(nonFuzzyPrefix);
LevenshteinAutomata lev = new LevenshteinAutomata(searchable, true);
Automaton utf32automaton = lev.toAutomaton(editDistance, prefix);
Automaton utf8automaton = new UTF32ToUTF8().convert(utf32automaton);
return Operations.determinize(utf8automaton, Operations.DEFAULT_MAX_DETERMINIZED_STATES);
}
/**
* Fit intersectPaths into stream processing.
*
* IOExceptions should be impossible in our configuration. They come from the
* underlying storage and we are doing everything in-memory.
*/
private static <T> Stream<FSTUtil.Path<T>> streamIntersectPaths(Automaton a, FST<T> fst) {
try {
return intersectPaths(a, fst).stream();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
/**
* This is almost exactly FSTUtil.intersectPrefixPaths, but the
* condition to add to endNodes has changed such that we continue
* searching and we only add full input queries to endNodes. This
* converts from a prefix search to exact.
*
* TODO: We currently generate suggestions in both directions,
* aa → ab and ab → aa. It seems a short circuiting condition
* could be added here to only traverse arcs where q1 ⇐ q2, and
* later processing can decide if aa → ab or the reversed ab → aa
* is the preferred direction of suggestion. This code is hairy
* and a bit scary though.
*/
@SuppressWarnings("CyclomaticComplexity")
@SuppressFBWarnings(value = {"CNC_COLLECTION_NAMING_CONFUSION", "SPP_USE_ISEMPTY"},
justification = "not our code, make minimal changes")
public static <T> List<FSTUtil.Path<T>> intersectPaths(Automaton a, FST<T> fst)
throws IOException {
assert a.isDeterministic();
final List<FSTUtil.Path<T>> queue = new ArrayList<>();
final List<FSTUtil.Path<T>> endNodes = new ArrayList<>();
if (a.getNumStates() == 0) {
return endNodes;
}
queue.add(new FSTUtil.Path<>(0, fst
.getFirstArc(new FST.Arc<T>()), fst.outputs.getNoOutput(),
new IntsRefBuilder()));
final FST.Arc<T> scratchArc = new FST.Arc<>();
final FST.BytesReader fstReader = fst.getBytesReader();
Transition t = new Transition();
while (!queue.isEmpty()) {
final FSTUtil.Path<T> path = queue.remove(queue.size() - 1);
// The only change from FSTUtil.intersectPrefixPaths is here, to
// add only final nodes and continue searching after finding matches.
if (a.isAccept(path.state) && path.fstNode.isFinal()) {
endNodes.add(path);
}
IntsRefBuilder currentInput = path.input;
int count = a.initTransition(path.state, t);
for (int i = 0; i < count; i++) {
a.getNextTransition(t);
final int min = t.min;
final int max = t.max;
if (min == max) {
final FST.Arc<T> nextArc = fst.findTargetArc(t.min,
path.fstNode, scratchArc, fstReader);
if (nextArc != null) {
final IntsRefBuilder newInput = new IntsRefBuilder();
newInput.copyInts(currentInput.get());
newInput.append(t.min);
queue.add(new FSTUtil.Path<>(t.dest, new FST.Arc<T>()
.copyFrom(nextArc), fst.outputs
.add(path.output, nextArc.output), newInput));
}
} else {
// TODO: if this transition's TO state is accepting, and
// it accepts the entire range possible in the FST (ie. 0 to 255),
// we can simply use the prefix as the accepted state instead of
// looking up all the ranges and terminate early
// here. This just shifts the work from one queue
// (this one) to another (the completion search
// done in AnalyzingSuggester).
FST.Arc<T> nextArc = Util.readCeilArc(min, fst, path.fstNode,
scratchArc, fstReader);
while (nextArc != null && nextArc.label <= max) {
assert nextArc.label <= max;
assert nextArc.label >= min : nextArc.label + " "
+ min;
final IntsRefBuilder newInput = new IntsRefBuilder();
newInput.copyInts(currentInput.get());
newInput.append(nextArc.label);
queue.add(new FSTUtil.Path<>(t.dest, new FST.Arc<T>()
.copyFrom(nextArc), fst.outputs
.add(path.output, nextArc.output), newInput));
final int label = nextArc.label; // used in assert
nextArc = nextArc.isLast() ? null : fst.readNextRealArc(nextArc,
fstReader);
assert nextArc == null || label < nextArc.label : "last: " + label
+ " next: " + nextArc.label;
}
}
}
}
return endNodes;
}
}