/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.dag.app;

import com.google.common.collect.Maps;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.ProtocolSignature;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.ipc.VersionedProtocol;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.authorize.PolicyProvider;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.LocalResource;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.tez.common.ContainerContext;
import org.apache.tez.common.ContainerTask;
import org.apache.tez.common.TezConverterUtils;
import org.apache.tez.common.TezLocalResource;
import org.apache.tez.common.TezTaskUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.security.JobTokenIdentifier;
import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.impl.TaskSpec;
import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
import org.apache.tez.serviceplugins.api.ContainerEndReason;
import org.apache.tez.serviceplugins.api.TaskAttemptEndReason;
import org.apache.tez.serviceplugins.api.TaskCommunicator;
import org.apache.tez.serviceplugins.api.TaskCommunicatorContext;
import org.apache.tez.serviceplugins.api.TaskHeartbeatRequest;
import org.apache.tez.serviceplugins.api.TaskHeartbeatResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TaskCommunicator} that hosts a {@link TezTaskUmbilicalProtocol} RPC server through which
 * task containers pull work ({@code getTask}), ask for commit permission ({@code canCommit}) and
 * send heartbeats.
 *
 * <p>State is kept in two concurrent maps: {@link #registeredContainers} (container id to its
 * {@link ContainerInfo}) and {@link #attemptToContainerMap} (task attempt to the container it was
 * assigned to). The mutable fields of each {@link ContainerInfo} are guarded by synchronizing on
 * that {@code ContainerInfo} instance.
 */
@InterfaceAudience.Private
public class TezTaskCommunicatorImpl extends TaskCommunicator {
    private static final Logger LOG = LoggerFactory.getLogger(TezTaskCommunicatorImpl.class);

    /**
     * Returned to containers that are unknown or no longer registered; callers check it via
     * {@code shouldDie()} and the log messages state such containers "will be killed".
     */
    private static final ContainerTask TASK_FOR_INVALID_JVM = new ContainerTask(null, true, null, null, false);

    private final TezTaskUmbilicalProtocol taskUmbilical;
    /** Containers currently registered with this communicator. */
    protected final ConcurrentMap<ContainerId, ContainerInfo> registeredContainers = new ConcurrentHashMap<>();
    /** Reverse index: running task attempt to the container it has been assigned to. */
    protected final ConcurrentMap<TezTaskAttemptID, ContainerId> attemptToContainerMap = new ConcurrentHashMap<>();
    /** Identifier under which the session token is registered with the RPC secret manager. */
    protected final String tokenIdentifier;
    protected final Token<JobTokenIdentifier> sessionToken;
    /** Configuration decoded from the plugin's initial user payload. */
    protected final Configuration conf;
    /** Externally reachable address of the RPC server; set by {@link #startRpcServer()}. */
    protected InetSocketAddress address;
    protected volatile Server server;

    /**
     * Creates the communicator. The RPC server is not started until {@link #start()}.
     *
     * @param taskCommunicatorContext context supplying the application attempt id, AM credentials
     *     and the initial user payload
     * @throws TezUncheckedException if the user payload cannot be parsed into a configuration
     */
    public TezTaskCommunicatorImpl(TaskCommunicatorContext taskCommunicatorContext) {
        super(taskCommunicatorContext);
        this.taskUmbilical = new TezTaskUmbilicalProtocolImpl();
        this.tokenIdentifier = taskCommunicatorContext.getApplicationAttemptId().getApplicationId().toString();
        this.sessionToken = TokenCache.getSessionToken(taskCommunicatorContext.getAMCredentials());
        try {
            this.conf = TezUtils.createConfFromUserPayload(this.getContext().getInitialUserPayload());
        } catch (IOException e) {
            throw new TezUncheckedException("Unable to parse user payload for " + TezTaskCommunicatorImpl.class.getSimpleName(), e);
        }
    }

    @Override
    public void start() {
        startRpcServer();
    }

    @Override
    public void shutdown() {
        stopRpcServer();
    }

    /**
     * Builds and starts the umbilical RPC server on an ephemeral port (bind address 0.0.0.0),
     * secured by a {@link JobTokenSecretManager} holding the session token, and records the
     * canonical-host address in {@link #address}.
     *
     * @throws TezUncheckedException wrapping any {@link IOException} from server construction
     */
    protected void startRpcServer() {
        try {
            JobTokenSecretManager jobTokenSecretManager = new JobTokenSecretManager(conf);
            jobTokenSecretManager.addTokenForJob(tokenIdentifier, sessionToken);
            server = new RPC.Builder(conf)
                    .setProtocol(TezTaskUmbilicalProtocol.class)
                    .setBindAddress("0.0.0.0")
                    .setPort(0)
                    .setInstance(taskUmbilical)
                    .setNumHandlers(conf.getInt("tez.am.task.listener.thread-count", 30))
                    .setPortRangeConfig("tez.am.task.am.port-range")
                    .setSecretManager(jobTokenSecretManager)
                    .build();
            if (conf.getBoolean("hadoop.security.authorization", false)) {
                refreshServiceAcls(conf, new TezAMPolicyProvider());
            }
            server.start();
            InetSocketAddress serverBindAddress = NetUtils.getConnectAddress(server);
            address = NetUtils.createSocketAddrForHost(serverBindAddress.getAddress().getCanonicalHostName(), serverBindAddress.getPort());
            LOG.info("Instantiated TezTaskCommunicator RPC at " + address);
        } catch (IOException e) {
            throw new TezUncheckedException(e);
        }
    }

    /** Stops the RPC server if one is running and clears the reference. */
    protected void stopRpcServer() {
        // Read the volatile field once so the null check and the stop() act on the same instance.
        Server current = server;
        if (current != null) {
            current.stop();
            server = null;
        }
    }

    protected Configuration getConf() {
        return conf;
    }

    private void refreshServiceAcls(Configuration configuration, PolicyProvider policyProvider) {
        server.refreshServiceAcl(configuration, policyProvider);
    }

    /**
     * Registers a newly running container.
     *
     * @throws TezUncheckedException if {@code containerId} is already registered
     */
    @Override
    public void registerRunningContainer(ContainerId containerId, String host, int port) {
        ContainerInfo oldInfo = registeredContainers.putIfAbsent(containerId, new ContainerInfo(containerId, host, port));
        if (oldInfo != null) {
            throw new TezUncheckedException("Multiple registrations for containerId: " + containerId);
        }
    }

    /**
     * Forgets a container and, if it still had a task assigned, that task attempt's container
     * mapping. Unknown container ids are ignored.
     */
    @Override
    public void registerContainerEnd(ContainerId containerId, ContainerEndReason endReason, String diagnostics) {
        ContainerInfo containerInfo = registeredContainers.remove(containerId);
        if (containerInfo == null) {
            return;
        }
        synchronized (containerInfo) {
            if (containerInfo.taskSpec != null && containerInfo.taskSpec.getTaskAttemptID() != null) {
                attemptToContainerMap.remove(containerInfo.taskSpec.getTaskAttemptID());
            }
        }
    }

    /**
     * Assigns a task attempt to a registered container; the container picks it up on its next
     * {@code getTask} call.
     *
     * @throws NullPointerException if {@code containerId} is not registered
     * @throws TezUncheckedException if the container already has a task, or the attempt is already
     *     registered to a container
     */
    @Override
    public void registerRunningTaskAttempt(ContainerId containerId, TaskSpec taskSpec, Map<String, LocalResource> additionalResources, Credentials credentials, boolean credentialsChanged, int priority) {
        ContainerInfo containerInfo = registeredContainers.get(containerId);
        // Supplier overload: only format the message when the check actually fails.
        Objects.requireNonNull(containerInfo, () -> String.format("Cannot register task attempt %s to unknown container %s", taskSpec.getTaskAttemptID(), containerId));
        synchronized (containerInfo) {
            if (containerInfo.taskSpec != null) {
                throw new TezUncheckedException("Cannot register task: " + taskSpec.getTaskAttemptID() + " to container: " + containerId + " , with pre-existing assignment: " + containerInfo.taskSpec.getTaskAttemptID());
            }
            // Claim the attempt id before touching the container's state so that a rejected
            // duplicate registration leaves this ContainerInfo unassigned.
            ContainerId oldId = attemptToContainerMap.putIfAbsent(taskSpec.getTaskAttemptID(), containerId);
            if (oldId != null) {
                throw new TezUncheckedException("Attempting to register an already registered taskAttempt with id: " + taskSpec.getTaskAttemptID() + " to containerId: " + containerId + ". Already registered to containerId: " + oldId);
            }
            containerInfo.taskSpec = taskSpec;
            containerInfo.additionalLRs = additionalResources;
            containerInfo.credentials = credentials;
            containerInfo.credentialsChanged = credentialsChanged;
            containerInfo.taskPulled = false;
        }
    }

    /**
     * Removes a task attempt's container mapping and clears the task from its container. Unknown
     * attempts or containers are logged and ignored.
     */
    @Override
    public void unregisterRunningTaskAttempt(TezTaskAttemptID taskAttemptID, TaskAttemptEndReason endReason, String diagnostics) {
        ContainerId containerId = attemptToContainerMap.remove(taskAttemptID);
        if (containerId == null) {
            LOG.warn("Unregister task attempt: " + taskAttemptID + " from unknown container");
            return;
        }
        ContainerInfo containerInfo = registeredContainers.get(containerId);
        if (containerInfo == null) {
            LOG.warn("Unregister task attempt: " + taskAttemptID + " from non-registered container: " + containerId);
            return;
        }
        synchronized (containerInfo) {
            // The attempt mapping was already removed above; a second remove here could only
            // discard a mapping registered concurrently for the same attempt id.
            containerInfo.reset();
        }
    }

    @Override
    public InetSocketAddress getAddress() {
        return address;
    }

    @Override
    public void onVertexStateUpdated(VertexStateUpdate stateUpdate) {
    }

    @Override
    public void dagComplete(int dagIdentifier) {
    }

    @Override
    public Object getMetaInfo() {
        return address;
    }

    protected String getTokenIdentifier() {
        return tokenIdentifier;
    }

    protected Token<JobTokenIdentifier> getSessionToken() {
        return sessionToken;
    }

    public TezTaskUmbilicalProtocol getUmbilical() {
        return taskUmbilical;
    }

    /**
     * Resolves what a polling container should receive.
     *
     * @return {@link #TASK_FOR_INVALID_JVM} for unregistered containers; the assigned task on the
     *     first poll after assignment; {@code null} when nothing new is assigned
     * @throws IOException if the task's local resources cannot be converted
     */
    private ContainerTask getContainerTask(ContainerId containerId) throws IOException {
        ContainerInfo containerInfo = registeredContainers.get(containerId);
        if (containerInfo == null) {
            if (getContext().isKnownContainer(containerId)) {
                LOG.info("Container with id: " + containerId + " is valid, but no longer registered, and will be killed");
            } else {
                LOG.info("Container with id: " + containerId + " is invalid and will be killed");
            }
            return TASK_FOR_INVALID_JVM;
        }
        synchronized (containerInfo) {
            getContext().containerAlive(containerId);
            if (containerInfo.taskSpec == null) {
                LOG.debug("No task assigned yet for running container: {}", containerId);
                return null;
            }
            if (containerInfo.taskPulled) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Task " + containerInfo.taskSpec.getTaskAttemptID() + " already sent to container: " + containerId);
                }
                return null;
            }
            // Build the task before marking it pulled: if conversion throws, the next poll retries
            // instead of the task being considered delivered.
            ContainerTask task = constructContainerTask(containerInfo);
            containerInfo.taskPulled = true;
            return task;
        }
    }

    private ContainerTask constructContainerTask(ContainerInfo containerInfo) throws IOException {
        return new ContainerTask(containerInfo.taskSpec, false, convertLocalResourceMap(containerInfo.additionalLRs), containerInfo.credentials, containerInfo.credentialsChanged);
    }

    /**
     * Converts YARN local resources to their Tez representation.
     *
     * @param ylrs resources keyed by name; may be {@code null}
     * @return a new mutable map (empty when {@code ylrs} is {@code null})
     * @throws IOException wrapping a {@link URISyntaxException} from the conversion
     */
    private Map<String, TezLocalResource> convertLocalResourceMap(Map<String, LocalResource> ylrs) throws IOException {
        Map<String, TezLocalResource> tlrs = Maps.newHashMap();
        if (ylrs == null) {
            return tlrs;
        }
        for (Map.Entry<String, LocalResource> ylrEntry : ylrs.entrySet()) {
            try {
                tlrs.put(ylrEntry.getKey(), TezConverterUtils.convertYarnLocalResourceToTez(ylrEntry.getValue()));
            } catch (URISyntaxException e) {
                throw new IOException(e);
            }
        }
        return tlrs;
    }

    protected ContainerInfo getContainerInfo(ContainerId containerId) {
        return registeredContainers.get(containerId);
    }

    protected ContainerId getContainerForAttempt(TezTaskAttemptID taskAttemptId) {
        return attemptToContainerMap.get(taskAttemptId);
    }

    /** Sums the most recently heartbeated memory usage of all registered containers. */
    @Override
    public long getTotalUsedMemory() {
        return registeredContainers.values().stream().mapToLong(c -> c.usedMemory).sum();
    }

    /** RPC endpoint implementation; delegates to the enclosing communicator and its context. */
    private class TezTaskUmbilicalProtocolImpl implements TezTaskUmbilicalProtocol {
        private TezTaskUmbilicalProtocolImpl() {
        }

        /**
         * Hands out the task assigned to the calling container, notifying the context when a real
         * task (not a die-marker) is returned.
         */
        public ContainerTask getTask(ContainerContext containerContext) throws IOException {
            ContainerTask task;
            if (containerContext == null || containerContext.getContainerIdentifier() == null) {
                LOG.info("Invalid task request with an empty containerContext or containerId");
                task = TASK_FOR_INVALID_JVM;
            } else {
                ContainerId containerId = ConverterUtils.toContainerId(containerContext.getContainerIdentifier());
                LOG.debug("Container with id: {} asked for a task", containerId);
                task = getContainerTask(containerId);
                if (task != null && !task.shouldDie()) {
                    TezTaskAttemptID attemptId = task.getTaskSpec().getTaskAttemptID();
                    getContext().taskSubmitted(attemptId, containerId);
                    getContext().taskStartedRemotely(attemptId);
                }
            }
            LOG.debug("getTask returning task: {}", task);
            return task;
        }

        public boolean canCommit(TezTaskAttemptID taskAttemptId) throws IOException {
            return getContext().canCommit(taskAttemptId);
        }

        /**
         * Processes a container heartbeat.
         *
         * <p>Unknown containers are told to die. A request whose id equals the last processed id
         * gets the cached previous response. When the request carries a task attempt, the attempt
         * must be mapped to this container and the request id must be exactly one greater than the
         * last one; otherwise a {@link TezException} is thrown.
         */
        public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOException, TezException {
            ContainerId containerId = ConverterUtils.toContainerId(request.getContainerIdentifier());
            long requestId = request.getRequestId();
            LOG.debug("Received heartbeat from container, request={}", request);

            ContainerInfo containerInfo = registeredContainers.get(containerId);
            if (containerInfo == null) {
                LOG.warn("Received task heartbeat from unknown container with id: " + containerId + ", asking it to die");
                TezHeartbeatResponse dieResponse = new TezHeartbeatResponse();
                dieResponse.setLastRequestId(requestId);
                dieResponse.setShouldDie();
                return dieResponse;
            }

            synchronized (containerInfo) {
                if (containerInfo.lastRequestId == requestId) {
                    LOG.warn("Old sequenceId received: " + requestId + ", Re-sending last response to client");
                    return containerInfo.lastResponse;
                }
            }

            TezHeartbeatResponse response = new TezHeartbeatResponse();
            TezTaskAttemptID taskAttemptID = request.getCurrentTaskAttemptID();
            if (taskAttemptID != null) {
                synchronized (containerInfo) {
                    ContainerId containerIdFromMap = attemptToContainerMap.get(taskAttemptID);
                    if (containerIdFromMap == null || !containerIdFromMap.equals(containerId)) {
                        throw new TezException("Attempt " + taskAttemptID + " is not recognized for heartbeat");
                    }
                    // Compute the expected id numerically; concatenating "+ 1" into the message
                    // string would print e.g. "51" instead of 6.
                    long expectedRequestId = containerInfo.lastRequestId + 1L;
                    if (expectedRequestId != requestId) {
                        throw new TezException("Container " + containerId + " has invalid request id. Expected: " + expectedRequestId + " and actual: " + requestId);
                    }
                }
                // The context heartbeat is invoked without holding the ContainerInfo lock.
                TaskHeartbeatRequest tRequest = new TaskHeartbeatRequest(request.getContainerIdentifier(), taskAttemptID, request.getEvents(), request.getStartIndex(), request.getPreRoutedStartIndex(), request.getMaxEvents());
                TaskHeartbeatResponse tResponse = getContext().heartbeat(tRequest);
                response.setEvents(tResponse.getEvents());
                response.setNextFromEventId(tResponse.getNextFromEventId());
                response.setNextPreRoutedEventId(tResponse.getNextPreRoutedEventId());
            }
            response.setLastRequestId(requestId);
            // Publish under the same monitor the duplicate/sequence checks above read under.
            synchronized (containerInfo) {
                containerInfo.lastRequestId = requestId;
                containerInfo.lastResponse = response;
                containerInfo.usedMemory = request.getUsedMemory();
            }
            return response;
        }

        public long getProtocolVersion(String protocol, long clientVersion) throws IOException {
            // NOTE(review): presumably mirrors TezTaskUmbilicalProtocol's versionID (inlined by
            // the decompiler); confirm and reference the constant directly.
            return 19L;
        }

        public ProtocolSignature getProtocolSignature(String protocol, long clientVersion, int clientMethodsHash) throws IOException {
            return ProtocolSignature.getProtocolSignature(this, protocol, clientVersion, clientMethodsHash);
        }
    }

    /**
     * Per-container bookkeeping. All mutable fields are accessed while synchronized on the
     * instance, except {@code usedMemory}, which is {@code volatile} so that
     * {@link #getTotalUsedMemory()} can read it without locking.
     */
    public static final class ContainerInfo {
        final ContainerId containerId;
        public final String host;
        public final int port;
        /** Response to the last processed heartbeat, replayed when its request id repeats. */
        TezHeartbeatResponse lastResponse = null;
        /** Currently assigned task, or {@code null} when the container is idle. */
        TaskSpec taskSpec = null;
        long lastRequestId = 0L;
        Map<String, LocalResource> additionalLRs = null;
        Credentials credentials = null;
        boolean credentialsChanged = false;
        /** Whether the assigned task has already been returned by {@code getTask}. */
        boolean taskPulled = false;
        volatile long usedMemory = 0L;

        ContainerInfo(ContainerId containerId, String host, int port) {
            this.containerId = containerId;
            this.host = host;
            this.port = port;
        }

        /** Clears the task assignment; heartbeat sequencing state is kept. */
        void reset() {
            this.taskSpec = null;
            this.additionalLRs = null;
            this.credentials = null;
            this.credentialsChanged = false;
            this.taskPulled = false;
        }
    }
}

