/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.shardingsphere.proxy.backend.handler.distsql.ral.updatable.label;

import org.apache.shardingsphere.distsql.handler.engine.update.DistSQLUpdateExecutor;
import org.apache.shardingsphere.distsql.statement.type.ral.updatable.LabelComputeNodeStatement;
import org.apache.shardingsphere.infra.instance.ClusterInstanceRegistry;
import org.apache.shardingsphere.infra.instance.ComputeNodeInstance;
import org.apache.shardingsphere.infra.instance.ComputeNodeInstanceContext;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.manager.cluster.persist.facade.ClusterPersistServiceFacade;
import org.apache.shardingsphere.mode.manager.cluster.persist.service.ClusterComputeNodePersistService;
import org.apache.shardingsphere.mode.persist.PersistServiceFacade;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.Collections;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link LabelComputeNodeExecutor}, obtained through the {@link DistSQLUpdateExecutor} SPI registered for {@link LabelComputeNodeStatement}.
 */
class LabelComputeNodeExecutorTest {
    
    private final LabelComputeNodeExecutor executor = (LabelComputeNodeExecutor) TypedSPILoader.getService(DistSQLUpdateExecutor.class, LabelComputeNodeStatement.class);
    
    @Test
    void assertExecuteUpdateWhenInstanceAbsent() {
        ContextManager contextManager = mock(ContextManager.class);
        ComputeNodeInstanceContext instanceContext = mock(ComputeNodeInstanceContext.class);
        ClusterInstanceRegistry clusterInstanceRegistry = mock(ClusterInstanceRegistry.class);
        when(contextManager.getComputeNodeInstanceContext()).thenReturn(instanceContext);
        when(instanceContext.getClusterInstanceRegistry()).thenReturn(clusterInstanceRegistry);
        // The registry has no entry for the target instance.
        when(clusterInstanceRegistry.find("instance-id")).thenReturn(Optional.empty());
        // The persist facade is deliberately left unstubbed (the mock returns null).
        // So if the executor dereferenced it here, the assertion would fail with an exception.
        assertDoesNotThrow(() -> executor.executeUpdate(new LabelComputeNodeStatement(true, "instance-id", Collections.singletonList("label_a")), contextManager));
    }
    
    @Test
    void assertExecuteUpdateWhenOverwriteLabels() {
        ContextManager contextManager = mock(ContextManager.class);
        ComputeNodeInstance computeNodeInstance = mock(ComputeNodeInstance.class);
        ClusterComputeNodePersistService computeNodeService = mockContextManager(contextManager, computeNodeInstance);
        executor.executeUpdate(new LabelComputeNodeStatement(true, "instance-id", Arrays.asList("label_a", "label_b")), contextManager);
        // With overwrite enabled, only the labels from the statement are expected to be persisted.
        verify(computeNodeService).persistLabels("instance-id", Arrays.asList("label_a", "label_b"));
    }
    
    @Test
    void assertExecuteUpdateWhenNotOverwriteLabels() {
        ContextManager contextManager = mock(ContextManager.class);
        ComputeNodeInstance computeNodeInstance = mock(ComputeNodeInstance.class);
        when(computeNodeInstance.getLabels()).thenReturn(Collections.singletonList("origin_label"));
        ClusterComputeNodePersistService computeNodeService = mockContextManager(contextManager, computeNodeInstance);
        executor.executeUpdate(new LabelComputeNodeStatement(false, "instance-id", Collections.singletonList("new_label")), contextManager);
        // Without overwrite, the new labels are expected to be merged with the instance's existing labels.
        verify(computeNodeService).persistLabels("instance-id", Arrays.asList("new_label", "origin_label"));
    }
    
    /**
     * Wires the given context manager so that {@code "instance-id"} resolves to the given compute node instance.
     * Also wires a mocked cluster persist service that is reachable through the persist facade.
     *
     * @param contextManager mocked context manager to stub
     * @param computeNodeInstance instance returned by the cluster registry for {@code "instance-id"}
     * @return the mocked compute node persist service, for verification by the caller
     */
    private ClusterComputeNodePersistService mockContextManager(final ContextManager contextManager, final ComputeNodeInstance computeNodeInstance) {
        // Deep stubs let getClusterInstanceRegistry().find(...) be stubbed in a single chained call.
        ComputeNodeInstanceContext instanceContext = mock(ComputeNodeInstanceContext.class, RETURNS_DEEP_STUBS);
        when(instanceContext.getClusterInstanceRegistry().find("instance-id")).thenReturn(Optional.of(computeNodeInstance));
        when(contextManager.getComputeNodeInstanceContext()).thenReturn(instanceContext);
        PersistServiceFacade persistServiceFacade = mock(PersistServiceFacade.class);
        ClusterPersistServiceFacade clusterPersistServiceFacade = mock(ClusterPersistServiceFacade.class);
        ClusterComputeNodePersistService result = mock(ClusterComputeNodePersistService.class);
        when(clusterPersistServiceFacade.getComputeNodeService()).thenReturn(result);
        when(persistServiceFacade.getModeFacade()).thenReturn(clusterPersistServiceFacade);
        when(contextManager.getPersistServiceFacade()).thenReturn(persistServiceFacade);
        return result;
    }
}
