diff --git a/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NebulaTemplate.java b/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NebulaTemplate.java index ba706c7e312..fb7cfcc79b7 100644 --- a/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NebulaTemplate.java +++ b/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NebulaTemplate.java @@ -30,7 +30,6 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import lombok.Getter; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @@ -42,27 +41,9 @@ @Slf4j public class NebulaTemplate { - private final String userName; - private final String password; - private final HostAddress hostAddress; - private final String spaceName; + private String spaceName; private Session session; NebulaPool pool; - @Getter - private boolean initSuccess; - - private final Integer timeout; - - public NebulaTemplate(NgqlProtocol protocol) { - this.userName = protocol.getUsername(); - this.password = protocol.getPassword(); - this.hostAddress = new HostAddress(protocol.getHost(), Integer.parseInt(protocol.getPort())); - this.spaceName = protocol.getSpaceName(); - this.timeout = Integer.valueOf(protocol.getTimeout()); - initSession(); - - } - public void closeSessionAndPool() { if (session != null && session.ping()) { @@ -74,22 +55,22 @@ public void closeSessionAndPool() { } @SneakyThrows - private void initSession() { + public boolean initSession(NgqlProtocol protocol) { + HostAddress hostAddress = new HostAddress(protocol.getHost(), Integer.parseInt(protocol.getPort())); + this.spaceName = protocol.getSpaceName(); pool = new NebulaPool(); - try { - NebulaPoolConfig nebulaPoolConfig = new NebulaPoolConfig(); - nebulaPoolConfig.setMaxConnSize(100); - nebulaPoolConfig.setTimeout(timeout); - boolean initResult = pool - .init(Collections.singletonList(hostAddress), nebulaPoolConfig); - initSuccess = initResult; - if (!initResult) { - log.error("pool init failed."); - } - session = pool.getSession(userName, password, false); - } catch (Exception e) { - log.error("初始化失败"); - } + + NebulaPoolConfig nebulaPoolConfig = new NebulaPoolConfig(); + nebulaPoolConfig.setMaxConnSize(100); + nebulaPoolConfig.setTimeout(Integer.parseInt(protocol.getTimeout())); + boolean initResult = pool + .init(Collections.singletonList(hostAddress), nebulaPoolConfig); + if (!initResult) { + log.error("pool init failed."); + return false; + } + session = pool.getSession(protocol.getUsername(), protocol.getPassword(), false); + return true; } private ResultSet execute(String ngql) { diff --git a/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NgqlCollectImpl.java b/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NgqlCollectImpl.java index 586193fdd78..9c3e6a27d62 100644 --- a/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NgqlCollectImpl.java +++ b/collector/src/main/java/org/apache/hertzbeat/collector/collect/nebulagraph/NgqlCollectImpl.java @@ -63,12 +63,20 @@ public void collect(Builder builder, long monitorId, String app, Metrics metrics NgqlProtocol ngql = metrics.getNgql(); StopWatch stopWatch = new StopWatch(); stopWatch.start(); - NebulaTemplate nebulaTemplate = new NebulaTemplate(metrics.getNgql()); - if (!nebulaTemplate.isInitSuccess()) { + NebulaTemplate nebulaTemplate = new NebulaTemplate(); + try { + boolean initSuccess = nebulaTemplate.initSession(ngql); + if (!initSuccess) { + builder.setCode(CollectRep.Code.FAIL); + builder.setMsg("Failed to connect Nebula Graph"); + return; + } + } catch (Exception e) { builder.setCode(CollectRep.Code.FAIL); - builder.setMsg("Failed to connect Nebula Graph"); + builder.setMsg(e.getMessage()); return; } + stopWatch.stop(); long responseTime = stopWatch.getTotalTimeMillis(); try { @@ -77,7 +85,8 @@ public void collect(Builder builder, long monitorId, String app, Metrics metrics case PARSE_TYPE_ONE_ROW -> queryOneRow(nebulaTemplate, ngql, metrics.getAliasFields(), builder, responseTime); case PARSE_TYPE_MULTI_ROW -> queryMultiRow(nebulaTemplate, ngql.getCommands(), metrics.getAliasFields(), builder, responseTime); case PARSE_TYPE_COLUMNS -> queryColumns(nebulaTemplate, ngql.getCommands(), metrics.getAliasFields(), builder, responseTime); - default -> {} + default -> { + } } } finally { nebulaTemplate.closeSessionAndPool(); diff --git a/collector/src/test/java/org/apache/hertzbeat/collector/collect/nebulagraph/NgqlCollectImplTest.java b/collector/src/test/java/org/apache/hertzbeat/collector/collect/nebulagraph/NgqlCollectImplTest.java new file mode 100644 index 00000000000..cb808896f8d --- /dev/null +++ b/collector/src/test/java/org/apache/hertzbeat/collector/collect/nebulagraph/NgqlCollectImplTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hertzbeat.collector.collect.nebulagraph; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.apache.hertzbeat.common.entity.job.Metrics; +import org.apache.hertzbeat.common.entity.job.protocol.NgqlProtocol; +import org.apache.hertzbeat.common.entity.message.CollectRep; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.MockedConstruction; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; + +/** + * Test case for {@link NgqlCollectImpl} + */ +@ExtendWith(MockitoExtension.class) +class NgqlCollectImplTest { + + @InjectMocks + private NgqlCollectImpl ngqlCollect; + + private NgqlProtocol ngqlProtocol; + + @BeforeEach + public void init() { + ngqlProtocol = NgqlProtocol.builder() + .host("127.0.0.1") + .port("9669") + .password("123456") + .username("root") + .timeout("60000").build(); + } + + @Test + void testOneRowCollect() { + String ngql = "SHOW COLLATION;"; + String charset = "utf8"; + String collation = "utf8_bin"; + CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder(); + ngqlProtocol.setCommands(Collections.singletonList(ngql)); + ngqlProtocol.setParseType("oneRow"); + List aliasField = Arrays.asList("Collation", "Charset"); + + List> result = new ArrayList<>(); + Map data = new HashMap<>(); + data.put("Collation", "utf8_bin"); + data.put("Charset", "utf8"); + result.add(data); + + MockedConstruction mocked = + Mockito.mockConstruction(NebulaTemplate.class, (template, context) -> { + Mockito.doNothing().when(template).closeSessionAndPool(); + Mockito.when(template.initSession(ngqlProtocol)).thenReturn(true); + Mockito.when(template.executeCommand(ngql)).thenReturn(result); + }); + + Metrics metrics = new Metrics(); + metrics.setNgql(ngqlProtocol); + metrics.setAliasFields(aliasField); + ngqlCollect.collect(builder, 1L, "test", metrics); + assertEquals(builder.getValuesCount(), 1); + assertEquals(builder.getValues(0).getColumns(0), collation); + assertEquals(builder.getValues(0).getColumns(1), charset); + mocked.close(); + } + + @Test + void testFilterCountCollect() { + String command = "offline#SHOW HOSTS#Status#OFFLINE"; + CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder(); + ngqlProtocol.setCommands(Collections.singletonList(command)); + ngqlProtocol.setParseType("filterCount"); + List aliasField = Collections.singletonList("offline"); + + List> result = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + Map data = new HashMap<>(); + data.put("Host", "graph" + 0); + data.put("Port", "9669"); + data.put("Status", i == 0 ? "OFFLINE" : "ONLINE"); + result.add(data); + } + MockedConstruction mocked = + Mockito.mockConstruction(NebulaTemplate.class, (template, context) -> { + Mockito.doNothing().when(template).closeSessionAndPool(); + Mockito.when(template.initSession(ngqlProtocol)).thenReturn(true); + Mockito.when(template.executeCommand("SHOW HOSTS")).thenReturn(result); + }); + + Metrics metrics = new Metrics(); + metrics.setNgql(ngqlProtocol); + metrics.setAliasFields(aliasField); + ngqlCollect.collect(builder, 1L, "test", metrics); + assertEquals(1, builder.getValuesCount()); + assertEquals("1", builder.getValues(0).getColumns(0)); + mocked.close(); + } + + @Test + void testMultiRowCollect() { + String command = "SHOW HOSTS"; + CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder(); + ngqlProtocol.setCommands(Collections.singletonList(command)); + ngqlProtocol.setParseType("multiRow"); + List aliasField = Arrays.asList("Host", "Port", "Status"); + + List> result = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + Map data = new LinkedHashMap<>(); + data.put("Host", "graph" + i); + data.put("Port", "9669"); + data.put("Status", i == 0 ? "OFFLINE" : "ONLINE"); + result.add(data); + } + MockedConstruction mocked = + Mockito.mockConstruction(NebulaTemplate.class, (template, context) -> { + Mockito.doNothing().when(template).closeSessionAndPool(); + Mockito.when(template.initSession(ngqlProtocol)).thenReturn(true); + Mockito.when(template.executeCommand(command)).thenReturn(result); + }); + + Metrics metrics = new Metrics(); + metrics.setNgql(ngqlProtocol); + metrics.setAliasFields(aliasField); + ngqlCollect.collect(builder, 1L, "test", metrics); + assertEquals(3, builder.getValuesCount()); + for (int i = 0; i < result.size(); i++) { + List> list = new ArrayList<>(result.get(i).entrySet()); + for (int j = 0; j < list.size(); j++) { + assertEquals(list.get(j).getValue().toString(), builder.getValues(i).getColumns(j)); + } + } + mocked.close(); + } + + @Test + void testColumnsCollect() { + String command = "SHOW HOSTS"; + CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder(); + ngqlProtocol.setCommands(Collections.singletonList(command)); + ngqlProtocol.setParseType("columns"); + List aliasField = Arrays.asList("graph0", "graph1", "graph2"); + + List> result = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + Map data = new LinkedHashMap<>(); + data.put("Host", "graph" + i); + data.put("Port", "9669" + i); + data.put("Status", i == 0 ? "OFFLINE" : "ONLINE"); + result.add(data); + } + MockedConstruction mocked = + Mockito.mockConstruction(NebulaTemplate.class, (template, context) -> { + Mockito.doNothing().when(template).closeSessionAndPool(); + Mockito.when(template.initSession(ngqlProtocol)).thenReturn(true); + Mockito.when(template.executeCommand(command)).thenReturn(result); + }); + + Metrics metrics = new Metrics(); + metrics.setNgql(ngqlProtocol); + metrics.setAliasFields(aliasField); + ngqlCollect.collect(builder, 1L, "test", metrics); + assertEquals(1, builder.getValuesCount()); + for (int i = 0; i < 3; i++) { + assertEquals("9669" + i, builder.getValues(0).getColumns(i)); + } + mocked.close(); + } + +}