View Javadoc
1   package org.wikimedia.search.extra.router;
2   
3   import java.io.IOException;
4   import java.util.Objects;
5   
6   import javax.annotation.Nullable;
7   
8   import org.apache.lucene.analysis.Analyzer;
9   import org.apache.lucene.analysis.TokenStream;
10  import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
11  import org.elasticsearch.common.ParseField;
12  import org.elasticsearch.common.ParsingException;
13  import org.elasticsearch.common.io.stream.StreamInput;
14  import org.elasticsearch.common.io.stream.StreamOutput;
15  import org.elasticsearch.common.xcontent.ObjectParser;
16  import org.elasticsearch.common.xcontent.XContentBuilder;
17  import org.elasticsearch.common.xcontent.XContentParser;
18  import org.elasticsearch.index.mapper.MappedFieldType;
19  import org.elasticsearch.index.mapper.MapperService;
20  import org.elasticsearch.index.query.QueryBuilder;
21  import org.elasticsearch.index.query.QueryRewriteContext;
22  import org.elasticsearch.index.query.QueryShardContext;
23  import org.wikimedia.search.extra.router.AbstractRouterQueryBuilder.Condition;
24  
25  import com.google.common.annotations.VisibleForTesting;
26  
27  import lombok.Getter;
28  import lombok.Setter;
29  import lombok.experimental.Accessors;
30  
31  /**
32   * Builds a token_count_router query.
33   */
34  @Getter
35  @Setter
36  @Accessors(fluent = true, chain = true)
37  public class TokenCountRouterQueryBuilder extends AbstractRouterQueryBuilder<Condition, TokenCountRouterQueryBuilder> {
38      public static final ParseField NAME = new ParseField("token_count_router");
39      private static final ParseField TEXT = new ParseField("text");
40      private static final ParseField FIELD = new ParseField("field");
41      private static final ParseField ANALYZER = new ParseField("analyzer");
42      private static final ParseField DISCOUNT_OVERLAPS = new ParseField("discount_overlaps");
43      private static final boolean DEFAULT_DISCOUNT_OVERLAPS = true;
44  
45      private static final ObjectParser<TokenCountRouterQueryBuilder, Void> PARSER;
46      private static final ObjectParser<ConditionParserState, Void> COND_PARSER;
47  
48      static {
49          COND_PARSER = new ObjectParser<>("condition", ConditionParserState::new);
50          declareConditionFields(COND_PARSER);
51  
52          PARSER = new ObjectParser<>(NAME.getPreferredName(), TokenCountRouterQueryBuilder::new);
53          PARSER.declareString(TokenCountRouterQueryBuilder::text, TEXT);
54          PARSER.declareString(TokenCountRouterQueryBuilder::field, FIELD);
55          PARSER.declareString(TokenCountRouterQueryBuilder::analyzer, ANALYZER);
56          PARSER.declareBoolean(TokenCountRouterQueryBuilder::discountOverlaps, DISCOUNT_OVERLAPS);
57          declareRouterFields(PARSER, (p, pc) -> parseCondition(COND_PARSER, p));
58          declareStandardFields(PARSER);
59      }
60  
61  
62      @Nullable private String analyzer;
63      @Nullable private String field;
64      private boolean discountOverlaps = DEFAULT_DISCOUNT_OVERLAPS;
65      @Nullable private String text;
66  
67      public TokenCountRouterQueryBuilder() {
68          super();
69      }
70  
71      public TokenCountRouterQueryBuilder(StreamInput in) throws IOException {
72          super(in, Condition::new);
73          analyzer = in.readOptionalString();
74          field = in.readOptionalString();
75          discountOverlaps = in.readBoolean();
76          text = in.readString();
77      }
78  
79      @Override
80      protected void doWriteTo(StreamOutput out) throws IOException {
81          super.doWriteTo(out);
82          out.writeOptionalString(analyzer);
83          out.writeOptionalString(field);
84          out.writeBoolean(discountOverlaps);
85          out.writeString(text);
86      }
87  
88      @Override
89      public String getWriteableName() {
90          return NAME.getPreferredName();
91      }
92  
93      @Override
94      protected void addXContent(XContentBuilder builder, Params params) throws IOException {
95          if (analyzer != null) {
96              builder.field(ANALYZER.getPreferredName(), analyzer);
97          }
98          if (field != null) {
99              builder.field(FIELD.getPreferredName(), field);
100         }
101         if (discountOverlaps != DEFAULT_DISCOUNT_OVERLAPS) {
102             builder.field(DISCOUNT_OVERLAPS.getPreferredName(), discountOverlaps);
103         }
104         if (text != null) {
105             builder.field(TEXT.getPreferredName(), text);
106         }
107     }
108 
109     @SuppressWarnings("ResultOfMethodCallIgnored")
110     public static TokenCountRouterQueryBuilder fromXContent(XContentParser parser) throws IOException {
111         TokenCountRouterQueryBuilder builder = AbstractRouterQueryBuilder.fromXContent(PARSER, parser);
112 
113         if (builder.text == null) {
114             throw new ParsingException(parser.getTokenLocation(), "No text provided");
115         }
116 
117         if (builder.analyzer == null && builder.field == null) {
118             throw new ParsingException(parser.getTokenLocation(), "Missing field or analyzer definition");
119         }
120 
121         return builder;
122     }
123 
124     private Analyzer resolveAnalyzer(QueryShardContext context) {
125         final Analyzer luceneAnalyzer;
126         MapperService mapper = context.getMapperService();
127         if (analyzer != null) {
128             luceneAnalyzer = mapper.getIndexAnalyzers().get(analyzer);
129             if (luceneAnalyzer == null) {
130                 throw new IllegalArgumentException("Unknown analyzer [" + analyzer + "]");
131             }
132         } else if (field != null) {
133             MappedFieldType fieldMapper = mapper.fieldType(field);
134             if (fieldMapper == null) {
135                 throw new IllegalArgumentException("Unknown field [" + field + "]");
136             }
137 
138             luceneAnalyzer = context.getSearchAnalyzer(fieldMapper);
139         } else {
140             throw new IllegalArgumentException("field or analyzer must be set");
141         }
142         if (text == null) {
143             throw new IllegalArgumentException("text cannot be null");
144         }
145         return luceneAnalyzer;
146     }
147 
148     @Override
149     public QueryBuilder doRewrite(QueryRewriteContext context) throws IOException {
150         QueryShardContext shardContext = context.convertToShardContext();
151         if (shardContext != null) {
152             final Analyzer luceneAnalyzer = resolveAnalyzer(shardContext);
153             final int count = countToken(luceneAnalyzer, text, discountOverlaps);
154             return super.doRewrite((c) -> c.test(count));
155         }
156         return this;
157     }
158 
159     static int countToken(Analyzer analyzer, String text, boolean discountOverlaps) throws IOException {
160         try (TokenStream ts = analyzer.tokenStream("", text)) {
161             ts.reset();
162             int count = 0;
163             PositionIncrementAttribute posInc = ts.getAttribute(PositionIncrementAttribute.class);
164             while (ts.incrementToken()) {
165                 if (!discountOverlaps || posInc.getPositionIncrement() > 0) {
166                     count++;
167                 }
168             }
169             return count;
170         }
171     }
172 
173     @Override
174     protected boolean doEquals(TokenCountRouterQueryBuilder other) {
175         return super.doEquals(other) &&
176                 Objects.equals(text, other.text) &&
177                 Objects.equals(field, other.field) &&
178                 Objects.equals(analyzer, other.analyzer) &&
179                 Objects.equals(discountOverlaps, other.discountOverlaps);
180     }
181 
182     @Override
183     protected int doHashCode() {
184         return Objects.hash(text, field, analyzer, discountOverlaps, super.doHashCode());
185     }
186 
187     @VisibleForTesting
188     public TokenCountRouterQueryBuilder condition(ConditionDefinition def, int value, QueryBuilder qb) {
189         condition(new Condition(def, value, qb));
190         return this;
191     }
192 }