package org.wikimedia.search.extra.termfreq;

import java.io.IOException;
import java.util.Objects;
import java.util.Set;
import java.util.function.IntPredicate;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.wikimedia.search.extra.util.ConcreteIntPredicate;

/**
 * Query matching the documents that contain {@code term} and whose frequency
 * for that term satisfies {@code predicate}.
 *
 * <p>The score reported for a matching document is the term frequency itself.
 */
public class TermFreqFilterQuery extends Query {
    private final Term term;
    private final ConcreteIntPredicate predicate;

    /**
     * @param term the term whose per-document frequency is tested
     * @param predicate the test applied to the term frequency of each candidate document
     */
    public TermFreqFilterQuery(Term term, ConcreteIntPredicate predicate) {
        this.term = term;
        this.predicate = predicate;
    }

    /** @return the term whose frequency is tested */
    public Term getTerm() {
        return term;
    }

    /** @return the predicate applied to the term frequency */
    public IntPredicate getPredicate() {
        return predicate;
    }

    /**
     * Renders the query as {@code term_freq([field:]text,predicate)}; the field
     * prefix is omitted when the term's field equals {@code field}.
     */
    @Override
    public String toString(String field) {
        StringBuilder buffer = new StringBuilder("term_freq(");
        if (!this.term.field().equals(field)) {
            buffer.append(this.term.field());
            buffer.append(':');
        }

        buffer.append(this.term.text());
        buffer.append(',');
        buffer.append(predicate);
        buffer.append(')');
        return buffer.toString();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        TermFreqFilterQuery that = (TermFreqFilterQuery) o;
        return Objects.equals(term, that.term)
                && Objects.equals(predicate, that.predicate);
    }

    @Override
    public int hashCode() {
        return Objects.hash(classHash(), term, predicate);
    }

    /**
     * Builds the weight for this query. The searcher, score mode and boost are
     * not used: the score is always the raw term frequency.
     */
    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) {
        return new TermFreqFilterWeight(this, term, predicate);
    }

    /** Weight that scans the postings of the term and filters them on frequency. */
    private static final class TermFreqFilterWeight extends Weight {
        private final Term term;
        private final IntPredicate predicate;

        private TermFreqFilterWeight(Query q, Term term, IntPredicate predicate) {
            super(q);
            this.term = term;
            this.predicate = predicate;
        }

        @Override
        public void extractTerms(Set<Term> set) {
            set.add(term);
        }

        /**
         * Explains the outcome for document {@code i} of the given segment: a
         * match carrying the term frequency when the predicate accepts it, a
         * non-match otherwise (including when the document lacks the term).
         */
        @Override
        public Explanation explain(LeafReaderContext leafReaderContext, int i) throws IOException {
            PostingsEnum postings = leafReaderContext.reader().postings(term);
            // advance(i) lands exactly on i only when document i contains the term.
            if (postings != null && postings.advance(i) == i) {
                int freq = postings.freq();
                String desc = "freq:" + freq + " " + predicate + " (" + term + ")";
                if (predicate.test(freq)) {
                    return Explanation.match(freq, desc);
                } else {
                    return Explanation.noMatch(desc);
                }
            } else {
                return Explanation.noMatch("(" + term + ")");
            }
        }

        /**
         * Creates the scorer for a segment, or returns {@code null} when the
         * segment has no postings for the term.
         */
        @Override
        public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException {
            PostingsEnum innerDocs = leafReaderContext.reader().postings(term);
            if (innerDocs == null) {
                return null;
            }
            // Approximation: every document containing the term.
            // Confirmation: the term frequency must satisfy the predicate.
            final TwoPhaseIterator iter = new TwoPhaseIterator(innerDocs) {
                @Override
                public boolean matches() throws IOException {
                    return predicate.test(innerDocs.freq());
                }

                @Override
                public float matchCost() {
                    // Fixed estimate of the per-document confirmation cost.
                    return 3;
                }
            };
            // Exact iterator derived from the two-phase iterator. Handing out the raw
            // postings here would expose documents whose frequency fails the predicate
            // to any consumer that uses iterator() rather than twoPhaseIterator().
            final DocIdSetIterator matchingDocs = TwoPhaseIterator.asDocIdSetIterator(iter);

            return new Scorer(this) {
                @Override
                public int docID() {
                    return innerDocs.docID();
                }

                @Override
                public float score() throws IOException {
                    return innerDocs.freq();
                }

                @Override
                public DocIdSetIterator iterator() {
                    return matchingDocs;
                }

                @Override
                public TwoPhaseIterator twoPhaseIterator() {
                    return iter;
                }

                @Override
                public float getMaxScore(int upTo) {
                    // No tighter bound on the term frequency is computed.
                    return Float.MAX_VALUE;
                }
            };
        }

        @Override
        public boolean isCacheable(LeafReaderContext leafReaderContext) {
            return true;
        }
    }
}