1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.wikimedia.search.extra.simswitcher;
18
import static org.hamcrest.Matchers.greaterThan;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.misc.SweetSpotSimilarity;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.similarities.AfterEffectB;
import org.apache.lucene.search.similarities.AxiomaticF3LOG;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BasicModelIF;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.DFISimilarity;
import org.apache.lucene.search.similarities.DFRSimilarity;
import org.apache.lucene.search.similarities.DistributionLL;
import org.apache.lucene.search.similarities.IBSimilarity;
import org.apache.lucene.search.similarities.IndependenceChiSquared;
import org.apache.lucene.search.similarities.LMDirichletSimilarity;
import org.apache.lucene.search.similarities.LMJelinekMercerSimilarity;
import org.apache.lucene.search.similarities.LambdaDF;
import org.apache.lucene.search.similarities.NormalizationH1;
import org.apache.lucene.search.similarities.NormalizationH3;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.QueryBuilder;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
63
64 @SuppressWarnings("checkstyle:classfanoutcomplexity")
65 public class SimSwitcherQueryTest extends LuceneTestCase {
66
67 private IndexSearcher searcherUnderTest;
68 private RandomIndexWriter indexWriterUnderTest;
69 private IndexReader indexReaderUnderTest;
70 private Directory dirUnderTest;
71 private Similarity similarity;
72 private StandardAnalyzer analyzer;
73 private Map<String, Similarity> similarityMap;
74
75
76 private final String[] docs = new String[] {
77 "how now brown cow",
78 "brown is the color of cows",
79 "brown cow",
80 "banana cows are yummy"
81 };
82
83 @Before
84 public void setupIndex() throws IOException {
85 dirUnderTest = newDirectory();
86 similarityMap = Stream.of(
87 new ClassicSimilarity(),
88 new SweetSpotSimilarity(),
89 new BM25Similarity(),
90 new LMDirichletSimilarity(),
91 new BooleanSimilarity(),
92 new LMJelinekMercerSimilarity(0.2F),
93 new AxiomaticF3LOG(0.5F, 10),
94 new DFISimilarity(new IndependenceChiSquared()),
95 new DFRSimilarity(new BasicModelIF(), new AfterEffectB(), new NormalizationH1()),
96 new IBSimilarity(new DistributionLL(), new LambdaDF(), new NormalizationH3())
97 ).collect(Collectors.toMap((s) -> s.getClass().getSimpleName(), (s) -> s));
98
99 similarity = new ArrayList<>(similarityMap.values()).get(random().nextInt(similarityMap.size()));
100 PerFieldSimilarityWrapper simWrapper = new PerFieldSimilarityWrapper() {
101 @Override
102 public Similarity get(String name) {
103 Similarity sim = similarityMap.get(name);
104 if (sim == null) {
105 return similarity;
106 }
107 return sim;
108 }
109 };
110 analyzer = new StandardAnalyzer();
111 indexWriterUnderTest = new RandomIndexWriter(random(), dirUnderTest, newIndexWriterConfig(analyzer).setSimilarity(similarity));
112
113 for (int i = 0; i < docs.length; i++) {
114 Document doc = new Document();
115 String data = docs[i];
116 doc.add(newTextField("id", "" + i, Field.Store.YES));
117 doc.add(newTextField("main_field", data, Field.Store.YES));
118 similarityMap.keySet().forEach((f) -> doc.add(newTextField(f, data, Field.Store.NO)));
119 indexWriterUnderTest.addDocument(doc);
120 }
121 indexWriterUnderTest.commit();
122 indexWriterUnderTest.forceMerge(1);
123 indexWriterUnderTest.flush();
124
125 indexReaderUnderTest = indexWriterUnderTest.getReader();
126 searcherUnderTest = newSearcher(indexReaderUnderTest);
127 searcherUnderTest.setSimilarity(simWrapper);
128 }
129
130 @Test
131 public void testSearch() throws IOException {
132 String q = "brown";
133
134 for (Map.Entry<String, Similarity> entry : similarityMap.entrySet()) {
135 String msg = "switch from " + similarity.getClass().getSimpleName() + " to " + entry.getKey();
136 Query query = new QueryBuilder(analyzer).createBooleanQuery(entry.getKey(), q);
137 Query hacked = new QueryBuilder(analyzer).createBooleanQuery("main_field", q);
138 TopDocs docs = searcherUnderTest.search(query, 10);
139 assertThat(docs.totalHits.value, greaterThan(0L));
140 TopDocs hackedDocs = searcherUnderTest.search(new SimSwitcherQuery(entry.getValue(), hacked), 10);
141 assertEquals(msg, docs.totalHits, hackedDocs.totalHits);
142 IntStream.range(0, docs.scoreDocs.length).forEach((i) -> {
143 ScoreDoc doc = docs.scoreDocs[i];
144 ScoreDoc hackedDoc = hackedDocs.scoreDocs[i];
145 assertEquals(msg, doc.doc, hackedDoc.doc);
146 assertEquals(msg, doc.score, hackedDoc.score, Math.ulp(doc.score));
147 });
148 }
149 }
150
151 @Test
152 public void testExplain() throws IOException {
153 String q = "brown cow";
154 for (Map.Entry<String, Similarity> entry : similarityMap.entrySet()) {
155 String msg = "switch from " + similarity.getClass().getSimpleName() + " to " + entry.getKey();
156 Query query = new QueryBuilder(analyzer).createBooleanQuery(entry.getKey(), q);
157 Query hacked = new SimSwitcherQuery(entry.getValue(), new QueryBuilder(analyzer).createBooleanQuery("main_field", q));
158 TopDocs docs = searcherUnderTest.search(query, 10);
159 Explanation exp = searcherUnderTest.explain(query, docs.scoreDocs[0].doc);
160 Explanation hackExp = searcherUnderTest.explain(hacked, docs.scoreDocs[0].doc);
161
162 double expectedDoubleValue = exp.getValue().doubleValue();
163 assertEquals(msg, expectedDoubleValue, hackExp.getValue().doubleValue(), Math.ulp(expectedDoubleValue));
164 }
165 }
166
167 @After
168 public void closeStuff() throws IOException {
169 indexReaderUnderTest.close();
170 indexWriterUnderTest.close();
171 dirUnderTest.close();
172 }
173 }