GetMainSearchRequest.java

package org.wikimedia.search.glent.udf;

import static org.apache.spark.sql.functions.udf;

import java.io.Serializable;
import java.util.List;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import com.google.common.annotations.VisibleForTesting;

import scala.collection.JavaConverters;
import scala.collection.Seq;

/**
 * Spark column function that picks the "main" search request out of an array of
 * request structs.
 *
 * <p>The main request is the first array element whose {@code querytype} field equals
 * {@code "full_text"} and whose {@code indices} list contains an index that is either
 * exactly the wiki name or starts with the wiki name followed by {@code "_"}.
 * When no element qualifies, the result is {@code null}.
 */
public class GetMainSearchRequest implements Serializable {
    private static final String FULL_TEXT = "full_text";

    /** Struct type of a single request element; also used as the UDF's return type. */
    private final StructType reqType;
    /** Position of the {@code indices} field within {@link #reqType}. */
    private final int indicesIdx;
    /** Position of the {@code querytype} field within {@link #reqType}. */
    private final int queryTypeIdx;

    /**
     * Resolves the request struct schema and the positions of the fields this UDF reads.
     *
     * @param reqArrayField schema field of the requests column; its data type must be an
     *     array of structs containing {@code indices} and {@code querytype} fields
     * @throws ClassCastException if the field is not an array of structs
     */
    public GetMainSearchRequest(StructField reqArrayField) {
        reqType = (StructType)((ArrayType)reqArrayField.dataType()).elementType();
        indicesIdx = reqType.fieldIndex("indices");
        queryTypeIdx = reqType.fieldIndex("querytype");
    }

    /**
     * Builds a column expression that evaluates {@link #getMainSearchReq} row by row.
     *
     * @param wiki column holding the wiki name
     * @param requests column holding the array of request structs
     * @return a column of the request struct type; null where no request matched
     */
    public Column apply(Column wiki, Column requests) {
        UDF2<String, Seq<Row>, Row> udfMethod = this::getMainSearchReq;
        return udf(udfMethod, reqType).apply(wiki, requests);
    }

    /**
     * Returns the first full-text request that targets an index of the given wiki.
     *
     * @param wiki wiki name; when null nothing can match, so null is returned
     *     (previously this produced the bogus prefix {@code "null_"})
     * @param requests request structs; may be null or empty, and may contain null elements
     * @return the first matching request row, or null if none matched
     */
    @VisibleForTesting
    Row getMainSearchReq(String wiki, Seq<Row> requests) {
        if (wiki == null || requests == null || requests.isEmpty()) {
            return null;
        }
        String prefix = wiki + "_";
        for (Row request : JavaConverters.seqAsJavaListConverter(requests).asJava()) {
            if (isMainSearchRequest(request, wiki, prefix)) {
                return request;
            }
        }
        return null;
    }

    /** True when the request is a full-text query hitting {@code wiki} or a {@code wiki_*} index. */
    private boolean isMainSearchRequest(Row request, String wiki, String prefix) {
        // Null array elements and null indices lists are data, not errors: skip them.
        if (request == null || !FULL_TEXT.equals(request.getString(queryTypeIdx))
                || request.isNullAt(indicesIdx)) {
            return false;
        }
        List<String> indices = request.getList(indicesIdx);
        for (String index : indices) {
            if (index != null && (index.equals(wiki) || index.startsWith(prefix))) {
                return true;
            }
        }
        return false;
    }
}