GlentControl.java

/*
 * Copyright (C) 2019 Glenbrook Networks
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package org.wikimedia.search.glent;

import static org.apache.spark.sql.functions.lit;

import java.util.Arrays;
import java.util.Map;
import java.util.Set;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;

import com.beust.jcommander.JCommander;
import com.google.common.collect.Sets;

/**
 * Entry point for Glent. Parses the requested subcommand with JCommander and runs it
 * against a shared SparkSession, along with helpers for writing job output.
 *
 * @author Julia.Glen
 */
public final class GlentControl {

    private GlentControl() {}

    public static void main(String... args) {
        Params params = new Params();

        JCommander.Builder builder = JCommander.newBuilder()
                .addObject(params);
        Map<String, Method> methods = Method.methods();
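        // Register each Glent method as a JCommander subcommand, keyed by its name, so the
        // parsed command can be looked up in the same map and dispatched below.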
        for (Map.Entry<String, Method> entry : methods.entrySet()) {
            builder.addCommand(entry.getKey(), entry.getValue());
        }
        JCommander jCommander = builder.build();
        jCommander.parse(args);

        if (params.help) {
            jCommander.usage();
            System.exit(1);
        }

        String command = jCommander.getParsedCommand();
        if (command == null) {
            // No subcommand was given; print usage instead of failing with a NullPointerException.
            jCommander.usage();
            System.exit(1);
        }
        methods.get(command).accept(initSpark());
    }

    /**
     * Initialize the spark session.
     * @return The spark session
     */
    private static SparkSession initSpark() {
        return SparkSession.builder()
                .appName("glent")
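                // Hive support is required so writeDf can read table schemas and insertInto Hive tables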
                .enableHiveSupport()
                // TODO: Why? Something to do with writeDf and insert overwrite partition?
                .config("spark.hadoop.hive.exec.dynamic.partition", "true")
                // Required for insertInto to only overwrite the partitions being written to and not the whole table
                .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
                // Nonstrict is required to allow partition selection on a per-row basis. We can't specify
                // anything more strict from Spark.
                .config("hive.exec.dynamic.partition.mode", "nonstrict")
                .getOrCreate();
    }

    /**
     * Write a dataframe either to a filesystem path (as gzipped CSV) or into a
     * partition of a Hive table.
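     *
     * <p>A minimal usage sketch; the table and partition names are illustrative only:
     * <pre>{@code
     * writeDf(df, "glent.suggestions", Collections.singletonMap("snapshot", "20190801"), 1);
     * }</pre>
     *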
     * @param df The dataset to write
     * @param dbAndTable Output path (treated as a path when it contains a '/') or Hive table name
     * @param partition Partition spec mapping partition column names to values in dbAndTable
     * @param maxPartitions If greater than zero, coalesce the output to at most this many partitions before writing
     */
    static void writeDf(Dataset<Row> df, final String dbAndTable,
                        final Map<String, String> partition, int maxPartitions) {
        if (dbAndTable.contains("/")) {
            df
                .write()
                // .csv.gz is primarily for debugging purposes
                .option("compression", "gzip")
                .csv(dbAndTable);
        } else {
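            // Table output: optionally coalesce to limit the number of output files, then align
            // the dataframe with the table schema and overwrite only the partitions being written.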
            if (maxPartitions > 0) {
                df = df.coalesce(maxPartitions);
            }
            prepareDatasetForInsertion(df, dbAndTable, partition)
                .write()
                .mode(SaveMode.Overwrite)
                .insertInto(dbAndTable);
        }
    }

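    /**
     * Align a dataframe with the schema of the target table prior to insertInto.
     *
     * <p>Each partition key is validated (it must exist in the target table and must not
     * already be a dataframe column) and added as a constant column cast to the target
     * column's type. Columns are then re-ordered to match the table schema, since
     * insertInto matches columns by position rather than by name.
     *
     * @param df The dataset to prepare
     * @param table Target table whose schema defines the expected columns and their order
     * @param partition Partition spec mapping partition column names to values
     * @return The dataframe with partition columns added and columns ordered to match the table
     */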
    private static Dataset<Row> prepareDatasetForInsertion(Dataset<Row> df, String table, Map<String, String> partition) {
        StructType expectSchema = df.sparkSession().read().table(table).schema();
        Set<String> sourceColumns = Sets.newHashSet(df.columns());
        Set<String> targetColumns = Sets.newHashSet(expectSchema.names());

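        // Validate each partition key against both the source dataframe and the target table
        // schema, then append it as a literal column cast to the target column's type.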
        for (Map.Entry<String, String> entry : partition.entrySet()) {
            String name = entry.getKey();
            if (sourceColumns.contains(name)) {
                throw new IllegalArgumentException("Partition key " + name + " overwriting dataframe column");
            }
            if (!targetColumns.contains(name)) {
                throw new IllegalArgumentException("Partition key " + name + " not found in target table");
            }
            df = df.withColumn(name, lit(entry.getValue()).cast(expectSchema.apply(name).dataType()));
        }
        // insertInto works off the order of columns, so we need to re-order df to match the schema
        return df.select(Arrays.stream(expectSchema.names()).map(functions::col).toArray(Column[]::new));
    }
}