JsonRowSerializationSchema.java
package org.wikimedia.eventutilities.flink.formats.json;
import static java.lang.String.format;
import static java.time.format.DateTimeFormatter.ISO_LOCAL_DATE;
import static org.apache.flink.formats.common.TimeFormats.RFC3339_TIMESTAMP_FORMAT;
import static org.apache.flink.formats.common.TimeFormats.RFC3339_TIME_FORMAT;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
import java.io.Serializable;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.serialization.SerializationSchema;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.types.Row;
import org.apache.flink.util.WrappingRuntimeException;
import org.jetbrains.annotations.Nullable;
import org.wikimedia.eventutilities.core.event.JsonEventGenerator.EventNormalizer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import lombok.EqualsAndHashCode;
/**
 * Serialization schema that serializes a Flink {@link Row} into JSON bytes.
 *
 * <p>The input row is converted into a Jackson {@link ObjectNode}, passed through a configurable
 * normalization function (see {@link Builder#withNormalizationFunction(Function)} and
 * {@link Builder#withoutNormalization()}), and written out as <code>byte[]</code>.
 *
 * <p>Null row fields and null map values are omitted from the output; null array elements are
 * written as JSON {@code null}.
 *
 * <p>Result <code>byte[]</code> messages can be deserialized using {@link
 * JsonRowDeserializationSchema}.
 *
 * <p>(Copied from the flink code-base and adapted.)
 */
// Suppress all spotbugs warnings for this class, as it was copy/pasted from upstream Flink.
@SuppressFBWarnings
@SuppressWarnings({"checkstyle:ClassFanoutComplexity", "checkstyle:CyclomaticComplexity"})
@EqualsAndHashCode
public final class JsonRowSerializationSchema implements SerializationSchema<Row> {
    private static final long serialVersionUID = -2885556750743978636L;

    /** Type information describing the input type. */
    private final RowTypeInfo typeInfo;

    /**
     * Produces the final event: it is given a consumer that writes the converted row into an
     * {@link ObjectNode}, and returns the (possibly normalized) resulting event node.
     */
    private final Function<Consumer<ObjectNode>, ObjectNode> normalization;

    /**
     * Converter derived from {@link #typeInfo}. Excluded from equals/hashCode: it is fully
     * determined by the type information, and lambdas only compare by identity, which would
     * otherwise make two identically configured schemas never equal.
     */
    @EqualsAndHashCode.Exclude
    private final SerializationRuntimeConverter runtimeConverter;

    /** Mapper used to create JSON nodes and to write the final bytes. */
    private final ObjectMapper mapper;

    private JsonRowSerializationSchema(
            TypeInformation<Row> typeInfo,
            Function<Consumer<ObjectNode>, ObjectNode> normalization,
            ObjectMapper mapper) {
        checkNotNull(typeInfo, "Type information");
        checkArgument(
                typeInfo instanceof RowTypeInfo, "Only RowTypeInfo is supported");
        this.typeInfo = (RowTypeInfo) typeInfo;
        this.runtimeConverter = createConverter(typeInfo);
        this.normalization = normalization;
        this.mapper = mapper;
    }

    /**
     * The type information supported by this serializer.
     */
    public RowTypeInfo getTypeInformation() {
        return typeInfo;
    }

    /** Builder for {@link JsonRowSerializationSchema}. */
    @PublicEvolving
    public static final class Builder {
        private RowTypeInfo typeInfo;
        private Function<Consumer<ObjectNode>, ObjectNode> normalization;
        /**
         * Set by {@link #withoutNormalization()}. The no-op normalizer is only created in
         * {@link #build()} so that it shares the final mapper, regardless of call order.
         */
        private boolean skipNormalization;
        private ObjectMapper mapper;

        private Builder() {
        }

        /**
         * Sets type information for JSON serialization schema.
         *
         * @param typeInfo Type information describing the result type. The field names of {@link
         *     Row} are used as the JSON property names.
         * @throws IllegalArgumentException if {@code typeInfo} is not a {@link RowTypeInfo}
         */
        public Builder withTypeInfo(@Nonnull TypeInformation<Row> typeInfo) {
            checkArgument(
                    typeInfo instanceof RowTypeInfo, "Only RowTypeInfo is supported");
            this.typeInfo = (RowTypeInfo) typeInfo;
            return this;
        }

        /**
         * The normalization function to apply on the provided event data.
         * Replaces any previous call to {@link #withoutNormalization()}.
         */
        public Builder withNormalizationFunction(Function<Consumer<ObjectNode>, ObjectNode> normalization) {
            this.normalization = checkNotNull(normalization);
            this.skipNormalization = false;
            return this;
        }

        /**
         * Do not apply any normalization to serialized events.
         * Should never be used for producing events to the WMF Event Platform.
         * Replaces any previous call to {@link #withNormalizationFunction(Function)}.
         */
        public Builder withoutNormalization() {
            this.normalization = null;
            this.skipNormalization = true;
            return this;
        }

        /**
         * Sets the mapper used to build JSON nodes and write bytes. If never set, a fresh
         * {@link ObjectMapper} is created in {@link #build()}.
         */
        public Builder withObjectMapper(ObjectMapper mapper) {
            this.mapper = checkNotNull(mapper);
            return this;
        }

        /**
         * Finalizes the configuration and checks validity.
         *
         * @return Configured {@link JsonRowSerializationSchema}
         * @throws NullPointerException if no type info or no normalization choice was set
         */
        public JsonRowSerializationSchema build() {
            checkNotNull(typeInfo, "typeInfo should be set.");
            ObjectMapper effectiveMapper = mapper != null ? mapper : new ObjectMapper();
            // Created here (not in withoutNormalization) so it always uses the final mapper.
            Function<Consumer<ObjectNode>, ObjectNode> effectiveNormalization =
                    skipNormalization ? new NoopEventNormalizer(effectiveMapper) : normalization;
            checkNotNull(effectiveNormalization, "A normalization method must be explicitly set.");
            return new JsonRowSerializationSchema(typeInfo, effectiveNormalization, effectiveMapper);
        }

        /** Normalizer that only creates an empty node and lets the consumer fill it in. */
        private static final class NoopEventNormalizer implements EventNormalizer {
            private final ObjectMapper mapper;

            private NoopEventNormalizer(ObjectMapper mapper) {
                this.mapper = mapper;
            }

            @Override
            public ObjectNode generateEvent(Consumer<ObjectNode> consumer,
                                            @Nullable Instant eventTime) {
                ObjectNode node = mapper.createObjectNode();
                consumer.accept(node);
                return node;
            }

            @Override
            public ObjectMapper getObjectMapper() {
                return mapper;
            }
        }
    }

    /** Creates a builder for {@link JsonRowSerializationSchema.Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Converts {@code row} to JSON, runs it through the normalization function and writes it.
     *
     * @throws RuntimeException wrapping any failure during conversion, normalization or writing
     */
    @SuppressWarnings("checkstyle:IllegalCatch")
    @Override
    public byte[] serialize(Row row) {
        try {
            ObjectNode event = normalization.apply(root -> runtimeConverter.convert(mapper, root, row));
            return mapper.writeValueAsBytes(event);
        } catch (Exception t) {
            throw new RuntimeException(
                    "Could not serialize row '"
                            + row
                            + "'. "
                            + "Make sure that the schema matches the input.",
                    t);
        }
    }

    /*
    Runtime converters
    */

    /** Runtime converter that maps between Java objects and corresponding {@link JsonNode}s. */
    @FunctionalInterface
    private interface SerializationRuntimeConverter extends Serializable {
        /**
         * @param mapper mapper providing the node factory
         * @param reuse node to populate (for rows/maps/arrays) or {@code null} to create a new one
         * @param object non-null value to convert
         */
        JsonNode convert(ObjectMapper mapper, JsonNode reuse, Object object);
    }

    /**
     * Picks a converter for {@code typeInfo}: simple types first, then containers, falling back
     * to Jackson's {@code valueToTree}.
     */
    private SerializationRuntimeConverter createConverter(TypeInformation<?> typeInfo) {
        // -- BEGIN WMF MODIFICATION --
        // dropped nullNode converter
        return createConverterForSimpleType(typeInfo)
                .orElseGet(
                        () ->
                                createContainerConverter(typeInfo)
                                        .orElseGet(this::createFallbackConverter));
        // -- END WMF MODIFICATION --
    }

    /** Converter for rows, object/basic arrays, {@code byte[]} and maps; empty otherwise. */
    private Optional<SerializationRuntimeConverter> createContainerConverter(
            TypeInformation<?> typeInfo) {
        if (typeInfo instanceof RowTypeInfo) {
            return Optional.of(createRowConverter((RowTypeInfo) typeInfo));
        } else if (typeInfo instanceof ObjectArrayTypeInfo) {
            return Optional.of(
                    createObjectArrayConverter(
                            ((ObjectArrayTypeInfo<?, ?>) typeInfo).getComponentInfo()));
        } else if (typeInfo instanceof BasicArrayTypeInfo) {
            return Optional.of(
                    createObjectArrayConverter(((BasicArrayTypeInfo<?, ?>) typeInfo).getComponentInfo()));
        } else if (isPrimitiveByteArray(typeInfo)) {
            return Optional.of(
                    (mapper, reuse, object) -> mapper.getNodeFactory().binaryNode((byte[]) object));
        } else if (typeInfo instanceof MapTypeInfo) {
            return createMapConverter((MapTypeInfo<?, ?>) typeInfo);
        } else {
            // Should we fail?
            return Optional.empty();
        }
    }

    /**
     * Converter for maps, written as JSON objects; entries with null values are omitted.
     *
     * <p>NOTE(review): the key check accepts any key class String is assignable to (e.g. Object),
     * so such maps holding non-String keys would fail at conversion time rather than here.
     *
     * @throws IllegalArgumentException if the key type cannot hold Strings
     */
    private Optional<SerializationRuntimeConverter> createMapConverter(MapTypeInfo<?, ?> typeInfo) {
        if (!typeInfo.getKeyTypeInfo().getTypeClass().isAssignableFrom(String.class)) {
            throw new IllegalArgumentException("Map types must have String keys");
        }
        SerializationRuntimeConverter mapValueConverter = createConverter(typeInfo.getValueTypeInfo());
        return Optional.of((mapper, reuse, object) -> {
            @SuppressWarnings("unchecked")
            Map<String, ?> map = (Map<String, ?>) object;
            ObjectNode node = reuse != null ? (ObjectNode) reuse : mapper.createObjectNode();
            map.forEach((String k, Object mapValue) -> {
                if (mapValue != null) {
                    node.set(k, mapValueConverter.convert(mapper, null, mapValue));
                }
            });
            return node;
        });
    }

    /** Whether {@code typeInfo} describes a primitive {@code byte[]}. */
    private boolean isPrimitiveByteArray(TypeInformation<?> typeInfo) {
        return typeInfo instanceof PrimitiveArrayTypeInfo
                && ((PrimitiveArrayTypeInfo<?>) typeInfo).getComponentType() == Types.BYTE;
    }

    /** Converter for object arrays (or Lists) whose elements are of {@code elementTypeInfo}. */
    private SerializationRuntimeConverter createObjectArrayConverter(
            TypeInformation<?> elementTypeInfo) {
        SerializationRuntimeConverter elementConverter = createConverter(elementTypeInfo);
        return assembleArrayConverter(elementConverter);
    }

    /** Converter for nested rows, with one field converter per field type. */
    private SerializationRuntimeConverter createRowConverter(RowTypeInfo typeInfo) {
        List<SerializationRuntimeConverter> fieldConverters =
                Arrays.stream(typeInfo.getFieldTypes())
                        .map(this::createConverter)
                        .collect(Collectors.toList());
        return assembleRowConverter(typeInfo.getFieldNames(), fieldConverters);
    }

    /** Converter delegating to Jackson's {@link ObjectMapper#valueToTree(Object)}. */
    private SerializationRuntimeConverter createFallbackConverter() {
        return (mapper, reuse, object) -> {
            // for types that were specified without JSON schema
            // e.g. POJOs
            try {
                return mapper.valueToTree(object);
            } catch (IllegalArgumentException e) {
                throw new WrappingRuntimeException(
                        format(Locale.ROOT, "Could not convert object: %s", object), e);
            }
        };
    }

    /** Converter for scalar and date/time types, or empty if {@code simpleTypeInfo} is not one. */
    private Optional<SerializationRuntimeConverter> createConverterForSimpleType(
            TypeInformation<?> simpleTypeInfo) {
        if (simpleTypeInfo == Types.VOID) {
            return Optional.of((mapper, reuse, object) -> mapper.getNodeFactory().nullNode());
        } else if (simpleTypeInfo == Types.BOOLEAN) {
            return Optional.of(
                    (mapper, reuse, object) ->
                            mapper.getNodeFactory().booleanNode((Boolean) object));
        } else if (simpleTypeInfo == Types.STRING) {
            return Optional.of(
                    (mapper, reuse, object) -> mapper.getNodeFactory().textNode((String) object));
        } else if (simpleTypeInfo == Types.INT) {
            return Optional.of(
                    (mapper, reuse, object) ->
                            mapper.getNodeFactory().numberNode((Integer) object));
        } else if (simpleTypeInfo == Types.LONG) {
            return Optional.of(
                    (mapper, reuse, object) -> mapper.getNodeFactory().numberNode((Long) object));
        } else if (simpleTypeInfo == Types.DOUBLE) {
            return Optional.of(
                    (mapper, reuse, object) -> mapper.getNodeFactory().numberNode((Double) object));
        } else if (simpleTypeInfo == Types.FLOAT) {
            return Optional.of(
                    (mapper, reuse, object) -> mapper.getNodeFactory().numberNode((Float) object));
        } else if (simpleTypeInfo == Types.SHORT) {
            return Optional.of(
                    (mapper, reuse, object) -> mapper.getNodeFactory().numberNode((Short) object));
        } else if (simpleTypeInfo == Types.BYTE) {
            return Optional.of(
                    (mapper, reuse, object) -> mapper.getNodeFactory().numberNode((Byte) object));
        } else if (simpleTypeInfo == Types.BIG_DEC) {
            return Optional.of(createBigDecimalConverter());
        } else if (simpleTypeInfo == Types.BIG_INT) {
            return Optional.of(createBigIntegerConverter());
        } else if (simpleTypeInfo == Types.SQL_DATE) {
            return Optional.of(this::convertDate);
        } else if (simpleTypeInfo == Types.SQL_TIME) {
            return Optional.of(this::convertTime);
        } else if (simpleTypeInfo == Types.SQL_TIMESTAMP) {
            return Optional.of(this::convertTimestamp);
        } else if (simpleTypeInfo == Types.LOCAL_DATE) {
            return Optional.of(this::convertLocalDate);
        } else if (simpleTypeInfo == Types.LOCAL_TIME) {
            return Optional.of(this::convertLocalTime);
        } else if (simpleTypeInfo == Types.LOCAL_DATE_TIME) {
            return Optional.of(this::convertLocalDateTime);
        // -- BEGIN WMF MODIFICATION --
        } else if (simpleTypeInfo == Types.INSTANT) {
            return Optional.of(this::convertInstant);
        // -- END WMF MODIFICATION --
        } else {
            return Optional.empty();
        }
    }

    /** Formats a {@link LocalDate} with {@code ISO_LOCAL_DATE}. */
    private JsonNode convertLocalDate(ObjectMapper mapper, JsonNode reuse, Object object) {
        return mapper.getNodeFactory().textNode(ISO_LOCAL_DATE.format((LocalDate) object));
    }

    private JsonNode convertDate(ObjectMapper mapper, JsonNode reuse, Object object) {
        Date date = (Date) object;
        return convertLocalDate(mapper, reuse, date.toLocalDate());
    }

    /** Formats a {@link LocalDateTime} with Flink's {@code RFC3339_TIMESTAMP_FORMAT}. */
    private JsonNode convertLocalDateTime(ObjectMapper mapper, JsonNode reuse, Object object) {
        return mapper.getNodeFactory()
                .textNode(RFC3339_TIMESTAMP_FORMAT.format((LocalDateTime) object));
    }

    // -- BEGIN WMF MODIFICATION --
    /** Formats an {@link Instant} as its UTC local date-time. */
    private JsonNode convertInstant(ObjectMapper mapper, JsonNode reuse, Object object) {
        Instant instant = (Instant) object;
        return convertLocalDateTime(mapper, reuse, instant.atZone(ZoneOffset.UTC).toLocalDateTime());
    }
    // -- END WMF MODIFICATION --

    private JsonNode convertTimestamp(ObjectMapper mapper, JsonNode reuse, Object object) {
        Timestamp timestamp = (Timestamp) object;
        return convertLocalDateTime(mapper, reuse, timestamp.toLocalDateTime());
    }

    /** Formats a {@link LocalTime} with Flink's {@code RFC3339_TIME_FORMAT}. */
    private JsonNode convertLocalTime(ObjectMapper mapper, JsonNode reuse, Object object) {
        JsonNodeFactory nodeFactory = mapper.getNodeFactory();
        return nodeFactory.textNode(RFC3339_TIME_FORMAT.format((LocalTime) object));
    }

    private JsonNode convertTime(ObjectMapper mapper, JsonNode reuse, Object object) {
        final Time time = (Time) object;
        return convertLocalTime(mapper, reuse, time.toLocalTime());
    }

    /** Converter for decimals; non-BigDecimal numbers are converted via their double value. */
    private SerializationRuntimeConverter createBigDecimalConverter() {
        return (mapper, reuse, object) -> {
            // convert decimal if necessary
            JsonNodeFactory nodeFactory = mapper.getNodeFactory();
            if (object instanceof BigDecimal) {
                return nodeFactory.numberNode((BigDecimal) object);
            }
            return nodeFactory.numberNode(BigDecimal.valueOf(((Number) object).doubleValue()));
        };
    }

    /** Converter for big integers; non-BigInteger numbers are converted via their long value. */
    private SerializationRuntimeConverter createBigIntegerConverter() {
        return (mapper, reuse, object) -> {
            // convert integer if necessary
            JsonNodeFactory nodeFactory = mapper.getNodeFactory();
            if (object instanceof BigInteger) {
                return nodeFactory.numberNode((BigInteger) object);
            }
            return nodeFactory.numberNode(BigInteger.valueOf(((Number) object).longValue()));
        };
    }

    /**
     * Converter writing each non-null row field under its field name; null fields are omitted.
     * Writes into {@code reuse} when provided (the normalization root), else a new object node.
     */
    private SerializationRuntimeConverter assembleRowConverter(
            String[] fieldNames, List<SerializationRuntimeConverter> fieldConverters) {
        return (mapper, reuse, object) -> {
            ObjectNode node = reuse != null ? (ObjectNode) reuse : mapper.createObjectNode();
            Row row = (Row) object;
            for (int i = 0; i < fieldNames.length; i++) {
                String fieldName = fieldNames[i];
                Object rowValue = row.getField(i);
                if (rowValue != null) {
                    node.set(fieldName,
                            fieldConverters.get(i).convert(mapper, null, rowValue));
                }
            }
            return node;
        };
    }

    /**
     * Converter for {@code Object[]} or {@link List} values. Null elements are written as JSON
     * {@code null} (element converters are never called with null, since several of them
     * would throw on it).
     */
    private SerializationRuntimeConverter assembleArrayConverter(
            SerializationRuntimeConverter elementConverter) {
        return (mapper, reuse, object) -> {
            ArrayNode node;
            // reuse could be a NullNode if last record is null.
            if (reuse == null || reuse.isNull()) {
                node = mapper.createArrayNode();
            } else {
                node = (ArrayNode) reuse;
                node.removeAll();
            }
            if (object instanceof List) {
                object = ((List<?>) object).toArray();
            }
            Object[] array = (Object[]) object;
            for (Object element : array) {
                if (element == null) {
                    node.addNull();
                } else {
                    node.add(elementConverter.convert(mapper, null, element));
                }
            }
            return node;
        };
    }
}