 

Python: find connected components in a 3D graph / tuples with three elements?

I have a binary 3D numpy array for which I would like to find the connected components, i.e. neighboring elements with value 1.

import numpy as np
data = np.random.binomial(1, 0.4, 1000)
data = data.reshape((10,10,10))

Alternatively, I can get the coordinates of each element with value 1, which gives a set of lists with three elements each, and then try to group neighboring coordinates into clusters:

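coordinates = np.argwhere(data > 0)

# pseudocode: getting the neighbors and finding their cluster is what I don't know how to do
connected_elements = []  # one list of coordinates per cluster
for node in coordinates:
  neighbors = ...  # get possible neighbors of node
  cluster = ...    # the cluster in connected_elements that contains one of the neighbors, if any
  if cluster is None:
    connected_elements.append([node])  # no neighbor seen yet: start a new cluster
  else:
    cluster.append(node)               # add node to its neighbor's cluster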

How can I do this, or adapt a 2D connected_components function to the 3D setting?

asked Mar 15 '19 19:03 by hirschme


2 Answers

As suggested in the question, we first generate the data and find the coordinates.

Then we can use a k-d tree (cKDTree) to find all pairs of points at a distance of at most 1 from each other with query_pairs, and use these pairs as the edges of a graph. This reduces the problem to a standard connected-component search on a graph.
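To make the radius concrete: on integer voxel coordinates with the default Euclidean metric (p=2), query_pairs(1) links only voxels that share a face (6-connectivity). The sketch below uses made-up toy points and a hypothetical helper, neighbour_pairs, to show that passing p=np.inf (Chebyshev distance) with the same radius also links edge and corner diagonals (26-connectivity), assuming that is the neighbourhood you want. This avoids having to choose a Euclidean radius between sqrt(3) and 2 for the same effect.

```python
import numpy as np
from scipy.spatial import cKDTree

def neighbour_pairs(points, diagonal=False):
    """Return index pairs (i, j), i < j, of adjacent grid points (hypothetical helper).

    diagonal=False: Euclidean distance <= 1, face neighbours only (6-connectivity).
    diagonal=True:  Chebyshev distance <= 1, faces, edges and corners (26-connectivity).
    """
    tree = cKDTree(points)
    p = np.inf if diagonal else 2
    # convert to plain Python ints so the pairs print and compare cleanly
    return {(int(i), int(j)) for i, j in tree.query_pairs(1, p=p)}

# toy coordinates, e.g. as returned by np.argwhere(data > 0)
pts = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [2, 2, 2], [3, 3, 3]])
print(sorted(neighbour_pairs(pts)))                 # [(0, 1), (1, 2)]
print(sorted(neighbour_pairs(pts, diagonal=True)))  # [(0, 1), (0, 2), (1, 2), (3, 4)]
```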

Then we build a networkx graph from these edges with from_edgelist and run connected_components on it.
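Note that nx.from_edgelist only creates nodes that appear in at least one edge, so isolated voxels end up without a component (these are the NaN values that the commented-out line in the code deals with). A minimal alternative sketch, using a hypothetical helper component_labels: add every point index as a node before adding the edges, so that singletons become components of their own and every point gets an integer label.

```python
import networkx as nx
import numpy as np
from scipy.spatial import cKDTree

def component_labels(points, radius=1):
    """Return one component label per row of `points`, singletons included (hypothetical helper)."""
    G = nx.Graph()
    G.add_nodes_from(range(len(points)))  # every point is a node, even if it has no neighbour
    G.add_edges_from(cKDTree(points).query_pairs(radius))
    labels = np.empty(len(points), dtype=int)
    for k, component in enumerate(nx.connected_components(G)):
        labels[list(component)] = k
    return labels

data = np.random.binomial(1, 0.1, 1000).reshape((10, 10, 10))
cs = np.argwhere(data > 0)
labels = component_labels(cs)  # labels[i] is the component of point cs[i]
```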

And the last step is visualization.

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from scipy.spatial import cKDTree
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions

# create data
data = np.random.binomial(1, 0.1, 1000)
data = data.reshape((10,10,10))

# find coordinates
cs = np.argwhere(data > 0)

# build k-d tree
kdt = cKDTree(cs)
edges = kdt.query_pairs(1)

# create graph
G = nx.from_edgelist(edges)

# find connected components
ccs = nx.connected_components(G)
node_component = {v:k for k,vs in enumerate(ccs) for v in vs}

# visualize
df = pd.DataFrame(cs, columns=['x','y','z'])
df['c'] = pd.Series(node_component)

# to include single-node connected components
# df.loc[df['c'].isna(), 'c'] = df.loc[df['c'].isna(), 'c'].isna().cumsum() + df['c'].max()

fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111, projection='3d')
cmhot = plt.get_cmap("hot")
ax.scatter(df['x'], df['y'], df['z'], c=df['c'], s=50, cmap=cmhot)
plt.show()

Output:

[image: 3D scatter plot of the nodes, coloured by connected component]

Notes:

  • I've reduced the probability in the binomial distribution used to generate the nodes from 0.4 to 0.1 to make the visualisation more readable
  • I'm not showing connected components that contain only a single node. They can be included by uncommenting the line below the "# to include single-node connected components" comment
  • DataFrame df contains coordinates x, y and z and the connected component index c for each node:
print(df)

Output:

     x  y  z     c
0    0  0  3  20.0
1    0  1  8  21.0
2    0  2  1   6.0
3    0  2  3  22.0
4    0  3  0  23.0
...
  • Based on the DataFrame df we can also check some fun stuff, like the sizes of the biggest connected components found (along with the connected component number):
df['c'].value_counts().nlargest(5)

Output:

4.0    5
1.0    4
7.0    3
8.0    3
5.0    2
Name: c, dtype: int64

answered Nov 03 '22 05:11 by perl


DFS to find connected components

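import itertools
from collections import deque

import numpy as np

n = 10

def DFS(data, v, x, y, z, component):
    """Label every 1-voxel connected to (x, y, z) with `component` in v (in place).

    Points are taken from a FIFO queue, so the traversal order is breadth-first.
    """
    q = deque([(x, y, z)])
    v[x, y, z] = component  # mark on enqueue so no point is queued twice
    while q:
        x, y, z = q.popleft()

        # candidate values per axis: same, decrement, increment (kept within bounds)
        l = [[x], [y], [z]]
        for i in range(3):
            if l[i][0] > 0:
                l[i].append(l[i][0] - 1)
            if l[i][0] < v.shape[i] - 1:  # size of axis i, so non-cubic arrays work too
                l[i].append(l[i][0] + 1)

        # every combination: the point itself plus up to 26 neighbours
        for xn, yn, zn in itertools.product(l[0], l[1], l[2]):
            if v[xn, yn, zn] == 0 and data[xn, yn, zn] == 1:
                v[xn, yn, zn] = component
                q.append((xn, yn, zn))

data = np.random.binomial(1, 0.2, n*n*n)
data = data.reshape((n, n, n))

coordinates = np.argwhere(data > 0)
v = np.zeros_like(data)  # 0 = not assigned to any component yet

component = 1
for x, y, z in coordinates:
    if v[x, y, z] != 0:
        continue  # already labelled by an earlier search
    DFS(data, v, x, y, z, component)
    component += 1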

Main algorithm:

  1. Set visited of each point to 0 (meaning it is not part of any connected component yet)
  2. For all points whose value == 1:
    1. If the point has not been visited, start a DFS from it

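DFS: This is the traditional graph search using a queue. Strictly speaking, taking points from a FIFO queue makes it a breadth-first search (BFS); the visiting order does not change which components are found. The only difference in the 3D case is that, given (x, y, z), we compute all valid neighbours using itertools.product. In 3D, every interior point has 27 candidate positions including itself (3 coordinates, each of which can stay the same, be incremented or be decremented: 3 × 3 × 3 = 27), i.e. 26 actual neighbours; points on the boundary of the array have fewer.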
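The neighbour enumeration can also be written with a fixed set of offsets instead of per-axis lists. A minimal sketch with a hypothetical helper, neighbours (not part of the answer), that yields only in-bounds positions and skips the point itself:

```python
import itertools

def neighbours(x, y, z, shape):
    """Yield the in-bounds 26-connected neighbours of (x, y, z) in an array of `shape`."""
    for dx, dy, dz in itertools.product((-1, 0, 1), repeat=3):
        if (dx, dy, dz) == (0, 0, 0):
            continue  # skip the point itself
        xn, yn, zn = x + dx, y + dy, z + dz
        if 0 <= xn < shape[0] and 0 <= yn < shape[1] and 0 <= zn < shape[2]:
            yield xn, yn, zn

print(len(list(neighbours(5, 5, 5, (10, 10, 10)))))  # 26 (interior point)
print(len(list(neighbours(0, 0, 0, (10, 10, 10)))))  # 7 (corner point)
```

Interior points get 26 neighbours and corner points 7, which matches the count of 27 candidates once the point itself is excluded.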

The matrix v stores the connected-component label of each point, numbered starting from 1 (0 means background).
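Because v holds one integer label per voxel, the size of each component can be read directly from it, similar to the value_counts check in the other answer. A small sketch, assuming non-negative integer labels with 0 as background; the helper component_sizes and the example array are made up for illustration:

```python
import numpy as np

def component_sizes(v):
    """Map each component label (1, 2, ...) in a label array to its number of voxels."""
    counts = np.bincount(v.ravel())  # counts[k] = number of voxels with label k
    return {label: int(counts[label]) for label in range(1, len(counts)) if counts[label] > 0}

# a hypothetical 2x2x2 label array with two components
v = np.array([[[1, 1], [0, 2]],
              [[1, 0], [0, 0]]])
print(component_sizes(v))  # {1: 3, 2: 1}
```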

Test case:

When data is

 [[[1 1 1]
  [1 1 1]
  [1 1 1]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]]

 [[1 1 1]
  [1 1 1]
  [1 1 1]]]

Visualisation: [image: 3D scatter plot of data]

The two opposite faces are two different connected components: the middle slice is all zeros, so they do not touch.

After the algorithm runs, v (filled in place) is

[[[1 1 1]
  [1 1 1]
  [1 1 1]]

 [[0 0 0]
  [0 0 0]
  [0 0 0]]

 [[2 2 2]
  [2 2 2]
  [2 2 2]]]

which is correct.
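As an independent cross-check (not part of the original answer), SciPy ships a labelling routine, scipy.ndimage.label, that does the same job for n-dimensional arrays. A minimal sketch reproducing the test case, assuming 26-connectivity is wanted; the default structuring element of ndimage.label only connects face neighbours, so a full 3×3×3 block of ones is passed explicitly:

```python
import numpy as np
from scipy import ndimage

# the test case: two 3x3 faces of ones separated by a slice of zeros
data = np.zeros((3, 3, 3), dtype=int)
data[0] = 1
data[2] = 1

# a full 3x3x3 structuring element means 26-connectivity, like the itertools.product neighbourhood
structure = np.ones((3, 3, 3), dtype=int)
v, num_components = ndimage.label(data, structure=structure)

print(num_components)  # 2
```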

Visualisation: [image: 3D scatter plot of v]

As can be seen in the visualisation of v, green represents one connected component and blue the other.

Visualization code

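import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions

def plot(data):
    fig = plt.figure(figsize=(10,10))
    # fig.gca(projection='3d') no longer works on recent matplotlib; add_subplot does
    ax = fig.add_subplot(projection='3d')

    zs = list(range(data.shape[2]))
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            # one column of points at (i, j): red for 0 (background), blue otherwise
            ax.scatter([i]*len(zs), [j]*len(zs), zs,
                       c=['r' if val == 0 else 'b' for val in data[i,j]], s=50)

plot(data)
plt.show()
plt.close('all')
plot(v)
plt.show()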

answered Nov 03 '22 07:11 by mujjiga