1 package org.wikimedia.search.extra.router;
2
3 import static org.hamcrest.CoreMatchers.containsString;
4 import static org.hamcrest.CoreMatchers.endsWith;
5 import static org.hamcrest.CoreMatchers.instanceOf;
6 import static org.wikimedia.search.extra.router.AbstractRouterQueryBuilder.ConditionDefinition.gt;
7 import static org.wikimedia.search.extra.router.AbstractRouterQueryBuilder.ConditionDefinition.gte;
8
9 import java.io.IOException;
10 import java.util.Arrays;
11 import java.util.Collection;
12 import java.util.Optional;
13
14 import org.apache.lucene.analysis.Analyzer;
15 import org.apache.lucene.index.Term;
16 import org.apache.lucene.index.memory.MemoryIndex;
17 import org.apache.lucene.search.MatchAllDocsQuery;
18 import org.apache.lucene.search.MatchNoDocsQuery;
19 import org.apache.lucene.search.Query;
20 import org.apache.lucene.search.TermQuery;
21 import org.elasticsearch.common.ParsingException;
22 import org.elasticsearch.common.compress.CompressedXContent;
23 import org.elasticsearch.index.mapper.MapperService;
24 import org.elasticsearch.index.query.MatchAllQueryBuilder;
25 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
26 import org.elasticsearch.index.query.MatchPhraseQueryBuilder;
27 import org.elasticsearch.index.query.QueryBuilder;
28 import org.elasticsearch.index.query.QueryBuilders;
29 import org.elasticsearch.index.query.QueryShardContext;
30 import org.elasticsearch.index.query.Rewriteable;
31 import org.elasticsearch.index.query.TermQueryBuilder;
32 import org.elasticsearch.index.query.WrapperQueryBuilder;
33 import org.elasticsearch.plugins.Plugin;
34 import org.elasticsearch.test.AbstractQueryTestCase;
35 import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin;
36 import org.wikimedia.search.extra.ExtraCorePlugin;
37 import org.wikimedia.search.extra.router.AbstractRouterQueryBuilder.Condition;
38
39 public class TokenCountRouterBuilderESTest extends AbstractQueryTestCase<TokenCountRouterQueryBuilder> {
40 protected Collection<Class<? extends Plugin>> getPlugins() {
41 return Arrays.asList(ExtraCorePlugin.class, TestGeoShapeFieldMapperPlugin.class);
42 }
43 private static final String MY_FIELD = "tok_count_field";
44
45 @Override
46 protected void initializeAdditionalMappings(MapperService mapperService) throws IOException {
47 mapperService.merge("_doc",
48 new CompressedXContent("{\"properties\":{\"" + MY_FIELD + "\":{\"type\":\"text\" }," +
49 "\"fallback\":{\"type\":\"text\"}}}"),
50 MapperService.MergeReason.MAPPING_UPDATE);
51 }
52
53 @Override
54 protected boolean supportsBoost() {
55 return false;
56 }
57
58 @Override
59 protected boolean builderGeneratesCacheableQueries() {
60 return false;
61 }
62
63 @Override
64 protected boolean supportsQueryName() {
65
66
67
68
69
70
71
72
73
74 return false;
75 }
76
77 @Override
78 protected TokenCountRouterQueryBuilder doCreateTestQueryBuilder() {
79 TokenCountRouterQueryBuilder builder = new TokenCountRouterQueryBuilder();
80 builder.text(randomRealisticUnicodeOfCodepointLengthBetween(0, 100));
81 if (randomBoolean()) {
82
83 builder.field(MY_FIELD);
84 builder.fallback(new MatchNoneQueryBuilder());
85 } else {
86 builder.analyzer(randomAnalyzer());
87 builder.fallback(new MatchAllQueryBuilder());
88 }
89
90 for (int i = randomIntBetween(1, 10); i > 0; i--) {
91 AbstractRouterQueryBuilder.ConditionDefinition cond = randomFrom(AbstractRouterQueryBuilder.ConditionDefinition.values());
92 int value = randomInt(10);
93 builder.condition(cond, value, new TermQueryBuilder(MY_FIELD, value + ':' + cond.name()));
94 }
95
96 if (randomBoolean()) {
97 builder.discountOverlaps(randomBoolean());
98 }
99 return builder;
100 }
101
102 @Override
103 protected void doAssertLuceneQuery(TokenCountRouterQueryBuilder queryBuilder, Query query, QueryShardContext context) throws IOException {
104 Analyzer analyzer;
105 if (queryBuilder.field() != null) {
106 analyzer = context.getSearchAnalyzer(context.fieldMapper(queryBuilder.field()));
107 } else {
108 assertNotNull("field or analyzer must be set", queryBuilder.analyzer());
109 analyzer = context.getIndexAnalyzers().get(queryBuilder.analyzer());
110 }
111
112 int tokCount = TokenCountRouterQueryBuilder.countToken(analyzer, queryBuilder.text(), queryBuilder.discountOverlaps());
113
114 Optional<Condition> qb = queryBuilder.conditionStream()
115 .filter(x -> x.test(tokCount))
116 .findFirst();
117
118 query = rewrite(query);
119
120 if (qb.isPresent()) {
121 assertThat(query, instanceOf(TermQuery.class));
122 TermQuery tq = (TermQuery) query;
123 assertEquals(new Term(MY_FIELD, qb.get().value() + ':' + qb.get().definition().name()), tq.getTerm());
124 } else {
125 if (queryBuilder.field() != null) {
126 assertThat(query, instanceOf(MatchNoDocsQuery.class));
127 } else {
128 assertThat(query, instanceOf(MatchAllDocsQuery.class));
129 }
130 }
131 }
132
133 public void testRequiredFields() throws IOException {
134 final TokenCountRouterQueryBuilder builder = new TokenCountRouterQueryBuilder();
135 assertThat(expectThrows(ParsingException.class, () -> parseQuery(builder)).getMessage(),
136 containsString("No conditions defined"));
137 builder.condition(gt, 1, new MatchNoneQueryBuilder());
138
139 assertThat(expectThrows(ParsingException.class, () -> parseQuery(builder)).getMessage(),
140 containsString("No fallback query defined"));
141 builder.fallback(new MatchNoneQueryBuilder());
142
143 assertThat(expectThrows(ParsingException.class, () -> parseQuery(builder)).getMessage(),
144 containsString("No text provided"));
145 builder.text("test text");
146
147 assertThat(expectThrows(ParsingException.class, () -> parseQuery(builder)).getMessage(),
148 containsString("Missing field or analyzer definition"));
149
150 builder.field("test");
151
152 parseQuery(builder);
153 }
154
155 public void testParseDocExample() throws IOException {
156 String json = "{\"token_count_router\": {\n" +
157 " \"field\": \"text\",\n" +
158 " \"text\": \"input query\",\n" +
159 " \"conditions\" : [\n" +
160 " {\n" +
161 " \"gte\": 2,\n" +
162 " \"query\": {\n" +
163 " \"match_phrase\": {\n" +
164 " \"text\": \"input query\"\n" +
165 " }\n" +
166 " }\n" +
167 " }\n" +
168 " ],\n" +
169 " \"fallback\": {\n" +
170 " \"match_none\": {}\n" +
171 " }\n" +
172 "}}";
173 QueryBuilder builder = parseQuery(json);
174 assertThat(builder, instanceOf(TokenCountRouterQueryBuilder.class));
175 TokenCountRouterQueryBuilder tok = (TokenCountRouterQueryBuilder) builder;
176 assertEquals("text", tok.field());
177 assertEquals("input query", tok.text());
178 assertNull(tok.analyzer());
179 assertTrue(tok.discountOverlaps());
180 assertEquals(1, tok.conditionStream().count());
181 Condition cond = tok.conditionStream().findFirst().get();
182 assertEquals(gte, cond.definition());
183 assertEquals(2, cond.value());
184 assertThat(cond.query(), instanceOf(MatchPhraseQueryBuilder.class));
185 assertThat(tok.fallback(), instanceOf(MatchNoneQueryBuilder.class));
186
187 TokenCountRouterQueryBuilder expected = new TokenCountRouterQueryBuilder();
188 expected.field("text");
189 expected.text("input query");
190 expected.condition(gte, 2, QueryBuilders.matchPhraseQuery("text", "input query"));
191 expected.fallback(new MatchNoneQueryBuilder());
192 assertEquals(expected, tok);
193 }
194
195 public void testFailOnMultiplePredicate() throws IOException {
196 String json = "{\"token_count_router\": {\n" +
197 " \"field\": \"text\",\n" +
198 " \"text\": \"input query\",\n" +
199 " \"conditions\" : [\n" +
200 " {\n" +
201 " \"gte\": 2,\n" +
202 " \"gt\": 2,\n" +
203 " \"query\": {\n" +
204 " \"match_phrase\": {\n" +
205 " \"text\": \"input query\"\n" +
206 " }\n" +
207 " }\n" +
208 " }\n" +
209 " ],\n" +
210 " \"fallback\": {\n" +
211 " \"match_none\": {}\n" +
212 " }\n" +
213 "}}";
214 Throwable t = expectThrows(ParsingException.class, () -> parseQuery(json));
215 assertThat(t.getMessage(), endsWith("[token_count_router] failed to parse field [conditions]"));
216
217
218
219
220
221 }
222
223 @Override
224 public void testMustRewrite() throws IOException {
225 TokenCountRouterQueryBuilder builder = new TokenCountRouterQueryBuilder();
226 builder.text(randomAlphaOfLength(20));
227 builder.analyzer(randomAnalyzer());
228 QueryBuilder toRewrite = new TermQueryBuilder("fallback", "fallback");
229 builder.fallback(new WrapperQueryBuilder(toRewrite.toString()));
230 for (int i = randomIntBetween(1, 10); i > 0; i--) {
231 AbstractRouterQueryBuilder.ConditionDefinition cond = randomFrom(AbstractRouterQueryBuilder.ConditionDefinition.values());
232 int value = randomInt(10);
233 builder.condition(cond, value, new WrapperQueryBuilder(toRewrite.toString()));
234 }
235 QueryBuilder rewrittenBuilder = Rewriteable.rewrite(builder, createShardContext());
236 assertEquals(rewrittenBuilder, toRewrite);
237 }
238
239 public void testUnknownAnalyzer() {
240 TokenCountRouterQueryBuilder expected = new TokenCountRouterQueryBuilder();
241 expected.field("unknown_field");
242 expected.text("input query");
243 expected.condition(gte, 2, QueryBuilders.matchPhraseQuery("text", "input query"));
244 expected.fallback(new MatchNoneQueryBuilder());
245 QueryShardContext context = createShardContext();
246 assertThat(expectThrows(IllegalArgumentException.class, () -> Rewriteable.rewrite(expected, context)).getMessage(),
247 containsString("Unknown field [unknown_field]"));
248
249 expected.field(null);
250 expected.analyzer("unknown_analyzer");
251 assertThat(expectThrows(IllegalArgumentException.class, () -> Rewriteable.rewrite(expected, context)).getMessage(),
252 containsString("Unknown analyzer [unknown_analyzer]"));
253 }
254
255 @Override
256 protected Query rewrite(Query query) throws IOException {
257 if (query != null) {
258
259
260
261
262 MemoryIndex idx = new MemoryIndex();
263 return idx.createSearcher().rewrite(query);
264 }
265 return new MatchAllDocsQuery();
266 }
267 }