// EventRowSerializer.java

package org.wikimedia.eventutilities.flink;

import static java.util.Objects.requireNonNull;
import static org.apache.flink.api.java.typeutils.runtime.MaskUtils.readIntoAndCopyMask;
import static org.apache.flink.api.java.typeutils.runtime.MaskUtils.readIntoMask;
import static org.apache.flink.api.java.typeutils.runtime.MaskUtils.writeMask;
import static org.apache.flink.util.Preconditions.checkNotNull;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

import javax.annotation.Nullable;
import javax.annotation.ParametersAreNonnullByDefault;
import javax.annotation.concurrent.NotThreadSafe;

import org.apache.flink.api.common.typeutils.CompositeTypeSerializerUtil;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
import org.apache.flink.api.common.typeutils.TypeSerializerUtils;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
import org.apache.flink.types.RowUtils;

import com.google.common.collect.Sets;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;

/**
 * Row serializer that supports simple schema evolution as defined by the WMF event platform.
 *
 * <br>
 * General notes about Flink serializers:
 * <ul>
 *     <li>They must be {@link java.io.Serializable} because they are constructed when the stream graph is built and
 *     thus must be transferred over the wire to the taskmanager nodes. This serialization form is not stored durably
 *     and so we do not have to bother about reading another version of this same class.</li>
 *     <li>At runtime the Rows transferred from one operator to another will be serialized using
 *     {@link #serialize(Row, DataOutputView)} and deserialized using {@link #deserialize(DataInputView)}</li>
 *     <li>When Rows have to be stored in a durable state the shape of this serializer must be stored in the header of
 *     this state. This is done by first extracting the {@link #snapshotConfiguration()} of this serializer and then
 *     storing it using {@link EventRowSerializerSnapshot#writeSnapshot(DataOutputView)}.</li>
 *     <li>When restoring a state (possibly generated by a previous version of this serializer) flink will read
 *     back the {@link EventRowSerializerSnapshot} that was previously stored and determine the level
 *     of compatibility of the previous version with the runtime version of the serializer using
 *     {@link EventRowSerializerSnapshot#resolveSchemaCompatibility(TypeSerializerSnapshot)}.</li>
 * </ul>
 * <br>
 *
 * This serializer is meant to support only rows that have named positions, for instance rows created from
 * {@link EventRowTypeInfo#createEmptyRow()}. This is the main difference from the upstream RowSerializer, which
 * supports either position-based or name-based rows: position-based rows cannot support migration in
 * compatibleAfterMigration mode, and name-based rows are not usable everywhere (the JSON serializer expects positions).
 *
 * The nature of aliased partition keys carried by {@link EventRowTypeInfo#keyTypeInfo()} does not affect how events
 * are serialized. Keys do not have to be taken into consideration when evaluating whether stored events
 * are compatible with a newer version of this class or its underlying schema. They are omitted
 * from {@link EventRowSerializerSnapshot} configurations.
 * <br>
 *
 * Compatibility levels:
 * <ul>
 *     <li>
 *        if the fields are exactly the same and at the same positions, compatibility is delegated to the
 *        worst of the field serializers' compatibilities, ordered from worst to best: incompatible,
 *        compatibleAfterMigration, compatibleWithReconfiguredSerializer, compatibleAsIs.
 *     </li>
 *     <li>
 *        if new fields are present and/or the field positions changed, we migrate the state and return
 *        compatibleAfterMigration (or incompatible if one of the field serializers is incompatible)
 *     </li>
 *     <li>
 *        in any other case we return incompatible (e.g. a field is missing)
 *     </li>
 * </ul>
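 *
 * <p>For illustration, a hedged sketch (the serializers and field layouts below are hypothetical): if the stored
 * state was written with fields {@code [meta_id, rev_id]} and the runtime schema appended a field
 * {@code comment}, resolution goes through the second case above:
 * <pre>{@code
 * // snapshot stored alongside the state (fields: meta_id, rev_id)
 * TypeSerializerSnapshot<Row> oldSnapshot = oldSerializer.snapshotConfiguration();
 * // snapshot of the serializer built from the current, extended schema (fields: meta_id, rev_id, comment)
 * TypeSerializerSnapshot<Row> newSnapshot = newSerializer.snapshotConfiguration();
 * assert newSnapshot.resolveSchemaCompatibility(oldSnapshot).isCompatibleAfterMigration();
 * }</pre>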
 *
 * Note on migration: migrating a state (compatibleAfterMigration) happens while the state is being restored; all
 * the state is read with the old version of the serializer and written back using {@link #serialize(Row, DataOutputView)}.
 * In other words, the {@link #serialize} function must accept Rows created by a different serializer
 * (with possibly missing fields and/or fields at different positions), but the serialization format must be the same as
 * if a properly constructed Row had been passed.
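 *
 * <p>A minimal round-trip sketch (the field names, field serializers and buffer size below are illustrative
 * assumptions, not part of this class):
 * <pre>{@code
 * LinkedHashMap<String, Integer> positions = new LinkedHashMap<>();
 * positions.put("meta_id", 0);
 * positions.put("rev_id", 1);
 * EventRowSerializer serializer = new EventRowSerializer(
 *         new TypeSerializer<?>[] {StringSerializer.INSTANCE, LongSerializer.INSTANCE},
 *         positions);
 *
 * // build a row with named positions, as EventRowTypeInfo#createEmptyRow() would
 * Row row = serializer.createInstance();
 * row.setField("meta_id", "abc-123");
 * row.setField("rev_id", 42L);
 *
 * // serialize to an in-memory buffer and read it back
 * DataOutputSerializer out = new DataOutputSerializer(64);
 * serializer.serialize(row, out);
 * Row copy = serializer.deserialize(new DataInputDeserializer(out.getCopyOfBuffer()));
 * }</pre>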
 */
@NotThreadSafe
@SuppressFBWarnings(value = {"PL_PARALLEL_LISTS", "SE_NO_SERIALVERSIONID"},
        justification = "reusableArray is mutable, fieldSerializers is not; " +
                "never serialized in app state")
@ParametersAreNonnullByDefault
public class EventRowSerializer extends TypeSerializer<Row> {
    private final TypeSerializer<Object>[] fieldSerializers;

    private final LinkedHashMap<String, Integer> positionByName;

    /**
     * Transient array used as a buffer when serializing/deserializing Rows.
     * It holds the nullability of the fields and ends up being stored as a bitset before the field data
     * (storing a bitset is more space efficient than writing a boolean is_null byte before every field).
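     *
     * <p>Assumed on-wire layout for one Row, matching {@link #serialize(Row, DataOutputView)}:
     * <pre>
     * [row kind: 1 byte][null mask: 1 bit per field, packed into bytes][field 0 if non-null]...[field N if non-null]
     * </pre>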
     */
    private transient boolean[] mask;

    /**
     * Transient array used as a buffer when serializing.
     * Mainly here for optimization purposes, as we have to loop over the fields multiple times: since
     * field names have to be checked to support state schema evolution, this array allows a single
     * name-based pass (slower), with the subsequent passes (nullability check, writing) iterating
     * over this array instead.
     */
    private transient Object[] reusableArray;

    @SuppressWarnings("unchecked")
    public EventRowSerializer(
            TypeSerializer<?>[] fieldSerializers,
            LinkedHashMap<String, Integer> positionByName) {
        this.fieldSerializers = (TypeSerializer<Object>[]) checkNotNull(fieldSerializers);
        this.positionByName = checkNotNull(positionByName);
        this.mask = new boolean[fieldSerializers.length];
        this.reusableArray = new Object[fieldSerializers.length];
    }

    @Override
    public boolean isImmutableType() {
        return false;
    }

    @Override
    public TypeSerializer<Row> duplicate() {
        TypeSerializer<?>[] duplicateFieldSerializers = new TypeSerializer[fieldSerializers.length];
        for (int i = 0; i < fieldSerializers.length; i++) {
            duplicateFieldSerializers[i] = fieldSerializers[i].duplicate();
        }
        return new EventRowSerializer(duplicateFieldSerializers, positionByName);
    }

    @Override
    public Row createInstance() {
        return RowUtils.createRowWithNamedPositions(
                RowKind.INSERT, new Object[fieldSerializers.length], positionByName);
    }

    @Override
    public Row copy(Row from) {
        Object[] values = new Object[fieldSerializers.length];
        if (from.getFieldNames(true) == null) {
            throw new IllegalArgumentException("Copied Row must support name-based access");
        }
        for (Map.Entry<String, Integer> en : positionByName.entrySet()) {
            Object value = from.getField(en.getKey());
            values[en.getValue()] = value != null ? fieldSerializers[en.getValue()].copy(value) : null;
        }
        return RowUtils.createRowWithNamedPositions(
                from.getKind(), values, positionByName);
    }

    @Override
    public Row copy(Row from, @Nullable Row reuse) {
        // cannot reuse, do a non-reuse copy
        if (reuse == null) {
            return copy(from);
        }

        reuse.setKind(from.getKind());
        for (Map.Entry<String, Integer> en : positionByName.entrySet()) {
            Object fieldReuse = reuse.getField(en.getValue());
            Object value = from.getField(en.getKey());
            reuse.setField(en.getValue(), value != null ? fieldSerializers[en.getValue()].copy(value, fieldReuse) : null);
        }
        return reuse;
    }

    @Override
    public int getLength() {
        return -1;
    }

    @Override
    public void serialize(Row row, DataOutputView target) throws IOException {
        // when migrating we will receive events written with a previous (hopefully compatible) schema,
        // so we have to ignore fields that are missing from the old schema
        Set<String> recordFieldNames = row.getFieldNames(true);
        if (recordFieldNames == null) {
            throw new IllegalArgumentException("record must support name-based access");
        }

        target.writeByte(row.getKind().toByteValue());
        Arrays.fill(reusableArray, null);
        for (Map.Entry<String, Integer> en : positionByName.entrySet()) {
            if (recordFieldNames.contains(en.getKey())) {
                reusableArray[en.getValue()] = row.getField(en.getKey());
            }
        }

        for (int i = 0; i < reusableArray.length; i++) {
            mask[i] = reusableArray[i] == null;
        }
        writeMask(mask, target);
        for (int i = 0; i < reusableArray.length; i++) {
            Object fieldValue = reusableArray[i];
            if (fieldValue != null) {
                fieldSerializers[i].serialize(fieldValue, target);
            }
        }
    }

    @Override
    public Row deserialize(DataInputView source) throws IOException {
        // read row kind
        final RowKind kind = RowKind.fromByteValue(source.readByte());

        // read bitmask
        readIntoMask(source, mask);
        // deserialize fields
        final int length = fieldSerializers.length;
        final Object[] fieldByPosition = new Object[length];
        for (int fieldPos = 0; fieldPos < length; fieldPos++) {
            if (!mask[fieldPos]) {
                fieldByPosition[fieldPos] = fieldSerializers[fieldPos].deserialize(source);
            }
        }

        return RowUtils.createRowWithNamedPositions(kind, fieldByPosition, positionByName);
    }

    @Override
    public Row deserialize(@Nullable Row reuse, DataInputView source) throws IOException {
        // reuse is null or uses name-based field mode (no named positions), do a non-reuse deserialize
        if (reuse == null || reuse.getFieldNames(false) != null) {
            return deserialize(source);
        }
        final int length = fieldSerializers.length;

        if (reuse.getArity() != length) {
            throw new RuntimeException(
                    "Row arity of reuse ("
                            + reuse.getArity()
                            + ") does not match "
                            + "this serializer's field length ("
                            + length
                            + ").");
        }

        reuse.setKind(RowKind.fromByteValue(source.readByte()));
        // read bitmask
        readIntoMask(source, mask);

        // deserialize fields
        for (int fieldPos = 0; fieldPos < length; fieldPos++) {
            if (mask[fieldPos]) {
                reuse.setField(fieldPos, null);
            } else {
                Object reuseField = reuse.getField(fieldPos);
                if (reuseField != null) {
                    reuse.setField(
                            fieldPos, fieldSerializers[fieldPos].deserialize(reuseField, source));
                } else {
                    reuse.setField(fieldPos, fieldSerializers[fieldPos].deserialize(source));
                }
            }
        }

        return reuse;
    }

    @Override
    public void copy(DataInputView source, DataOutputView target) throws IOException {
        int len = fieldSerializers.length;
        // copy row kind
        target.writeByte(source.readByte());
        // copy bitmask
        readIntoAndCopyMask(source, target, mask);
        // copy non-null fields
        for (int fieldPos = 0; fieldPos < len; fieldPos++) {
            if (!mask[fieldPos]) {
                fieldSerializers[fieldPos].copy(source, target);
            }
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        EventRowSerializer that = (EventRowSerializer) o;
        return Arrays.equals(fieldSerializers, that.fieldSerializers) && positionByName.equals(that.positionByName);
    }

    @Override
    public int hashCode() {
        return 31 + Arrays.hashCode(fieldSerializers) + positionByName.hashCode();
    }

    /**
     * Called during <a href="https://docs.oracle.com/javase/7/docs/api/java/io/ObjectInputStream.html">java deserialization</a>.
     * We need this because we have two transient fields: constructors and field initializers are not run when
     * restoring an object, so these two fields would remain null if we did not explicitly set them when restoring
     * this object's state.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.mask = new boolean[fieldSerializers.length];
        this.reusableArray = new Object[fieldSerializers.length];
    }

    @Override
    public TypeSerializerSnapshot<Row> snapshotConfiguration() {
        return new EventRowSerializerSnapshot(this.fieldSerializers, this.positionByName);
    }

    LinkedHashMap<String, Integer> getPositionByName() {
        return this.positionByName;
    }

    /**
     * The serializer snapshot; this object is meant to represent the state of a particular {@link EventRowSerializer}.
     * It is able to durably store this state (which is why it allows specifying a version).
     * When storing,
     * {@link TypeSerializerSnapshotSerializationUtil#writeSerializerSnapshot(DataOutputView, TypeSerializerSnapshot)}
     * is used and will store:
     * <ul>
     *     <li>the classname</li>
     *     <li>its version with {@link #getCurrentVersion()}</li>
     *     <li>its data with {@link #writeSnapshot(DataOutputView)}</li>
     * </ul>
     * When restoring using
     * {@link TypeSerializerSnapshotSerializationUtil#readSerializerSnapshot(DataInputView, ClassLoader)}
     * it will:
     * <ul>
     *     <li>instantiate the {@link TypeSerializerSnapshot} using the classname stored (using the empty ctor)</li>
     *     <li>read the version</li>
     *     <li>restore the state {@link #readVersionedSnapshot(DataInputView, ClassLoader)}</li>
     * </ul>
     *
     * This class must then be able to restore, with {@link #restoreSerializer()}, a serializer able to deserialize
     * the data that it had previously written, and also determine its level of compatibility with the runtime
     * serializer with {@link #resolveSchemaCompatibility(TypeSerializerSnapshot)}.
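     *
     * <p>A minimal sketch of a versioned snapshot round trip, using the same helpers this class uses for its nested
     * field snapshots (the buffer handling and class loader choice are illustrative assumptions):
     * <pre>{@code
     * DataOutputSerializer out = new DataOutputSerializer(256);
     * TypeSerializerSnapshot.writeVersionedSnapshot(out, serializer.snapshotConfiguration());
     *
     * DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
     * TypeSerializerSnapshot<Row> restored =
     *         TypeSerializerSnapshot.readVersionedSnapshot(in, EventRowSerializer.class.getClassLoader());
     * TypeSerializer<Row> restoredSerializer = restored.restoreSerializer();
     * }</pre>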
     */
    public static class EventRowSerializerSnapshot implements TypeSerializerSnapshot<Row> {
        static final int VERSION = 1;
        private TypeSerializerSnapshot<?>[] fieldSnapshots;
        private LinkedHashMap<String, Integer> positionByNames;

        public EventRowSerializerSnapshot() {}

        public EventRowSerializerSnapshot(TypeSerializer<?>[] fieldSerializers, LinkedHashMap<String, Integer> positionByNames) {
            this.fieldSnapshots = Arrays.stream(requireNonNull(fieldSerializers))
                    .map(TypeSerializer::snapshotConfiguration)
                    .toArray(TypeSerializerSnapshot[]::new);
            this.positionByNames = requireNonNull(positionByNames);
            if (fieldSerializers.length != positionByNames.size()) {
                throw new IllegalArgumentException("fieldSerializers and positionByNames must have the same size");
            }
        }
        @Override
        public int getCurrentVersion() {
            return VERSION;
        }

        @Override
        public void writeSnapshot(DataOutputView out) throws IOException {
            out.writeInt(positionByNames.size());
            for (Map.Entry<String, Integer> en : positionByNames.entrySet()) {
                out.writeUTF(en.getKey());
                out.writeInt(en.getValue());
            }
            for (TypeSerializerSnapshot<?> snap : fieldSnapshots) {
                TypeSerializerSnapshot.writeVersionedSnapshot(out, snap);
            }
        }

        @Override
        public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader) throws IOException {
            switch (readVersion) {
                case 1:
                    readV1(in, userCodeClassLoader);
                    break;
                default:
                    throw new IllegalStateException("Unsupported version " + readVersion);
            }
        }
        public void readV1(DataInputView in, ClassLoader userCodeClassLoader) throws IOException {
            positionByNames = new LinkedHashMap<>();
            int size = in.readInt();
            for (int i = 0; i < size; i++) {
                positionByNames.put(in.readUTF(), in.readInt());
            }
            fieldSnapshots = new TypeSerializerSnapshot<?>[size];
            for (int i = 0; i < fieldSnapshots.length; i++) {
                fieldSnapshots[i] = TypeSerializerSnapshot.readVersionedSnapshot(in, userCodeClassLoader);
            }
        }

        @Override
        public TypeSerializer<Row> restoreSerializer() {
            return new EventRowSerializer(
                    Arrays.stream(fieldSnapshots)
                            .map(TypeSerializerSnapshot::restoreSerializer)
                            .toArray(TypeSerializer<?>[]::new),
                    positionByNames);
        }

        @Override
        // Bound wildcards to Row to avoid clashing methods with the same erasure.
        public TypeSerializerSchemaCompatibility<Row> resolveSchemaCompatibility(TypeSerializerSnapshot<Row> oldSerializerSnapshot) {
            TypeSerializer<Row> restoredSerializer = oldSerializerSnapshot.restoreSerializer();
            // check the class before casting so a foreign serializer yields "incompatible" instead of a ClassCastException
            if (!restoredSerializer.getClass().equals(EventRowSerializer.class)) {
                return TypeSerializerSchemaCompatibility.incompatible();
            }
            EventRowSerializer oldSer = (EventRowSerializer) restoredSerializer;
            LinkedHashMap<String, Integer> oldPositionByName = oldSer.getPositionByName();

            if (oldSer.fieldSerializers.length == fieldSnapshots.length && positionByNames.equals(oldPositionByName)) {
                return resolveCompatWithSameFields(TypeSerializerUtils.snapshot(oldSer.fieldSerializers));
            } else {
                return resolveCompatWithDifferentFields(oldSer.fieldSerializers, oldPositionByName);
            }
        }

        private TypeSerializerSchemaCompatibility<Row> resolveCompatWithDifferentFields(TypeSerializer<?>[] oldFieldSerializers,
                                                                                        LinkedHashMap<String, Integer> oldPositionByNames) {
            Sets.SetView<String> removedFields = Sets.difference(oldPositionByNames.keySet(), positionByNames.keySet());
            if (!removedFields.isEmpty()) {
                // Fields were removed; this is not considered a valid schema upgrade. Schema utilities should not have
                // allowed this, or the caller is mixing up incompatible schemas.
                return TypeSerializerSchemaCompatibility.incompatible();
            }
            for (Map.Entry<String, Integer> en : oldPositionByNames.entrySet()) {
                Integer newPosition = positionByNames.get(en.getKey());
                if (newPosition == null) {
                    throw new IllegalStateException("positionByNames or oldPositionByNames must have changed concurrently: all entries in " +
                            "oldPositionByNames should be present in positionByNames");
                }
                TypeSerializerSnapshot<?> newFieldSerSnapshot = fieldSnapshots[newPosition];
                TypeSerializerSnapshot<?> oldFieldSerSnapshot = oldFieldSerializers[en.getValue()].snapshotConfiguration();
                TypeSerializerSchemaCompatibility<?> compatibility = resolveSubFieldCompat(newFieldSerSnapshot, oldFieldSerSnapshot);
                if (compatibility.isIncompatible()) {
                    return TypeSerializerSchemaCompatibility.incompatible();
                }
            }
            return TypeSerializerSchemaCompatibility.compatibleAfterMigration();
        }

        private TypeSerializerSchemaCompatibility<Row> resolveCompatWithSameFields(TypeSerializerSnapshot<?>[] oldFieldSnapshots) {
            CompositeTypeSerializerUtil.IntermediateCompatibilityResult<Object> intermediateCompatibilityResult = CompositeTypeSerializerUtil
                    .constructIntermediateCompatibilityResult(fieldSnapshots, oldFieldSnapshots);
            if (intermediateCompatibilityResult.isCompatibleAsIs()) {
                return TypeSerializerSchemaCompatibility.compatibleAsIs();
            }
            if (intermediateCompatibilityResult.isCompatibleAfterMigration()) {
                return TypeSerializerSchemaCompatibility.compatibleAfterMigration();
            }
            if (intermediateCompatibilityResult.isCompatibleWithReconfiguredSerializer()) {
                return TypeSerializerSchemaCompatibility.compatibleWithReconfiguredSerializer(
                        new EventRowSerializer(intermediateCompatibilityResult.getNestedSerializers(), positionByNames));
            }
            // remaining case: at least one field serializer is incompatible
            return TypeSerializerSchemaCompatibility.incompatible();
        }

        @SuppressWarnings("unchecked")
        public <E> TypeSerializerSchemaCompatibility<E> resolveSubFieldCompat(TypeSerializerSnapshot<?> newFieldSnapshot,
                                                                              TypeSerializerSnapshot<?> oldFieldSnapshot) {
            TypeSerializerSnapshot<E> typedOldFieldSnapshot = (TypeSerializerSnapshot<E>) oldFieldSnapshot;
            TypeSerializerSnapshot<E> typedNewFieldSnapshot = (TypeSerializerSnapshot<E>) newFieldSnapshot;

            return typedNewFieldSnapshot.resolveSchemaCompatibility(typedOldFieldSnapshot);
        }
    }
}