// QueryResultRecorder.scala
package org.wikidata.query.rdf.spark.metrics.queries
import com.google.common.hash.{HashFunction, Hashing}
import com.ibm.icu.text.{CollationKey, Collator}
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Row, functions}
import org.eclipse.jetty.http.{HttpField, HttpHeader}
import org.openrdf.model.impl.ValueFactoryImpl
import org.openrdf.model.{BNode, Value, ValueFactory}
import org.openrdf.query.impl.{ListBindingSet, TupleQueryResultImpl}
import org.openrdf.query.{Binding, BindingSet, TupleQueryResult}
import org.openrdf.rio.ntriples.NTriplesUtil
import org.wikidata.query.rdf.tool.HttpClientUtils
import org.wikidata.query.rdf.tool.rdf.client.RdfClient
import java.io.Serializable
import java.net.URI
import java.nio.charset.StandardCharsets
import java.time.Duration
import java.util.Locale
import java.util.regex.Pattern
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.util.{Failure, Success, Try}
/**
* QueryResultRecorder is a class usable as a Spark UDF to run a SPARQL query against an endpoint and record its output.
*
* The class can be used directly to obtain a QueryResult via the query(query: String) method,
* or as a UDF via its own apply function:
* {{{
* val recorder = QueryResultRecorder.create("https://query.wikidata.org/sparql", "user@example.org")
* val df = ... // a dataframe with a string column named "query"
* val results = df.select(recorder(df("query")).alias("query_result"))
* }}}
* The resulting dataframe will have a column named "query_result" of struct type with the schema [[QueryResultRecorder.outputStruct]]:
* <ul>
* <li>results: a list of maps from variable names to the N-Triples representation of the bound values</li>
* <li>success: a boolean indicating if the query succeeded or not</li>
* <li>error_msg: a message with the error if the query failed</li>
* <li>exactHash: a SHA-256 hash of the results in their original order</li>
* <li>reorderedHash: a SHA-256 hash of the results after sorting binding names and solutions</li>
* <li>resultSize: the number of solutions in the result</li>
* </ul>
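*
* The struct fields can then be selected individually from the result column, e.g.:
* {{{
* results.select("query_result.success", "query_result.resultSize", "query_result.exactHash")
* }}}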
*
* NOTE: extra care must be taken when running this with Spark so as not to overload the SPARQL endpoint.
*
* @param rdfClientBuilder the RdfClient supplier
*/
@SerialVersionUID(1)
class QueryResultRecorder(rdfClientBuilder: () => RdfClient) extends Serializable {
@transient
private lazy val hashFunction: HashFunction = Hashing.sha256()
@transient
private lazy val client: RdfClient = rdfClientBuilder()
@transient
private lazy val valueFactory: ValueFactory = new ValueFactoryImpl()
@transient
private lazy val udf: UserDefinedFunction = {
def udfFunc: UDF1[String,Row] = (query: String) => {
this.query(query) match {
case QueryResult(Some(results), true, msg, Some(size), Some(exactHash), Some(reorderedHash)) =>
Row(results, true, msg.orNull, size, exactHash, reorderedHash)
case QueryResult(results, false, Some(msg), size, exactHash, reorderedHash) =>
Row(results.orNull, false, msg, size.orNull, exactHash.orNull, reorderedHash.orNull)
case x: QueryResult => throw new IllegalArgumentException("Invalid QueryResult for query " + query + ": " + x)
}
}
functions.udf(udfFunc, QueryResultRecorder.outputStruct)
}
@transient
private lazy val collator: Collator = {
val collator = Collator.getInstance(Locale.ROOT)
// Set tertiary strength (same as blazegraph, for context see T233204)
collator.setStrength(Collator.TERTIARY)
collator
}
def apply(column: Column): Column = {
udf(column)
}
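/**
* Wraps the boolean outcome of an ASK query as a single-solution
* TupleQueryResult with one binding named "b", so it is encoded and
* hashed like any SELECT result.
*/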
private def fromAsk(result: Boolean): QueryResult = {
val names = List("b").asJava
fromResult(new TupleQueryResultImpl(names, List(new ListBindingSet(names, valueFactory.createLiteral(result))).asJava))
}
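/**
* Runs the given SPARQL query, dispatching on its detected query type
* (SELECT, ASK, DESCRIBE or CONSTRUCT). Exceptions are not propagated:
* a failure is returned as an unsuccessful QueryResult carrying the
* message and stack trace.
*/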
def query(query: String): QueryResult = {
def execute(query: String): QueryResult = {
QueryResultRecorder.getQueryType(query) match {
case "SELECT" => fromResult(client.query(query))
case "ASK" => fromAsk(client.ask(query))
case "DESCRIBE" => fromResult(client.describe(query))
case "CONSTRUCT" => fromResult(client.construct(query))
}
}
Try(execute(query)) match {
case Success(result) => result
case Failure(e) =>
QueryResult(results = None, success = false,
errorMessage = Some(e.getMessage + ": " + ExceptionUtils.getStackTrace(e)),
resultSize = None, exactHash = None, reorderedHash = None)
}
}
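/**
* Builds a successful QueryResult: encodes the solutions and computes
* the exact hash over the original order as well as the reordered hash
* over sorted binding names and sorted solutions.
*/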
private def fromResult(result: TupleQueryResult): QueryResult = {
val (bindings, solutions) = copyResult(result)
val sortedBindings = bindings.sorted
QueryResult(
results = Some(QueryResultRecorder.encodeQueryResults(solutions)),
success = true,
errorMessage = None,
resultSize = Some(solutions.length),
exactHash = Some(hashSolutions(bindings, solutions)),
reorderedHash = Some(hashSolutions(sortedBindings, solutions.sorted(sortBindingSet(sortedBindings))))
)
}
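/**
* Ordering over binding sets that compares the values bound to each of
* the given binding names in turn: a missing binding sorts before any
* present one, and the first non-zero comparison wins.
*/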
private def sortBindingSet(bindings: Array[String]): Ordering[BindingSet] = {
(x: BindingSet, y: BindingSet) => {
bindings map {
b => (Option(x.getBinding(b)), Option(y.getBinding(b)))
} map {
case (None, None) => 0
case (Some(_), None) => 1 // non-null > null
case (None, Some(_)) => -1 // null < non-null
case (Some(bvx), Some(bvy)) => // actual comparison
compareBinding(bvx, bvy)
} collectFirst {
case cmp: Int if cmp != 0 => cmp
} getOrElse 0
}
}
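/**
* Compares two bindings, first by the class name of their values and
* then by the collated string representation of the values.
*/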
private def compareBinding(bvx: Binding, bvy: Binding): Int = {
(bvx.getValue, bvy.getValue) match {
// Consider two blank nodes equal: their ids are assigned by blazegraph while generating
// the results, so once solutions are re-ordered there is no point in keeping the specific
// blank node id. The risk of missing a real difference seems low enough compared to the
// noise caused by false positives.
case (_: BNode, _: BNode) => 0
case (_: BNode, _: Value) => -1
case (_: Value, _: BNode) => 1
case (vx: Value, vy: Value) =>
vx.getClass.getName.compareTo(vy.getClass.getName) match {
case 0 =>
encodeValue(vx).compareTo(encodeValue(vy))
case x: Int => x
}
}
}
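/**
* Hashes the binding names followed by, for each solution, every bound
* value (its class name and collation key bytes, or the "_BLANK" marker
* for blank nodes) and a -1 sentinel for unbound variables.
*/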
private def hashSolutions(bindings: Array[String], solutions: Iterable[BindingSet]): String = {
val hasher = hashFunction.newHasher()
bindings foreach {
hasher.putString(_, StandardCharsets.UTF_8)
}
solutions foreach { bs =>
bindings map bs.getBinding map (Option(_)) foreach {
case Some(b) =>
b.getValue match {
case _: BNode => hasher.putString("_BLANK", StandardCharsets.UTF_8)
case v: Value =>
hasher.putString(v.getClass.getName, StandardCharsets.UTF_8)
hasher.putBytes(encodeValue(v).toByteArray)
}
case None =>
hasher.putInt(-1)
}
}
hasher.hash().toString
}
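/**
* Encodes a value as a collation key so that string comparisons are
* locale-independent and use the same tertiary strength as blazegraph.
*/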
private def encodeValue(v: Value): CollationKey = {
collator.getCollationKey(v.toString)
}
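/**
* Drains the streaming TupleQueryResult into in-memory arrays of
* binding names and solutions.
*/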
private def copyResult(result: TupleQueryResult): (Array[String], Array[BindingSet]) = {
val bindings = result.getBindingNames.toArray(Array[String]())
val solutions: ListBuffer[BindingSet] = ListBuffer()
while (result.hasNext) {
solutions += result.next()
}
(bindings, solutions.toArray)
}
}
@SerialVersionUID(1)
object QueryResultRecorder {
private val IRI_PATTERN = Pattern.compile("^<[^>]*>")
private val PREFIX_PATTERN = Pattern.compile("^prefix([^:]+):", Pattern.CASE_INSENSITIVE)
private val COMMENT_PATTERN = Pattern.compile("^(#.*\r?\n*)*")
private val QUERY_TYPE = Pattern.compile("^(SELECT|ASK|CONSTRUCT|DESCRIBE)", Pattern.CASE_INSENSITIVE)
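/**
* Creates a recorder querying the given SPARQL endpoint. The underlying
* RdfClient is built lazily (typically once per executor) with a
* 65-second timeout, a 16 MiB response buffer and a User-Agent
* identifying this tool, suffixed with uaSuffix.
*/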
def create(endpoint: String, uaSuffix: String): QueryResultRecorder = {
new QueryResultRecorder(() => {
val timeout = Duration.ofSeconds(65)
val httpClient = HttpClientUtils.buildHttpClient(None.orNull, None.orNull)
httpClient.setUserAgentField(new HttpField(HttpHeader.USER_AGENT, "QueryResultRecorder (org.wikidata.query.rdf:rdf-spark-tools) bot " + uaSuffix))
new RdfClient(httpClient, URI.create(endpoint), HttpClientUtils.buildHttpClientRetryer(), timeout, 16*1024*1024)
})
}
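/**
* Encodes each solution as a map from variable name to the N-Triples
* representation of the bound value.
*/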
def encodeQueryResults(solutions: Array[BindingSet]): Array[Map[String, String]] = {
solutions map { s =>
s.asScala.map { b => b.getName -> NTriplesUtil.toNTriplesString(b.getValue) }.toMap
}
}
private val resultList: ArrayType = ArrayType(MapType(DataTypes.StringType, DataTypes.StringType, valueContainsNull = true))
/**
* The schema of the output of the UDF
*/
val outputStruct: StructType = new StructType(Array(
StructField(name = "results", dataType = resultList),
StructField(name = "success", dataType = DataTypes.BooleanType),
StructField(name = "error_msg", dataType = DataTypes.StringType),
StructField(name = "resultSize", dataType = DataTypes.IntegerType),
StructField(name = "exactHash", dataType = DataTypes.StringType),
StructField(name = "reorderedHash", dataType = DataTypes.StringType)
))
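/**
* Determines the query type by skipping the query prologue (comments,
* BASE and PREFIX declarations) and matching the first keyword of the
* remaining query, defaulting to SELECT when nothing matches. For example:
* {{{
* getQueryType("PREFIX wd: <http://www.wikidata.org/entity/> ASK { ?s ?p ?o }") // returns "ASK"
* }}}
*/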
// scalastyle:off cyclomatic.complexity
def getQueryType(input: CharSequence): CharSequence = {
var i = 0
var restOfQuery: CharSequence = ""
while (i < input.length) {
val c = input.charAt(i)
c match {
case '#' =>
i += readComment(input, i)
case 'p' | 'P' =>
// read PREFIX
i += readPrefix(input, i)
case 'b' | 'B' =>
i += 4 // skip the 4-character BASE keyword
case '<' =>
// read IRI
i += readIRI(input, i)
case _ =>
if (Character.isWhitespace(c)) {
i += 1
} else {
restOfQuery = input.subSequence(i, input.length)
i += restOfQuery.length
}
}
}
val m = QUERY_TYPE.matcher(restOfQuery)
if (m.find()) {
m.group().toUpperCase(Locale.ROOT)
} else {
"SELECT"
}
}
// scalastyle:on cyclomatic.complexity
/**
* Reads the leading comment line(s) starting at the given index and
* returns the number of characters consumed, including the line
* break characters.
*/
private def readComment(input: CharSequence, index: Int): Int = {
val matcher = COMMENT_PATTERN.matcher(input.subSequence(index, input.length()))
if (matcher.find()) {
matcher.end()
} else {
1
}
}
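/**
* Returns the number of characters consumed by a PREFIX declaration, up
* to and including the colon that ends the prefix name.
*/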
private def readPrefix(input: CharSequence, index: Int): Int = {
val matcher = PREFIX_PATTERN.matcher(input.subSequence(index, input.length()))
if (matcher.find()) {
matcher.end()
} else {
1
}
}
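/**
* Returns the number of characters consumed by an IRI reference, up to
* and including the closing angle bracket.
*/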
private def readIRI(input: CharSequence, index: Int): Int = {
val matcher = IRI_PATTERN.matcher(input.subSequence(index, input.length()))
if (matcher.find()) {
matcher.end()
} else {
1
}
}
}
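/**
* Outcome of recording a single query.
*
* @param results the solutions as maps from variable name to N-Triples value, when the query succeeded
* @param success whether the query succeeded
* @param errorMessage the error message and stack trace, when the query failed
* @param resultSize the number of solutions in the result
* @param exactHash SHA-256 hash of the results in their original order
* @param reorderedHash SHA-256 hash of the results after sorting binding names and solutions
*/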
case class QueryResult(results: Option[Array[Map[String, String]]],
success: Boolean,
errorMessage: Option[String],
resultSize: Option[Int],
exactHash: Option[String],
reorderedHash: Option[String]
)