/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ratis.netty.server;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ratis.client.RaftClientConfigKeys;
import org.apache.ratis.client.impl.ClientProtoUtils;
import org.apache.ratis.client.impl.DataStreamClientImpl;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
import org.apache.ratis.datastream.impl.DataStreamRequestByteBuf;
import org.apache.ratis.io.StandardWriteOption;
import org.apache.ratis.io.WriteOption;
import org.apache.ratis.metrics.Timekeeper;
import org.apache.ratis.netty.metrics.NettyServerStreamRpcMetrics;
import org.apache.ratis.netty.server.ChannelMap;
import org.apache.ratis.netty.server.StreamMap;
import org.apache.ratis.proto.RaftProtos;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.DataStreamPacket;
import org.apache.ratis.protocol.DataStreamReply;
import org.apache.ratis.protocol.RaftClientMessage;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftClientRequest;
import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.protocol.RoutingTable;
import org.apache.ratis.protocol.exceptions.AlreadyExistsException;
import org.apache.ratis.protocol.exceptions.DataStreamException;
import org.apache.ratis.protocol.exceptions.RaftException;
import org.apache.ratis.server.RaftConfiguration;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.RaftServerConfigKeys;
import org.apache.ratis.statemachine.StateMachine;
import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelId;
import org.apache.ratis.util.ConcurrentUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.util.TimeoutExecutor;
import org.apache.ratis.util.function.CheckedBiFunction;
import org.apache.ratis.util.function.UncheckedAutoCloseableSupplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DataStreamManagement {
    public static final Logger LOG = LoggerFactory.getLogger(DataStreamManagement.class);
    private final RaftServer server;
    private final String name;
    private final StreamMap<StreamInfo> streams = new StreamMap();
    private final ChannelMap channels;
    private final ExecutorService requestExecutor;
    private final ExecutorService writeExecutor;
    private final TimeDuration requestTimeout;
    private final NettyServerStreamRpcMetrics nettyServerStreamRpcMetrics;

    DataStreamManagement(RaftServer server, NettyServerStreamRpcMetrics metrics) {
        this.server = server;
        this.name = server.getId() + "-" + JavaUtils.getClassSimpleName(this.getClass());
        this.channels = new ChannelMap();
        RaftProperties properties = server.getProperties();
        boolean useCachedThreadPool = RaftServerConfigKeys.DataStream.asyncRequestThreadPoolCached((RaftProperties)properties);
        this.requestExecutor = ConcurrentUtils.newThreadPoolWithMax((boolean)useCachedThreadPool, (int)RaftServerConfigKeys.DataStream.asyncRequestThreadPoolSize((RaftProperties)properties), (String)(this.name + "-request-"));
        this.writeExecutor = ConcurrentUtils.newThreadPoolWithMax((boolean)useCachedThreadPool, (int)RaftServerConfigKeys.DataStream.asyncWriteThreadPoolSize((RaftProperties)properties), (String)(this.name + "-write-"));
        this.requestTimeout = RaftClientConfigKeys.DataStream.requestTimeout((RaftProperties)server.getProperties());
        this.nettyServerStreamRpcMetrics = metrics;
    }

    void shutdown() {
        ConcurrentUtils.shutdownAndWait((TimeDuration)TimeDuration.ONE_SECOND, (ExecutorService)this.requestExecutor, timeout -> LOG.warn("{}: requestExecutor shutdown timeout in {}", (Object)this, timeout));
        ConcurrentUtils.shutdownAndWait((TimeDuration)TimeDuration.ONE_SECOND, (ExecutorService)this.writeExecutor, timeout -> LOG.warn("{}: writeExecutor shutdown timeout in {}", (Object)this, timeout));
    }

    private CompletableFuture<StateMachine.DataStream> stream(RaftClientRequest request, StateMachine stateMachine) {
        NettyServerStreamRpcMetrics.RequestMetrics metrics = this.getMetrics().newRequestMetrics(NettyServerStreamRpcMetrics.RequestType.STATE_MACHINE_STREAM);
        Timekeeper.Context context = metrics.start();
        return stateMachine.data().stream(request).whenComplete((r, e) -> metrics.stop(context, e == null));
    }

    private CompletableFuture<StateMachine.DataStream> computeDataStreamIfAbsent(RaftClientRequest request) throws IOException {
        RaftServer.Division division = this.server.getDivision(request.getRaftGroupId());
        ClientInvocationId invocationId = ClientInvocationId.valueOf((RaftClientMessage)request);
        CompletableFuture<StateMachine.DataStream> created = new CompletableFuture<StateMachine.DataStream>();
        CompletableFuture returned = division.getDataStreamMap().computeIfAbsent(invocationId, key -> created);
        if (returned != created) {
            throw new AlreadyExistsException("A DataStream already exists for " + invocationId);
        }
        this.stream(request, division.getStateMachine()).whenComplete(JavaUtils.asBiConsumer(created));
        return created;
    }

    private StreamInfo newStreamInfo(ByteBuf buf, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamClientImpl.DataStreamOutputImpl>, IOException> getStreams) {
        try {
            RaftClientRequest request = ClientProtoUtils.toRaftClientRequest((RaftProtos.RaftClientRequestProto)RaftProtos.RaftClientRequestProto.parseFrom((ByteBuffer)buf.nioBuffer()));
            boolean isPrimary = this.server.getId().equals((Object)request.getServerId());
            RaftServer.Division division = this.server.getDivision(request.getRaftGroupId());
            return new StreamInfo(request, isPrimary, this.computeDataStreamIfAbsent(request), division, getStreams, this.getMetrics()::newRequestMetrics);
        }
        catch (Throwable e) {
            throw new CompletionException(e);
        }
    }

    static <T> CompletableFuture<T> composeAsync(AtomicReference<CompletableFuture<T>> future, Executor executor, Function<T, CompletableFuture<T>> function) {
        return future.updateAndGet(previous -> previous.thenComposeAsync(function, executor));
    }

    static CompletableFuture<Long> writeToAsync(ByteBuf buf, Iterable<WriteOption> options, StateMachine.DataStream stream, Executor defaultExecutor) {
        Executor e = Optional.ofNullable(stream.getExecutor()).orElse(defaultExecutor);
        return CompletableFuture.supplyAsync(() -> DataStreamManagement.writeTo(buf, options, stream), e);
    }

    static long writeTo(ByteBuf buf, Iterable<WriteOption> options, StateMachine.DataStream stream) {
        StateMachine.DataChannel channel = stream.getDataChannel();
        long byteWritten = 0L;
        for (ByteBuffer buffer : buf.nioBuffers()) {
            if (buffer.remaining() == 0) continue;
            ReferenceCountedObject wrapped = ReferenceCountedObject.wrap((Object)buffer, () -> ((ByteBuf)buf).retain(), ignored -> buf.release());
            try (UncheckedAutoCloseableSupplier ignore = wrapped.retainAndReleaseOnClose();){
                byteWritten += (long)channel.write(wrapped);
            }
            catch (Throwable t) {
                throw new CompletionException(t);
            }
        }
        if (WriteOption.containsOption(options, (WriteOption)StandardWriteOption.SYNC)) {
            try {
                channel.force(false);
            }
            catch (IOException e) {
                throw new CompletionException(e);
            }
        }
        if (WriteOption.containsOption(options, (WriteOption)StandardWriteOption.CLOSE)) {
            DataStreamManagement.close(stream);
        }
        return byteWritten;
    }

    static void close(StateMachine.DataStream stream) {
        try {
            stream.getDataChannel().close();
        }
        catch (IOException e) {
            throw new CompletionException("Failed to close " + stream, e);
        }
    }

    static DataStreamReplyByteBuffer newDataStreamReplyByteBuffer(DataStreamRequestByteBuf request, RaftClientReply reply) {
        ByteBuffer buffer = ClientProtoUtils.toRaftClientReplyProto((RaftClientReply)reply).toByteString().asReadOnlyByteBuffer();
        return DataStreamReplyByteBuffer.newBuilder().setDataStreamPacket((DataStreamPacket)request).setBuffer(buffer).setSuccess(reply.isSuccess()).build();
    }

    private void sendReply(List<CompletableFuture<DataStreamReply>> remoteWrites, DataStreamRequestByteBuf request, long bytesWritten, Collection<RaftProtos.CommitInfoProto> commitInfos, ChannelHandlerContext ctx) {
        boolean success = this.checkSuccessRemoteWrite(remoteWrites, bytesWritten, request);
        DataStreamReplyByteBuffer.Builder builder = DataStreamReplyByteBuffer.newBuilder().setDataStreamPacket((DataStreamPacket)request).setSuccess(success).setCommitInfos(commitInfos);
        if (success) {
            builder.setBytesWritten(bytesWritten);
        }
        ctx.writeAndFlush((Object)builder.build());
    }

    static void replyDataStreamException(RaftServer server, Throwable cause, RaftClientRequest raftClientRequest, DataStreamRequestByteBuf request, ChannelHandlerContext ctx) {
        RaftClientReply reply = RaftClientReply.newBuilder().setRequest(raftClientRequest).setException((RaftException)new DataStreamException(server.getId(), cause)).build();
        DataStreamManagement.sendDataStreamException(cause, request, reply, ctx);
    }

    void replyDataStreamException(Throwable cause, DataStreamRequestByteBuf request, ChannelHandlerContext ctx) {
        RaftClientReply reply = RaftClientReply.newBuilder().setClientId(ClientId.emptyClientId()).setServerId(this.server.getId()).setGroupId(RaftGroupId.emptyGroupId()).setException((RaftException)new DataStreamException(this.server.getId(), cause)).build();
        DataStreamManagement.sendDataStreamException(cause, request, reply, ctx);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static void sendDataStreamException(Throwable throwable, DataStreamRequestByteBuf request, RaftClientReply reply, ChannelHandlerContext ctx) {
        LOG.warn("Failed to process {}", (Object)request, (Object)throwable);
        try {
            ctx.writeAndFlush((Object)DataStreamManagement.newDataStreamReplyByteBuffer(request, reply));
        }
        catch (Throwable t) {
            LOG.warn("Failed to sendDataStreamException {} for {}", new Object[]{throwable, request, t});
        }
        finally {
            request.release();
        }
    }

    void cleanUp(Set<ClientInvocationId> ids) {
        for (ClientInvocationId clientInvocationId : ids) {
            this.removeDataStream(clientInvocationId);
        }
    }

    void cleanUpOnChannelInactive(ChannelId channelId, TimeDuration channelInactiveGracePeriod) {
        Optional.ofNullable(this.channels.remove(channelId)).ifPresent(ids -> {
            LOG.info("Channel {} is inactive, cleanup clientInvocationIds={}", (Object)channelId, ids);
            TimeoutExecutor.getInstance().onTimeout(channelInactiveGracePeriod, () -> this.cleanUp((Set<ClientInvocationId>)ids), LOG, () -> "Timeout check failed, clientInvocationIds=" + ids);
        });
    }

    void read(DataStreamRequestByteBuf request, ChannelHandlerContext ctx, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamClientImpl.DataStreamOutputImpl>, IOException> getStreams) {
        LOG.debug("{}: read {}", (Object)this, (Object)request);
        try {
            this.readImpl(request, ctx, getStreams);
        }
        catch (Throwable t) {
            this.replyDataStreamException(t, request, ctx);
            this.removeDataStream(ClientInvocationId.valueOf((ClientId)request.getClientId(), (long)request.getStreamId()));
        }
    }

    private StreamInfo removeDataStream(ClientInvocationId invocationId) {
        StreamInfo removed = this.streams.remove(invocationId);
        if (removed != null) {
            removed.cleanUp(invocationId);
        }
        return removed;
    }

    private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ctx, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamClientImpl.DataStreamOutputImpl>, IOException> getStreams) {
        List<Object> remoteWrites;
        CompletableFuture<Long> localWrite;
        StreamInfo info;
        boolean close = request.getWriteOptionList().contains(StandardWriteOption.CLOSE);
        ClientInvocationId key = ClientInvocationId.valueOf((ClientId)request.getClientId(), (long)request.getStreamId());
        ChannelId channelId = ctx.channel().id();
        this.channels.add(channelId, key);
        if (request.getType() == RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_HEADER) {
            MemoizedSupplier supplier = JavaUtils.memoize(() -> this.newStreamInfo(request.slice(), getStreams));
            info = this.streams.computeIfAbsent(key, id -> (StreamInfo)supplier.get());
            if (!supplier.isInitialized()) {
                throw new IllegalStateException("Failed to create a new stream for " + request + " since a stream already exists Key: " + key + " StreamInfo:" + info);
            }
            this.getMetrics().onRequestCreate(NettyServerStreamRpcMetrics.RequestType.HEADER);
        } else {
            info = close ? Optional.ofNullable(this.streams.remove(key)).orElseThrow(() -> new IllegalStateException("Failed to remove StreamInfo for " + request)) : Optional.ofNullable(this.streams.get(key)).orElseThrow(() -> new IllegalStateException("Failed to get StreamInfo for " + request));
        }
        if (request.getType() == RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_HEADER) {
            localWrite = CompletableFuture.completedFuture(0L);
            remoteWrites = Collections.emptyList();
        } else if (request.getType() == RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_DATA) {
            localWrite = info.getLocal().write(request.slice(), request.getWriteOptionList(), this.writeExecutor);
            remoteWrites = info.applyToRemotes(out -> out.write(request, this.requestExecutor));
        } else {
            throw new IllegalStateException(this + ": Unexpected type " + request.getType() + ", request=" + request);
        }
        DataStreamManagement.composeAsync(info.getPrevious(), this.requestExecutor, n -> JavaUtils.allOf((Collection)remoteWrites).thenCombineAsync(localWrite, (v, bytesWritten) -> {
            if (request.getType() != RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_HEADER && request.getType() != RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_DATA && !close) {
                throw new IllegalStateException(this + ": Unexpected type " + request.getType() + ", request=" + request);
            }
            this.sendReply(remoteWrites, request, (long)bytesWritten, info.getCommitInfos(), ctx);
            return null;
        }, (Executor)this.requestExecutor)).whenComplete((v, exception) -> {
            try {
                if (exception != null) {
                    DataStreamManagement.replyDataStreamException(this.server, exception, info.getRequest(), request, ctx);
                    StreamInfo removed = this.removeDataStream(key);
                    if (removed != null) {
                        Preconditions.assertSame((Object)info, (Object)removed, (String)"removed");
                    } else {
                        info.cleanUp(key);
                    }
                } else if (close) {
                    info.applyToRemotes(remote -> ((RemoteStream)remote).out.closeAsync());
                }
            }
            finally {
                request.release();
                this.channels.remove(channelId, key);
            }
        });
    }

    static void assertReplyCorrespondingToRequest(DataStreamRequestByteBuf request, DataStreamReply reply) {
        Preconditions.assertTrue((boolean)request.getClientId().equals((Object)reply.getClientId()));
        Preconditions.assertTrue((request.getType() == reply.getType() ? 1 : 0) != 0);
        Preconditions.assertTrue((request.getStreamId() == reply.getStreamId() ? 1 : 0) != 0);
        Preconditions.assertTrue((request.getStreamOffset() == reply.getStreamOffset() ? 1 : 0) != 0);
    }

    private boolean checkSuccessRemoteWrite(List<CompletableFuture<DataStreamReply>> replyFutures, long bytesWritten, DataStreamRequestByteBuf request) {
        for (CompletableFuture<DataStreamReply> replyFuture : replyFutures) {
            DataStreamReply reply;
            try {
                reply = replyFuture.get(this.requestTimeout.getDuration(), this.requestTimeout.getUnit());
            }
            catch (Exception e) {
                throw new CompletionException("Failed to get reply for bytesWritten=" + bytesWritten + ", " + request, e);
            }
            DataStreamManagement.assertReplyCorrespondingToRequest(request, reply);
            if (!reply.isSuccess()) {
                LOG.warn("reply is not success, request: {}", (Object)request);
                return false;
            }
            if (reply.getBytesWritten() == bytesWritten) continue;
            LOG.warn("reply written bytes not match, local size: {} remote size: {} request: {}", new Object[]{bytesWritten, reply.getBytesWritten(), request});
            return false;
        }
        return true;
    }

    NettyServerStreamRpcMetrics getMetrics() {
        return this.nettyServerStreamRpcMetrics;
    }

    public String toString() {
        return this.name;
    }

    static class StreamInfo {
        private final RaftClientRequest request;
        private final boolean primary;
        private final LocalStream local;
        private final Set<RemoteStream> remotes;
        private final RaftServer.Division division;
        private final AtomicReference<CompletableFuture<Void>> previous = new AtomicReference<CompletableFuture<Object>>(CompletableFuture.completedFuture(null));

        StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture<StateMachine.DataStream> stream, RaftServer.Division division, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamClientImpl.DataStreamOutputImpl>, IOException> getStreams, Function<NettyServerStreamRpcMetrics.RequestType, NettyServerStreamRpcMetrics.RequestMetrics> metricsConstructor) throws IOException {
            this.request = request;
            this.primary = primary;
            this.local = new LocalStream(stream, metricsConstructor.apply(NettyServerStreamRpcMetrics.RequestType.LOCAL_WRITE));
            this.division = division;
            Set<RaftPeer> successors = this.getSuccessors(division.getId());
            Set outs = (Set)getStreams.apply((Object)request, successors);
            this.remotes = outs.stream().map(o -> new RemoteStream((DataStreamClientImpl.DataStreamOutputImpl)o, (NettyServerStreamRpcMetrics.RequestMetrics)metricsConstructor.apply(NettyServerStreamRpcMetrics.RequestType.REMOTE_WRITE))).collect(Collectors.toSet());
        }

        AtomicReference<CompletableFuture<Void>> getPrevious() {
            return this.previous;
        }

        RaftClientRequest getRequest() {
            return this.request;
        }

        RaftServer.Division getDivision() {
            return this.division;
        }

        Collection<RaftProtos.CommitInfoProto> getCommitInfos() {
            return this.getDivision().getCommitInfos();
        }

        boolean isPrimary() {
            return this.primary;
        }

        LocalStream getLocal() {
            return this.local;
        }

        <T> List<T> applyToRemotes(Function<RemoteStream, T> function) {
            return this.remotes.isEmpty() ? Collections.emptyList() : this.remotes.stream().map(function).collect(Collectors.toList());
        }

        public String toString() {
            return JavaUtils.getClassSimpleName(this.getClass()) + ":" + this.request;
        }

        private Set<RaftPeer> getSuccessors(RaftPeerId peerId) {
            RaftConfiguration conf = this.getDivision().getRaftConf();
            RoutingTable routingTable = this.request.getRoutingTable();
            if (routingTable != null) {
                return routingTable.getSuccessors(peerId).stream().map(x$0 -> conf.getPeer(x$0, new RaftProtos.RaftPeerRole[0])).collect(Collectors.toSet());
            }
            if (this.isPrimary()) {
                return conf.getCurrentPeers().stream().filter(p -> !p.getId().equals((Object)this.division.getId())).collect(Collectors.toSet());
            }
            return Collections.emptySet();
        }

        void cleanUp(ClientInvocationId invocationId) {
            this.getDivision().getDataStreamMap().remove(invocationId);
            this.getLocal().cleanUp();
            this.applyToRemotes(remote -> ((RemoteStream)remote).out.closeAsync());
        }
    }

    static class RemoteStream {
        private final DataStreamClientImpl.DataStreamOutputImpl out;
        private final AtomicReference<CompletableFuture<DataStreamReply>> sendFuture = new AtomicReference<CompletableFuture<Object>>(CompletableFuture.completedFuture(null));
        private final NettyServerStreamRpcMetrics.RequestMetrics metrics;

        RemoteStream(DataStreamClientImpl.DataStreamOutputImpl out, NettyServerStreamRpcMetrics.RequestMetrics metrics) {
            this.metrics = metrics;
            this.out = out;
        }

        static Iterable<WriteOption> addFlush(List<WriteOption> original) {
            if (original.contains(StandardWriteOption.FLUSH)) {
                return original;
            }
            return Stream.concat(Stream.of(StandardWriteOption.FLUSH), original.stream()).collect(Collectors.toList());
        }

        CompletableFuture<DataStreamReply> write(DataStreamRequestByteBuf request, Executor executor) {
            Timekeeper.Context context = this.metrics.start();
            return DataStreamManagement.composeAsync(this.sendFuture, executor, n -> this.out.writeAsync(request.slice().retain(), RemoteStream.addFlush(request.getWriteOptionList())).whenComplete((l, e) -> this.metrics.stop(context, e == null)));
        }
    }

    static class LocalStream {
        private final CompletableFuture<StateMachine.DataStream> streamFuture;
        private final AtomicReference<CompletableFuture<Long>> writeFuture;
        private final NettyServerStreamRpcMetrics.RequestMetrics metrics;

        LocalStream(CompletableFuture<StateMachine.DataStream> streamFuture, NettyServerStreamRpcMetrics.RequestMetrics metrics) {
            this.streamFuture = streamFuture;
            this.writeFuture = new AtomicReference<CompletionStage>(streamFuture.thenApply(s -> 0L));
            this.metrics = metrics;
        }

        CompletableFuture<Long> write(ByteBuf buf, Iterable<WriteOption> options, Executor executor) {
            Timekeeper.Context context = this.metrics.start();
            return DataStreamManagement.composeAsync(this.writeFuture, executor, n -> this.streamFuture.thenCompose(stream -> DataStreamManagement.writeToAsync(buf, options, stream, executor).whenComplete((l, e) -> this.metrics.stop(context, e == null))));
        }

        void cleanUp() {
            this.streamFuture.thenAccept(StateMachine.DataStream::cleanUp);
        }
    }
}

