/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.plugin.flink;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.DriverChangedException;
import org.apache.celeborn.common.exception.PartitionUnRetryAbleException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.plugin.flink.RemoteBufferStreamReader;
import org.apache.celeborn.plugin.flink.RemoteShuffleDescriptor;
import org.apache.celeborn.plugin.flink.RemoteShuffleResource;
import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
import org.apache.celeborn.plugin.flink.buffer.TransferBufferPool;
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
import org.apache.celeborn.plugin.flink.utils.Utils;
import org.apache.celeborn.shaded.org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.io.AvailabilityProvider;
import org.apache.flink.runtime.io.network.api.EndOfData;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.function.SupplierWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RemoteShuffleInputGateDelegation {
    private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleInputGateDelegation.class);
    private Object lock = new Object();
    private String taskName;
    private int gateIndex;
    private InputGateDeploymentDescriptor gateDescriptor;
    private SupplierWithException<BufferPool, IOException> bufferPoolFactory;
    private BufferPool bufferPool;
    private TransferBufferPool transferBufferPool = new TransferBufferPool(Collections.emptySet());
    private List<RemoteBufferStreamReader> bufferReaders = new ArrayList<RemoteBufferStreamReader>();
    private List<InputChannelInfo> channelsInfo;
    private int[] channelIndexToReaderIndex;
    private int[] readerIndexToChannelIndex;
    private int[] numSubPartitionsNotConsumed;
    private long numUnconsumedSubpartitions;
    private Queue<Pair<Buffer, InputChannelInfo>> receivedBuffers = new LinkedList<Pair<Buffer, InputChannelInfo>>();
    private Throwable cause;
    private boolean closed;
    private boolean initialChannelsOpened;
    private long pendingEndOfDataEvents;
    private int numConcurrentReading;
    private boolean shouldDrainOnEndOfData = true;
    private BufferDecompressor bufferDecompressor;
    private FlinkShuffleClientImpl shuffleClient;
    private int numOpenedReaders = 0;
    private AvailabilityProvider.AvailabilityHelper availabilityHelper;
    private int startSubIndex;
    private int endSubIndex;

    public RemoteShuffleInputGateDelegation(CelebornConf celebornConf, String taskName, int gateIndex, InputGateDeploymentDescriptor gateDescriptor, SupplierWithException<BufferPool, IOException> bufferPoolFactory, BufferDecompressor bufferDecompressor, int numConcurrentReading, AvailabilityProvider.AvailabilityHelper availabilityHelper, int startSubIndex, int endSubIndex) {
        this.taskName = taskName;
        this.gateIndex = gateIndex;
        this.gateDescriptor = gateDescriptor;
        this.bufferPoolFactory = bufferPoolFactory;
        int numChannels = gateDescriptor.getShuffleDescriptors().length;
        this.channelIndexToReaderIndex = new int[numChannels];
        this.readerIndexToChannelIndex = new int[numChannels];
        this.numSubPartitionsNotConsumed = new int[numChannels];
        this.bufferDecompressor = bufferDecompressor;
        RemoteShuffleDescriptor remoteShuffleDescriptor = (RemoteShuffleDescriptor)gateDescriptor.getShuffleDescriptors()[0];
        RemoteShuffleResource shuffleResource = remoteShuffleDescriptor.getShuffleResource();
        try {
            String appUniqueId = ((RemoteShuffleDescriptor)gateDescriptor.getShuffleDescriptors()[0]).getCelebornAppId();
            this.shuffleClient = FlinkShuffleClientImpl.get(appUniqueId, shuffleResource.getLifecycleManagerHost(), shuffleResource.getLifecycleManagerPort(), shuffleResource.getLifecycleManagerTimestamp(), celebornConf, new UserIdentifier("default", "default"));
        }
        catch (DriverChangedException e) {
            throw new RuntimeException(e.getMessage());
        }
        this.startSubIndex = startSubIndex;
        this.endSubIndex = endSubIndex;
        this.initShuffleReadClients();
        this.channelsInfo = this.createChannelInfos();
        this.numConcurrentReading = numConcurrentReading;
        this.availabilityHelper = availabilityHelper;
        LOG.debug("Initial input gate with numConcurrentReading {}", (Object)this.numConcurrentReading);
    }

    private void initShuffleReadClients() {
        int numSubpartitionsPerChannel = this.endSubIndex - this.startSubIndex + 1;
        long numUnconsumedSubpartitions = 0L;
        List descriptors = IntStream.range(0, this.gateDescriptor.getShuffleDescriptors().length).mapToObj(i -> Pair.of(i, this.gateDescriptor.getShuffleDescriptors()[i])).collect(Collectors.toList());
        int readerIndex = 0;
        for (Pair descriptor : descriptors) {
            RemoteShuffleDescriptor remoteDescriptor = (RemoteShuffleDescriptor)descriptor.getRight();
            ShuffleResourceDescriptor shuffleDescriptor = remoteDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor();
            LOG.debug("create shuffle reader for descriptor {}", (Object)shuffleDescriptor);
            RemoteBufferStreamReader reader = new RemoteBufferStreamReader(this.shuffleClient, shuffleDescriptor, this.startSubIndex, this.endSubIndex, this.transferBufferPool, this.getDataListener((Integer)descriptor.getLeft()), this.getFailureListener(remoteDescriptor.getResultPartitionID()));
            this.bufferReaders.add(reader);
            this.numSubPartitionsNotConsumed[((Integer)descriptor.getLeft()).intValue()] = numSubpartitionsPerChannel;
            numUnconsumedSubpartitions += (long)numSubpartitionsPerChannel;
            this.channelIndexToReaderIndex[((Integer)descriptor.getLeft()).intValue()] = readerIndex;
            this.readerIndexToChannelIndex[readerIndex] = (Integer)descriptor.getLeft();
            ++readerIndex;
        }
        this.numUnconsumedSubpartitions = numUnconsumedSubpartitions;
        this.pendingEndOfDataEvents = numUnconsumedSubpartitions;
    }

    private List<InputChannelInfo> createChannelInfos() {
        return IntStream.range(0, this.gateDescriptor.getShuffleDescriptors().length).mapToObj(i -> new InputChannelInfo(this.gateIndex, i)).collect(Collectors.toList());
    }

    private Consumer<ByteBuf> getDataListener(int channelIdx) {
        return byteBuf -> {
            Queue<Buffer> unpackedBuffers = null;
            try {
                unpackedBuffers = BufferPacker.unpack(byteBuf);
                while (!unpackedBuffers.isEmpty()) {
                    this.onBuffer(unpackedBuffers.poll(), channelIdx);
                }
            }
            catch (Throwable throwable) {
                Object object = this.lock;
                synchronized (object) {
                    this.cause = this.cause == null ? throwable : this.cause;
                    this.availabilityHelper.getUnavailableToResetAvailable().complete(null);
                }
                if (unpackedBuffers != null) {
                    unpackedBuffers.forEach(Buffer::recycleBuffer);
                }
                LOG.error("Failed to process the received buffer.", throwable);
            }
        };
    }

    private Consumer<Throwable> getFailureListener(ResultPartitionID rpID) {
        return throwable -> {
            Object object = this.lock;
            synchronized (object) {
                if (this.cause != null) {
                    return;
                }
                Class<PartitionUnRetryAbleException> clazz = PartitionUnRetryAbleException.class;
                this.cause = throwable.getMessage() != null && throwable.getMessage().contains(clazz.getName()) ? new PartitionNotFoundException(rpID) : throwable;
                this.availabilityHelper.getUnavailableToResetAvailable().complete(null);
            }
        };
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void onBuffer(Buffer buffer, int channelIdx) {
        Object object = this.lock;
        synchronized (object) {
            if (this.closed || this.cause != null) {
                buffer.recycleBuffer();
                throw new IllegalStateException("Input gate already closed or failed.");
            }
            boolean needRecycle = true;
            try {
                boolean wasEmpty = this.receivedBuffers.isEmpty();
                InputChannelInfo channelInfo = this.channelsInfo.get(channelIdx);
                Utils.checkState(channelInfo.getInputChannelIdx() == channelIdx, "Illegal channel index.");
                LOG.debug("ReceivedBuffers is adding buffer {} on {}", (Object)buffer, (Object)channelInfo);
                this.receivedBuffers.add(Pair.of(buffer, channelInfo));
                needRecycle = false;
                if (wasEmpty) {
                    this.availabilityHelper.getUnavailableToResetAvailable().complete(null);
                }
            }
            catch (Throwable throwable) {
                if (needRecycle) {
                    buffer.recycleBuffer();
                }
                throw throwable;
            }
        }
    }

    public void setup() throws IOException {
        long startTime = System.nanoTime();
        this.bufferPool = (BufferPool)this.bufferPoolFactory.get();
        BufferUtils.reserveNumRequiredBuffers(this.bufferPool, 16);
        this.tryRequestBuffers();
        this.availabilityHelper.getUnavailableToResetAvailable().complete(null);
        LOG.info("Set up read gate by {} ms.", (Object)TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime));
    }

    private void tryRequestBuffers() {
        Buffer buffer;
        Utils.checkState(this.bufferPool != null, "Not initialized yet.");
        ArrayList<ByteBuf> buffers = new ArrayList<ByteBuf>();
        while ((buffer = this.bufferPool.requestBuffer()) != null) {
            buffers.add(buffer.asByteBuf());
        }
        if (!buffers.isEmpty()) {
            this.transferBufferPool.addBuffers(buffers);
        }
    }

    private Buffer decompressBufferIfNeeded(Buffer buffer) throws IOException {
        if (buffer.isCompressed()) {
            try {
                Utils.checkState(this.bufferDecompressor != null, "Buffer decompressor not set.");
                Buffer buffer2 = this.bufferDecompressor.decompressToIntermediateBuffer(buffer);
                return buffer2;
            }
            catch (Throwable t) {
                throw new IOException("Decompress failure", t);
            }
            finally {
                buffer.recycleBuffer();
            }
        }
        return buffer;
    }

    public List<InputChannelInfo> getChannelsInfo() {
        return this.channelsInfo;
    }

    public List<RemoteBufferStreamReader> getBufferReaders() {
        return this.bufferReaders;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void tryOpenSomeChannels() throws IOException {
        if (this.bufferReaders.size() == this.numOpenedReaders) {
            return;
        }
        ArrayList<RemoteBufferStreamReader> readersToOpen = new ArrayList<RemoteBufferStreamReader>();
        Iterator iterator = this.lock;
        synchronized (iterator) {
            if (this.closed) {
                throw new IOException("Input gate already closed.");
            }
            LOG.debug("Try open some partition readers.");
            int numOnGoing = 0;
            for (int i = 0; i < this.bufferReaders.size(); ++i) {
                RemoteBufferStreamReader bufferStreamReader = this.bufferReaders.get(i);
                LOG.debug("Trying reader: {}, isOpened={}, numSubPartitionsHasNotConsumed={}.", new Object[]{bufferStreamReader, bufferStreamReader.isOpened(), this.numSubPartitionsNotConsumed[this.readerIndexToChannelIndex[i]]});
                if (numOnGoing >= this.numConcurrentReading) break;
                if (bufferStreamReader.isOpened() && this.numSubPartitionsNotConsumed[this.readerIndexToChannelIndex[i]] > 0) {
                    ++numOnGoing;
                    continue;
                }
                if (bufferStreamReader.isOpened()) continue;
                readersToOpen.add(bufferStreamReader);
                ++numOnGoing;
            }
        }
        for (RemoteBufferStreamReader reader : readersToOpen) {
            reader.open(0);
            ++this.numOpenedReaders;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private Pair<Buffer, InputChannelInfo> getReceived() throws IOException {
        Object object = this.lock;
        synchronized (object) {
            this.healthCheck();
            if (!this.receivedBuffers.isEmpty()) {
                return this.receivedBuffers.poll();
            }
            if (!this.allReadersEOF()) {
                this.availabilityHelper.resetUnavailable();
            }
            return null;
        }
    }

    private void healthCheck() throws IOException {
        if (this.closed) {
            throw new IOException("Input gate already closed.");
        }
        if (this.cause != null) {
            if (this.cause instanceof IOException) {
                throw (IOException)this.cause;
            }
            throw new IOException(this.cause);
        }
    }

    private boolean allReadersEOF() {
        return this.numUnconsumedSubpartitions <= 0L;
    }

    private Optional<BufferOrEvent> transformBuffer(Buffer buf, InputChannelInfo info) throws IOException {
        return Optional.of(new BufferOrEvent(this.decompressBufferIfNeeded(buf), info, !this.isFinished(), false));
    }

    private Optional<BufferOrEvent> transformEvent(Buffer buffer, InputChannelInfo channelInfo) throws IOException {
        AbstractEvent event;
        try {
            event = EventSerializer.fromBuffer((Buffer)buffer, (ClassLoader)this.getClass().getClassLoader());
        }
        catch (Throwable t) {
            throw new IOException("Deserialize failure.", t);
        }
        finally {
            buffer.recycleBuffer();
        }
        if (event.getClass() == EndOfPartitionEvent.class) {
            Utils.checkState(this.numSubPartitionsNotConsumed[channelInfo.getInputChannelIdx()] > 0, "BUG -- EndOfPartitionEvent received repeatedly.");
            int n = channelInfo.getInputChannelIdx();
            this.numSubPartitionsNotConsumed[n] = this.numSubPartitionsNotConsumed[n] - 1;
            --this.numUnconsumedSubpartitions;
            if (this.numSubPartitionsNotConsumed[channelInfo.getInputChannelIdx()] != 0) {
                LOG.debug("numSubPartitionsNotConsumed: {}", (Object)this.numSubPartitionsNotConsumed[channelInfo.getInputChannelIdx()]);
            } else {
                this.bufferReaders.get(this.channelIndexToReaderIndex[channelInfo.getInputChannelIdx()]).close();
                this.tryOpenSomeChannels();
                if (this.allReadersEOF()) {
                    this.availabilityHelper.getUnavailableToResetAvailable().complete(null);
                }
            }
        } else if (event.getClass() == EndOfData.class) {
            Utils.checkState(!this.hasReceivedEndOfData(), "Too many EndOfData event.");
            --this.pendingEndOfDataEvents;
        }
        return Optional.of(new BufferOrEvent(event, buffer.getDataType().hasPriority(), channelInfo, !this.isFinished(), buffer.getSize(), false));
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public boolean isFinished() {
        Object object = this.lock;
        synchronized (object) {
            return this.allReadersEOF() && this.receivedBuffers.isEmpty();
        }
    }

    public boolean hasReceivedEndOfData() {
        return this.pendingEndOfDataEvents <= 0L;
    }

    public int getGateIndex() {
        return this.gateIndex;
    }

    public Optional<BufferOrEvent> pollNext() throws IOException {
        if (!this.initialChannelsOpened) {
            this.tryOpenSomeChannels();
            this.initialChannelsOpened = true;
        }
        Pair<Buffer, InputChannelInfo> pair = this.getReceived();
        Optional<BufferOrEvent> bufferOrEvent = Optional.empty();
        LOG.debug("pollNext called with pair null {}", (Object)(pair == null ? 1 : 0));
        while (pair != null) {
            Buffer buffer = pair.getLeft();
            InputChannelInfo channelInfo = pair.getRight();
            if (buffer.isBuffer()) {
                bufferOrEvent = this.transformBuffer(buffer, channelInfo);
            } else {
                bufferOrEvent = this.transformEvent(buffer, channelInfo);
                LOG.debug("received event: {}.", bufferOrEvent.isPresent() ? bufferOrEvent.get().getEvent().getClass().getName() : Optional.empty());
            }
            if (bufferOrEvent.isPresent()) break;
            pair = this.getReceived();
        }
        this.tryRequestBuffers();
        return bufferOrEvent;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void close() throws Exception {
        List<Buffer> buffersToRecycle;
        Throwable closeException = null;
        for (RemoteBufferStreamReader reader : this.bufferReaders) {
            try {
                reader.close();
            }
            catch (Throwable throwable) {
                closeException = closeException == null ? throwable : closeException;
                LOG.error("Failed to close shuffle read client.", throwable);
            }
        }
        Object object = this.lock;
        synchronized (object) {
            buffersToRecycle = this.receivedBuffers.stream().map(Pair::getLeft).collect(Collectors.toList());
            this.receivedBuffers.clear();
            this.closed = true;
        }
        try {
            buffersToRecycle.forEach(Buffer::recycleBuffer);
        }
        catch (Throwable throwable) {
            closeException = closeException == null ? throwable : closeException;
            LOG.error("Failed to recycle buffers.", throwable);
        }
        try {
            this.transferBufferPool.destroy();
        }
        catch (Throwable throwable) {
            closeException = closeException == null ? throwable : closeException;
            LOG.error("Failed to close transfer buffer pool.", throwable);
        }
        try {
            if (this.bufferPool != null) {
                this.bufferPool.lazyDestroy();
            }
        }
        catch (Throwable throwable) {
            closeException = closeException == null ? throwable : closeException;
            LOG.error("Failed to close local buffer pool.", throwable);
        }
        if (closeException != null) {
            ExceptionUtils.rethrowException((Throwable)closeException);
        }
    }

    public String getTaskName() {
        return this.taskName;
    }

    public InputGateDeploymentDescriptor getGateDescriptor() {
        return this.gateDescriptor;
    }

    public long getPendingEndOfDataEvents() {
        return this.pendingEndOfDataEvents;
    }
}

