// SuggestionAggregator.java

package org.wikimedia.search.glent;

import static org.apache.spark.sql.functions.array_contains;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.collect_set;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.log10;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.row_number;
import static org.apache.spark.sql.functions.sum;
import static org.apache.spark.sql.functions.when;

import java.util.function.Function;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Collapses per-algorithm suggestion rows into a single best suggestion per
 * (wiki, query) pair, assigns each a score, and reshapes the output columns to
 * match the CirrusSearch template. The result is validated against
 * {@link #SCHEMA_OUT} so downstream consumers never see schema drift.
 */
public class SuggestionAggregator implements Function<Dataset<Row>, Dataset<Row>> {
    /** Exact schema the dataset returned by {@link #apply(Dataset)} must have. */
    static final StructType SCHEMA_OUT = DataTypes.createStructType(new StructField[] {
            DataTypes.createStructField("method",
                    DataTypes.createArrayType(DataTypes.StringType, false), false),
            DataTypes.createStructField("wiki", DataTypes.StringType, false),
            DataTypes.createStructField("suggestion_score", DataTypes.FloatType, false),
            DataTypes.createStructField("query", DataTypes.StringType, false),
            DataTypes.createStructField("dym", DataTypes.StringType, false)
    });

    /** Signals that the produced dataset does not match {@link #SCHEMA_OUT}. */
    private static class SchemaInconsistencyException extends RuntimeException {
        // RuntimeException is Serializable; pin the UID to silence serial-version warnings.
        private static final long serialVersionUID = 1L;

        SchemaInconsistencyException(String message, StructType schema) {
            // Append the actual schema tree so the mismatch is visible in the error.
            super(message + "\n" + schema.treeString());
        }
    }

    /**
     * Runs the full aggregation pipeline: merge algorithms, score, keep the
     * best suggestion per (wiki, query), and rename columns for CirrusSearch.
     *
     * @param df input with one row per (algorithm, query, suggestion); must
     *           contain the columns referenced in {@link #mergeAlgorithms}
     * @return one best-scored row per (wiki, query), shaped as {@link #SCHEMA_OUT}
     * @throws SchemaInconsistencyException if the output schema does not match
     *         {@link #SCHEMA_OUT}
     */
    @Override
    public Dataset<Row> apply(Dataset<Row> df) {
        df = mergeAlgorithms(df);
        df = withSuggestionScore(df);
        df = filterToBestSuggestions(df);
        // Rename to match CirrusSearch template
        df = df.select(
                col("algos").alias("method"),
                col("wikiid").alias("wiki"),
                col("suggestion_score"),
                col("query"), col("dym"));
        StructType schema = df.schema();
        if (!schema.sameType(SCHEMA_OUT)) {
            throw new SchemaInconsistencyException(
                    "Output does not match expected schema!", schema);
        }
        return df;
    }

    /**
     * Merge results from multiple algorithms so we have a single row per
     * (query, suggestion) pair. Hit totals take the max across algorithms;
     * suggestion counts are summed; contributing algorithm names are collected
     * into the {@code algos} set.
     */
    private Dataset<Row> mergeAlgorithms(Dataset<Row> df) {
        return df
                .groupBy("query", "dym", "q1q2EditDist", "wikiid", "lang")
                .agg(
                        collect_set("algo").alias("algos"),
                        max("queryHitsTotal").alias("queryHitsTotal"),
                        max("dymHitsTotal").alias("dymHitsTotal"),
                        sum("suggCount").alias("suggCount"));
    }

    /** Attaches the {@code suggestion_score} column computed by {@link #score}. */
    private Dataset<Row> withSuggestionScore(Dataset<Row> df) {
        return df.withColumn("suggestion_score",
                score("algos", "q1q2EditDist", "dymHitsTotal", "suggCount"));
    }

    /**
     * Choose best suggestion for all (wiki, query) sets.
     * NOTE(review): ties on suggestion_score are broken arbitrarily by
     * row_number with no secondary sort key — results may be nondeterministic
     * for tied scores; confirm whether that matters downstream.
     */
    private Dataset<Row> filterToBestSuggestions(Dataset<Row> df) {
        WindowSpec w = Window
                .partitionBy("wikiid", "query")
                .orderBy(col("suggestion_score").desc());
        return df
        // row_number is 1-indexed. Go figure.
                .withColumn("rn", row_number().over(w))
                .where(col("rn").equalTo(1))
                .drop("rn");
    }

    /**
     * FRHED-style score: log-geometric mean of (suggCount, suggCount, adjHits)
     * minus the edit distance, so frequent low-distance suggestions win.
     */
    private Column frhedScore(Column suggCount, Column suggHits, Column editDist) {
        // Dampen very popular suggestions:
        // adjHits = suggHits < 10_000 ? suggHits : 10_000 + suggHits/10_000;
        Column tenK = lit(10_000F);
        Column adjHits = when(suggHits.lt(tenK), suggHits)
                        .otherwise(tenK.plus(suggHits.divide(tenK)));

        // logGeoMean = log(∛(suggCount^2 * adjHits)) = (2*log(suggCount) + log(adjHits))/3
        Column logGeoMean = log10(suggCount).multiply(lit(2)).plus(log10(adjHits)).divide(lit(3));

        // frhedScore = log(suggCount^2 * adjHits)/3 - editDist
        return logGeoMean.minus(editDist);
    }

    /** Convenience overload taking column names instead of {@link Column}s. */
    private Column score(String algos, String dist, String hitsTotal, String suggCount) {
        return score(col(algos), col(dist), col(hitsTotal), col(suggCount));
    }

    /**
     * Scoring expression combining all algorithms' signals.
     * M2 rows are scored purely by edit distance; M0/M1 use the FRHED score,
     * with M0 boosted so it outranks M1 when both produced the suggestion.
     */
    private Column score(Column algos, Column dist, Column hitsTotal, Column suggCount) {
        // make sure M0 beats M1; M2 should be acting alone

        // Convert M2 to 10 - dist; for M0, M1, calculate frhedScore
        Column adjustedDist = when(array_contains(algos, Method.M2Run.ALGO),
            lit(10F).minus(dist)).otherwise(frhedScore(suggCount, hitsTotal, dist));

        // M0 has higher precision than M1, give it a boost
        Column m0runBoost = when(array_contains(algos, Method.M0Run.ALGO), lit(20F))
            .otherwise(lit(0F));

        return adjustedDist.cast(DataTypes.FloatType).plus(m0runBoost);
    }
}