// SubgraphRuleMapper.scala

package org.wikidata.query.rdf.spark.transform.structureddata.dumps

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.openrdf.model.{BNode, Resource, Value}
import org.wikidata.query.rdf.common.uri.{Ontology, UrisScheme}
import org.wikidata.query.rdf.tool.subgraph.SubgraphRule.{Outcome, TriplePattern}
import org.wikidata.query.rdf.tool.subgraph.{SubgraphDefinition, SubgraphDefinitions, SubgraphRule}

import java.util
import scala.collection.JavaConverters._
import scala.language.postfixOps

/**
 * Single-column row carrying the URI of one entity.
 *
 * The field is named `entity_uri` so that it lines up with the column of the same name in the
 * DataFrames that are converted with `.as[Entity]`.
 *
 * NOTE(review): "scala:S117" looks like a lint rule id about naming conventions, presumably
 * suppressed because of the snake_case field name; confirm.
 */
@SuppressWarnings(Array("scala:S117"))
case class Entity(entity_uri: String)
/**
 * An entity URI together with the encoded names of the rules that matched it.
 *
 * Field names mirror the `entity_uri` and `matched_rules` columns of the DataFrames converted
 * with `.as[MappedRules]`. Because `matched_rules` is an `Array`, the generated `equals` compares
 * that field by reference, not by content.
 */
@SuppressWarnings(Array("scala:S117"))
case class MappedRules(entity_uri: String, matched_rules: Array[String])

class SubgraphRuleMapper(urisScheme: UrisScheme, subgraphDefinition: SubgraphDefinitions, subgraphNames: List[String]) {
  // Column names shared by the intermediate datasets built in this class.
  private val ENTITY_URI_FIELD: String = "entity_uri"
  private val MATCHED_RULLES_FIELD: String = "matched_rules"

  /** Definitions looked up for each requested subgraph name, in the order the names were given. */
  val subgraphs: List[SubgraphDefinition] =
    subgraphNames.map(name => subgraphDefinition.getDefinitionByName(name))
  // Encoder used to turn RDF values into the literals compared against the table columns.
  private val statementEncoder: StatementEncoder = new StatementEncoder

  /**
   * Flat list of (subgraphName, ruleIndex, ruleDefinition) covering every selected subgraph.
   * The ruleIndex is the position of the rule within its own subgraph, not within this list.
   */
  private val rules: List[(String, Int, SubgraphRule)] = subgraphs.flatMap { subgraph =>
    val subgraphName = subgraph.getName
    subgraph.getRules.asScala.zipWithIndex.map {
      case (rule, position) => (subgraphName, position, rule)
    }
  }

  // Spark encoders for the two row shapes, declared implicit so that the `.as[Entity]` and
  // `.as[MappedRules]` conversions in this class can resolve them.
  implicit private val entityEncoder: Encoder[Entity] = Encoders.product[Entity]
  implicit private val mappedRulesEncoder: Encoder[MappedRules] = Encoders.product[MappedRules]

  /**
   * Compute, for every selected subgraph, the set of entities that belong to it.
   *
   * Entities are the distinct `context` values starting with "<" followed by one of the
   * entity URI prefixes of the URI scheme. Each entity is paired with the rules it matched
   * (an empty array when none matched) and the rules of each subgraph then decide membership.
   *
   * @param baseTable triples table exposing at least a `context` column
   * @return the entities selected for each subgraph definition
   */
  def mapSubgraphs(baseTable: DataFrame): Map[SubgraphDefinition, Dataset[Entity]] = {
    val contextCol = baseTable("context")
    val entityUrisFilter: Column = urisScheme.entityURIs().asScala
      .map(uriPrefix => contextCol.startsWith("<" + uriPrefix))
      .reduce((left, right) => left.or(right))

    val entityUris: Dataset[Entity] = baseTable
      .filter(entityUrisFilter)
      .select(contextCol.as(ENTITY_URI_FIELD))
      .dropDuplicates()
      .as[Entity]
      .cache()

    val appliedRules: Dataset[MappedRules] = mapRules(baseTable).cache()

    // Entities absent from appliedRules matched no rule at all: keep them with an empty rule list.
    val sameEntity = entityUris(ENTITY_URI_FIELD) === appliedRules(ENTITY_URI_FIELD)
    val entitiesWithoutMatch = entityUris
      .withColumn(MATCHED_RULLES_FIELD, lit(Array.empty[String]))
      .join(appliedRules, sameEntity, "left_anti")
      .select(ENTITY_URI_FIELD, MATCHED_RULLES_FIELD)
      .as[MappedRules]
    val allEntities = entitiesWithoutMatch.union(appliedRules)

    subgraphs.map { subgraph =>
      val members = allEntities.filter(filterEntities(subgraph)).select(ENTITY_URI_FIELD).as[Entity]
      subgraph -> members
    }.toMap
  }

  /**
   * Build the predicate deciding whether an entity belongs to the given subgraph.
   *
   * The rules of the subgraph are checked in declaration order: the first one present in the
   * entity's `matched_rules` decides, keeping the entity only when its outcome is `Outcome.pass`.
   * When none of the subgraph's rules matched, the subgraph's default outcome is applied.
   *
   * @param subgraph the subgraph whose rules and default outcome drive the decision
   * @return a filter function returning true for entities that belong to the subgraph
   */
  private def filterEntities(subgraph: SubgraphDefinition): FilterFunction[MappedRules] = {
    val subgraphName = subgraph.getName
    // Resolve everything to plain strings and booleans up front, so that the returned closure
    // captures neither the SubgraphDefinition nor the rule outcomes and does no per-row lookups.
    val passByRule: List[(String, Boolean)] = rules collect {
      case (name, idx, rule) if subgraphName.equals(name) =>
        (ruleName(name, idx), rule.getOutcome == Outcome.pass)
    }
    val passByDefault: Boolean = subgraph.getRuleDefault == Outcome.pass

    (t: MappedRules) => {
      passByRule collectFirst {
        case (rule, passes) if t.matched_rules.contains(rule) => passes
      } getOrElse passByDefault
    }
  }

  /**
   * Build, for every subgraph, the stub triples pointing to entities hosted by other subgraphs.
   *
   * Every subgraph flagged as a stubs source contributes one triple per entity, of the form
   * (context = entity, subject = entity, predicate = `Ontology.QueryService.SUBGRAPH`,
   * object = URI of the source subgraph). For a given destination the stubs of all the other
   * stubs sources are unioned, and the ones whose entity is already part of the destination are
   * removed.
   *
   * @param mappedSubgraphs entities selected for each subgraph
   * @return for each subgraph its stubs, or None when no other subgraph is a stubs source
   */
  def buildStubs(mappedSubgraphs: Map[SubgraphDefinition, Dataset[Entity]]): Map[SubgraphDefinition, Option[DataFrame]] = {
    // The predicate and the per-source stub frames do not depend on the destination: build them once.
    val subgraphPredicate = lit(statementEncoder.encodeURI(Ontology.QueryService.SUBGRAPH))
    val stubsBySource: Seq[(SubgraphDefinition, DataFrame)] = mappedSubgraphs.toSeq.collect {
      case (source, sourceEntities) if source.isStubsSource =>
        val entityUri = sourceEntities(ENTITY_URI_FIELD)
        val stubs = sourceEntities
          .withColumn("context", entityUri)
          .withColumn("subject", entityUri)
          .withColumn("predicate", subgraphPredicate)
          .withColumn("object", lit(statementEncoder.encode(source.getSubgraphUri)))
          .drop(ENTITY_URI_FIELD)
        source -> stubs
    }

    mappedSubgraphs map { case (dest, entities) =>
      val stubsDf = stubsBySource
        .collect { case (source, stubs) if !source.equals(dest) => stubs }
        .reduceOption(_ union _)
        // a stub is only useful for entities the destination does not already contain
        .map(stubs => stubs.join(entities, entities(ENTITY_URI_FIELD) === stubs("context"), "left_anti"))
      dest -> stubsDf
    }
  }

  /**
   * Associate every entity with the rules matched by at least one of its triples.
   *
   * @param baseTable triples table to evaluate the rules against
   * @return one row per entity having at least one row kept by the rule filters
   */
  private def mapRules(baseTable: DataFrame): Dataset[MappedRules] = {
    val ruleFilters: List[Column] = rules.map {
      case (_, _, rule) => buildFilter(baseTable.apply, rule)
    }

    // Keep only the rows matched by at least one rule; without any rule the table is left untouched.
    val candidateRows: DataFrame = ruleFilters.reduceOption(_ or _) match {
      case Some(matchesAnyRule) => baseTable.filter(matchesAnyRule)
      case None => baseTable
    }

    // Tag the remaining rows with the rules they match and merge them per entity.
    addRuleColumnAndGroupByEntity(candidateRows, rules)
  }

  /**
   * Identifier of a rule: the subgraph name immediately followed by the rule index in square brackets.
   *
   * @param subgraphName name of the subgraph declaring the rule
   * @param idx          index of the rule within that subgraph
   */
  private def ruleName(subgraphName: String, idx: Int): String =
    subgraphName + "[" + idx + "]"

  /**
   * Tag every row with the rules it matches, then collapse the rows to one per entity.
   *
   * A `matched_rules` array column is added, initially empty; for each rule whose filter matches
   * the row, the identifier produced by `ruleName` is appended. Rows are then grouped by their
   * `context` (exposed as `entity_uri`) and the per-row arrays are merged into a single
   * duplicate-free array.
   *
   * @param dataFrame triples table; the rule filters read its `context`, `subject`, `predicate`
   *                  and `object` columns
   * @param rules     (subgraphName, ruleIndex, rule) entries evaluated against every row
   * @return one MappedRules row per distinct context, listing the rules matched by any of its rows
   */
  private def addRuleColumnAndGroupByEntity(dataFrame: DataFrame, rules: List[(String, Int, SubgraphRule)]): Dataset[MappedRules] = {
    val noRuleMatched = dataFrame.withColumn(MATCHED_RULLES_FIELD, lit(Array[String]()))

    // Thread the DataFrame through the rules rather than reassigning a var. At each step:
    //    if the rule filter matches: matched_rules + [ruleId]
    //    otherwise:                  matched_rules unchanged
    val tagged = rules.foldLeft(noRuleMatched) {
      case (df, (name, idx, rule)) =>
        val matchedSoFar = df(MATCHED_RULLES_FIELD)
        val ruleId = ruleName(name, idx)
        val withRule = when(buildFilter(dataFrame.apply, rule), array_union(matchedSoFar, lit(Array(ruleId))))
          .otherwise(matchedSoFar)
        df.withColumn(MATCHED_RULLES_FIELD, withRule)
    }

    // groupBy works on the untyped view, so the only typed conversion needed is the final one.
    tagged
      .withColumn(ENTITY_URI_FIELD, tagged("context"))
      .drop("context", "subject", "predicate", "object")
      .groupBy(ENTITY_URI_FIELD)
      .agg(array_distinct(flatten(collect_list(MATCHED_RULLES_FIELD))).alias(MATCHED_RULLES_FIELD))
      .as[MappedRules]
  }

  /**
   * Turn a subgraph rule into a row filter. Only the triple pattern of the rule is used here;
   * its outcome plays no part in the filter.
   *
   * @param colSupplier  resolves a column name to the column to test
   * @param subgraphRule rule whose pattern is translated
   */
  private def buildFilter(colSupplier: String => Column, subgraphRule: SubgraphRule): Column = {
    val pattern: TriplePattern = subgraphRule.getPattern
    filterFromTriplePattern(colSupplier, pattern)
  }

  /**
   * Turn a triple pattern into a row filter: the `subject`, `predicate` and `object` columns must
   * each satisfy the corresponding element of the pattern. The `context` column is handed to every
   * element check as the entity column.
   *
   * @param colSupplier   resolves a column name to the column to test
   * @param triplePattern pattern providing the subject, predicate, object and the bindings
   */
  private def filterFromTriplePattern(colSupplier: String => Column, triplePattern: TriplePattern): Column = {
    val entityCol = colSupplier("context")
    val bindings = triplePattern.getBindings

    val subjectMatches = filterFromValue(entityCol, colSupplier("subject"), triplePattern.getSubject, bindings)
    val predicateMatches = filterFromValue(entityCol, colSupplier("predicate"), triplePattern.getPredicate, bindings)
    val objectMatches = filterFromValue(entityCol, colSupplier("object"), triplePattern.getObject, bindings)

    subjectMatches && predicateMatches && objectMatches
  }

  /**
   * Build the condition a single column must satisfy for one element of a triple pattern.
   * - BNode whose id is `TriplePattern.ENTITY_BINDING_NAME`: the column must equal the entity column
   * - BNode whose id is `TriplePattern.WILDCARD_BNODE_LABEL`: always true, any value is accepted
   * - any other BNode: the column must be one of the encoded values bound to the BNode id
   * - anything else: the column must equal the value as encoded by statementEncoder
   *
   * @param entityColumn column holding the current entity
   * @param col          column being tested
   * @param value        pattern element to translate
   * @param bindings     values allowed for each named binding
   * @throws UnsupportedOperationException when a BNode id has no entry in `bindings`
   */
  private def filterFromValue(entityColumn: Column, col: Column, value: Value, bindings: util.Map[String, util.Collection[Resource]]): Column = {
    value match {
      case bnode: BNode if bnode.getID.equals(SubgraphRule.TriplePattern.ENTITY_BINDING_NAME) =>
        col === entityColumn
      case bnode: BNode if bnode.getID.equals(TriplePattern.WILDCARD_BNODE_LABEL) =>
        lit(true)
      case bnode: BNode =>
        // named binding: the java map returns null for an unknown name
        Option(bindings.get(bnode.getID)) match {
          case Some(bindingValues) =>
            col.isInCollection(bindingValues.asScala.map(statementEncoder.encode(_)))
          case None =>
            throw new UnsupportedOperationException("Unknown binding: " + bnode.getID)
        }
      case _ =>
        col === lit(statementEncoder.encode(value))
    }
  }
}