1 package org.wikimedia.search.extra.levenshtein;
2
3 import java.io.IOException;
4 import java.util.Objects;
5
6 import javax.annotation.Nullable;
7
8 import org.elasticsearch.common.ParseField;
9 import org.elasticsearch.common.ParsingException;
10 import org.elasticsearch.common.io.stream.StreamInput;
11 import org.elasticsearch.common.io.stream.StreamOutput;
12 import org.elasticsearch.common.lucene.search.function.ScoreFunction;
13 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
14 import org.elasticsearch.common.xcontent.XContentBuilder;
15 import org.elasticsearch.common.xcontent.XContentParser;
16 import org.elasticsearch.index.mapper.MappedFieldType;
17 import org.elasticsearch.index.query.QueryShardContext;
18 import org.elasticsearch.index.query.QueryShardException;
19 import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
20
21 import lombok.Setter;
22 import lombok.experimental.Accessors;
23
24
25
26
27 @Accessors(chain = true, fluent = true)
28 public class LevenshteinDistanceScoreBuilder extends ScoreFunctionBuilder<LevenshteinDistanceScoreBuilder> {
29 public static final ParseField NAME = new ParseField("levenshtein_distance_score", "levenshteinDistanceScore");
30 public static final ParseField FIELD = new ParseField("field");
31 public static final ParseField TEXT = new ParseField("text");
32 public static final ParseField MISSING = new ParseField("missing");
33
34 private final String field;
35 private final String text;
36 @Nullable @Setter private String missing;
37
38 private static final ConstructingObjectParser<LevenshteinDistanceScoreBuilder, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
39 params -> new LevenshteinDistanceScoreBuilder((String) params[0], (String) params[1]));
40
41 static {
42 PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
43 PARSER.declareString(ConstructingObjectParser.constructorArg(), TEXT);
44 PARSER.declareString((b, s) -> b.missing = s, MISSING);
45 }
46 public LevenshteinDistanceScoreBuilder(String field, String text) {
47 this.field = field;
48 this.text = text;
49 }
50
51 public LevenshteinDistanceScoreBuilder(StreamInput in) throws IOException {
52 super(in);
53 field = in.readString();
54 text = in.readString();
55 missing = in.readOptionalString();
56 }
57
58 @Override
59 public String getName() {
60 return NAME.getPreferredName();
61 }
62
63 @Override
64 protected void doXContent(XContentBuilder builder, Params params) throws IOException {
65 builder.startObject(getName());
66 builder.field(FIELD.getPreferredName(), field);
67 builder.field(TEXT.getPreferredName(), text);
68 if (missing != null) {
69 builder.field(MISSING.getPreferredName(), missing);
70 }
71 builder.endObject();
72 }
73
74 @Override
75 protected void doWriteTo(StreamOutput out) throws IOException {
76 out.writeString(field);
77 out.writeString(text);
78 out.writeOptionalString(missing);
79 }
80
81 @Override
82 protected boolean doEquals(LevenshteinDistanceScoreBuilder other) {
83 return Objects.equals(field, other.field)
84 && Objects.equals(text, other.text)
85 && Objects.equals(missing, other.missing);
86 }
87
88 @Override
89 protected int doHashCode() {
90 return Objects.hash(field, text, missing);
91 }
92
93 @Override
94 protected ScoreFunction doToFunction(QueryShardContext context) {
95 MappedFieldType fieldType = context.getMapperService().fieldType(field);
96
97 if (fieldType == null) {
98 throw new QueryShardException(context, "Unable to load field type for field {}", field);
99 }
100 return new LevenshteinDistanceScore(context.lookup(), fieldType, text, missing);
101 }
102
103 public static LevenshteinDistanceScoreBuilder fromXContent(XContentParser parser) throws IOException, ParsingException {
104 return PARSER.parse(parser, null);
105 }
106 }