/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.spi.infinispan.impl.embedded;

import jakarta.persistence.EntityManager;
import java.lang.invoke.MethodHandles;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.TrustManager;
import org.infinispan.commons.configuration.attributes.Attribute;
import org.infinispan.configuration.global.GlobalConfiguration;
import org.infinispan.configuration.global.TransportConfiguration;
import org.infinispan.configuration.global.TransportConfigurationBuilder;
import org.infinispan.configuration.parsing.ConfigurationBuilderHolder;
import org.infinispan.remoting.transport.jgroups.EmbeddedJGroupsChannelConfigurator;
import org.jboss.logging.Logger;
import org.jgroups.Address;
import org.jgroups.JChannel;
import org.jgroups.conf.ClassConfigurator;
import org.jgroups.conf.ProtocolConfiguration;
import org.jgroups.protocols.TCP;
import org.jgroups.protocols.TCP_NIO2;
import org.jgroups.protocols.UDP;
import org.jgroups.stack.Protocol;
import org.jgroups.util.DefaultSocketFactory;
import org.jgroups.util.ExtendedUUID;
import org.jgroups.util.SocketFactory;
import org.jgroups.util.UUID;
import org.keycloak.Config;
import org.keycloak.common.util.Retry;
import org.keycloak.config.CachingOptions;
import org.keycloak.config.Option;
import org.keycloak.connections.jpa.JpaConnectionProvider;
import org.keycloak.connections.jpa.JpaConnectionProviderFactory;
import org.keycloak.connections.jpa.util.JpaUtils;
import org.keycloak.infinispan.util.InfinispanUtils;
import org.keycloak.jgroups.header.TracerHeader;
import org.keycloak.jgroups.protocol.KEYCLOAK_JDBC_PING2;
import org.keycloak.jgroups.protocol.OPEN_TELEMETRY;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.provider.ProviderConfigurationBuilder;
import org.keycloak.spi.infinispan.JGroupsCertificateProvider;
import org.keycloak.spi.infinispan.impl.Util;
import org.keycloak.storage.configuration.ServerConfigStorageProvider;

/**
 * Static helpers that configure the JGroups transport of Keycloak's embedded Infinispan.
 *
 * <p>Covers topology attributes (site/rack/machine/node), bind/external network system properties,
 * JDBC_PING based discovery and optional mutual TLS. Not instantiable.
 */
public final class JGroupsConfigurator {
    private static final Logger logger = Logger.getLogger(MethodHandles.lookup().lookupClass());
    private static final String TLS_PROTOCOL_VERSION = "TLSv1.3";
    private static final String TLS_PROTOCOL = "TLS";
    public static final String JGROUPS_ADDRESS_SEQUENCE = "JGROUPS_ADDRESS_SEQUENCE";

    /** Logical name of the discovery table; resolved to the physical name via {@link JpaUtils}. */
    private static final String JGROUPS_PING_TABLE = "JGROUPS_PING";
    /** Single-row insert used both for the address reservation and by KEYCLOAK_JDBC_PING2. */
    private static final String INSERT_SINGLE_SQL = "INSERT INTO %s values (?, ?, ?, ?, ?)";

    static {
        // Register Keycloak's custom JGroups protocols and header under fixed IDs.
        ClassConfigurator.addProtocol((short) 1025, KEYCLOAK_JDBC_PING2.class);
        ClassConfigurator.addProtocol((short) 1026, OPEN_TELEMETRY.class);
        ClassConfigurator.add((short) 1050, TracerHeader.class);
    }

    private JGroupsConfigurator() {
    }

    /**
     * Returns {@code true} when the global configuration has no transport configured.
     */
    public static boolean isLocal(ConfigurationBuilderHolder holder) {
        return transportOf(holder).getTransport() == null;
    }

    /**
     * Returns {@code true} when the global configuration has a transport configured.
     */
    public static boolean isClustered(ConfigurationBuilderHolder holder) {
        return !isLocal(holder);
    }

    /**
     * Configures the JGroups transport. Does nothing for local (non-clustered) configurations.
     *
     * <p>Otherwise it:
     * <ul>
     *   <li>applies the optional {@code stack} override;</li>
     *   <li>exports the bind/external network options as JGroups system properties;</li>
     *   <li>sets up JDBC_PING discovery, unless a custom non-JDBC_PING stack is configured;</li>
     *   <li>enables mTLS when a {@link JGroupsCertificateProvider} is enabled;</li>
     *   <li>warns about deprecated stacks.</li>
     * </ul>
     *
     * @param config  the cache-embedded configuration scope
     * @param holder  the Infinispan configuration being built
     * @param session session used to look up JPA and certificate providers
     */
    public static void configureJGroups(Config.Scope config, ConfigurationBuilderHolder holder, KeycloakSession session) {
        if (isLocal(holder)) {
            return;
        }
        String stack = config.get("stack");
        if (stack != null) {
            transportOf(holder).stack(stack);
        }
        configureTransport(config);
        boolean tracingEnabled = config.getBoolean("tracingEnabled", false);
        configureDiscovery(holder, session, tracingEnabled);
        configureTls(holder, session);
        warnDeprecatedStack(holder);
    }

    /**
     * Applies site, rack, machine and node names from {@code config} to the transport.
     *
     * <p>The deprecated {@code connectionsInfinispan/quarkus} {@code site-name} is passed as the
     * fallback value for {@code siteName}, and a deprecation warning is logged when it is set.
     * Empty values are ignored.
     *
     * @throws IllegalArgumentException if the legacy {@code jboss.site.name} or
     *                                  {@code jboss.node.name} system property is set
     */
    public static void configureTopology(Config.Scope config, ConfigurationBuilderHolder holder) {
        if (System.getProperty("jboss.site.name") != null) {
            throw new IllegalArgumentException(String.format("System property %s is in use. Use --spi-cache-embedded-%s-site-name config option instead", "jboss.site.name", "default"));
        }
        if (System.getProperty("jboss.node.name") != null) {
            throw new IllegalArgumentException(String.format("System property %s is in use. Use --spi-cache-embedded-%s-node-name config option instead", "jboss.node.name", "default"));
        }
        TransportConfigurationBuilder transport = transportOf(holder);
        String legacySiteName = Config.scope("connectionsInfinispan", "quarkus").get("site-name");
        if (legacySiteName != null) {
            logger.warn("--spi-connections-infinispan-quarkus-site-name is deprecated and may be removed in the future. Use --spi-cache-embedded-%s-site-name".formatted("default"));
        }
        String siteName = config.get("siteName", legacySiteName);
        if (siteName != null && !siteName.isEmpty()) {
            transport.siteId(siteName);
        }
        readConfigAndSet(config, "rackName", transport::rackId);
        readConfigAndSet(config, "machineName", transport::machineId);
        readConfigAndSet(config, "nodeName", transport::nodeName);
    }

    /**
     * Declares the bind/external address and port provider properties, copying their metadata from
     * the corresponding {@link CachingOptions}.
     */
    static void createJGroupsProperties(ProviderConfigurationBuilder builder) {
        Util.copyFromOption(builder, SystemProperties.BIND_ADDRESS.configKey, "address", "String", CachingOptions.CACHE_EMBEDDED_NETWORK_BIND_ADDRESS, false);
        Util.copyFromOption(builder, SystemProperties.BIND_PORT.configKey, "port", "Integer", CachingOptions.CACHE_EMBEDDED_NETWORK_BIND_PORT, false);
        Util.copyFromOption(builder, SystemProperties.EXTERNAL_ADDRESS.configKey, "address", "String", CachingOptions.CACHE_EMBEDDED_NETWORK_EXTERNAL_ADDRESS, false);
        Util.copyFromOption(builder, SystemProperties.EXTERNAL_PORT.configKey, "port", "Integer", CachingOptions.CACHE_EMBEDDED_NETWORK_EXTERNAL_PORT, false);
    }

    /** Exports every configured network option as its JGroups system property. */
    private static void configureTransport(Config.Scope config) {
        for (SystemProperties property : SystemProperties.values()) {
            property.set(config);
        }
    }

    /**
     * Installs an mTLS socket factory on the transport when a {@link JGroupsCertificateProvider}
     * is present and enabled.
     *
     * <p>After installing it, verifies that the selected stack can use TLS.
     */
    private static void configureTls(ConfigurationBuilderHolder holder, KeycloakSession session) {
        JGroupsCertificateProvider provider = session.getProvider(JGroupsCertificateProvider.class);
        if (provider == null || !provider.isEnabled()) {
            return;
        }
        SocketFactory factory = createSocketFactory(provider);
        transportOf(holder).addProperty("socketFactory", factory);
        validateTlsAvailable(holder);
        logger.info("JGroups Encryption enabled (mTLS).");
    }

    /**
     * Builds a TLS socket factory from the provider's key and trust managers.
     *
     * @throws RuntimeException wrapping any {@link KeyManagementException} or
     *                          {@link NoSuchAlgorithmException}
     */
    private static SocketFactory createSocketFactory(JGroupsCertificateProvider provider) {
        try {
            SSLContext sslContext = SSLContext.getInstance(TLS_PROTOCOL);
            sslContext.init(new KeyManager[]{provider.keyManager()}, new TrustManager[]{provider.trustManager()}, null);
            return createFromContext(sslContext);
        } catch (KeyManagementException | NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Wraps {@code context} in a JGroups socket factory.
     *
     * <p>Server sockets are restricted to {@value #TLS_PROTOCOL_VERSION} and require client
     * authentication. Client sockets are not configured here.
     */
    private static SocketFactory createFromContext(SSLContext context) {
        DefaultSocketFactory socketFactory = new DefaultSocketFactory(context);
        SSLParameters serverParameters = new SSLParameters();
        serverParameters.setProtocols(new String[]{TLS_PROTOCOL_VERSION});
        serverParameters.setNeedClientAuth(true);
        socketFactory.setServerSocketConfigurator(socket -> ((SSLServerSocket) socket).setSSLParameters(serverParameters));
        return socketFactory;
    }

    /**
     * Sets up JDBC_PING discovery.
     *
     * <p>Skipped when the user explicitly configured a stack other than {@code jdbc-ping} or
     * {@code jdbc-ping-udp}. When no stack was configured, {@code jdbc-ping} is selected.
     *
     * <p>Before the stack is registered, a database-backed address is reserved for this node, and
     * {@link KEYCLOAK_JDBC_PING2} is wired to the JPA connection factory.
     */
    private static void configureDiscovery(ConfigurationBuilderHolder holder, KeycloakSession session, boolean tracingEnabled) {
        Attribute<String> stackAttribute = transportStackOf(holder);
        if (stackAttribute.isModified() && !isJdbcPingStack(stackAttribute.get())) {
            logger.debugf("Custom stack configured (%s). JDBC_PING discovery disabled.", stackAttribute.get());
            return;
        }
        logger.debug("JDBC_PING discovery enabled.");
        if (!stackAttribute.isModified()) {
            transportOf(holder).stack("jdbc-ping");
        }
        EntityManager em = session.getProvider(JpaConnectionProvider.class).getEntityManager();
        String stackName = transportStackOf(holder).get();
        boolean isUdp = stackName.endsWith("udp");
        String tableName = JpaUtils.getTableNameForNativeQuery(JGROUPS_PING_TABLE, em);
        List<ProtocolConfiguration> stack = getProtocolConfigurations(tableName, isUdp, tracingEnabled);
        JpaConnectionProviderFactory connectionFactory = (JpaConnectionProviderFactory) session.getKeycloakSessionFactory().getProviderFactory(JpaConnectionProvider.class);
        String clusterName = transportOf(holder).attributes().attribute(TransportConfiguration.CLUSTER_NAME).get();
        // Each attempt runs in its own transaction. The arguments (50, 10L) are presumably the
        // attempt count and the retry interval in ms; confirm against Retry.call.
        Address address = Retry.call(
                ignored -> KeycloakModelUtils.runJobInTransactionWithResult(
                        session.getKeycloakSessionFactory(),
                        s -> prepareJGroupsAddress(s, clusterName)),
                50, 10L);
        EmbeddedJGroupsChannelConfigurator configurator =
                new JpaFactoryAwareJGroupsChannelConfigurator(stackName, stack, connectionFactory, isUdp, address);
        holder.addJGroupsStack(configurator, null);
        transportOf(holder).stack(stackName);
        logger.info("JGroups JDBC_PING discovery enabled.");
    }

    /**
     * Picks the next address sequence for {@code clusterName} and inserts a placeholder row for it.
     *
     * <p>The sequence is at least one higher than both the highest sequence already in the ping
     * table and the persisted {@link #JGROUPS_ADDRESS_SEQUENCE}.
     *
     * @return the reserved address
     */
    private static Address prepareJGroupsAddress(KeycloakSession session, String clusterName) {
        JpaConnectionProvider cp = session.getProvider(JpaConnectionProvider.class);
        String tableName = JpaUtils.getTableNameForNativeQuery(JGROUPS_PING_TABLE, cp.getEntityManager());
        long highestSequence = findHighestSequenceInTable(cp, clusterName, tableName);
        long mySequence = getNextSequence(session.getKeycloakSessionFactory(), highestSequence);
        return insertSequenceInTable(cp, clusterName, tableName, mySequence);
    }

    /**
     * Returns the largest sequence encoded in the cluster's ping rows, or {@code -1} if there is none.
     *
     * <p>Only rows whose address parses to a {@link UUID} with zero most-significant bits are
     * considered. That is the shape produced by {@link #insertSequenceInTable}; the sequence lives
     * in the least-significant bits.
     */
    private static long findHighestSequenceInTable(JpaConnectionProvider cp, String clusterName, String tableName) {
        // The explicit Connection type binds the generic connection parameter of callWithConnection.
        return cp.getEntityManager().callWithConnection((Connection con) -> {
            long maxSequence = -1L;
            try (PreparedStatement statement = con.prepareStatement("SELECT address FROM %s WHERE cluster_name=?".formatted(tableName))) {
                statement.setString(1, clusterName);
                try (ResultSet resultSet = statement.executeQuery()) {
                    while (resultSet.next()) {
                        Address addr = org.jgroups.util.Util.addressFromString(resultSet.getString(1));
                        if (addr instanceof UUID uuid && uuid.getMostSignificantBits() == 0L) {
                            maxSequence = Math.max(maxSequence, uuid.getLeastSignificantBits());
                        }
                    }
                }
            }
            return maxSequence;
        });
    }

    /**
     * Computes {@code max(highestSequence, storedSequence) + 1} in its own transaction and writes it
     * back as the new {@link #JGROUPS_ADDRESS_SEQUENCE}.
     *
     * <p>The stored sequence is created as {@code "0"} if it is absent.
     */
    private static long getNextSequence(KeycloakSessionFactory sf, long highestSequence) {
        return KeycloakModelUtils.runJobInTransactionWithResult(sf, session -> {
            ServerConfigStorageProvider storage = session.getProvider(ServerConfigStorageProvider.class);
            String current = storage.loadOrCreate(JGROUPS_ADDRESS_SEQUENCE, () -> "0");
            long next = Math.max(highestSequence + 1L, Long.parseLong(current) + 1L);
            // NOTE(review): the result of replace(...) is not checked. If it reports a failed
            // compare-and-set, concurrent nodes could compute the same value. Confirm its contract.
            storage.replace(JGROUPS_ADDRESS_SEQUENCE, current, Long.toString(next));
            return next;
        });
    }

    /**
     * Inserts a placeholder ping row for {@code mySequence}.
     *
     * <p>The row has name {@code "(starting)"}, ip {@code "127.0.0.1:0"} and coord {@code false}.
     * Its address is a UUID with zero most-significant bits.
     *
     * @return the {@link ExtendedUUID} address carrying the sequence
     */
    private static Address insertSequenceInTable(JpaConnectionProvider cp, String clusterName, String tableName, long mySequence) {
        ExtendedUUID address = new ExtendedUUID(0L, mySequence);
        cp.getEntityManager().runWithConnection((Connection con) -> {
            // Column order presumably (address, name, cluster_name, ip, coord); verify against the schema.
            try (PreparedStatement statement = con.prepareStatement(INSERT_SINGLE_SQL.formatted(tableName))) {
                // Persist the plain UUID form of the address, not the ExtendedUUID.
                Address plainAddress = new UUID(address.getMostSignificantBits(), address.getLeastSignificantBits());
                statement.setString(1, org.jgroups.util.Util.addressToString(plainAddress));
                statement.setString(2, "(starting)");
                statement.setString(3, clusterName);
                statement.setString(4, "127.0.0.1:0");
                statement.setBoolean(5, false);
                statement.execute();
            }
        });
        return address;
    }

    /**
     * Builds the protocol overrides for the JDBC_PING stack.
     *
     * <ul>
     *   <li>{@link KEYCLOAK_JDBC_PING2} replaces {@code PING} (UDP) or {@code MPING} (TCP).</li>
     *   <li>For TCP with virtual threads enabled, the TCP bundler is set to per-destination.</li>
     *   <li>When tracing is enabled, {@link OPEN_TELEMETRY} is inserted above the transport.</li>
     * </ul>
     */
    private static List<ProtocolConfiguration> getProtocolConfigurations(String tableName, boolean udp, boolean tracingEnabled) {
        // At most three entries: discovery, TCP bundler, tracing.
        List<ProtocolConfiguration> list = new ArrayList<>(3);
        list.add(new ProtocolConfiguration(KEYCLOAK_JDBC_PING2.class.getName(), Map.of(
                "initialize_sql", "",
                "clear_sql", String.format("DELETE from %s WHERE cluster_name=?", tableName),
                "delete_single_sql", String.format("DELETE from %s WHERE address=?", tableName),
                "insert_single_sql", INSERT_SINGLE_SQL.formatted(tableName),
                "select_all_pingdata_sql", String.format("SELECT address, name, ip, coord FROM %s WHERE cluster_name=?", tableName),
                "remove_all_data_on_view_change", "true",
                "write_data_on_find", "true",
                "register_shutdown_hook", "false",
                "stack.combine", "REPLACE",
                "stack.position", udp ? "PING" : "MPING")));
        if (!udp && InfinispanUtils.isVirtualThreadsEnabled()) {
            list.add(new ProtocolConfiguration(TCP.class.getSimpleName(), Map.of("bundler_type", "per-destination")));
        }
        if (tracingEnabled) {
            list.add(new ProtocolConfiguration(OPEN_TELEMETRY.class.getName(), Map.of("stack.combine", "INSERT_ABOVE", "stack.position", udp ? "UDP" : "TCP")));
        }
        return list;
    }

    /** Logs a warning if the selected stack is one of the deprecated built-in stacks. */
    private static void warnDeprecatedStack(ConfigurationBuilderHolder holder) {
        String stackName = transportStackOf(holder).get();
        if (stackName == null) {
            // A String switch on null throws NPE; nothing to warn about anyway.
            return;
        }
        switch (stackName) {
            case "jdbc-ping-udp", "tcp", "udp", "azure", "ec2", "google" ->
                    logger.warnf("Stack '%s' is deprecated. We recommend to use 'jdbc-ping' instead", stackName);
            default -> {
                // not a deprecated stack
            }
        }
    }

    private static TransportConfigurationBuilder transportOf(ConfigurationBuilderHolder holder) {
        return holder.getGlobalConfigurationBuilder().transport();
    }

    /** Returns the transport's {@code stack} attribute, which holds the selected stack name. */
    private static Attribute<String> transportStackOf(ConfigurationBuilderHolder holder) {
        TransportConfigurationBuilder transport = transportOf(holder);
        assert transport != null;
        return transport.attributes().attribute(TransportConfiguration.STACK);
    }

    /**
     * Fails if the selected stack contains a transport protocol that cannot use the TLS socket
     * factory ({@link UDP} or {@link TCP_NIO2}).
     *
     * @throws RuntimeException naming the offending protocol
     */
    private static void validateTlsAvailable(ConfigurationBuilderHolder holder) {
        String stackName = transportStackOf(holder).get();
        if (stackName == null) {
            return;
        }
        GlobalConfiguration config = transportOf(holder).build();
        for (ProtocolConfiguration protocol : config.transport().jgroups().configurator(stackName).getProtocolStack()) {
            String name = protocol.getProtocolName();
            if (isTlsIncompatibleProtocol(name)) {
                throw new RuntimeException("Cache TLS is not available with protocol " + name);
            }
        }
    }

    /** Matches UDP or TCP_NIO2 by simple or fully qualified class name. */
    private static boolean isTlsIncompatibleProtocol(String name) {
        return name.equals(UDP.class.getSimpleName())
                || name.equals(UDP.class.getName())
                || name.equals(TCP_NIO2.class.getSimpleName())
                || name.equals(TCP_NIO2.class.getName());
    }

    private static boolean isJdbcPingStack(String stackName) {
        return "jdbc-ping".equals(stackName) || "jdbc-ping-udp".equals(stackName);
    }

    /** Passes {@code scope.get(configKey)} to {@code consumer} when it is non-null and non-empty. */
    private static void readConfigAndSet(Config.Scope scope, String configKey, Consumer<String> consumer) {
        String value = scope.get(configKey);
        if (value != null && !value.isEmpty()) {
            consumer.accept(value);
        }
    }

    /**
     * Cache-embedded network options that are exported as JGroups system properties.
     *
     * <p>Each constant maps an {@link Option} to the property it sets. An optional alternative
     * property name is cleared, with a warning, when the option is configured.
     */
    private enum SystemProperties {
        BIND_ADDRESS(CachingOptions.CACHE_EMBEDDED_NETWORK_BIND_ADDRESS, "jgroups.bind_addr", "jgroups.bind.address"),
        BIND_PORT(CachingOptions.CACHE_EMBEDDED_NETWORK_BIND_PORT, "jgroups.bind_port", "jgroups.bind.port"),
        EXTERNAL_ADDRESS(CachingOptions.CACHE_EMBEDDED_NETWORK_EXTERNAL_ADDRESS, "jgroups.external_addr"),
        EXTERNAL_PORT(CachingOptions.CACHE_EMBEDDED_NETWORK_EXTERNAL_PORT, "jgroups.external_port");

        final Option<?> option;
        final String property;
        final String altProperty;
        /** Camel-cased option key without the {@code cache-embedded-} prefix, e.g. {@code networkBindAddress}. */
        final String configKey;

        SystemProperties(Option<?> option, String property) {
            this(option, property, null);
        }

        SystemProperties(Option<?> option, String property, String altProperty) {
            this.option = option;
            this.property = property;
            this.altProperty = altProperty;
            this.configKey = configKey();
        }

        /**
         * Sets {@link #property} from the configured option value.
         *
         * <p>Does nothing when the option is unset. Any pre-existing value of {@link #property} or
         * {@link #altProperty} is cleared, with a warning, so the CLI value wins.
         */
        void set(Config.Scope config) {
            String userConfig = fromConfig(config);
            if (userConfig == null) {
                return;
            }
            checkPropertyAlreadySet(userConfig, property);
            if (altProperty != null) {
                checkPropertyAlreadySet(userConfig, altProperty);
            }
            System.setProperty(property, userConfig);
        }

        /** Warns about and clears {@code property} if it is already set as a system property. */
        void checkPropertyAlreadySet(String userValue, String property) {
            if (System.getProperty(property) != null) {
                logger.warnf("Conflicting system property '%s' and CLI arg '%s' set, utilising CLI value '%s'", property, option.getKey(), userValue);
                System.clearProperty(property);
            }
        }

        /** Reads the option from {@code config}, as an integer for Integer-typed options; {@code null} if unset. */
        String fromConfig(Config.Scope config) {
            if (option.getType() == Integer.class) {
                Integer val = config.getInt(configKey);
                return val == null ? null : val.toString();
            }
            return config.get(configKey);
        }

        /** Strips the {@code cache-embedded-} prefix and converts the dashed remainder to camelCase. */
        String configKey() {
            String key = option.getKey().substring("cache-embedded".length() + 1);
            StringBuilder sb = new StringBuilder(key);
            for (int i = 0; i < sb.length(); ++i) {
                if (sb.charAt(i) == '-') {
                    sb.deleteCharAt(i);
                    sb.setCharAt(i, Character.toUpperCase(sb.charAt(i)));
                }
            }
            return sb.toString();
        }
    }

    /**
     * Channel configurator for the JDBC_PING stack.
     *
     * <p>Forces the pre-reserved address onto the channel and hands the JPA connection factory to
     * {@link KEYCLOAK_JDBC_PING2} when that protocol is created.
     */
    private static final class JpaFactoryAwareJGroupsChannelConfigurator extends EmbeddedJGroupsChannelConfigurator {
        private final JpaConnectionProviderFactory factory;
        private final Address address;

        public JpaFactoryAwareJGroupsChannelConfigurator(String name, List<ProtocolConfiguration> stack, JpaConnectionProviderFactory factory, boolean isUdp, Address address) {
            // The last argument names the built-in stack this one is based on.
            super(name, stack, null, isUdp ? "udp" : "tcp");
            this.factory = Objects.requireNonNull(factory);
            this.address = address;
        }

        @Override
        protected JChannel amendChannel(JChannel channel) {
            channel.addAddressGenerator(() -> address);
            return super.amendChannel(channel);
        }

        @Override
        public void afterCreation(Protocol protocol) {
            super.afterCreation(protocol);
            if (protocol instanceof KEYCLOAK_JDBC_PING2 kcPing) {
                kcPing.setJpaConnectionProviderFactory(factory);
            }
        }
    }
}

