View Javadoc
1   package org.wikimedia.search.extra.termfreq;
2   
3   import java.io.IOException;
4   import java.util.Objects;
5   import java.util.Set;
6   import java.util.function.IntPredicate;
7   
8   import org.apache.lucene.index.LeafReaderContext;
9   import org.apache.lucene.index.PostingsEnum;
10  import org.apache.lucene.index.Term;
11  import org.apache.lucene.search.DocIdSetIterator;
12  import org.apache.lucene.search.Explanation;
13  import org.apache.lucene.search.IndexSearcher;
14  import org.apache.lucene.search.Query;
15  import org.apache.lucene.search.ScoreMode;
16  import org.apache.lucene.search.Scorer;
17  import org.apache.lucene.search.TwoPhaseIterator;
18  import org.apache.lucene.search.Weight;
19  import org.wikimedia.search.extra.util.ConcreteIntPredicate;
20  
21  public class TermFreqFilterQuery extends Query {
22      private final Term term;
23      private final ConcreteIntPredicate predicate;
24  
25      public TermFreqFilterQuery(Term term, ConcreteIntPredicate predicate) {
26          this.term = term;
27          this.predicate = predicate;
28      }
29  
30      public Term getTerm() {
31          return term;
32      }
33  
34      public IntPredicate getPredicate() {
35          return predicate;
36      }
37  
38      @Override
39      public String toString(String field) {
40          StringBuilder buffer = new StringBuilder("term_freq(");
41          if (!this.term.field().equals(field)) {
42              buffer.append(this.term.field());
43              buffer.append(':');
44          }
45  
46          buffer.append(this.term.text());
47          buffer.append(',');
48          buffer.append(predicate);
49          buffer.append(')');
50          return buffer.toString();
51      }
52  
53      @Override
54      public boolean equals(Object o) {
55          if (this == o) return true;
56          if (o == null || getClass() != o.getClass()) return false;
57          TermFreqFilterQuery that = (TermFreqFilterQuery) o;
58          return Objects.equals(term, that.term) &&
59                  Objects.equals(predicate, that.predicate);
60      }
61  
62      @Override
63      public int hashCode() {
64          return Objects.hash(classHash(), term, predicate);
65      }
66  
67      @Override
68      public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {
69          return new TermFreqFilterWeight(this, term, predicate);
70      }
71  
72      private static final class TermFreqFilterWeight extends Weight {
73          private final Term term;
74          private final IntPredicate predicate;
75  
76          private TermFreqFilterWeight(Query q, Term term, IntPredicate predicate) {
77              super(q);
78              this.term = term;
79              this.predicate = predicate;
80          }
81  
82          @Override
83          public void extractTerms(Set<Term> set) {
84              set.add(term);
85          }
86  
87          @Override
88          public Explanation explain(LeafReaderContext leafReaderContext, int i) throws IOException {
89              PostingsEnum postings = leafReaderContext.reader().postings(term);
90              if (postings != null && postings.advance(i) == i) {
91                  String desc = "freq:" + postings.freq() + " " + predicate + " (" + term + ")";
92                  if (predicate.test(postings.freq())) {
93                      return Explanation.match(postings.freq(), desc);
94                  } else {
95                      return Explanation.noMatch(desc);
96                  }
97              } else {
98                  return Explanation.noMatch("(" + term + ")");
99              }
100         }
101 
102         @Override
103         public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException {
104             PostingsEnum innerDocs = leafReaderContext.reader().postings(term);
105             if (innerDocs == null) {
106                 return null;
107             }
108             TwoPhaseIterator iter = new TwoPhaseIterator(innerDocs) {
109                 @Override
110                 public boolean matches() throws IOException {
111                     return predicate.test(innerDocs.freq());
112                 }
113 
114                 @Override
115                 public float matchCost() {
116                     return 3;
117                 }
118             };
119 
120             return new Scorer(this) {
121                 @Override
122                 public int docID() {
123                     return innerDocs.docID();
124                 }
125 
126                 @Override
127                 public float score() throws IOException {
128                     return innerDocs.freq();
129                 }
130 
131                 @Override
132                 public DocIdSetIterator iterator() {
133                     return innerDocs;
134                 }
135 
136                 @Override
137                 public TwoPhaseIterator twoPhaseIterator() {
138                     return iter;
139                 }
140 
141                 @Override
142                 public float getMaxScore(int upTo) {
143                     return Float.MAX_VALUE;
144                 }
145             };
146         }
147 
148         @Override
149         public boolean isCacheable(LeafReaderContext leafReaderContext) {
150             return true;
151         }
152     }
153 }