EventJsonFormatFactory.java

package org.wikimedia.eventutilities.flink.formats.json;

import static org.apache.flink.formats.json.JsonFormatOptions.ENCODE_DECIMAL_AS_PLAIN_NUMBER;
import static org.apache.flink.formats.json.JsonFormatOptions.MAP_NULL_KEY_LITERAL;
import static org.wikimedia.eventutilities.flink.formats.json.EventJsonFormatOptions.EVENT_SCHEMAS_BASE_URIS;
import static org.wikimedia.eventutilities.flink.formats.json.EventJsonFormatOptions.EVENT_STREAM_CONFIG_URI;
import static org.wikimedia.eventutilities.flink.formats.json.EventJsonFormatOptions.FORWARD_OPTIONS;
import static org.wikimedia.eventutilities.flink.formats.json.EventJsonFormatOptions.HTTP_ROUTES;
import static org.wikimedia.eventutilities.flink.formats.json.EventJsonFormatOptions.SCHEMA_VERSION;
import static org.wikimedia.eventutilities.flink.formats.json.EventJsonFormatOptions.STREAM_NAME;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.flink.api.common.serialization.SerializationSchema;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.formats.common.TimestampFormat;
import org.apache.flink.formats.json.JsonFormatFactory;
import org.apache.flink.formats.json.JsonFormatOptions;
import org.apache.flink.formats.json.JsonFormatOptionsUtil;
import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.connector.format.EncodingFormat;
import org.apache.flink.table.connector.sink.DynamicTableSink;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.factories.DynamicTableFactory;
import org.apache.flink.table.factories.FactoryUtil;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.RowType;
import org.wikimedia.eventutilities.core.event.EventStream;
import org.wikimedia.eventutilities.core.event.EventStreamFactory;
import org.wikimedia.eventutilities.core.event.JsonEventGenerator;

/**
 * Flink table format factory for the {@code event-json} format.
 *
 * <p>Extends Flink's {@link JsonFormatFactory} with options that identify an event stream
 * (stream name, schema version, schema base URIs, stream config URI and optional HTTP routes).
 * The encoding side serializes rows with a {@link JsonRowDataSerializationSchema} that is handed a
 * {@link JsonEventGenerator.EventNormalizer} created for the configured stream and the schema URI
 * of the configured schema version. Decoding is not overridden here and is inherited from
 * {@link JsonFormatFactory}.
 */
public class EventJsonFormatFactory extends JsonFormatFactory {

    /** Identifier returned by {@link #factoryIdentifier()} to select this format. */
    public static final String IDENTIFIER = "event-json";

    @Override
    public String factoryIdentifier() {
        return IDENTIFIER;
    }

    /**
     * Returns the parent JSON format's required options plus the stream name, schema version,
     * event schema base URIs and event stream config URI.
     *
     * <p>The parent set is copied before adding to it, so the superclass's set is never mutated.
     */
    @Override
    public Set<ConfigOption<?>> requiredOptions() {
        Set<ConfigOption<?>> superOptions = super.requiredOptions();
        Set<ConfigOption<?>> options = new HashSet<>(superOptions);

        options.add(STREAM_NAME);
        options.add(SCHEMA_VERSION);
        options.add(EVENT_SCHEMAS_BASE_URIS);
        options.add(EVENT_STREAM_CONFIG_URI);

        return options;
    }

    /**
     * Returns the parent JSON format's optional options plus the optional HTTP routes option.
     *
     * <p>The parent set is copied before adding to it, so the superclass's set is never mutated.
     */
    @Override
    public Set<ConfigOption<?>> optionalOptions() {
        Set<ConfigOption<?>> superOptions = super.optionalOptions();
        Set<ConfigOption<?>> options = new HashSet<>(superOptions);

        options.add(HTTP_ROUTES);
        return options;
    }

    /**
     * Returns the parent JSON format's forwardable options plus {@code FORWARD_OPTIONS}.
     *
     * <p>Like {@link #requiredOptions()} and {@link #optionalOptions()}, the parent set is copied
     * before adding to it. Calling {@code addAll} directly on the superclass's returned set would
     * fail if that set were immutable, and would leak changes if it were shared.
     */
    @Override
    public Set<ConfigOption<?>> forwardOptions() {
        Set<ConfigOption<?>> options = new HashSet<>(super.forwardOptions());
        options.addAll(FORWARD_OPTIONS);
        return options;
    }

    /**
     * Creates the insert-only encoding format for {@code event-json}.
     *
     * <p>Validates the format options, then reads both the standard JSON encoding options and the
     * event-specific options. The returned format builds its runtime encoder lazily in
     * {@code createRuntimeEncoder}. At that point it constructs an {@link EventStreamFactory}
     * from the configured URIs, resolves the configured stream and schema version, and wraps the
     * resulting {@link JsonEventGenerator.EventNormalizer} in a
     * {@link JsonRowDataSerializationSchema}.
     *
     * @param context       the table factory context (not used directly here)
     * @param formatOptions the options scoped to this format
     * @return an encoding format producing {@link SerializationSchema} instances for
     *         {@link RowData}
     */
    @Override
    public EncodingFormat<SerializationSchema<RowData>> createEncodingFormat(
        DynamicTableFactory.Context context, ReadableConfig formatOptions
    ) {
        FactoryUtil.validateFactoryOptions(this, formatOptions);
        JsonFormatOptionsUtil.validateEncodingFormatOptions(formatOptions);

        // Standard JSON encoding options, reused unchanged from Flink's JSON format.
        TimestampFormat timestampOption = JsonFormatOptionsUtil.getTimestampFormat(formatOptions);
        JsonFormatOptions.MapNullKeyMode mapNullKeyMode =
            JsonFormatOptionsUtil.getMapNullKeyMode(formatOptions);
        String mapNullKeyLiteral = formatOptions.get(MAP_NULL_KEY_LITERAL);

        // Event-specific options identifying the stream and the schema to encode against.
        String streamName = formatOptions.get(STREAM_NAME);
        String schemaVersion = formatOptions.get(SCHEMA_VERSION);
        List<String> eventSchemaBaseUris = formatOptions.get(EVENT_SCHEMAS_BASE_URIS);
        String eventStreamConfigUri = formatOptions.get(EVENT_STREAM_CONFIG_URI);
        // null when HTTP_ROUTES is unset; EventStreamFactory.from is relied on to accept that.
        Map<String, String> httpRoutes = formatOptions.getOptional(HTTP_ROUTES).orElse(null);

        final boolean encodeDecimalAsPlainNumber =
            formatOptions.get(ENCODE_DECIMAL_AS_PLAIN_NUMBER);

        return new EncodingFormat<SerializationSchema<RowData>>() {
            @Override
            public SerializationSchema<RowData> createRuntimeEncoder(
                DynamicTableSink.Context context,
                DataType consumedDataType
            ) {
                // The consumed type is assumed to be a row type (cast fails otherwise).
                final RowType rowType = (RowType) consumedDataType.getLogicalType();

                // Stream config and schemas are resolved here, when the encoder is created;
                // NOTE(review): this may involve fetching remote resources — confirm.
                final EventStreamFactory eventStreamFactory = EventStreamFactory.from(
                    eventSchemaBaseUris, eventStreamConfigUri, httpRoutes
                );
                final JsonEventGenerator eventGenerator = JsonEventGenerator.builder()
                    .eventStreamConfig(eventStreamFactory.getEventStreamConfig())
                    .schemaLoader(eventStreamFactory.getEventSchemaLoader())
                    .build();
                final EventStream eventStream = eventStreamFactory.createEventStream(streamName);
                final JsonEventGenerator.EventNormalizer generator = eventGenerator
                    .createEventStreamEventGenerator(streamName, eventStream.schemaUri(schemaVersion).toString());

                return new JsonRowDataSerializationSchema(
                    rowType,
                    generator,
                    generator.getObjectMapper(),
                    timestampOption,
                    mapNullKeyMode,
                    mapNullKeyLiteral,
                    encodeDecimalAsPlainNumber);
            }

            @Override
            public ChangelogMode getChangelogMode() {
                // We might want to allow other changelog modes in the future
                return ChangelogMode.insertOnly();
            }
        };
    }
}