package org.wikimedia.search.glent;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udf;
import static org.apache.spark.sql.functions.unix_timestamp;

import java.time.Instant;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF8;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;

import scala.collection.JavaConverters;
import scala.collection.Seq;

/**
 * Reads both the CirrusSearchRequestSet Avro schema, as "old", and the
 * eventgate mediawiki_cirrussearch_request JSON schema, as "new". Returns a
 * unified view over the requested timespan.
 *
 * <p>Around June 2019 the logging schema changed. This loader massages away
 * the differences between the two input formats and unions the datasets
 * together into one cohesive dataset. The output shape matches the old schema.
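 *
 * <p>A minimal usage sketch; the table name is illustrative, not something
 * this class prescribes:
 * <pre>{@code
 * Dataset<Row> requests = new CirrusLogLoader().load(
 *         spark, "event.mediawiki_cirrussearch_request",
 *         Instant.parse("2019-07-01T00:00:00Z"),
 *         Instant.parse("2019-07-02T00:00:00Z"));
 * }</pre>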
 */
public class CirrusLogLoader {
    /**
     * UDF that reassembles the {@code requests} array with the expected
     * field names and types.
     */
    private final UserDefinedFunction fixer;

    public CirrusLogLoader() {
        // Element shape mirrors the old CirrusSearchRequestSet `requests` struct.
        DataType outputSchema = DataTypes.createArrayType(DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("query", DataTypes.StringType, true),
                DataTypes.createStructField("querytype", DataTypes.StringType, true),
                DataTypes.createStructField("indices", DataTypes.createArrayType(DataTypes.StringType), true),
                DataTypes.createStructField("limit", DataTypes.IntegerType, true),
                DataTypes.createStructField("hitstotal", DataTypes.IntegerType, true),
                DataTypes.createStructField("hitsoffset", DataTypes.IntegerType, true),
                DataTypes.createStructField("namespaces", DataTypes.createArrayType(DataTypes.IntegerType), true),
                DataTypes.createStructField("syntax", DataTypes.createArrayType(DataTypes.StringType), true),
        }));

        fixer = udf((UDF8<Seq<String>, Seq<String>, Seq<Seq<String>>, Seq<Number>, Seq<Number>,
                Seq<Number>, Seq<Seq<Number>>, Seq<Seq<String>>, Row[]>) CirrusLogLoader::assembleRows, outputSchema);
    }

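    /**
     * Load request events from table {@code name}, restricted to the hourly
     * partitions selected by {@link HourlyPartitionSelector} between
     * {@code from} and {@code to}, with columns renamed and the requests
     * array reshaped to the old schema.
     */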
    public Dataset<Row> load(SparkSession spark, String name, Instant from, Instant to) {
        return spark.read().table(name)
                .where(new HourlyPartitionSelector().apply(from, to))
                .select(unix_timestamp(col("meta.dt"), "yyyy-MM-dd'T'HH:mm:ss'Z'").alias("ts"),
                        col("database").alias("wikiid"), col("source"),
                        col("identity"), col("http.client_ip").alias("ip"),
                        fixNewRequests(col("elasticsearch_requests")).alias("requests"));
    }

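    /**
     * {@code getItem} on an array-of-structs column extracts one array per
     * struct field; the eight parallel arrays are handed to the fixer UDF,
     * which zips them back into an array of structs with the old field
     * names and types.
     */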
    private Column fixNewRequests(Column in) {
        return fixer.apply(
                in.getItem("query"), in.getItem("query_type"), in.getItem("indices"),
                in.getItem("limit"), in.getItem("hits_total"), in.getItem("hits_offset"),
                in.getItem("namespaces"), in.getItem("syntax"));
    }

    // This is probably OO overkill for null-safe type conversion.
    private static final Assembler<String, String> STRINGS = nullSafe(Seq::apply);
    private static final Assembler<Seq<String>, Seq<String>> STRING_SEQS = nullSafe(Seq::apply);
    // Spark may hand back any Number subtype; normalize to Integer.
    private static final Assembler<Number, Integer> INTS = nullSafe((seq, i) -> {
        Number x = seq.apply(i);
        return x == null ? null : x.intValue();
    });
    // Same normalization for nested sequences, converted to a primitive int array.
    private static final Assembler<Seq<Number>, int[]> INT_SEQS = nullSafe((seq, i) -> {
        Seq<Number> x = seq.apply(i);
        return x == null ? null : JavaConverters.seqAsJavaList(x)
                .stream()
                .mapToInt(Number::intValue)
                .toArray();
    });

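    /**
     * Zip the per-field arrays back into an array of rows matching the
     * declared output schema.
     */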
    private static Row[] assembleRows(Seq<String> query, Seq<String> queryType, Seq<Seq<String>> indices,
                                      Seq<Number> limit, Seq<Number> hitsTotal,
                                      Seq<Number> hitsOffset, Seq<Seq<Number>> namespaces,
                                      Seq<Seq<String>> syntax) {
        // Because we are transforming this before any kind of filtering is applied we
        // have to deal with everything potentially being null. All eight arrays derive
        // from the same source column, so any non-null one provides the common length.
        OptionalInt optLength = Stream.of(
                    query, queryType, indices, limit, hitsTotal, hitsOffset,
                    namespaces, syntax)
                .filter(Objects::nonNull)
                .mapToInt(Seq::size)
                .findFirst();
        if (!optLength.isPresent()) {
            // All null, implies the source was null
            return null;
        }
        return IntStream.range(0, optLength.getAsInt())
                .mapToObj(i -> RowFactory.create(
                    STRINGS.valueOf(query, i), STRINGS.valueOf(queryType, i),
                    STRING_SEQS.valueOf(indices, i), INTS.valueOf(limit, i),
                    INTS.valueOf(hitsTotal, i), INTS.valueOf(hitsOffset, i),
                    INT_SEQS.valueOf(namespaces, i), STRING_SEQS.valueOf(syntax, i)))
                .toArray(Row[]::new);
    }

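    /** Extracts element {@code i} of a {@code Seq<T>}, converted to a {@code U}. */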
    private interface Assembler<T, U> {
        U valueOf(Seq<T> seq, int i);
    }

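    /** Wraps an assembler so that a null input sequence yields a null result. */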
    private static <T, U> Assembler<T, U> nullSafe(Assembler<T, U> nested) {
        return (seq, i) -> seq == null ? null : nested.valueOf(seq, i);
    }
}