// SparkUtils.scala
package org.wikidata.query.rdf.spark.utils

import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.spark.sql.SaveMode.Overwrite
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

import java.util.regex.Pattern
import scala.util.matching.Regex
object SparkUtils {

  // Matches a single "column=value" partition spec. The '-' is placed last in the
  // character class so it is matched literally instead of forming a '.-/' range
  // (which would silently exclude dashes, common in date-valued partitions).
  private val partitionRegex: Regex = "^(\\w+)=([\\w./-]+)$".r
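  /**
   * Renames the files of a Spark output directory to stable, numbered names and,
   * unless `flag` is None, drops a flag file marking the rename as done.
   *
   * A minimal usage sketch; the path, pattern and naming format below are
   * hypothetical, not values this codebase mandates:
   * {{{
   * implicit val spark: SparkSession = SparkUtils.getSparkSession("rename-example")
   * SparkUtils.renameSparkPartitions(
   *   source = "hdfs:///some/output/dir",   // hypothetical output directory
   *   format = "dump-%03d.ttl.gz",          // must contain a %d placeholder
   *   filter = Pattern.compile("^part-"))   // selects Spark's part-* files
   * // leaves hdfs:///some/output/dir/_RENAMED behind as a completion flag
   * }}}
   */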
  def renameSparkPartitions(source: String, format: String, filter: Pattern, flag: Option[String] = Some("_RENAMED"))(implicit spark: SparkSession): Unit = {
    val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)
    if (!Pattern.compile("%\\d*d").matcher(format).find()) {
      throw new IllegalArgumentException(s"Invalid format $format: it should have a %d place-holder for the partition index.")
    }
    fs.listStatus(new Path(source)).filter { file =>
      file.isFile && filter.matcher(file.getPath.getName).find()
    }.zipWithIndex.foreach { case (file: FileStatus, idx: Int) =>
      fs.rename(file.getPath, new Path(file.getPath.getParent, format.format(idx + 1)))
    }
    flag foreach { f =>
      val out = fs.create(new Path(source, f))
      out.writeUTF("renamed")
      out.close()
    }
  }
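  /**
   * Reads a Hive table, optionally narrowed to one partition, from a spec of the
   * form "table" or "table/col1=val1/col2=val2".
   *
   * A usage sketch; the table and partition names are illustrative:
   * {{{
   * val df = SparkUtils.readTablePartition("discovery.wikibase_rdf/date=20230101/wiki=wikidata")
   * // roughly equivalent to:
   * // spark.read.table("discovery.wikibase_rdf")
   * //   .where($"date" === "20230101" && $"wiki" === "wikidata")
   * }}}
   */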
  def readTablePartition(tableAndPartitionSpecs: String)(implicit spark: SparkSession): DataFrame = {
    applyTablePartitions[DataFrame, DataFrame](tableAndPartitionSpecs,
      spark.read.table,
      (column, value, df) => df.filter(df(column).equalTo(lit(value))),
      (_, df) => df)
  }
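  /**
   * Writes a DataFrame into the Hive table partition described by the same
   * "table/col=value/..." spec used by [[readTablePartition]]; each partition
   * column from the spec is added to the DataFrame as a literal column before
   * the insert.
   *
   * A usage sketch; the spec and DataFrame names are illustrative:
   * {{{
   * SparkUtils.insertIntoTablePartition(
   *   "discovery.wikibase_rdf/date=20230101/wiki=wikidata",
   *   someDataFrame,
   *   saveMode = SaveMode.Overwrite,
   *   format = Some("hive"))
   * }}}
   */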
  def insertIntoTablePartition(tableAndPartitionSpecs: String,
                               dataFrame: DataFrame,
                               saveMode: SaveMode = SaveMode.Overwrite,
                               format: Option[String] = None
                              )(implicit spark: SparkSession): Unit = {
    def insertIntoFunction(table: String, df: DataFrame): Unit = {
      // Reorder the columns to match the target table, because
      // DataFrame.insertInto only cares about column position, not names.
      // Columns of simple types are also cast to the target column type.
      val dfw = df.select(spark.read.table(table).schema.fields.map { e =>
        e.dataType match {
          case t @ (_: StringType | _: VarcharType | _: CharType | _: NumericType | _: BooleanType | _: DateType | _: TimestampType) =>
            df(e.name).cast(t)
          case _ => df(e.name)
        }
      }: _*)
        .write.mode(saveMode)
      format.foreach(dfw.format)
      dfw.insertInto(table)
    }
    applyTablePartitions[DataFrame, Unit](tableAndPartitionSpecs,
      _ => dataFrame,
      (column, value, df) => df.withColumn(column, lit(value)),
      insertIntoFunction)
  }
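  /**
   * Shared driver for the "table/col=value/..." specs: `tablePreOp` turns the
   * table name into an initial value, `partitionOp` is applied once per
   * "col=value" pair, and `tablePostOp` produces the final result.
   */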
  private def applyTablePartitions[E, O](
    tableAndPartitionSpecs: String,
    tablePreOp: String => E,
    partitionOp: (String, String, E) => E,
    tablePostOp: (String, E) => O
  ): O = {
    // Split on the first '/' only: the remainder is the partition spec, which may
    // itself contain further '/'-separated "col=value" pairs.
    val tableAndPartitions: Array[String] = tableAndPartitionSpecs.split("/", 2)
    tableAndPartitions match {
      case Array(table, partition) => tablePostOp(table, applyPartitions(tablePreOp(table), partition, partitionOp))
      case Array(table) => tablePostOp(table, tablePreOp(table))
      case _ => throw new IllegalArgumentException("Invalid table or partition specifications: [" + tableAndPartitionSpecs + "]")
    }
  }
  private def applyPartitions[E](input: E, partitionSpec: String, func: (String, String, E) => E): E = {
    // Fold each "column=value" pair into the accumulator, left to right.
    partitionSpec.split("/").foldLeft(input) { (acc, spec) =>
      partitionRegex.findFirstMatchIn(spec) match {
        case Some(m) => func(m.group(1), m.group(2), acc)
        case None =>
          throw new IllegalArgumentException("Invalid partition specifications: [" + partitionSpec + "]")
      }
    }
  }
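  /**
   * Builds (or reuses) a SparkSession configured for dynamic-partition writes
   * to Hive.
   *
   * Typical use at the top of a job; the application name is illustrative:
   * {{{
   * implicit val spark: SparkSession = SparkUtils.getSparkSession("wikibase-rdf-dumper")
   * }}}
   */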
  def getSparkSession(appName: String): SparkSession = {
    SparkSession
      .builder()
      // Required because Spark would otherwise fail with:
      //   Exception in thread "main" org.apache.spark.SparkException: Dynamic partition strict mode requires
      //   at least one static partition column. To turn this off set hive.exec.dynamic.partition.mode=nonstrict
      .config("hive.exec.dynamic.partition", value = true)
      // Hive expects the value "nonstrict" (no dash)
      .config("hive.exec.dynamic.partition.mode", "nonstrict")
      // Allows overwriting only the target partitions instead of the whole table
      .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
      .appName(appName)
      .getOrCreate()
  }
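  /**
   * Overwrites each (DataFrame, spec) pair through [[insertIntoTablePartition]],
   * using the Hive format.
   *
   * A usage sketch; the DataFrames and specs are illustrative:
   * {{{
   * SparkUtils.saveTables(List(
   *   (entitiesDf, "discovery.entities/snapshot=20230101"),
   *   (literalsDf, "discovery.literals/snapshot=20230101")))
   * }}}
   */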
  def saveTables(dataframeAndPathList: List[(DataFrame, String)])
                (implicit spark: SparkSession): Unit = {
    for ((df, tableAndPartitionSpec) <- dataframeAndPathList) {
      insertIntoTablePartition(tableAndPartitionSpec, df, saveMode = Overwrite, format = Some("hive"))
    }
  }
}