/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.plugin.flink;

import java.io.IOException;
import java.util.Optional;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.DriverChangedException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.plugin.flink.RemoteShuffleDescriptor;
import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
import org.apache.celeborn.plugin.flink.client.FlinkShuffleClientImpl;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
import org.apache.celeborn.plugin.flink.utils.Utils;
import org.apache.celeborn.shaded.io.netty.buffer.Unpooled;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.metrics.groups.ShuffleIOMetricGroup;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.util.function.SupplierWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Output gate that writes one Flink map task's partition data to Celeborn through a {@link
 * FlinkShuffleClientImpl}.
 *
 * <p>Call sequence, as implemented below:
 *
 * <ul>
 *   <li>{@link #setup()} obtains the buffer pool from the supplied factory.
 *   <li>{@link #regionStart(boolean)} registers the map-partition task (once), sends the push-data
 *       handshake (once per location) and announces a new region.
 *   <li>{@link #write(Buffer, int)} hands buffers to a {@link BufferPacker}, whose sink is this
 *       gate's {@code write} method.
 *   <li>{@link #regionFinish()} drains the packer and closes the region.
 *   <li>{@link #finish()} reports mapper end; {@link #close()} releases resources.
 * </ul>
 *
 * <p>When the client answers a handshake or region start with a revived {@link
 * PartitionLocation}, the gate switches to that location and retries, up to {@code
 * celebornConf.clientPushMaxReviveTimes()} attempts per operation.
 */
public class RemoteShuffleOutputGate {
    private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleOutputGate.class);

    private final RemoteShuffleDescriptor shuffleDesc;
    /** Subpartition count (presumably); passed to the client in the push-data handshake. */
    protected final int numSubs;
    protected FlinkShuffleClientImpl flinkShuffleClient;
    protected final SupplierWithException<BufferPool, IOException> bufferPoolFactory;
    /** Obtained in {@link #setup()}; {@code null} before that. */
    protected BufferPool bufferPool;
    private final CelebornConf celebornConf;
    private final int numMappers;
    /** Current target location; replaced whenever the client returns a revived location. */
    private PartitionLocation partitionLocation;
    /** Reset to 0 on registration and incremented by each successful {@link #regionFinish()}. */
    private int currentRegionIndex = 0;
    private final int bufferSize;
    private final BufferPacker bufferPacker;
    private final String applicationId;
    private final int shuffleId;
    private final int mapId;
    private final int attemptId;
    private final int partitionId;
    private final String lifecycleManagerHost;
    private final int lifecycleManagerPort;
    private final long lifecycleManagerTimestamp;
    // NOTE(review): never assigned in this class, so getShuffleClient() always passes null.
    // Confirm whether it should be populated (e.g. from the shuffle descriptor).
    private UserIdentifier userIdentifier;
    private boolean isRegisterShuffle = false;
    private final int maxReviveTimes;
    /** Set once handshake() completes; cleared when region start revives to a new location. */
    private boolean hasSentHandshake = false;
    protected final ShuffleIOMetricGroup shuffleIOMetricGroup;

    /**
     * Creates a gate that obtains its shuffle client via {@link #getShuffleClient()}.
     *
     * @param shuffleDesc descriptor providing app id, shuffle/map/attempt/partition ids and the
     *     lifecycle manager address
     * @param numSubs value sent as the partition count in the handshake
     * @param bufferSize buffer size sent in the handshake
     * @param bufferPoolFactory supplies the buffer pool in {@link #setup()}
     * @param celebornConf client configuration; also provides the max revive attempts
     * @param numMappers number of mappers, used at registration and mapper end
     * @param shuffleIOMetricGroup metric group whose bytes-out counter is updated on push
     */
    public RemoteShuffleOutputGate(RemoteShuffleDescriptor shuffleDesc, int numSubs, int bufferSize, SupplierWithException<BufferPool, IOException> bufferPoolFactory, CelebornConf celebornConf, int numMappers, ShuffleIOMetricGroup shuffleIOMetricGroup) {
        this(shuffleDesc, numSubs, bufferSize, bufferPoolFactory, celebornConf, numMappers, shuffleIOMetricGroup, null);
    }

    /**
     * Creates a gate, optionally with an injected shuffle client.
     *
     * @param flinkShuffleClient client to use; when {@code null} one is obtained from {@link
     *     #getShuffleClient()}
     * @see #RemoteShuffleOutputGate(RemoteShuffleDescriptor, int, int, SupplierWithException,
     *     CelebornConf, int, ShuffleIOMetricGroup)
     */
    public RemoteShuffleOutputGate(RemoteShuffleDescriptor shuffleDesc, int numSubs, int bufferSize, SupplierWithException<BufferPool, IOException> bufferPoolFactory, CelebornConf celebornConf, int numMappers, ShuffleIOMetricGroup shuffleIOMetricGroup, FlinkShuffleClientImpl flinkShuffleClient) {
        this.shuffleDesc = shuffleDesc;
        this.numSubs = numSubs;
        this.bufferPoolFactory = bufferPoolFactory;
        this.bufferPacker = new BufferPacker(this::write);
        this.celebornConf = celebornConf;
        this.numMappers = numMappers;
        this.bufferSize = bufferSize;
        this.applicationId = shuffleDesc.getCelebornAppId();
        this.shuffleId = shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getShuffleId();
        this.mapId = shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getMapId();
        this.attemptId = shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getAttemptId();
        this.partitionId = shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getPartitionId();
        this.lifecycleManagerHost = shuffleDesc.getShuffleResource().getLifecycleManagerHost();
        this.lifecycleManagerPort = shuffleDesc.getShuffleResource().getLifecycleManagerPort();
        this.lifecycleManagerTimestamp = shuffleDesc.getShuffleResource().getLifecycleManagerTimestamp();
        // getShuffleClient() runs while the object is still under construction: it only reads the
        // fields assigned above, and overrides must not depend on subclass state.
        this.flinkShuffleClient = flinkShuffleClient == null ? this.getShuffleClient() : flinkShuffleClient;
        this.maxReviveTimes = celebornConf.clientPushMaxReviveTimes();
        this.shuffleIOMetricGroup = shuffleIOMetricGroup;
    }

    /**
     * Obtains the buffer pool and reserves one required buffer.
     *
     * @throws IOException if the buffer pool factory fails
     * @throws InterruptedException propagated from the buffer reservation
     */
    public void setup() throws IOException, InterruptedException {
        this.bufferPool = Utils.checkNotNull(this.bufferPoolFactory.get());
        Utils.checkArgument(this.bufferPool.getNumberOfRequiredMemorySegments() >= 2, "Too few buffers for transfer, the minimum valid required size is 2.");
        BufferUtils.reserveNumRequiredBuffers(this.bufferPool, 1);
    }

    /** Returns the buffer pool obtained in {@link #setup()}, or {@code null} before setup. */
    public BufferPool getBufferPool() {
        return this.bufferPool;
    }

    /**
     * Hands {@code buffer} for subpartition {@code subIdx} to the buffer packer.
     *
     * @throws InterruptedException propagated from the packer
     */
    public void write(Buffer buffer, int subIdx) throws InterruptedException {
        this.bufferPacker.process(buffer, subIdx);
    }

    /**
     * Starts a new region: registers the task if needed, sends the handshake if needed, then
     * sends the region start (with revive handling).
     *
     * @param isBroadcast forwarded to the client's region start
     * @throws RuntimeException wrapping any {@link IOException} from registration
     */
    public void regionStart(boolean isBroadcast) {
        try {
            this.registerShuffle();
            this.handshake();
            this.regionStartWithRevive(isBroadcast);
        } catch (IOException e) {
            Utils.rethrowAsRuntimeException(e);
        }
    }

    /**
     * Drains packed data, tells the client the current region is finished and advances the
     * region index.
     *
     * @throws InterruptedException propagated from draining the packer
     * @throws RuntimeException wrapping any {@link IOException} from the client
     */
    public void regionFinish() throws InterruptedException {
        this.bufferPacker.drain();
        try {
            this.flinkShuffleClient.regionFinish(this.shuffleId, this.mapId, this.attemptId, this.partitionLocation);
            ++this.currentRegionIndex;
        } catch (IOException e) {
            Utils.rethrowAsRuntimeException(e);
        }
    }

    /**
     * Reports mapper end for the current partition location. Requires a prior {@link
     * #regionStart(boolean)} so that a location has been registered.
     */
    public void finish() throws InterruptedException, IOException {
        this.flinkShuffleClient.mapPartitionMapperEnd(this.shuffleId, this.mapId, this.attemptId, this.numMappers, this.partitionLocation.getId());
    }

    /**
     * Releases the buffer pool and packer, then cleans up client-side state for this map attempt.
     * Client cleanup runs even if releasing the pool or packer fails.
     */
    public void close() throws IOException {
        try {
            if (this.bufferPool != null) {
                this.bufferPool.lazyDestroy();
            }
            this.bufferPacker.close();
        } finally {
            this.flinkShuffleClient.cleanup(this.shuffleId, this.mapId, this.attemptId);
        }
    }

    /** Returns the descriptor this gate was created with. */
    public RemoteShuffleDescriptor getShuffleDesc() {
        return this.shuffleDesc;
    }

    /**
     * Looks up the shuffle client for this application and lifecycle manager.
     *
     * @throws RuntimeException if the lookup raises {@link DriverChangedException}; the original
     *     exception is kept as the cause
     */
    @VisibleForTesting
    FlinkShuffleClientImpl getShuffleClient() {
        try {
            return FlinkShuffleClientImpl.get(this.applicationId, this.lifecycleManagerHost, this.lifecycleManagerPort, this.lifecycleManagerTimestamp, this.celebornConf, this.userIdentifier);
        } catch (DriverChangedException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Pushes a packed buffer to the current partition location and records the bytes written.
     * {@code byteBuf} is released through the callback handed to the client.
     *
     * @throws RuntimeException wrapping any {@link IOException} from the push
     */
    public void write(ByteBuf byteBuf, BufferHeader bufferHeader) {
        try {
            int bytesWritten = this.flinkShuffleClient.pushDataToLocation(
                    this.shuffleId,
                    this.mapId,
                    this.attemptId,
                    bufferHeader.getSubPartitionId(),
                    Unpooled.wrappedBuffer(byteBuf.nioBuffer()),
                    this.partitionLocation,
                    () -> byteBuf.release());
            this.shuffleIOMetricGroup.getNumBytesOut().inc(bytesWritten);
        } catch (IOException e) {
            Utils.rethrowAsRuntimeException(e);
        }
    }

    /**
     * Registers the map-partition task with the client on first call and stores the returned
     * location; later calls are no-ops.
     *
     * @throws IOException propagated from the client
     */
    public void registerShuffle() throws IOException {
        if (!this.isRegisterShuffle) {
            this.partitionLocation = this.flinkShuffleClient.registerMapPartitionTask(this.shuffleId, this.numMappers, this.mapId, this.attemptId, this.partitionId);
            Utils.checkNotNull(this.partitionLocation);
            this.currentRegionIndex = 0;
            this.isRegisterShuffle = true;
        }
    }

    /**
     * Sends the region start for {@link #currentRegionIndex}. Each time the client returns a
     * revived location, the gate switches to it, re-sends the handshake and retries. At most
     * {@code maxReviveTimes} region-start attempts are made.
     *
     * @param isBroadcast forwarded to the client
     * @throws RuntimeException if no region start succeeds within the allowed attempts, or
     *     wrapping any {@link IOException} from the client
     */
    public void regionStartWithRevive(boolean isBroadcast) {
        try {
            int remainingReviveTimes = this.maxReviveTimes;
            boolean hasSentRegionStart = false;
            while (!hasSentRegionStart && remainingReviveTimes > 0) {
                remainingReviveTimes--;
                Optional<PartitionLocation> revivePartition = this.flinkShuffleClient.regionStart(this.shuffleId, this.mapId, this.attemptId, this.partitionLocation, this.currentRegionIndex, isBroadcast);
                if (!revivePartition.isPresent()) {
                    hasSentRegionStart = true;
                    continue;
                }
                LOG.info("Revive at regionStart, currentTimes:{}, totalTimes:{} for shuffleId:{}, mapId:{}, attempId:{}, currentRegionIndex:{}, isBroadcast:{}, newPartition:{}, oldPartition:{}", remainingReviveTimes, this.maxReviveTimes, this.shuffleId, this.mapId, this.attemptId, this.currentRegionIndex, isBroadcast, revivePartition, this.partitionLocation);
                this.partitionLocation = revivePartition.get();
                // The revived location has not seen a handshake yet.
                this.hasSentHandshake = false;
                this.handshake();
            }
            // Judge failure by the flag, not the counter: a revive on the final attempt leaves the
            // region start unsent for the new location, and must not pass silently.
            if (!hasSentRegionStart) {
                throw new RuntimeException("After retry " + this.maxReviveTimes + " times, still failed to send regionStart");
            }
        } catch (IOException e) {
            Utils.rethrowAsRuntimeException(e);
        }
    }

    /**
     * Sends the push-data handshake to the current location unless it was already sent. On a
     * revive the gate switches to the new location and retries, up to {@code maxReviveTimes}
     * attempts. A revive returned on the final attempt is not followed: the current location is
     * kept and the handshake is recorded as sent.
     *
     * @throws RuntimeException if no attempt could be made ({@code maxReviveTimes <= 0}), or
     *     wrapping any {@link IOException} from the client
     */
    public void handshake() {
        try {
            int remainingReviveTimes = this.maxReviveTimes;
            while (!this.hasSentHandshake && remainingReviveTimes > 0) {
                remainingReviveTimes--;
                Optional<PartitionLocation> revivePartition = this.flinkShuffleClient.pushDataHandShake(this.shuffleId, this.mapId, this.attemptId, this.numSubs, this.bufferSize, this.partitionLocation);
                // NOTE(review): on the last attempt the revived location is ignored (no attempt is
                // left to handshake it); presumably the following region start revives again.
                if (revivePartition.isPresent() && remainingReviveTimes > 0) {
                    LOG.info("Revive at handshake, currentTimes:{}, totalTimes:{} for shuffleId:{}, mapId:{}, attempId:{}, currentRegionIndex:{}, newPartition:{}, oldPartition:{}", remainingReviveTimes, this.maxReviveTimes, this.shuffleId, this.mapId, this.attemptId, this.currentRegionIndex, revivePartition, this.partitionLocation);
                    this.partitionLocation = revivePartition.get();
                    continue;
                }
                this.hasSentHandshake = true;
            }
            if (!this.hasSentHandshake) {
                throw new RuntimeException("After retry " + this.maxReviveTimes + " times, still failed to send handshake");
            }
        } catch (IOException e) {
            Utils.rethrowAsRuntimeException(e);
        }
    }
}

