package org.wikimedia.search.extra.analysis.filters;

import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonList;

import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

import javax.annotation.Nullable;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.core.LowerCaseFilter;
import org.apache.lucene.analysis.core.StopFilter;
import org.apache.lucene.analysis.core.WhitespaceTokenizer;
import org.apache.lucene.analysis.fr.FrenchAnalyzer;
import org.apache.lucene.analysis.fr.FrenchLightStemFilter;
import org.apache.lucene.analysis.miscellaneous.ASCIIFoldingFilter;
import org.apache.lucene.analysis.miscellaneous.KeywordRepeatFilter;
import org.apache.lucene.analysis.miscellaneous.RemoveDuplicatesTokenFilter;
import org.apache.lucene.analysis.shingle.ShingleFilter;
import org.apache.lucene.analysis.standard.StandardTokenizer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.util.TokenFilterFactory;
import org.junit.Test;

import com.google.common.primitives.Ints;

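/**
 * Tests for {@link PreserveOriginalFilter}, comparing its output against
 * equivalent "preserve original" strategies: ASCIIFoldingFilter with
 * preserveOriginal and KeywordRepeatFilter + RemoveDuplicatesTokenFilter.
 */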
@SuppressWarnings("checkstyle:classfanoutcomplexity") // do not care too much about complexity of test classes
public class PreserveOriginalFilterTest extends BaseTokenStreamTestCase {
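    // Randomized shingle sizes: max in [3,5], min in [2, max-1], so min < max always holds.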
    private final int shingleMaxSize = random().nextInt(3) + 3;
    private final int shingleMinSize = random().nextInt(shingleMaxSize - 2) + 2;

    @Test
    public void simpleTest() throws IOException {
        String input = "Hello the World";
        try (Analyzer ws = newPreserveOriginalLowerCase()) {
            TokenStream ts = ws.tokenStream("", input);
            assertTokenStreamContents(ts,
                    new String[]{"hello", "Hello", "world", "World"},
                    new int[]{0, 0, 10, 10}, // start offsets
                    new int[]{5, 5, 15, 15}, // end offsets
                    null, // types (unsupported)
                    new int[]{1, 0, 2, 0}, // position increments
                    null, // position lengths (unsupported)
                    15, // final offset
                    null, // keyword attributes (unsupported)
                    true);
        }
    }

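    /**
     * Whitespace tokenizer and a stop filter, followed by a PreserveOriginalFilter
     * built from the generic "lowercase" TokenFilterFactory.
     */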
    private Analyzer newPreserveOriginalLowerCase() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new WhitespaceTokenizer();
                TokenStream ts = new StopFilter(tok, new CharArraySet(new HashSet<>(singletonList("the")), true));
                ts = new PreserveOriginalFilter(ts, TokenFilterFactory.forName("lowercase", emptyMap()));
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

    @Test
    public void simpleTestWithStop() throws IOException {
        // Same test, but with a stop filter in the chain, verifying that the
        // recorded states stay consistent when a term is removed.
        String input = "Hello the World";
        try (Analyzer ws = newPreserveOriginalWithStop()) {
            TokenStream ts = ws.tokenStream("", input);
            assertTokenStreamContents(ts,
                    new String[]{"hello", "Hello", "world", "World"},
                    new int[]{0, 0, 10, 10}, // start offsets
                    new int[]{5, 5, 15, 15}, // end offsets
                    null, // types (unsupported)
                    new int[]{1, 0, 2, 0}, // position increments
                    null, // position lengths (unsupported)
                    15, // final offset
                    null, // keyword attributes (unsupported)
                    true);
        }
    }

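    /**
     * Equivalent chain built with an explicit Recorder placed before the stop
     * and lowercase filters.
     */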
    private Analyzer newPreserveOriginalWithStop() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new WhitespaceTokenizer();
                TokenStream ts = new PreserveOriginalFilter.Recorder(tok);
                ts = new StopFilter(ts, new CharArraySet(new HashSet<>(singletonList("the")), true));
                ts = new LowerCaseFilter(ts);
                ts = new PreserveOriginalFilter(ts);
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

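    /** PreserveOriginalFilter requires a Recorder earlier in the chain: building one without it must fail. */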
    @Test(expected = IllegalArgumentException.class)
    public void testBadSetup() throws IOException {
        try (Analyzer a = new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new StandardTokenizer();
                TokenStream ts = new StopFilter(tok, FrenchAnalyzer.getDefaultStopSet());
                ts = new ASCIIFoldingFilter(ts, false);
                ts = new PreserveOriginalFilter(ts); // should fail here
                return new TokenStreamComponents(tok, ts);
            }
        }) {
            a.tokenStream("", "");
        }
    }

    /**
     * Test that the preserve original filter works like other preserve strategies:
     * - ASCII folding with preserveOriginal
     * - KeywordRepeatFilter + RemoveDuplicatesTokenFilter
     * @throws IOException on failure to read the test resource
     */
    @Test
    public void longTextTest() throws IOException {
        String textRes = "/Prise de possession.txt";
        try (Analyzer expected = stopAndAsciiFoldingPreserve();
             Analyzer actual = stopGenericPreserveAsciiFolding()) {
            assertSameOutput(expected, actual, textRes);
            // test reuse
            assertSameOutput(expected, actual, textRes);
        }
        // Let's retry with a shingle filter, which stores/restores states
        try (Analyzer expected = stopAndAsciiFoldingAndShingle();
             Analyzer actual = stopGenericPreserveAsciiFoldingShingles()) {
            assertSameOutput(expected, actual, textRes);
            // test reuse
            assertSameOutput(expected, actual, textRes);
        }
        // now with a keyword repeat and a stemmer
        try (Analyzer expected = stopKWRepeatStemmerAndShingles();
             Analyzer actual = stopGenericPreserveStemmerAndShingles()) {
            assertSameOutput(expected, actual, textRes);
            // test reuse
            assertSameOutput(expected, actual, textRes);
        }
    }

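    /** Generic preserve chain: Recorder and PreserveOriginalFilter around a French light stemmer, then shingles. */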
    private Analyzer stopGenericPreserveStemmerAndShingles() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new StandardTokenizer();
                TokenStream ts = new StopFilter(tok, FrenchAnalyzer.getDefaultStopSet());
                ts = new PreserveOriginalFilter.Recorder(ts);
                ts = new FrenchLightStemFilter(ts);
                ts = new PreserveOriginalFilter(ts);
                ts = new ShingleFilter(ts, shingleMinSize, shingleMaxSize);
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

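    /** Reference chain: KeywordRepeatFilter + stemmer + RemoveDuplicatesTokenFilter, the classic preserve strategy. */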
    private Analyzer stopKWRepeatStemmerAndShingles() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new StandardTokenizer();
                TokenStream ts = new StopFilter(tok, FrenchAnalyzer.getDefaultStopSet());
                ts = new KeywordRepeatFilter(ts);
                // KeywordRepeatFilter emits tokens in the wrong order for this comparison
                // (it returns the preserved original first), so this filter swaps each pair.
                ts = new TokenFilter(ts) {
                    private @Nullable State state;
                    private final PositionIncrementAttribute pattr = getAttribute(PositionIncrementAttribute.class);
                    @Override
                    public boolean incrementToken() throws IOException {
                        if (state != null) {
                            // Emit the buffered first token of the pair at the same position.
                            restoreState(state);
                            pattr.setPositionIncrement(0);
                            state = null;
                            return true;
                        } else if (input.incrementToken()) {
                            // Buffer the first token of the pair, emit the second one now,
                            // carrying over the original position increment.
                            state = captureState();
                            int posInc = pattr.getPositionIncrement();
                            boolean hasSecond = input.incrementToken();
                            assert hasSecond : "KeywordRepeatFilter must emit tokens in pairs";
                            assert pattr.getPositionIncrement() == 0;
                            pattr.setPositionIncrement(posInc);
                            return true;
                        }
                        return false;
                    }
                };
                ts = new FrenchLightStemFilter(ts);
                ts = new RemoveDuplicatesTokenFilter(ts);
                ts = new ShingleFilter(ts, shingleMinSize, shingleMaxSize);
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

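    /** Generic preserve chain: Recorder and PreserveOriginalFilter around ASCII folding, then shingles. */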
    private Analyzer stopGenericPreserveAsciiFoldingShingles() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new StandardTokenizer();
                TokenStream ts = new StopFilter(tok, FrenchAnalyzer.getDefaultStopSet());
                ts = new PreserveOriginalFilter.Recorder(ts);
                ts = new ASCIIFoldingFilter(ts);
                ts = new PreserveOriginalFilter(ts);
                ts = new ShingleFilter(ts, shingleMinSize, shingleMaxSize);
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

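    /** Reference chain: ASCIIFoldingFilter with preserveOriginal set to true, then shingles. */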
    private Analyzer stopAndAsciiFoldingAndShingle() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new StandardTokenizer();
                TokenStream ts = new StopFilter(tok, FrenchAnalyzer.getDefaultStopSet());
                ts = new ASCIIFoldingFilter(ts, true);
                ts = new ShingleFilter(ts, shingleMinSize, shingleMaxSize);
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

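    /** Generic preserve chain: Recorder and PreserveOriginalFilter around ASCII folding. */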
    private Analyzer stopGenericPreserveAsciiFolding() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new StandardTokenizer();
                TokenStream ts = new StopFilter(tok, FrenchAnalyzer.getDefaultStopSet());
                ts = new PreserveOriginalFilter.Recorder(ts);
                ts = new ASCIIFoldingFilter(ts, false);
                ts = new PreserveOriginalFilter(ts);
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

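    /** Reference chain: ASCIIFoldingFilter with its built-in preserveOriginal flag. */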
    private Analyzer stopAndAsciiFoldingPreserve() {
        return new Analyzer() {
            @Override
            protected TokenStreamComponents createComponents(String fieldName) {
                Tokenizer tok = new StandardTokenizer();
                TokenStream ts = new StopFilter(tok, FrenchAnalyzer.getDefaultStopSet());
                ts = new ASCIIFoldingFilter(ts, true);
                return new TokenStreamComponents(tok, ts);
            }
        };
    }

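    /**
     * Runs both analyzers over the given classpath resource and asserts that they
     * produce identical terms, offsets, position increments and final offset.
     */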
    private void assertSameOutput(Analyzer expectedAnalyzer, Analyzer actualAnalyzer, String res) throws IOException {
        List<String> output = new ArrayList<>();
        List<Integer> posInc = new ArrayList<>();
        List<Integer> startOffsets = new ArrayList<>();
        List<Integer> endOffsets = new ArrayList<>();
        int finalOffset = -1;
        try (TokenStream expected = expectedAnalyzer.tokenStream("",
                     new InputStreamReader(this.getClass().getResourceAsStream(res), StandardCharsets.UTF_8));
             TokenStream actual = actualAnalyzer.tokenStream("",
                     new InputStreamReader(this.getClass().getResourceAsStream(res), StandardCharsets.UTF_8))) {
            expected.reset();
            CharTermAttribute cattr = expected.getAttribute(CharTermAttribute.class);
            PositionIncrementAttribute pInc = expected.getAttribute(PositionIncrementAttribute.class);
            OffsetAttribute oattr = expected.getAttribute(OffsetAttribute.class);
            // Drain the expected stream, collecting terms, offsets and position increments.
            while (expected.incrementToken()) {
                output.add(cattr.toString());
                posInc.add(pInc.getPositionIncrement());
                startOffsets.add(oattr.startOffset());
                endOffsets.add(oattr.endOffset());
            }
            expected.end();
            finalOffset = oattr.endOffset();
            // Then assert that the actual stream produces exactly the same output.
            assertTokenStreamContents(actual,
                    output.toArray(new String[0]),
                    Ints.toArray(startOffsets),
                    Ints.toArray(endOffsets),
                    null, Ints.toArray(posInc),
                    null,
                    finalOffset,
                    null,
                    true);
        }
    }
}