/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hertzbeat.collector.collect.common.ssh;

import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.RemovalCause;
import com.github.benmanes.caffeine.cache.Scheduler;
import java.io.IOException;
import java.net.ServerSocket;
import java.security.GeneralSecurityException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Predicate;
import org.apache.hertzbeat.collector.collect.common.ssh.SshHelper;
import org.apache.hertzbeat.common.entity.job.SshTunnel;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.client.session.forward.ExplicitPortForwardingTracker;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
 * Helper that opens and reuses SSH local port forwardings ("SSH tunnels").
 *
 * <p>Forwarding trackers are cached per SSH client session. A Caffeine cache keyed by
 * {@link SshClientSessionWrapper} expires entries after {@link #DEFAULT_CACHE_TIMEOUT}
 * milliseconds without access. When an entry is removed, every forwarding tracker registered
 * for that session is closed, and unshared sessions are closed as well. Trackers that have not
 * been selected for longer than the same timeout are also closed lazily on the next lookup.
 */
public class SshTunnelHelper {
    private static final Logger log = LoggerFactory.getLogger(SshTunnelHelper.class);

    /** Idle timeout in milliseconds, used both for cache expiry and for lazy tracker eviction. */
    private static final long DEFAULT_CACHE_TIMEOUT = 500000L;

    private static final Cache<SshClientSessionWrapper, LocalPortForwardingWrapper> TRACKER_CACHE = Caffeine.newBuilder()
            .initialCapacity(1)
            .maximumSize(1000L)
            .expireAfterAccess(Duration.ofMillis(DEFAULT_CACHE_TIMEOUT))
            .scheduler(Scheduler.systemScheduler())
            .removalListener((SshClientSessionWrapper key, LocalPortForwardingWrapper value, RemovalCause cause) -> {
                // A replaced value still belongs to a live session; only tear down on real removal.
                if (cause == RemovalCause.REPLACED || key == null || value == null) {
                    return;
                }
                value.remove(key.getClientSession());
                if (!key.isShareConnection()) {
                    try {
                        key.close();
                        log.info("[SSH Tunnel] close unshared ssh connection, {}", key);
                    } catch (IOException e) {
                        log.error("[SSH Tunnel] close unshared ssh connection error", e);
                    }
                }
            })
            .build();

    /**
     * Validates the mandatory SSH tunnel parameters when the tunnel is enabled.
     *
     * @param sshTunnel tunnel configuration; {@code null} or a disabled tunnel is accepted as-is
     * @throws IllegalArgumentException if the host, port or username is blank
     */
    public static void checkTunnelParam(SshTunnel sshTunnel) {
        if (sshTunnel == null || !Boolean.parseBoolean(sshTunnel.getEnable())) {
            return;
        }
        if (!StringUtils.hasText(sshTunnel.getHost())) {
            throw new IllegalArgumentException("ssh tunnel must has ssh host param");
        }
        if (!StringUtils.hasText(sshTunnel.getPort())) {
            throw new IllegalArgumentException("ssh tunnel must has ssh port param");
        }
        if (!StringUtils.hasText(sshTunnel.getUsername())) {
            throw new IllegalArgumentException("ssh tunnel must has ssh username param");
        }
    }

    /**
     * Returns a local port that forwards through the SSH tunnel to {@code remoteHost:remotePort}.
     * Reuses an open forwarding on the same session when one exists; otherwise creates one.
     *
     * @param sshTunnel  tunnel configuration (host, credentials, timeout, share flag)
     * @param remoteHost target host as seen from the SSH server
     * @param remotePort target port as a decimal string
     * @return the local port to connect to
     * @throws NumberFormatException    if {@code remotePort} or the tunnel timeout is not an integer
     * @throws GeneralSecurityException if the SSH session cannot be authenticated
     * @throws IOException              if the session or the forwarding cannot be established
     */
    public static int localPortForward(SshTunnel sshTunnel, String remoteHost, String remotePort)
            throws GeneralSecurityException, IOException {
        // Parse before opening a session so a malformed port cannot leave a connection behind.
        int remotePortNumber = Integer.parseInt(remotePort);
        boolean shareConnection = Boolean.parseBoolean(sshTunnel.getShareConnection());
        ClientSession session = SshHelper.getConnectSession(sshTunnel.getHost(), sshTunnel.getPort(),
                sshTunnel.getUsername(), sshTunnel.getPassword(), sshTunnel.getPrivateKey(),
                sshTunnel.getPrivateKeyPassphrase(), Integer.parseInt(sshTunnel.getTimeout()), shareConnection);
        SshClientSessionWrapper sessionWrapper = new SshClientSessionWrapper(session, shareConnection);

        LocalPortForwardingWrapper cached = TRACKER_CACHE.getIfPresent(sessionWrapper);
        LocalPortForwardingWrapper existing = selectWrapper(cached, sessionWrapper, remoteHost, remotePortNumber);
        if (existing != null) {
            return existing.getTracker().getLocalAddress().getPort();
        }

        int localPort;
        LocalPortForwardingWrapper created;
        try {
            localPort = getRandomPort();
            created = sessionWrapper.createLocalPortForwardingTracker(localPort, remoteHost, remotePortNumber);
        } catch (IOException | RuntimeException e) {
            // An unshared session with no cache entry is never reached by the removal listener,
            // so it must be closed here or it leaks.
            if (cached == null && !shareConnection) {
                closeQuietly(sessionWrapper, e);
            }
            throw e;
        }
        // Atomic: never overwrite a wrapper another thread registered for this session meanwhile.
        TRACKER_CACHE.asMap().putIfAbsent(sessionWrapper, created);
        log.info("[SSH Tunnel] created ssh forwarding tracker ssh:{}, remote:{}, localPort:{}",
                sshTunnel.getHost() + ":" + sshTunnel.getPort(), remoteHost + ":" + remotePort, localPort);
        return localPort;
    }

    /** Closes a session after a failed tunnel setup, attaching any close error to {@code cause}. */
    private static void closeQuietly(SshClientSessionWrapper sessionWrapper, Exception cause) {
        try {
            sessionWrapper.close();
        } catch (IOException closeError) {
            cause.addSuppressed(closeError);
        }
    }

    /**
     * Picks the least recently used open forwarding on the session that targets the given remote
     * address, and marks it as used now.
     *
     * @return the selected wrapper, or {@code null} if there is no cached wrapper or no match
     */
    private static LocalPortForwardingWrapper selectWrapper(LocalPortForwardingWrapper wrapper,
            SshClientSessionWrapper sessionWrapper, String remoteHost, int remotePort) {
        if (wrapper == null) {
            return null;
        }
        List<LocalPortForwardingWrapper> candidates = wrapper.select(sessionWrapper.getClientSession(), candidate -> {
            if (!candidate.isOpen()) {
                return false;
            }
            SshdSocketAddress remoteAddress = candidate.getTracker().getRemoteAddress();
            return Objects.equals(remoteAddress.getHostName(), remoteHost) && remoteAddress.getPort() == remotePort;
        });
        LocalPortForwardingWrapper selected = candidates.stream()
                .min(Comparator.comparingLong(LocalPortForwardingWrapper::getLastAccessTime))
                .orElse(null);
        if (selected != null) {
            selected.setLastAccessTime(System.currentTimeMillis());
        }
        return selected;
    }

    /**
     * Asks the OS for a free ephemeral port. The probe socket is closed before the forwarding
     * binds, so another process could still take the port in between.
     */
    private static int getRandomPort() throws IOException {
        try (ServerSocket serverSocket = new ServerSocket(0)) {
            return serverSocket.getLocalPort();
        }
    }

    /** Immutable cache key pairing an SSH client session with its sharing mode. */
    private static final class SshClientSessionWrapper {
        private final ClientSession clientSession;
        private final boolean shareConnection;

        SshClientSessionWrapper(ClientSession clientSession, boolean shareConnection) {
            this.clientSession = clientSession;
            this.shareConnection = shareConnection;
        }

        /** Opens a forwarding from {@code localhost:localPort} to {@code remoteHost:remotePort}. */
        LocalPortForwardingWrapper createLocalPortForwardingTracker(int localPort, String remoteHost, int remotePort)
                throws IOException {
            SshdSocketAddress remoteAddress = new SshdSocketAddress(remoteHost, remotePort);
            SshdSocketAddress localAddress = new SshdSocketAddress("localhost", localPort);
            ExplicitPortForwardingTracker tracker = clientSession.createLocalPortForwardingTracker(localAddress, remoteAddress);
            return new LocalPortForwardingWrapper(tracker);
        }

        void close() throws IOException {
            clientSession.close();
        }

        ClientSession getClientSession() {
            return clientSession;
        }

        boolean isShareConnection() {
            return shareConnection;
        }

        @Override
        public String toString() {
            return "{ ssh:%s, shareConnection:%b }".formatted(clientSession, shareConnection);
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof SshClientSessionWrapper)) {
                return false;
            }
            SshClientSessionWrapper other = (SshClientSessionWrapper) o;
            return shareConnection == other.shareConnection && Objects.equals(clientSession, other.clientSession);
        }

        @Override
        public int hashCode() {
            return Objects.hash(clientSession, shareConnection);
        }
    }

    /** A single local port forwarding plus the time it was last handed out. */
    private static final class LocalPortForwardingWrapper {
        /**
         * Every forwarding opened on each session. Lists are copy-on-write because they are read
         * and modified from collector threads and from the cache removal listener concurrently.
         */
        private static final Map<ClientSession, List<LocalPortForwardingWrapper>> TRACKERS_BY_SESSION =
                new ConcurrentHashMap<>();

        private final ExplicitPortForwardingTracker tracker;
        // Written by selecting threads and read during lazy eviction on other threads.
        private volatile long lastAccessTime;

        LocalPortForwardingWrapper(ExplicitPortForwardingTracker tracker) {
            this.tracker = tracker;
            this.lastAccessTime = System.currentTimeMillis();
            // compute() makes "create list if absent, then add" atomic relative to pruneIfEmpty().
            TRACKERS_BY_SESSION.compute(tracker.getClientSession(), (session, trackers) -> {
                List<LocalPortForwardingWrapper> list = trackers == null ? new CopyOnWriteArrayList<>() : trackers;
                list.add(this);
                return list;
            });
        }

        /**
         * Returns the session's forwardings that satisfy {@code predicate} (all when it is null).
         * As a side effect, forwardings idle longer than the timeout are closed and unregistered.
         *
         * @return a new, possibly empty list; never {@code null}
         */
        List<LocalPortForwardingWrapper> select(ClientSession session, Predicate<LocalPortForwardingWrapper> predicate) {
            List<LocalPortForwardingWrapper> matched = new ArrayList<>();
            List<LocalPortForwardingWrapper> trackers = TRACKERS_BY_SESSION.get(session);
            if (CollectionUtils.isEmpty(trackers)) {
                pruneIfEmpty(session);
                return matched;
            }
            long now = System.currentTimeMillis();
            for (LocalPortForwardingWrapper candidate : trackers) {
                if (now - candidate.getLastAccessTime() > DEFAULT_CACHE_TIMEOUT) {
                    try {
                        candidate.close();
                        trackers.remove(candidate);
                        log.info("[SSH Tunnel] Lazy Remove ssh local port forwarding {}", candidate);
                    } catch (IOException e) {
                        log.warn("[SSH Tunnel] Lazy Remove ssh local port forwarding  Error", e);
                    }
                    continue;
                }
                if (predicate == null || predicate.test(candidate)) {
                    matched.add(candidate);
                }
            }
            pruneIfEmpty(session);
            return matched;
        }

        /**
         * Closes and unregisters every forwarding of {@code session}. Forwardings whose close
         * fails are logged and kept registered.
         */
        void remove(ClientSession session) {
            List<LocalPortForwardingWrapper> trackers = TRACKERS_BY_SESSION.get(session);
            if (trackers != null) {
                for (LocalPortForwardingWrapper next : trackers) {
                    try {
                        next.close();
                        trackers.remove(next);
                        log.info("[SSH Tunnel] Remove ssh local port forwarding, {}", next);
                    } catch (IOException e) {
                        log.error("[SSH Tunnel] Remove ssh session local port forwarding  error", e);
                    }
                }
            }
            pruneIfEmpty(session);
        }

        /** Drops the session's entry once it has no forwardings, so closed sessions are not retained. */
        private static void pruneIfEmpty(ClientSession session) {
            TRACKERS_BY_SESSION.computeIfPresent(session, (key, trackers) -> trackers.isEmpty() ? null : trackers);
        }

        void close() throws IOException {
            tracker.close();
        }

        boolean isOpen() {
            return tracker.isOpen();
        }

        ExplicitPortForwardingTracker getTracker() {
            return tracker;
        }

        long getLastAccessTime() {
            return lastAccessTime;
        }

        void setLastAccessTime(long lastAccessTime) {
            this.lastAccessTime = lastAccessTime;
        }

        @Override
        public String toString() {
            return "{ ssh:%s, remote:%s, localPort:%d }".formatted(tracker.getSession().getConnectAddress(),
                    tracker.getRemoteAddress(), tracker.getLocalAddress().getPort());
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof LocalPortForwardingWrapper)) {
                return false;
            }
            LocalPortForwardingWrapper other = (LocalPortForwardingWrapper) o;
            return lastAccessTime == other.lastAccessTime && Objects.equals(tracker, other.tracker);
        }

        @Override
        public int hashCode() {
            return Objects.hash(lastAccessTime, tracker);
        }
    }
}

