1 package org.wikimedia.search.extra.regex.ngram;
2
3 import java.io.IOException;
4 import java.util.ArrayList;
5 import java.util.HashMap;
6 import java.util.LinkedList;
7 import java.util.List;
8 import java.util.Map;
9
10 import javax.annotation.Nullable;
11
12 import org.apache.lucene.analysis.Analyzer;
13 import org.apache.lucene.analysis.TokenStream;
14 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
15 import org.apache.lucene.util.automaton.Automaton;
16 import org.apache.lucene.util.automaton.Transition;
17 import org.wikimedia.search.extra.regex.expression.And;
18 import org.wikimedia.search.extra.regex.expression.Expression;
19 import org.wikimedia.search.extra.regex.expression.ExpressionSource;
20 import org.wikimedia.search.extra.regex.expression.False;
21 import org.wikimedia.search.extra.regex.expression.Leaf;
22 import org.wikimedia.search.extra.regex.expression.Or;
23 import org.wikimedia.search.extra.regex.expression.True;
24
25 import com.google.common.collect.ImmutableSet;
26
27 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
28 import lombok.EqualsAndHashCode;
29
30
31
32
33
34 @SuppressFBWarnings(value = "DLC_DUBIOUS_LIST_COLLECTION", justification = "Need more time to investigate")
35
36
37 public class NGramAutomaton {
38 private final Automaton source;
39 private final int gramSize;
40 private final int maxExpand;
41 private final int maxStatesTraced;
42 private final int maxTransitions;
43 private final List<NGramState> initialStates = new ArrayList<>();
44 private final List<NGramState> acceptStates = new ArrayList<>();
45 private final Map<NGramState, NGramState> states = new HashMap<>();
46 private final Analyzer ngramAnalyzer;
47
48
49
50
51
52
53
54
55
56
57
58
59
60 public NGramAutomaton(Automaton source, int gramSize, int maxExpand, int maxStatesTraced, int maxTransitions, Analyzer ngramAnalyzer) {
61 this.source = source;
62 this.gramSize = gramSize;
63 this.maxExpand = maxExpand;
64 this.maxStatesTraced = maxStatesTraced;
65 this.maxTransitions = maxTransitions;
66 this.ngramAnalyzer = ngramAnalyzer;
67 if (source.getNumStates() == 0) {
68 return;
69 }
70
71 int[] codePoints = new int[gramSize - 1];
72 buildInitial(codePoints, 0, 0);
73 traceRemainingStates();
74 }
75
76
77
78
79
80 public String toDot() {
81 StringBuilder b = new StringBuilder("digraph Automaton {\n");
82 b.append(" rankdir = LR;\n");
83 b.append(" initial [shape=plaintext,label=\"\"];\n");
84 for (NGramState state : states.keySet()) {
85 b.append(" ").append(state.dotName());
86 if (acceptStates.contains(state)) {
87 b.append(" [shape=doublecircle,label=\"").append(state).append("\"];\n");
88 } else {
89 b.append(" [shape=circle,label=\"").append(state).append("\"];\n");
90 }
91 if (state.initial) {
92 b.append(" initial -> ").append(state.dotName()).append('\n');
93 }
94 for (NGramTransition transition : state.outgoingTransitions) {
95 b.append(" ").append(transition).append('\n');
96 }
97 }
98 return b.append("}\n").toString();
99 }
100
101
102
103
104
105
106 public Expression<String> expression() {
107 return Or.fromExpressionSources(acceptStates);
108 }
109
110
111
112
113
114
115
116
117
118 private boolean buildInitial(int[] codePoints, int offset, int currentState) {
119 if (source.isAccept(currentState)) {
120
121
122
123 initialStates.clear();
124 states.clear();
125 return false;
126 }
127 if (offset == gramSize - 1) {
128
129 NGramState state = new NGramState(currentState, new String(codePoints, 0, gramSize - 1), true);
130
131
132 if (states.containsKey(state)) {
133 return true;
134 }
135 initialStates.add(state);
136 states.put(state, state);
137 return true;
138 }
139
140 Transition transition = new Transition();
141 int totalLeavingState = source.initTransition(currentState, transition);
142 for (int currentLeavingState = 0; currentLeavingState < totalLeavingState; currentLeavingState++) {
143 source.getNextTransition(transition);
144 int min;
145 int max;
146 if (transition.max - transition.min >= maxExpand) {
147
148 min = 0;
149 max = 0;
150 } else {
151 min = transition.min;
152 max = transition.max;
153 }
154 for (int c = min; c <= max; c++) {
155 codePoints[offset] = c;
156 if (!buildInitial(codePoints, offset + 1, transition.dest)) {
157 return false;
158 }
159 }
160 }
161 return true;
162 }
163
164 private void traceRemainingStates() {
165 LinkedList<NGramState> leftToProcess = new LinkedList<>(initialStates);
166 int[] codePoint = new int[1];
167 int statesTraced = 0;
168 Transition transition = new Transition();
169 int currentTransitions = 0;
170 while (!leftToProcess.isEmpty()) {
171 if (statesTraced >= maxStatesTraced) {
172 throw new AutomatonTooComplexException();
173 }
174 statesTraced++;
175 NGramState from = leftToProcess.pop();
176 if (acceptStates.contains(from)) {
177
178
179 continue;
180 }
181 int totalLeavingState = source.initTransition(from.sourceState, transition);
182 if (currentTransitions >= maxTransitions) {
183 acceptStates.add(from);
184 continue;
185 }
186 for (int currentLeavingState = 0; currentLeavingState < totalLeavingState; currentLeavingState++) {
187 source.getNextTransition(transition);
188 int min;
189 int max;
190 if (transition.max - transition.min >= maxExpand) {
191
192 min = 0;
193 max = 0;
194 } else {
195 min = transition.min;
196 max = transition.max;
197 }
198 for (int c = min; c <= max; c++) {
199 codePoint[0] = c;
200 String ngram = from.prefix + new String(codePoint, 0, 1);
201 NGramState next = buildOrFind(leftToProcess, transition.dest, ngram.substring(1));
202
203
204 if (ngram.indexOf(0) >= 0) {
205 ngram = null;
206 }
207 if (currentTransitions >= maxTransitions) {
208 acceptStates.add(from);
209 continue;
210 }
211 currentTransitions++;
212 NGramTransition ngramTransition = new NGramTransition(from, next, analyze(ngram));
213 from.outgoingTransitions.add(ngramTransition);
214 ngramTransition.to.incomingTransitions.add(ngramTransition);
215 }
216 }
217 }
218 }
219
220 @Nullable
221 private String analyze(@Nullable String ngram) {
222 if (ngram == null) {
223 return ngram;
224 }
225 try (TokenStream ts = ngramAnalyzer.tokenStream("", ngram)) {
226 CharTermAttribute cattr = ts.addAttribute(CharTermAttribute.class);
227 ts.reset();
228 if (ts.incrementToken()) {
229 ngram = cattr.toString();
230 if (ts.incrementToken()) {
231 throw new IllegalArgumentException("Analyzer provided generate more than one tokens, " +
232 "if using 3grams make sure to use a 3grams analyzer, " +
233 "for input [" + ngram + "] first is [" + ngram + "] " +
234 "but [" + cattr + "] was generated.");
235 }
236 }
237 } catch (IOException ioe) {
238 throw new RuntimeException(ioe);
239 }
240 return ngram;
241 }
242
243 private NGramState buildOrFind(LinkedList<NGramState> leftToProcess, int sourceState, String prefix) {
244 NGramState built = new NGramState(sourceState, prefix, false);
245 NGramState found = states.get(built);
246 if (found != null) {
247 return found;
248 }
249 if (source.isAccept(sourceState)) {
250 acceptStates.add(built);
251 }
252 states.put(built, built);
253 leftToProcess.add(built);
254 return built;
255 }
256
257
258
259
260
261 @EqualsAndHashCode(of = { "prefix", "sourceState" })
262 private static final class NGramState implements ExpressionSource<String> {
263
264
265
266 private static final String INVALID_CHAR = new String(new int[] {0}, 0, 1);
267
268
269
270 private static final String INVALID_PRINT_CHAR = "__";
271
272
273
274
275 private final int sourceState;
276
277
278
279 private final String prefix;
280
281
282
283
284 private final boolean initial;
285
286
287
288 private final List<NGramTransition> outgoingTransitions = new ArrayList<>();
289
290
291
292 private final List<NGramTransition> incomingTransitions = new ArrayList<>();
293
294
295
296
297 @Nullable private Expression<String> expression;
298
299
300
301 private boolean inPath;
302
303 private NGramState(int sourceState, String prefix, boolean initial) {
304 this.sourceState = sourceState;
305 this.prefix = prefix;
306 this.initial = initial;
307 }
308
309 @Override
310 public String toString() {
311 return "(" + prettyPrefix() + ", " + sourceState + ")";
312 }
313
314 public String dotName() {
315
316 return prettyPrefix().replace(" ", "___").replace("`", "_bt_")
317 .replace("^", "_caret_").replace("|", "_pipe_")
318 .replace("{", "_lcb_").replace("}", "_rcb_")
319 .replace("=", "_eq_") + sourceState;
320 }
321
322 public String prettyPrefix() {
323 return prefix.replace(INVALID_CHAR, INVALID_PRINT_CHAR);
324 }
325
326 @Override
327 public Expression<String> expression() {
328 if (expression == null) {
329 if (initial) {
330 expression = True.instance();
331 } else {
332 inPath = true;
333 expression = Or.fromExpressionSources(incomingTransitions);
334 inPath = false;
335 }
336 }
337 return expression;
338 }
339 }
340
341 private static final class NGramTransition implements ExpressionSource<String> {
342 private final NGramState from;
343 private final NGramState to;
344 @Nullable private final String ngram;
345
346 private NGramTransition(NGramState from, NGramState to, @Nullable String ngram) {
347 this.from = from;
348 this.to = to;
349 this.ngram = ngram;
350 }
351
352 @Override
353 public Expression<String> expression() {
354 if (from.inPath) {
355 return False.instance();
356 }
357 if (ngram == null) {
358 return from.expression();
359 }
360 return new And<>(ImmutableSet.of(from.expression(), new Leaf<>(ngram)));
361 }
362
363 @Override
364 public String toString() {
365 StringBuilder b = new StringBuilder();
366 b.append(from.dotName()).append(" -> ").append(to.dotName());
367 if (ngram != null) {
368 b.append(" [label=\"").append(ngram.replace(' ', '_')).append("\"]");
369 }
370 return b.toString();
371 }
372 }
373 }