// GetMainSearchRequest.java
package org.wikimedia.search.glent.udf;
import static org.apache.spark.sql.functions.udf;
import java.io.Serializable;
import java.util.List;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import com.google.common.annotations.VisibleForTesting;
import scala.collection.JavaConverters;
import scala.collection.Seq;
/**
 * Builds a Spark UDF that selects, from an array of request rows, the first
 * request whose {@code querytype} is {@code "full_text"} and whose
 * {@code indices} list contains an index belonging to the given wiki.
 *
 * <p>An index is considered to belong to a wiki when it is exactly the wiki
 * name or starts with the wiki name followed by an underscore.
 *
 * <p>The class is {@link Serializable} because the UDF created in
 * {@link #apply(Column, Column)} is a method reference bound to this instance.
 */
public class GetMainSearchRequest implements Serializable {
    /** Value of the {@code querytype} field that identifies a candidate request. */
    private static final String FULL_TEXT = "full_text";

    /** Struct type of a single request; also the return type of the UDF. */
    private final StructType reqType;
    /** Position of the {@code indices} field inside {@link #reqType}. */
    private final int indicesIdx;
    /** Position of the {@code querytype} field inside {@link #reqType}. */
    private final int queryTypeIdx;

    /**
     * Resolves the request struct type and the field positions used for lookups.
     *
     * @param reqArrayField field whose data type is an array of request structs;
     *     the struct must contain fields named {@code indices} and {@code querytype}
     * @throws ClassCastException if the field is not an array of structs
     * @throws IllegalArgumentException if a required field is missing from the struct
     *     (raised by {@code StructType.fieldIndex})
     */
    public GetMainSearchRequest(StructField reqArrayField) {
        reqType = (StructType) ((ArrayType) reqArrayField.dataType()).elementType();
        indicesIdx = reqType.fieldIndex("indices");
        queryTypeIdx = reqType.fieldIndex("querytype");
    }

    /**
     * Applies the selection UDF to the given columns.
     *
     * @param wiki column holding the wiki name
     * @param requests column holding the array of request structs
     * @return a column of the request struct type, null where no request matches
     */
    public Column apply(Column wiki, Column requests) {
        UDF2<String, Seq<Row>, Row> udfMethod = this::getMainSearchReq;
        return udf(udfMethod, reqType).apply(wiki, requests);
    }

    /**
     * Finds the first full text request that targets an index of {@code wiki}.
     *
     * @param wiki wiki name to match indices against; may be null
     * @param requests request rows in their original order; may be null or empty
     * @return the first matching request, or null when {@code wiki} is null,
     *     {@code requests} is null or empty, or no request matches
     */
    @VisibleForTesting
    Row getMainSearchReq(String wiki, Seq<Row> requests) {
        // Without a wiki there is nothing to match against; concatenating a null
        // wiki would otherwise produce the bogus prefix "null_".
        if (wiki == null || requests == null || requests.isEmpty()) {
            return null;
        }
        String prefix = wiki + "_";
        for (Row request : JavaConverters.seqAsJavaListConverter(requests).asJava()) {
            // Null array elements cannot be inspected, so they are skipped.
            if (request == null) {
                continue;
            }
            if (!FULL_TEXT.equals(request.getString(queryTypeIdx))) {
                continue;
            }
            if (targetsWiki(request, wiki, prefix)) {
                return request;
            }
        }
        return null;
    }

    /** Reports whether any non-null index of the request equals the wiki or starts with the prefix. */
    private boolean targetsWiki(Row request, String wiki, String prefix) {
        // A null indices column has no indices to match; iterating it would throw.
        if (request.isNullAt(indicesIdx)) {
            return false;
        }
        List<String> indices = request.getList(indicesIdx);
        if (indices == null) {
            return false;
        }
        for (String index : indices) {
            if (index != null && (index.equals(wiki) || index.startsWith(prefix))) {
                return true;
            }
        }
        return false;
    }
}