1 package org.wikimedia.search.extra.router;
2
3 import java.io.IOException;
4 import java.util.ArrayList;
5 import java.util.List;
6 import java.util.Objects;
7 import java.util.function.Predicate;
8 import java.util.stream.Stream;
9
10 import javax.annotation.Nullable;
11
12 import org.apache.lucene.search.Query;
13 import org.elasticsearch.common.ParseField;
14 import org.elasticsearch.common.ParsingException;
15 import org.elasticsearch.common.io.stream.StreamInput;
16 import org.elasticsearch.common.io.stream.StreamOutput;
17 import org.elasticsearch.common.io.stream.Writeable;
18 import org.elasticsearch.common.xcontent.ContextParser;
19 import org.elasticsearch.common.xcontent.ObjectParser;
20 import org.elasticsearch.common.xcontent.ToXContent;
21 import org.elasticsearch.common.xcontent.XContentBuilder;
22 import org.elasticsearch.common.xcontent.XContentParser;
23 import org.elasticsearch.index.query.AbstractQueryBuilder;
24 import org.elasticsearch.index.query.BoolQueryBuilder;
25 import org.elasticsearch.index.query.QueryBuilder;
26 import org.elasticsearch.index.query.QueryShardContext;
27 import org.wikimedia.search.extra.router.AbstractRouterQueryBuilder.Condition;
28
29 import com.google.common.annotations.VisibleForTesting;
30
31 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
32 import lombok.AccessLevel;
33 import lombok.EqualsAndHashCode;
34 import lombok.Getter;
35 import lombok.Setter;
36 import lombok.experimental.Accessors;
37
38
39
40
41
42 @SuppressWarnings("checkstyle:classfanoutcomplexity")
43 @Getter
44 @Setter
45 @Accessors(fluent = true, chain = true)
46 public abstract class AbstractRouterQueryBuilder<C extends Condition, QB extends AbstractRouterQueryBuilder<C, QB>> extends AbstractQueryBuilder<QB> {
47 private static final ParseField FALLBACK = new ParseField("fallback");
48 private static final ParseField CONDITIONS = new ParseField("conditions");
49 public static final ParseField QUERY = new ParseField("query");
50
51 @Getter(AccessLevel.PRIVATE)
52 private List<C> conditions;
53
54 @Nullable private QueryBuilder fallback;
55
56
57
58
59 AbstractRouterQueryBuilder() {
60 this.conditions = new ArrayList<>();
61 }
62
63
64
65
66 AbstractRouterQueryBuilder(StreamInput in, Writeable.Reader<C> reader) throws IOException {
67 super(in);
68 conditions = in.readList(reader);
69 fallback = in.readNamedWriteable(QueryBuilder.class);
70 }
71
72
73
74
75 protected void doWriteTo(StreamOutput out) throws IOException {
76 out.writeList(conditions);
77 out.writeNamedWriteable(fallback);
78 }
79
80
81
82
83 final QueryBuilder doRewrite(Predicate<C> condition) {
84 QueryBuilder qb = conditions.stream()
85 .filter(condition)
86 .findFirst()
87 .map(Condition::query)
88 .orElse(fallback);
89
90 if (boost() != DEFAULT_BOOST || queryName() != null) {
91
92
93
94
95 return new BoolQueryBuilder().must(qb);
96 }
97 return qb;
98
99 }
100
101 @Override
102 protected boolean doEquals(QB other) {
103 AbstractRouterQueryBuilder<C, QB> qb = other;
104 return Objects.equals(fallback, qb.fallback) &&
105 Objects.equals(conditions, qb.conditions);
106 }
107
108 @Override
109 protected int doHashCode() {
110 return Objects.hash(fallback, conditions);
111 }
112
113 @Override
114 @SuppressFBWarnings("ACEM_ABSTRACT_CLASS_EMPTY_METHODS")
115 protected Query doToQuery(QueryShardContext queryShardContext) {
116 throw new UnsupportedOperationException("This query must be rewritten.");
117 }
118
119
120
121
122 @SuppressFBWarnings("ACEM_ABSTRACT_CLASS_EMPTY_METHODS")
123 protected void addXContent(XContentBuilder builder, Params params) throws IOException {
124 }
125
126 @Override
127 @SuppressFBWarnings("PRMC_POSSIBLY_REDUNDANT_METHOD_CALLS")
128 protected void doXContent(XContentBuilder builder, Params params) throws IOException {
129 builder.startObject(getWriteableName());
130 if (fallback() != null) {
131 builder.field(FALLBACK.getPreferredName(), fallback());
132 }
133 if (!conditions().isEmpty()) {
134 builder.startArray(CONDITIONS.getPreferredName());
135 for (C c : conditions()) {
136 c.toXContent(builder, params);
137 }
138 builder.endArray();
139 }
140
141 addXContent(builder, params);
142 printBoostAndQueryName(builder);
143 builder.endObject();
144 }
145
146
147
148
149 @SuppressFBWarnings(value = "OCP_OVERLY_CONCRETE_PARAMETER", justification = "No need to be generic in this case")
150 static <C extends Condition, CPS extends AbstractConditionParserState<C>> C parseCondition(
151 ObjectParser<CPS, Void> condParser, XContentParser parser
152 ) throws IOException {
153 CPS state = condParser.parse(parser, null);
154 state.checkValid();
155 return state.condition();
156 }
157
158
159
160
161
162 @SuppressFBWarnings(value = "LEST_LOST_EXCEPTION_STACK_TRACE", justification = "The new exception contains all needed context")
163 static <QB extends AbstractRouterQueryBuilder<?, QB>> QB fromXContent(
164 ObjectParser<QB, Void> objectParser, XContentParser parser) throws IOException {
165 final QB builder;
166 try {
167 builder = objectParser.parse(parser, null);
168 } catch (IllegalArgumentException iae) {
169 throw new ParsingException(parser.getTokenLocation(), iae.getMessage());
170 }
171
172 final AbstractRouterQueryBuilder<?, QB> qb = builder;
173 if (qb.conditions.isEmpty()) {
174 throw new ParsingException(parser.getTokenLocation(), "No conditions defined");
175 }
176
177 if (qb.fallback == null) {
178 throw new ParsingException(parser.getTokenLocation(), "No fallback query defined");
179 }
180
181 return builder;
182 }
183
184
185
186
187 @Getter
188 @Accessors(fluent = true, chain = true)
189 @EqualsAndHashCode
190 public static class Condition implements Writeable, ToXContent {
191 private final ConditionDefinition definition;
192 private final int value;
193 private final QueryBuilder query;
194
195 @SuppressFBWarnings("NP_NULL_ON_SOME_PATH_FROM_RETURN_VALUE")
196 Condition(StreamInput in) throws IOException {
197 definition = ConditionDefinition.readFrom(in);
198 value = in.readVInt();
199 query = in.readNamedWriteable(QueryBuilder.class);
200 }
201
202 Condition(ConditionDefinition definition, int value, QueryBuilder query) {
203 this.definition = Objects.requireNonNull(definition);
204 this.value = value;
205 this.query = Objects.requireNonNull(query);
206 }
207
208 @Override
209 public void writeTo(StreamOutput out) throws IOException {
210 definition.writeTo(out);
211 out.writeVInt(value);
212 out.writeNamedWriteable(query);
213 }
214
215
216
217
218 public boolean test(long lhs) {
219 return definition.test(lhs, value);
220 }
221
222
223
224
225 @SuppressWarnings({"EmptyMethod"})
226 void addXContent(XContentBuilder builder, Params params) throws IOException {
227
228 }
229
230 @Override
231 public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
232 builder.startObject();
233 builder.field(definition.parseField.getPreferredName(), value);
234 builder.field(QUERY.getPreferredName(), query);
235 addXContent(builder, params);
236 return builder.endObject();
237 }
238 }
239
240
241
242
243 static <C extends Condition, QB extends AbstractRouterQueryBuilder<C, QB>> void declareRouterFields(
244 ObjectParser<QB, Void> parser,
245 ContextParser<Void, C> objectParser) {
246 parser.declareObjectArray(QB::conditions, objectParser, CONDITIONS);
247 parser.declareObject(QB::fallback,
248 (p, ctx) -> parseInnerQueryBuilder(p),
249 FALLBACK);
250 }
251
252
253
254
255 static <CPS extends AbstractConditionParserState<?>> void declareConditionFields(
256 ObjectParser<CPS, Void> parser) {
257 for (ConditionDefinition def : ConditionDefinition.values()) {
258
259 parser.declareInt((cps, value) -> cps.addPredicate(def, value), def.parseField);
260 }
261
262 parser.declareObject(CPS::setQuery,
263 (p, ctx) -> parseInnerQueryBuilder(p),
264 QUERY);
265 }
266
267
268
269
270 abstract static class AbstractConditionParserState<C extends Condition> {
271 @Nullable protected ConditionDefinition definition;
272 protected int value;
273 @Nullable protected QueryBuilder query;
274
275
276
277
278 void addPredicate(ConditionDefinition def, int value) {
279 if (this.definition != null) {
280 throw new IllegalArgumentException("Cannot set extra predicate [" + def.parseField + "] " +
281 "on condition: [" + this.definition.parseField + "] already set");
282 }
283 this.definition = def;
284 this.value = value;
285 }
286
287
288
289
290 protected void setQuery(QueryBuilder query) {
291 this.query = query;
292 }
293
294
295
296
297 abstract C condition();
298
299
300
301
302
303 void checkValid() {
304 if (query == null) {
305 throw new IllegalArgumentException("Missing field [query] in condition");
306 }
307 if (definition == null) {
308 throw new IllegalArgumentException("Missing condition predicate in condition");
309 }
310 }
311 }
312
313
314
315
316 static class ConditionParserState extends AbstractConditionParserState<Condition> {
317 @Override
318 Condition condition() {
319 return new Condition(definition, value, query);
320 }
321 }
322
323
324
325
326 @FunctionalInterface
327 interface BiLongPredicate {
328
329 boolean test(long a, long b);
330 }
331
332
333
334
335 public enum ConditionDefinition implements BiLongPredicate, Writeable {
336 eq((a, b) -> a == b),
337 neq((a, b) -> a != b),
338 lte((a, b) -> a <= b),
339 lt((a, b) -> a < b),
340 gte((a, b) -> a >= b),
341 gt((a, b) -> a > b);
342
343 final ParseField parseField;
344 final BiLongPredicate predicate;
345
346 ConditionDefinition(BiLongPredicate predicate) {
347 this.predicate = predicate;
348 this.parseField = new ParseField(name());
349 }
350
351
352
353
354 @Override
355 public boolean test(long a, long b) {
356 return predicate.test(a, b);
357 }
358
359 @Override
360 public void writeTo(StreamOutput out) throws IOException {
361 out.writeVInt(ordinal());
362 }
363
364
365
366
367 static ConditionDefinition readFrom(StreamInput in) throws IOException {
368 int ord = in.readVInt();
369 if (ord < 0 || ord >= ConditionDefinition.values().length) {
370 throw new IOException("Unknown ConditionDefinition ordinal [" + ord + "]");
371 }
372 return ConditionDefinition.values()[ord];
373 }
374
375 }
376
377
378
379
380 @VisibleForTesting
381 Stream<C> conditionStream() {
382 return conditions.stream();
383 }
384
385
386
387
388 @VisibleForTesting
389 void condition(C condition) {
390 conditions.add(condition);
391 }
392 }