View Javadoc
1   package org.wikimedia.search.extra.router;
2   
3   import java.io.IOException;
4   import java.util.ArrayList;
5   import java.util.List;
6   import java.util.Objects;
7   import java.util.function.Predicate;
8   import java.util.stream.Stream;
9   
10  import javax.annotation.Nullable;
11  
12  import org.apache.lucene.search.Query;
13  import org.elasticsearch.common.ParseField;
14  import org.elasticsearch.common.ParsingException;
15  import org.elasticsearch.common.io.stream.StreamInput;
16  import org.elasticsearch.common.io.stream.StreamOutput;
17  import org.elasticsearch.common.io.stream.Writeable;
18  import org.elasticsearch.common.xcontent.ContextParser;
19  import org.elasticsearch.common.xcontent.ObjectParser;
20  import org.elasticsearch.common.xcontent.ToXContent;
21  import org.elasticsearch.common.xcontent.XContentBuilder;
22  import org.elasticsearch.common.xcontent.XContentParser;
23  import org.elasticsearch.index.query.AbstractQueryBuilder;
24  import org.elasticsearch.index.query.BoolQueryBuilder;
25  import org.elasticsearch.index.query.QueryBuilder;
26  import org.elasticsearch.index.query.QueryShardContext;
27  import org.wikimedia.search.extra.router.AbstractRouterQueryBuilder.Condition;
28  
29  import com.google.common.annotations.VisibleForTesting;
30  
31  import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
32  import lombok.AccessLevel;
33  import lombok.EqualsAndHashCode;
34  import lombok.Getter;
35  import lombok.Setter;
36  import lombok.experimental.Accessors;
37  
38  
39  /**
40   * Base class for "router" like queries.
41   */
42  @SuppressWarnings("checkstyle:classfanoutcomplexity") // TODO: refactor at some point
43  @Getter
44  @Setter
45  @Accessors(fluent = true, chain = true)
46  public abstract class AbstractRouterQueryBuilder<C extends Condition, QB extends AbstractRouterQueryBuilder<C, QB>> extends AbstractQueryBuilder<QB> {
47      private static final ParseField FALLBACK = new ParseField("fallback");
48      private static final ParseField CONDITIONS = new ParseField("conditions");
49      public static final ParseField QUERY = new ParseField("query");
50  
51      @Getter(AccessLevel.PRIVATE)
52      private List<C> conditions;
53  
54      @Nullable private QueryBuilder fallback;
55  
56      /**
57       * Empty ctor.
58       */
59      AbstractRouterQueryBuilder() {
60          this.conditions = new ArrayList<>();
61      }
62  
63      /**
64       * Build from a StreamInput.
65       */
66      AbstractRouterQueryBuilder(StreamInput in, Writeable.Reader<C> reader) throws IOException {
67          super(in);
68          conditions = in.readList(reader);
69          fallback = in.readNamedWriteable(QueryBuilder.class);
70      }
71  
72      /**
73       * Write to the StreamOuput.
74       */
75      protected void doWriteTo(StreamOutput out) throws IOException {
76          out.writeList(conditions);
77          out.writeNamedWriteable(fallback);
78      }
79  
80      /**
81       * Evaluates conditions and returns the associated QueryBuilder.
82       */
83      final QueryBuilder doRewrite(Predicate<C> condition) {
84          QueryBuilder qb = conditions.stream()
85                  .filter(condition)
86                  .findFirst()
87                  .map(Condition::query)
88                  .orElse(fallback);
89  
90          if (boost() != DEFAULT_BOOST || queryName() != null) {
91              // AbstractQueryBuilder#rewrite will copy non default boost/name
92              // to the rewritten query, we pass a fresh BoolQuery so we don't
93              // override the one on the rewritten query here
94              // Is this really useful?
95              return new BoolQueryBuilder().must(qb);
96          }
97          return qb;
98  
99      }
100 
101     @Override
102     protected boolean doEquals(QB other) {
103         AbstractRouterQueryBuilder<C, QB> qb = other;
104         return Objects.equals(fallback, qb.fallback) &&
105                 Objects.equals(conditions, qb.conditions);
106     }
107 
108     @Override
109     protected int doHashCode() {
110         return Objects.hash(fallback, conditions);
111     }
112 
113     @Override
114     @SuppressFBWarnings("ACEM_ABSTRACT_CLASS_EMPTY_METHODS")
115     protected Query doToQuery(QueryShardContext queryShardContext) {
116         throw new UnsupportedOperationException("This query must be rewritten.");
117     }
118 
119     /**
120      * Override when custom params need to be written to the ContentBuilder.
121      */
122     @SuppressFBWarnings("ACEM_ABSTRACT_CLASS_EMPTY_METHODS")
123     protected void addXContent(XContentBuilder builder, Params params) throws IOException {
124     }
125 
126     @Override
127     @SuppressFBWarnings("PRMC_POSSIBLY_REDUNDANT_METHOD_CALLS")
128     protected void doXContent(XContentBuilder builder, Params params) throws IOException {
129         builder.startObject(getWriteableName());
130         if (fallback() != null) {
131             builder.field(FALLBACK.getPreferredName(), fallback());
132         }
133         if (!conditions().isEmpty()) {
134             builder.startArray(CONDITIONS.getPreferredName());
135             for (C c : conditions()) {
136                 c.toXContent(builder, params);
137             }
138             builder.endArray();
139         }
140 
141         addXContent(builder, params);
142         printBoostAndQueryName(builder);
143         builder.endObject();
144     }
145 
146     /**
147      * Parse condition given a condition parser.
148      */
149     @SuppressFBWarnings(value = "OCP_OVERLY_CONCRETE_PARAMETER", justification = "No need to be generic in this case")
150     static <C extends Condition, CPS extends AbstractConditionParserState<C>> C parseCondition(
151             ObjectParser<CPS, Void> condParser, XContentParser parser
152     ) throws IOException {
153         CPS state = condParser.parse(parser, null);
154         state.checkValid();
155         return state.condition();
156     }
157 
158 
159     /**
160      * Parse and build the query.
161      */
162     @SuppressFBWarnings(value = "LEST_LOST_EXCEPTION_STACK_TRACE", justification = "The new exception contains all needed context")
163     static <QB extends AbstractRouterQueryBuilder<?, QB>> QB fromXContent(
164             ObjectParser<QB, Void> objectParser, XContentParser parser) throws IOException {
165         final QB builder;
166         try {
167             builder = objectParser.parse(parser, null);
168         } catch (IllegalArgumentException iae) {
169             throw new ParsingException(parser.getTokenLocation(), iae.getMessage());
170         }
171 
172         final AbstractRouterQueryBuilder<?, QB> qb = builder;
173         if (qb.conditions.isEmpty()) {
174             throw new ParsingException(parser.getTokenLocation(), "No conditions defined");
175         }
176 
177         if (qb.fallback == null) {
178             throw new ParsingException(parser.getTokenLocation(), "No fallback query defined");
179         }
180 
181         return builder;
182     }
183 
184     /**
185      * A condition.
186      */
187     @Getter
188     @Accessors(fluent = true, chain = true)
189     @EqualsAndHashCode
190     public static class Condition implements Writeable, ToXContent {
191         private final ConditionDefinition definition;
192         private final int value;
193         private final QueryBuilder query;
194 
195         @SuppressFBWarnings("NP_NULL_ON_SOME_PATH_FROM_RETURN_VALUE")
196         Condition(StreamInput in) throws IOException {
197             definition = ConditionDefinition.readFrom(in);
198             value = in.readVInt();
199             query = in.readNamedWriteable(QueryBuilder.class);
200         }
201 
202         Condition(ConditionDefinition definition, int value, QueryBuilder query) {
203             this.definition = Objects.requireNonNull(definition);
204             this.value = value;
205             this.query = Objects.requireNonNull(query);
206         }
207 
208         @Override
209         public void writeTo(StreamOutput out) throws IOException {
210             definition.writeTo(out);
211             out.writeVInt(value);
212             out.writeNamedWriteable(query);
213         }
214 
215         /**
216          * Test the condition.
217          */
218         public boolean test(long lhs) {
219             return definition.test(lhs, value);
220         }
221 
222         /**
223          * Build the XContent. Override when custom fields are needed.
224          */
225         @SuppressWarnings({"EmptyMethod"})
226         void addXContent(XContentBuilder builder, Params params) throws IOException {
227             // Empty implementation, but allow extending classes to add things.
228         }
229 
230         @Override
231         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
232             builder.startObject();
233             builder.field(definition.parseField.getPreferredName(), value);
234             builder.field(QUERY.getPreferredName(), query);
235             addXContent(builder, params);
236             return builder.endObject();
237         }
238     }
239 
240     /**
241      * Helper method to declare a field on the ObjectParser.
242      */
243     static <C extends Condition, QB extends AbstractRouterQueryBuilder<C, QB>> void declareRouterFields(
244             ObjectParser<QB, Void> parser,
245             ContextParser<Void, C> objectParser) {
246         parser.declareObjectArray(QB::conditions, objectParser, CONDITIONS);
247         parser.declareObject(QB::fallback,
248             (p, ctx) -> parseInnerQueryBuilder(p),
249             FALLBACK);
250     }
251 
252     /**
253      * Helper method to declare a condition field on the ObjectParser.
254      */
255     static <CPS extends AbstractConditionParserState<?>> void declareConditionFields(
256             ObjectParser<CPS, Void> parser) {
257         for (ConditionDefinition def : ConditionDefinition.values()) {
258             // gt: int, addPredicate will fail if a predicate has already been set
259             parser.declareInt((cps, value) -> cps.addPredicate(def, value), def.parseField);
260         }
261         // query: { }
262         parser.declareObject(CPS::setQuery,
263                 (p, ctx) -> parseInnerQueryBuilder(p),
264                 QUERY);
265     }
266 
267     /**
268      * Parser state for a Condition.
269      */
270     abstract static class AbstractConditionParserState<C extends Condition> {
271         @Nullable protected ConditionDefinition definition;
272         protected int value;
273         @Nullable protected QueryBuilder query;
274 
275         /**
276          * Add a new predicate to the state.
277          */
278         void addPredicate(ConditionDefinition def, int value) {
279             if (this.definition != null) {
280                 throw new IllegalArgumentException("Cannot set extra predicate [" + def.parseField + "] " +
281                         "on condition: [" + this.definition.parseField + "] already set");
282             }
283             this.definition = def;
284             this.value = value;
285         }
286 
287         /**
288          * Set the query associated to the condition.
289          */
290         protected void setQuery(QueryBuilder query) {
291             this.query = query;
292         }
293 
294         /**
295          * The parsed Condition.
296          */
297         abstract C condition();
298 
299         /**
300          * Check validity of the condition.
301          * @throws IllegalArgumentException if the condition is missing mandatory fields
302          */
303         void checkValid() {
304             if (query == null) {
305                 throw new IllegalArgumentException("Missing field [query] in condition");
306             }
307             if (definition == null) {
308                 throw new IllegalArgumentException("Missing condition predicate in condition");
309             }
310         }
311     }
312 
313     /**
314      * Simple parser state for conditions.
315      */
316     static class ConditionParserState extends AbstractConditionParserState<Condition> {
317         @Override
318         Condition condition() {
319             return new Condition(definition, value, query);
320         }
321     }
322 
323     /**
324      * A predicate that accepts two values.
325      */
326     @FunctionalInterface
327     interface BiLongPredicate {
328         /** test a versus b. */
329         boolean test(long a, long b);
330     }
331 
332     /**
333      * List of supported conditions and predicates.
334      */
335     public enum ConditionDefinition implements BiLongPredicate, Writeable {
336         eq((a, b) -> a == b),
337         neq((a, b) -> a != b),
338         lte((a, b) -> a <= b),
339         lt((a, b) -> a < b),
340         gte((a, b) -> a >= b),
341         gt((a, b) -> a > b);
342 
343         final ParseField parseField;
344         final BiLongPredicate predicate;
345 
346         ConditionDefinition(BiLongPredicate predicate) {
347             this.predicate = predicate;
348             this.parseField = new ParseField(name());
349         }
350 
351         /**
352          * Test the condition.
353          */
354         @Override
355         public boolean test(long a, long b) {
356             return predicate.test(a, b);
357         }
358 
359         @Override
360         public void writeTo(StreamOutput out) throws IOException {
361             out.writeVInt(ordinal());
362         }
363 
364         /**
365          * Read from a StreamInput.
366          */
367         static ConditionDefinition readFrom(StreamInput in) throws IOException {
368             int ord = in.readVInt();
369             if (ord < 0 || ord >= ConditionDefinition.values().length) {
370                 throw new IOException("Unknown ConditionDefinition ordinal [" + ord + "]");
371             }
372             return ConditionDefinition.values()[ord];
373         }
374 
375     }
376 
377     /**
378      * Simple stream view of the conditions.
379      */
380     @VisibleForTesting
381     Stream<C> conditionStream() {
382         return conditions.stream();
383     }
384 
385     /**
386      * Add a condition.
387      */
388     @VisibleForTesting
389     void condition(C condition) {
390         conditions.add(condition);
391     }
392 }