From 36bb73c57d2e7fab592db4b2159d4db9f1aa4118 Mon Sep 17 00:00:00 2001
From: AnishChandurkar
Date: Thu, 18 Jun 2026 15:41:42 +0530
Subject: [PATCH] feat: add Betweenness Centrality graph algorithm for issue #801

---
 .../function/BuildInSqlFunctionTable.java     |   2 +
 .../dsl/udf/graph/BetweennessCentrality.java  | 319 ++++++++++++++++++
 .../dsl/runtime/query/GQLAlgorithmTest.java   |   9 +
 .../gql_algorithm_betweenness_centrality.txt  |   6 +
 .../gql_algorithm_betweenness_centrality.sql  |  58 +++
 5 files changed, 394 insertions(+)
 create mode 100644 geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/BetweennessCentrality.java
 create mode 100644 geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_betweenness_centrality.txt
 create mode 100644 geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_betweenness_centrality.sql

diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
index 47addc84a..239f7e1ef 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java
@@ -35,6 +35,7 @@
 import org.apache.geaflow.dsl.planner.GQLJavaTypeFactory;
 import org.apache.geaflow.dsl.schema.GeaFlowFunction;
 import org.apache.geaflow.dsl.udf.graph.AllSourceShortestPath;
+import org.apache.geaflow.dsl.udf.graph.BetweennessCentrality;
 import org.apache.geaflow.dsl.udf.graph.ClosenessCentrality;
 import org.apache.geaflow.dsl.udf.graph.ClusterCoefficient;
 import org.apache.geaflow.dsl.udf.graph.CommonNeighbors;
@@ -231,6 +232,7 @@ public class BuildInSqlFunctionTable extends ListSqlOperatorTable {
             .add(GeaFlowFunction.of(IncrementalKCore.class))
             .add(GeaFlowFunction.of(IncMinimumSpanningTree.class))
             .add(GeaFlowFunction.of(ClosenessCentrality.class))
+            .add(GeaFlowFunction.of(BetweennessCentrality.class))
             .add(GeaFlowFunction.of(WeakConnectedComponents.class))
             .add(GeaFlowFunction.of(TriangleCount.class))
             .add(GeaFlowFunction.of(ClusterCoefficient.class))
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/BetweennessCentrality.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/BetweennessCentrality.java
new file mode 100644
index 000000000..5a865392d
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/BetweennessCentrality.java
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.geaflow.dsl.udf.graph;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import org.apache.geaflow.common.type.primitive.DoubleType;
+import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
+import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction;
+import org.apache.geaflow.dsl.common.data.Row;
+import org.apache.geaflow.dsl.common.data.RowEdge;
+import org.apache.geaflow.dsl.common.data.RowVertex;
+import org.apache.geaflow.dsl.common.data.impl.ObjectRow;
+import org.apache.geaflow.dsl.common.function.Description;
+import org.apache.geaflow.dsl.common.types.GraphSchema;
+import org.apache.geaflow.dsl.common.types.StructType;
+import org.apache.geaflow.dsl.common.types.TableField;
+import org.apache.geaflow.dsl.common.util.TypeCastUtil;
+import org.apache.geaflow.model.graph.edge.EdgeDirection;
+
+/**
+ * Betweenness centrality for directed, unweighted graphs, computed with a vertex-centric
+ * formulation of Brandes' algorithm in which every vertex is a BFS source.
+ *
+ * <p>The computation is split into two phases that are selected by the iteration id:
+ * <ul>
+ *   <li>Forward phase, iterations {@code 1 .. maxDiameter + 1}: each source floods its
+ *       distance and shortest-path count (sigma) along OUT edges.</li>
+ *   <li>Backward phase, iterations {@code maxDiameter + 2 .. 2 * maxDiameter + 2}: dependency
+ *       values (delta) are pushed along IN edges, starting from the farthest BFS level.</li>
+ * </ul>
+ * Vertices farther than {@code maxDiameter} hops from a source are never recorded by the
+ * forward phase, so longer shortest paths do not contribute to the score.
+ *
+ * <p>The vertex value is the row {@code (score, stateMap)}, where {@code stateMap} maps a
+ * source id to {@code double[]{distance, sigma, delta}}.
+ *
+ * <p>NOTE(review): the backward phase only runs if the runtime keeps calling
+ * {@code process()} up to iteration {@code 2 * maxDiameter + 2}, even when the forward waves
+ * stop earlier on a small graph - confirm against the algorithm runtime's termination rule.
+ */
+@Description(name = "betweenness_centrality", description = "built-in udga for BetweennessCentrality")
+public class BetweennessCentrality implements AlgorithmUserFunction<Object, Object[]> {
+
+    /** Tag of the messages exchanged during the forward (BFS) phase. */
+    private static final String FORWARD_TAG = "F";
+    /** Tag of the messages exchanged during the backward (accumulation) phase. */
+    private static final String BACKWARD_TAG = "B";
+
+    /** Index of the BFS distance inside a per-source state array. */
+    private static final int DIST = 0;
+    /** Index of the shortest-path count inside a per-source state array. */
+    private static final int SIGMA = 1;
+    /** Index of the accumulated dependency inside a per-source state array. */
+    private static final int DELTA = 2;
+
+    private AlgorithmRuntimeContext<Object, Object[]> context;
+    /** Number of BFS levels explored per source; also fixes the length of both phases. */
+    private int maxDiameter = 10;
+
+    /**
+     * Stores the runtime context and reads the optional {@code maxDiameter} argument.
+     *
+     * @param context runtime context used for messaging, edge loading and result output
+     * @param parameters either empty (default {@code maxDiameter} of 10) or a single value
+     *                   whose string form parses as a non-negative int
+     * @throws IllegalArgumentException if more than one argument is given, or the argument is
+     *                                  not a non-negative integer
+     */
+    @Override
+    public void init(AlgorithmRuntimeContext<Object, Object[]> context, Object[] parameters) {
+        this.context = context;
+        if (parameters.length > 1) {
+            throw new IllegalArgumentException(
+                "Only support zero or one arguments, usage: betweenness_centrality([maxDiameter])");
+        }
+        if (parameters.length == 1) {
+            maxDiameter = Integer.parseInt(String.valueOf(parameters[0]));
+            if (maxDiameter < 0) {
+                // A negative bound leaves neither phase an iteration to work in: all scores 0.
+                throw new IllegalArgumentException(
+                    "maxDiameter must not be negative, got: " + maxDiameter);
+            }
+        }
+    }
+
+    /**
+     * Runs one superstep for {@code vertex}.
+     *
+     * <p>Iteration 1 seeds the vertex as its own source (distance 0, sigma 1). Later
+     * iterations dispatch to the forward or backward phase according to the iteration id and
+     * then store the updated {@code (score, stateMap)} row as the vertex value.
+     *
+     * @param vertex the vertex being processed
+     * @param updatedValues the current vertex value row, if one has been stored
+     * @param messages forward messages {@code {"F", sourceId, distance, sigma}} or backward
+     *                 messages {@code {"B", sourceId, sigma, delta}} addressed to this vertex
+     */
+    @Override
+    public void process(RowVertex vertex, Optional<Row> updatedValues, Iterator<Object[]> messages) {
+        long iter = context.getCurrentIterationId();
+        double score = 0.0;
+        Map<Object, double[]> srcStateMap = new HashMap<>();
+
+        if (iter == 1L) {
+            // Each vertex is initially its own source with distance 0 and sigma 1.
+            srcStateMap.put(vertex.getId(), new double[]{0.0, 1.0, 0.0});
+            broadcastForward(vertex.getId(), 0L, 1L);
+            context.updateVertexValue(ObjectRow.create(score, srcStateMap));
+            return;
+        }
+
+        if (updatedValues.isPresent()) {
+            Row row = updatedValues.get();
+            score = (double) row.getField(0, DoubleType.INSTANCE);
+            srcStateMap = readStateMap(row);
+        }
+
+        if (iter <= maxDiameter + 1L) {
+            runForwardPhase(srcStateMap, messages);
+        } else {
+            score = runBackwardPhase(vertex, iter, score, srcStateMap, messages);
+        }
+        context.updateVertexValue(ObjectRow.create(score, srcStateMap));
+    }
+
+    /**
+     * Emits the result row {@code (id, score)} for {@code vertex}.
+     *
+     * @param vertex the vertex whose result is emitted
+     * @param newValue the final vertex value; a missing value is reported as score 0.0
+     */
+    @Override
+    public void finish(RowVertex vertex, Optional<Row> newValue) {
+        double score = 0.0;
+        if (newValue.isPresent()) {
+            score = (double) newValue.get().getField(0, DoubleType.INSTANCE);
+        }
+        context.take(ObjectRow.create(vertex.getId(), score));
+    }
+
+    /**
+     * Describes the output rows: the vertex id and its betweenness score.
+     *
+     * @param graphSchema schema of the graph, used for the id column type
+     * @return a struct with the fields {@code id} and {@code score}
+     */
+    @Override
+    public StructType getOutputType(GraphSchema graphSchema) {
+        return new StructType(
+            new TableField("id", graphSchema.getIdType(), false),
+            new TableField("score", DoubleType.INSTANCE, false)
+        );
+    }
+
+    /** Reads the per-source state map that {@code process} stores in column 1 of the row. */
+    @SuppressWarnings("unchecked")
+    private static Map<Object, double[]> readStateMap(Row row) {
+        return (Map<Object, double[]>) row.getField(1, null);
+    }
+
+    /**
+     * Forward phase: merges the BFS messages received in this iteration into the state map.
+     *
+     * <p>For every source, only the messages with the smallest distance are kept and their
+     * sigma values are summed; the vertex then sits one hop farther from that source.
+     */
+    private void runForwardPhase(Map<Object, double[]> srcStateMap, Iterator<Object[]> messages) {
+        Map<Object, List<long[]>> incoming = new HashMap<>();
+        while (messages.hasNext()) {
+            Object[] msg = messages.next();
+            if (FORWARD_TAG.equals(msg[0])) {
+                long mDist = (long) msg[2];
+                long mSigma = (long) msg[3];
+                incoming.computeIfAbsent(msg[1], k -> new ArrayList<>())
+                    .add(new long[]{mDist, mSigma});
+            }
+        }
+
+        for (Map.Entry<Object, List<long[]>> entry : incoming.entrySet()) {
+            Object srcId = entry.getKey();
+            long minDist = Long.MAX_VALUE;
+            for (long[] m : entry.getValue()) {
+                minDist = Math.min(minDist, m[0]);
+            }
+            long sumSigma = 0;
+            for (long[] m : entry.getValue()) {
+                if (m[0] == minDist) {
+                    sumSigma += m[1];
+                }
+            }
+
+            long newDist = minDist + 1;
+            double[] state = srcStateMap.get(srcId);
+            if (state == null) {
+                // First time this source reaches the vertex.
+                srcStateMap.put(srcId, new double[]{newDist, sumSigma, 0.0});
+                broadcastForward(srcId, newDist, sumSigma);
+            } else if (newDist < (long) state[DIST]) {
+                // A strictly shorter path replaces the previous record.
+                state[DIST] = newDist;
+                state[SIGMA] = sumSigma;
+                broadcastForward(srcId, newDist, sumSigma);
+            } else if (newDist == (long) state[DIST]) {
+                // Additional shortest paths of the same length: forward only the increment.
+                state[SIGMA] += sumSigma;
+                broadcastForward(srcId, newDist, sumSigma);
+            }
+        }
+    }
+
+    /**
+     * Backward phase: accumulates dependencies for the BFS level handled in this iteration.
+     *
+     * <p>Iteration {@code i} handles the level at distance {@code 2 * maxDiameter + 2 - i}. A
+     * vertex on that level first adds {@code sigma_v / sigma_w * (1 + delta_w)} for every
+     * message received from a vertex {@code w}, then reports its own sigma and delta along
+     * its IN edges. In the last iteration the deltas of all foreign sources are summed up.
+     *
+     * @return the score to store: unchanged until the last iteration, then the final sum
+     */
+    private double runBackwardPhase(RowVertex vertex, long iter, double score,
+                                    Map<Object, double[]> srcStateMap,
+                                    Iterator<Object[]> messages) {
+        // Long arithmetic keeps a very large maxDiameter from overflowing int.
+        long lastIteration = 2L * maxDiameter + 2;
+        long targetDist = lastIteration - iter;
+
+        while (messages.hasNext()) {
+            Object[] msg = messages.next();
+            if (!BACKWARD_TAG.equals(msg[0])) {
+                continue;
+            }
+            double wSigma = (double) msg[2];
+            double wDelta = (double) msg[3];
+            double[] state = srcStateMap.get(msg[1]);
+            if (state == null || (long) state[DIST] != targetDist) {
+                continue;
+            }
+            double mySigma = state[SIGMA];
+            if (mySigma > 0 && wSigma > 0) {
+                state[DELTA] += (mySigma / wSigma) * (1.0 + wDelta);
+            }
+        }
+
+        for (Map.Entry<Object, double[]> entry : srcStateMap.entrySet()) {
+            Object srcId = entry.getKey();
+            double[] state = entry.getValue();
+            long myDist = (long) state[DIST];
+            // The source itself (distance 0) has nothing to report.
+            if (myDist == targetDist && myDist > 0 && !srcId.equals(vertex.getId())) {
+                broadcastBackward(srcId, state[SIGMA], state[DELTA]);
+            }
+        }
+
+        if (iter == lastIteration) {
+            for (Map.Entry<Object, double[]> entry : srcStateMap.entrySet()) {
+                if (!entry.getKey().equals(vertex.getId())) {
+                    score += entry.getValue()[DELTA];
+                }
+            }
+        }
+        return score;
+    }
+
+    /** Sends {@code {"F", srcId, dist, sigma}} to every distinct OUT-edge target. */
+    private void broadcastForward(Object srcId, long dist, long sigma) {
+        broadcast(EdgeDirection.OUT, new Object[]{FORWARD_TAG, srcId, dist, sigma});
+    }
+
+    /** Sends {@code {"B", srcId, sigma, delta}} to every distinct IN-edge target. */
+    private void broadcastBackward(Object srcId, double sigma, double delta) {
+        broadcast(EdgeDirection.IN, new Object[]{BACKWARD_TAG, srcId, sigma, delta});
+    }
+
+    /**
+     * Sends {@code msg} once to each distinct target id of the edges loaded for
+     * {@code direction}, so parallel edges do not duplicate a message.
+     *
+     * <p>NOTE(review): for {@link EdgeDirection#IN} this relies on {@code getTargetId()}
+     * returning the neighbouring vertex rather than the current one - confirm.
+     */
+    private void broadcast(EdgeDirection direction, Object[] msg) {
+        List<RowEdge> edges = context.loadEdges(direction);
+        if (edges == null) {
+            return;
+        }
+        Set<Object> targetIds = new HashSet<>();
+        for (RowEdge e : edges) {
+            targetIds.add(e.getTargetId());
+        }
+        for (Object targetId : targetIds) {
+            context.sendMessage(targetId, msg);
+        }
+    }
+}
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java
index 8c75d6e78..cfa1000ef 100644
--- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java
+++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GQLAlgorithmTest.java
@@ -124,6 +124,15 @@ public void testAlgorithmClosenessCentrality() throws Exception {
             .checkSinkResult();
     }
 
+    @Test
+    public void testAlgorithmBetweennessCentrality() throws Exception {
+        QueryTester
+            .build()
+            .withQueryPath("/query/gql_algorithm_betweenness_centrality.sql")
+            .execute()
+            .checkSinkResult();
+    }
+
     @Test
     public void testAlgorithmWeakConnectedComponents() throws Exception {
         QueryTester
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_betweenness_centrality.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_betweenness_centrality.txt
new file mode 100644
index 000000000..462ed4678
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_betweenness_centrality.txt
@@ -0,0 +1,6 @@
+1,5.0
+2,0.0
+3,4.5
+4,8.0
+5,0.5
+6,0.0
diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_betweenness_centrality.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_betweenness_centrality.sql
new file mode 100644
index 000000000..3b53b4510
--- /dev/null
+++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_betweenness_centrality.sql
@@ -0,0 +1,58 @@
+set geaflow.dsl.window.size = -1;
+set geaflow.dsl.ignore.exception = true;
+
+CREATE GRAPH IF NOT EXISTS g (
+  Vertex v (
+    vid varchar ID,
+    vvalue int
+  ),
+  Edge e (
+    srcId varchar SOURCE ID,
+    targetId varchar DESTINATION ID
+  )
+) WITH (
+  storeType='rocksdb',
+  shardCount = 1
+);
+
+CREATE TABLE IF NOT EXISTS v_source (
+  v_id varchar,
+  v_value int,
+  ts varchar,
+  type varchar
+) WITH (
+  type='file',
+  geaflow.dsl.file.path = 'resource:///input/test_vertex'
+);
+
+CREATE TABLE IF NOT EXISTS e_source (
+  src_id varchar,
+  dst_id varchar
+) WITH (
+  type='file',
+  geaflow.dsl.file.path = 'resource:///input/test_edge'
+);
+
+CREATE TABLE IF NOT EXISTS tbl_result (
+  v_id varchar,
+  score double
+) WITH (
+  type='file',
+  geaflow.dsl.file.path = '${target}'
+);
+
+USE GRAPH g;
+
+INSERT INTO g.v(vid, vvalue)
+SELECT
+v_id, v_value
+FROM v_source;
+
+INSERT INTO g.e(srcId, targetId)
+SELECT
+ src_id, dst_id
+FROM e_source;
+
+INSERT INTO tbl_result(v_id, score)
+CALL betweenness_centrality(10) YIELD (id, score)
+RETURN id, score;