1 package org.wikimedia.search.extra.router;
2
3 import java.io.IOException;
4 import java.util.Objects;
5
6 import javax.annotation.Nullable;
7
8 import org.apache.lucene.analysis.Analyzer;
9 import org.apache.lucene.analysis.TokenStream;
10 import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
11 import org.elasticsearch.common.ParseField;
12 import org.elasticsearch.common.ParsingException;
13 import org.elasticsearch.common.io.stream.StreamInput;
14 import org.elasticsearch.common.io.stream.StreamOutput;
15 import org.elasticsearch.common.xcontent.ObjectParser;
16 import org.elasticsearch.common.xcontent.XContentBuilder;
17 import org.elasticsearch.common.xcontent.XContentParser;
18 import org.elasticsearch.index.mapper.MappedFieldType;
19 import org.elasticsearch.index.mapper.MapperService;
20 import org.elasticsearch.index.query.QueryBuilder;
21 import org.elasticsearch.index.query.QueryRewriteContext;
22 import org.elasticsearch.index.query.QueryShardContext;
23 import org.wikimedia.search.extra.router.AbstractRouterQueryBuilder.Condition;
24
25 import com.google.common.annotations.VisibleForTesting;
26
27 import lombok.Getter;
28 import lombok.Setter;
29 import lombok.experimental.Accessors;
30
31
32
33
34 @Getter
35 @Setter
36 @Accessors(fluent = true, chain = true)
37 public class TokenCountRouterQueryBuilder extends AbstractRouterQueryBuilder<Condition, TokenCountRouterQueryBuilder> {
38 public static final ParseField NAME = new ParseField("token_count_router");
39 private static final ParseField TEXT = new ParseField("text");
40 private static final ParseField FIELD = new ParseField("field");
41 private static final ParseField ANALYZER = new ParseField("analyzer");
42 private static final ParseField DISCOUNT_OVERLAPS = new ParseField("discount_overlaps");
43 private static final boolean DEFAULT_DISCOUNT_OVERLAPS = true;
44
45 private static final ObjectParser<TokenCountRouterQueryBuilder, Void> PARSER;
46 private static final ObjectParser<ConditionParserState, Void> COND_PARSER;
47
48 static {
49 COND_PARSER = new ObjectParser<>("condition", ConditionParserState::new);
50 declareConditionFields(COND_PARSER);
51
52 PARSER = new ObjectParser<>(NAME.getPreferredName(), TokenCountRouterQueryBuilder::new);
53 PARSER.declareString(TokenCountRouterQueryBuilder::text, TEXT);
54 PARSER.declareString(TokenCountRouterQueryBuilder::field, FIELD);
55 PARSER.declareString(TokenCountRouterQueryBuilder::analyzer, ANALYZER);
56 PARSER.declareBoolean(TokenCountRouterQueryBuilder::discountOverlaps, DISCOUNT_OVERLAPS);
57 declareRouterFields(PARSER, (p, pc) -> parseCondition(COND_PARSER, p));
58 declareStandardFields(PARSER);
59 }
60
61
62 @Nullable private String analyzer;
63 @Nullable private String field;
64 private boolean discountOverlaps = DEFAULT_DISCOUNT_OVERLAPS;
65 @Nullable private String text;
66
67 public TokenCountRouterQueryBuilder() {
68 super();
69 }
70
71 public TokenCountRouterQueryBuilder(StreamInput in) throws IOException {
72 super(in, Condition::new);
73 analyzer = in.readOptionalString();
74 field = in.readOptionalString();
75 discountOverlaps = in.readBoolean();
76 text = in.readString();
77 }
78
79 @Override
80 protected void doWriteTo(StreamOutput out) throws IOException {
81 super.doWriteTo(out);
82 out.writeOptionalString(analyzer);
83 out.writeOptionalString(field);
84 out.writeBoolean(discountOverlaps);
85 out.writeString(text);
86 }
87
88 @Override
89 public String getWriteableName() {
90 return NAME.getPreferredName();
91 }
92
93 @Override
94 protected void addXContent(XContentBuilder builder, Params params) throws IOException {
95 if (analyzer != null) {
96 builder.field(ANALYZER.getPreferredName(), analyzer);
97 }
98 if (field != null) {
99 builder.field(FIELD.getPreferredName(), field);
100 }
101 if (discountOverlaps != DEFAULT_DISCOUNT_OVERLAPS) {
102 builder.field(DISCOUNT_OVERLAPS.getPreferredName(), discountOverlaps);
103 }
104 if (text != null) {
105 builder.field(TEXT.getPreferredName(), text);
106 }
107 }
108
109 @SuppressWarnings("ResultOfMethodCallIgnored")
110 public static TokenCountRouterQueryBuilder fromXContent(XContentParser parser) throws IOException {
111 TokenCountRouterQueryBuilder builder = AbstractRouterQueryBuilder.fromXContent(PARSER, parser);
112
113 if (builder.text == null) {
114 throw new ParsingException(parser.getTokenLocation(), "No text provided");
115 }
116
117 if (builder.analyzer == null && builder.field == null) {
118 throw new ParsingException(parser.getTokenLocation(), "Missing field or analyzer definition");
119 }
120
121 return builder;
122 }
123
124 private Analyzer resolveAnalyzer(QueryShardContext context) {
125 final Analyzer luceneAnalyzer;
126 MapperService mapper = context.getMapperService();
127 if (analyzer != null) {
128 luceneAnalyzer = mapper.getIndexAnalyzers().get(analyzer);
129 if (luceneAnalyzer == null) {
130 throw new IllegalArgumentException("Unknown analyzer [" + analyzer + "]");
131 }
132 } else if (field != null) {
133 MappedFieldType fieldMapper = mapper.fieldType(field);
134 if (fieldMapper == null) {
135 throw new IllegalArgumentException("Unknown field [" + field + "]");
136 }
137
138 luceneAnalyzer = context.getSearchAnalyzer(fieldMapper);
139 } else {
140 throw new IllegalArgumentException("field or analyzer must be set");
141 }
142 if (text == null) {
143 throw new IllegalArgumentException("text cannot be null");
144 }
145 return luceneAnalyzer;
146 }
147
148 @Override
149 public QueryBuilder doRewrite(QueryRewriteContext context) throws IOException {
150 QueryShardContext shardContext = context.convertToShardContext();
151 if (shardContext != null) {
152 final Analyzer luceneAnalyzer = resolveAnalyzer(shardContext);
153 final int count = countToken(luceneAnalyzer, text, discountOverlaps);
154 return super.doRewrite((c) -> c.test(count));
155 }
156 return this;
157 }
158
159 static int countToken(Analyzer analyzer, String text, boolean discountOverlaps) throws IOException {
160 try (TokenStream ts = analyzer.tokenStream("", text)) {
161 ts.reset();
162 int count = 0;
163 PositionIncrementAttribute posInc = ts.getAttribute(PositionIncrementAttribute.class);
164 while (ts.incrementToken()) {
165 if (!discountOverlaps || posInc.getPositionIncrement() > 0) {
166 count++;
167 }
168 }
169 return count;
170 }
171 }
172
173 @Override
174 protected boolean doEquals(TokenCountRouterQueryBuilder other) {
175 return super.doEquals(other) &&
176 Objects.equals(text, other.text) &&
177 Objects.equals(field, other.field) &&
178 Objects.equals(analyzer, other.analyzer) &&
179 Objects.equals(discountOverlaps, other.discountOverlaps);
180 }
181
182 @Override
183 protected int doHashCode() {
184 return Objects.hash(text, field, analyzer, discountOverlaps, super.doHashCode());
185 }
186
187 @VisibleForTesting
188 public TokenCountRouterQueryBuilder condition(ConditionDefinition def, int value, QueryBuilder qb) {
189 condition(new Condition(def, value, qb));
190 return this;
191 }
192 }