SubgraphMapper.scala
package org.wikidata.query.rdf.spark.transform.structureddata.subgraphs
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{broadcast, col, lit, rand}
import org.wikidata.query.rdf.common.uri.{DefaultUrisScheme, PropertyType, UrisSchemeFactory}
import org.wikidata.query.rdf.spark.utils.{SparkUtils, SubgraphUtils}
/**
* Maps items and triples to subgraphs. Here `item` means any entity in Wikidata.
* An item is part of a subgraph if it is an instance of (P31) the subgraph entity.
* All triples (direct or statements) originating from those items are the subgraph triples.
*
* @param wikidataTriples expected columns: context, subject, predicate, object
*/
class SubgraphMapper(wikidataTriples: DataFrame) {
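// The Wikidata URI scheme and the full URI of the direct "instance of" property (P31),
// used below to select the instance-of triples that define subgraph membership.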
val scheme: DefaultUrisScheme = UrisSchemeFactory.WIKIDATA
val p31: String = scheme.property(PropertyType.DIRECT) + "P31"
/**
* Lists all subgraphs in Wikidata
*
* @return spark dataframe containing list of all subgraphs with columns: subgraph, count
*/
def getAllSubgraphs(): DataFrame = {
wikidataTriples
.filter(s"predicate='<$p31>'")
.selectExpr("object as subgraph")
.groupBy("subgraph")
.count()
}
/**
* Performs a right join when the left side is skewed. The right side must be small
* enough to be collected and broadcast to executors.
*
* A standard broadcast join cannot be used here: in a right outer join Spark can
* only broadcast the left dataset, but in our case the right side is the one small
* enough to broadcast.
*
* Instead we perform an inner join, which can use a broadcast hash join, and an
* anti join after pruning the left down to its set of distinct subgraphs; the union
* of the two reproduces the right outer join.
*
* This is specialized to the exact dataframes being used here, and is not a generic
* right-join with skew implementation.
*/
def rightSkewJoin(left: DataFrame, right: DataFrame): DataFrame = {
// Using a broadcast join avoids skew problems here; the left side won't even need to be shuffled.
val inner = left.join(broadcast(right), Seq("subgraph"), "inner")
// Then we need to union in the rows in the right dataset that don't match the left dataset.
// Thankfully the right dataset is ~1 MB and easy to deal with.
// First we prepare the set of subgraphs that exist in the left dataset. We do a two-pass
// distinct with a salt to deal with the skew, avoiding sending the largest subgraphs
// to a single executor.
val salt = (rand(0) * 10).cast("int").alias("salt")
val leftSubgraphs = left
.select(col("subgraph"), salt)
.distinct()
.drop("salt")
.distinct()
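// For example, with rand(0) * 10 the salt takes values 0..9, so a subgraph with millions
// of rows in `left` is hashed across up to 10 distinct (subgraph, salt) keys in the first
// pass; the second distinct then only has to collapse at most 10 rows per subgraph.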
// Then perform an anti join from the right side. This will give us all subgraphs in right that
// are not referenced in left. The anti join doesn't add in the extra null columns, so we manually
// add the "item" column that would have been set to null in a standard right join.
// scalastyle:off null
val rightNotLeft = right.join(leftSubgraphs, right("subgraph").equalTo(leftSubgraphs("subgraph")), "anti")
.withColumn("item", lit(null).cast(left.schema("item").dataType))
// scalastyle:on null
inner.union(rightNotLeft)
}
/**
* Maps all items to one or more of the top subgraphs
*
* @param allSubgraphs expected columns: subgraph, count
* @param minItems the minimum number of items a subgraph should have to be called a top_subgraph
* @return spark dataframe with columns: subgraph, item
*/
def getTopSubgraphItems(allSubgraphs: DataFrame, minItems: Long): DataFrame = {
val topSubgraphs = allSubgraphs
.filter(col("count") >= minItems)
.drop("count")
val filteredTriples = wikidataTriples
.filter(s"predicate='<$p31>'")
.selectExpr("object as subgraph", "subject as item")
rightSkewJoin(filteredTriples, topSubgraphs)
}
/**
* Maps all triples to one or more of the top subgraphs. Does this by listing all triples under the items
* that were identified as being part of a subgraph. Here predicate_code means the last part of the
* predicate URI. For Wikidata predicates, this is the P-id.
* See [[SubgraphUtils.extractItem]] for the extraction process.
*
* @param topSubgraphItems expected columns: subgraph, item
* @return spark dataframe with columns: subject, predicate, object, subgraph, item, predicate_code
*/
def getTopSubgraphTriples(topSubgraphItems: DataFrame): DataFrame = {
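// Each triple's context identifies the entity (item) it originates from, so joining
// context to item attaches all of an item's triples to that item's subgraph(s).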
wikidataTriples
.select("context", "subject", "predicate", "object")
.join(topSubgraphItems, wikidataTriples("context") === topSubgraphItems("item"), "inner")
.drop("context")
.withColumn("predicate_code", SubgraphUtils.extractItem(col("predicate"), lit("/")))
}
}
object SubgraphMapper {
/**
* When the same data is referenced multiple times in a pipeline, Spark will recompute it
* each time. Use this for dataframes that need to be both written to disk and reused in
* later computation, so they are only computed once.
*/
private def saveAndReload(df: DataFrame, path: String, numPartitions: Int)(implicit spark: SparkSession): DataFrame = {
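// Write the dataframe out, then read it back so that later stages operate on the
// materialized table rather than recomputing the full lineage.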
SparkUtils.saveTables((df.repartition(numPartitions), path) :: Nil)
SparkUtils.readTablePartition(path)
}
/**
* Reads the input triples table, extracts the subgraph mapping via getAllSubgraphs(...),
* getTopSubgraphItems(...) and getTopSubgraphTriples(...), and saves the output tables
*/
def extractAndSaveSubgraphMapping(wikidataTriplesPath: String,
minItems: Long,
allSubgraphsPath: String,
topSubgraphItemsPath: String,
topSubgraphTriplesPath: String): Unit = {
implicit val spark: SparkSession = SparkUtils.getSparkSession("SubgraphMapper")
val subgraphMapper = new SubgraphMapper(SparkUtils.readTablePartition(wikidataTriplesPath))
// 1 partition since data is ~1 MB
val allSubgraphs = saveAndReload(subgraphMapper.getAllSubgraphs(), allSubgraphsPath, 1)
// data is ~800 MB
val topSubgraphItems = saveAndReload(
subgraphMapper
.getTopSubgraphItems(allSubgraphs, minItems),
topSubgraphItemsPath, 8)
// data is ~340 GB
val topSubgraphTriples = subgraphMapper
.getTopSubgraphTriples(topSubgraphItems)
.repartition(3000)
SparkUtils.saveTables((topSubgraphTriples, topSubgraphTriplesPath) :: Nil)
}
}
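// A minimal usage sketch. The snapshot, table/partition specs, and minItems value below
// are hypothetical examples (not production settings), and assume the "table/partition"
// format accepted by SparkUtils.readTablePartition and SparkUtils.saveTables:
//
// SubgraphMapper.extractAndSaveSubgraphMapping(
//   wikidataTriplesPath = "discovery.wikibase_rdf/snapshot=20240101",
//   minItems = 10000L,
//   allSubgraphsPath = "discovery.all_subgraphs/snapshot=20240101",
//   topSubgraphItemsPath = "discovery.top_subgraph_items/snapshot=20240101",
//   topSubgraphTriplesPath = "discovery.top_subgraph_triples/snapshot=20240101")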