 

border/edge operations on numpy arrays

Suppose I have a 3D numpy array of nonzero values and "background" = 0. As an example I will take a sphere of random values:

import numpy as np

array = np.random.randint(1, 5, size=(100, 100, 100))
z, y, x = np.ogrid[-50:50, -50:50, -50:50]
mask = x**2 + y**2 + z**2 <= 20**2
array[np.invert(mask)] = 0

First, I would like to find the "border voxels" (all nonzero values that have a zero within their 3x3x3 neighbourhood). Second, I would like to replace all border voxels with the mean of their nonzero neighbours. So far I have tried scipy's generic_filter in the following way:

Function to apply at each element:

def borderCheck(values):
    #check if the footprint center is on a nonzero value
    if values[13] != 0:
        #replace border voxels with the mean of nonzero neighbours
        if 0 in values:
            return np.sum(values)/np.count_nonzero(values)
        else:
            return values[13]
    else:
        return 0

Generic filter:

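from scipy import ndimage
# cast to float first: generic_filter returns the input dtype, so an integer
# array would silently truncate the means to integers
result = ndimage.generic_filter(array.astype(np.double), borderCheck, footprint=np.ones((3, 3, 3)))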
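The hard-coded `values[13]` relies on generic_filter handing the function the footprint's elements as a flat 1-D array in C (row-major) order; for a full 3x3x3 footprint the centre (1, 1, 1) lands at flat index 13. A quick check, using a small test array made up for illustration:

```python
import numpy as np
from scipy import ndimage

# Flat (C-order) index of the centre of a 3x3x3 window
centre = np.ravel_multi_index((1, 1, 1), (3, 3, 3))
print(centre)  # 13

# If values[13] really is the centre, returning it reproduces the input exactly
a = np.arange(27.0).reshape(3, 3, 3)
same = ndimage.generic_filter(a, lambda v: v[13], footprint=np.ones((3, 3, 3)))
print(np.array_equal(same, a))  # True
```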

Is this a proper way to handle this problem? I feel that I am reinventing the wheel here and that there must be a shorter, nicer way to achieve the result. Are there any other suitable (NumPy, SciPy) functions that I can use?

EDIT

I got one thing wrong: I would like to replace all border voxels with the mean of their nonzero AND non-border neighbours. For this, I tried to filter the border voxels out of the neighbour lists from ali_m's code (2D case):

#for each neighbour voxel, check whether it also appears in the border/edges
non_border_neighbours = []
for each in neighbours:
    non_border_neighbours.append([i for i in each if nonzero_idx[i] not in edge_idx])

I can't figure out why non_border_neighbours comes back empty.
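A likely cause is how NumPy defines `in` for arrays: `x in arr` is evaluated as `(arr == x).any()`. For a 2-element coordinate this broadcasts against every row of `edge_idx` and is true as soon as any single coordinate value matches anywhere, so for a disc almost every pixel "is in" `edge_idx` and the `not in` filter drops nearly everything. The sketch below shows this on small made-up coordinates (`p` is a hypothetical point) together with two whole-point membership tests:

```python
import numpy as np

# Made-up edge coordinates, one (row, col) pair per row
edge_idx = np.array([[0, 5], [3, 1]])
p = np.array([3, 5])  # not an edge pixel

# `in` compares element-wise and reduces with .any(), so any shared
# coordinate value is enough to make this True
naive = p in edge_idx

# Whole-row membership: every coordinate must match within the same row
row_match = (edge_idx == p).all(axis=1).any()

# Alternatively, look the point up in a boolean edge mask directly
edges = np.zeros((6, 6), dtype=bool)
edges[tuple(edge_idx.T)] = True
mask_lookup = edges[tuple(p)]

print(naive, row_match, mask_lookup)  # True False False
```

In the loop above, `edges[tuple(nonzero_idx[i])]`, with `edges` being the boolean mask from the answer, would be the corresponding per-point check.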

Furthermore, correct me if I am wrong, but doesn't tree.query_ball_point with radius 1 address only the 6 nearest neighbours (Euclidean distance 1)? Should I set the radius to sqrt(3) (3D case) to get the 26-neighbourhood?

asked Dec 24 '22 12:12 by a.smiet


1 Answer

I think it's best to start with the 2D case, since it is much easier to visualize:

import numpy as np
from matplotlib import pyplot as plt
from scipy import ndimage

A = np.random.randint(1, 5, size=(100, 100)).astype(np.double)
y, x = np.ogrid[-50:50, -50:50]
mask = x**2 + y**2 <= 30**2
A[~mask] = 0

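To find the edge pixels, you could perform binary erosion on your mask and then XOR the result with the mask. Erosion with a fully connected structuring element removes every pixel that has a background pixel in its 3x3 neighbourhood, so the XOR leaves exactly the removed pixels: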

# rank 2 structure with full connectivity
struct = ndimage.generate_binary_structure(2, 2)
erode = ndimage.binary_erosion(mask, struct)
edges = mask ^ erode
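To make the erosion/XOR step concrete, here is a small illustrative 2D example that is not part of the original answer: for a filled 5x5 square inside a 7x7 array, erosion with the full 3x3 structure keeps only the inner 3x3 block, so the XOR leaves the 16-pixel outer ring.

```python
import numpy as np
from scipy import ndimage

# A filled 5x5 square in the middle of a 7x7 background
mask = np.zeros((7, 7), dtype=bool)
mask[1:6, 1:6] = True

# Full 8-connectivity: a pixel survives erosion only if all 8 neighbours are set
struct = ndimage.generate_binary_structure(2, 2)
erode = ndimage.binary_erosion(mask, struct)
edges = mask ^ erode

print(erode.sum(), edges.sum())  # 9 16
print(edges.astype(int))
```

With `generate_binary_structure(2, 1)` (4-connectivity) instead, a pixel that touches the background only diagonally would not be classified as an edge.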

One approach to find the nearest non-zero neighbours of each edge pixel would be to use a scipy.spatial.cKDTree:

from scipy.spatial import cKDTree

# the indices of the non-zero locations and their corresponding values
nonzero_idx = np.vstack(np.where(mask)).T
nonzero_vals = A[mask]

# build a k-D tree
tree = cKDTree(nonzero_idx)

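# use it to find the indices of all non-zero values that are at most 1 pixel
# away from each edge pixel (p=np.inf selects the Chebyshev/max-norm distance,
# so this is the full 3x3 neighbourhood, edge pixel included)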
edge_idx = np.vstack(np.where(edges)).T
neighbours = tree.query_ball_point(edge_idx, r=1, p=np.inf)

# take the average value for each set of neighbours
new_vals = np.array([np.mean(nonzero_vals[n]) for n in neighbours])

# use these to replace the values of the edge pixels
A_new = A.astype(np.double, copy=True)
A_new[edges] = new_vals
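Regarding the radius question from the edit: with the default Euclidean metric (`p=2`), `r=1` does reach only the face neighbours. The calls above pass `p=np.inf`, which makes `query_ball_point` use the Chebyshev (maximum-coordinate) distance, so `r=1` already covers the whole 3x3 (or 3x3x3) block, the point itself included. A small check on a 3x3x3 grid of integer coordinates, built just for this illustration:

```python
import numpy as np
from scipy.spatial import cKDTree

# All integer coordinates of a 3x3x3 block
grid = np.argwhere(np.ones((3, 3, 3), dtype=bool))
tree = cKDTree(grid)
centre = [1, 1, 1]

# Euclidean distance: only the centre and its 6 face neighbours are within 1
n_euclid = len(tree.query_ball_point(centre, r=1))
# Chebyshev distance (p=inf): the whole 3x3x3 block is within 1
n_cheb = len(tree.query_ball_point(centre, r=1, p=np.inf))
print(n_euclid, n_cheb)  # 7 27
```

`r=np.sqrt(3)` with the default metric aims at the same 26 neighbours, but it depends on a floating-point comparison exactly at the boundary, which is fragile; `p=np.inf, r=1` avoids that.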

Some visualisation:

fig, ax = plt.subplots(1, 3, figsize=(10, 4), sharex=True, sharey=True)
norm = plt.Normalize(0, A.max())
ax[0].imshow(A, norm=norm)
ax[0].set_title('Original', fontsize='x-large')
ax[1].imshow(edges)
ax[1].set_title('Edges', fontsize='x-large')
ax[2].imshow(A_new, norm=norm)
ax[2].set_title('Averaged', fontsize='x-large')
for aa in ax:
    aa.set_axis_off()
ax[0].set_xlim(20, 50)
ax[0].set_ylim(50, 80)
fig.tight_layout()
plt.show()

[Figure: three panels titled 'Original', 'Edges' and 'Averaged', zoomed in on part of the disc boundary]

This approach generalizes directly to the 3D case:

B = np.random.randint(1, 5, size=(100, 100, 100)).astype(np.double)
z, y, x = np.ogrid[-50:50, -50:50, -50:50]
mask = x**2 + y**2 + z**2 <= 20**2
B[~mask] = 0

struct = ndimage.generate_binary_structure(3, 3)
erode = ndimage.binary_erosion(mask, struct)
edges = mask ^ erode

nonzero_idx = np.vstack(np.where(mask)).T
nonzero_vals = B[mask]

tree = cKDTree(nonzero_idx)

edge_idx = np.vstack(np.where(edges)).T
neighbours = tree.query_ball_point(edge_idx, r=1, p=np.inf)

new_vals = np.array([np.mean(nonzero_vals[n]) for n in neighbours])

B_new = B.astype(np.double, copy=True)
B_new[edges] = new_vals

Test against your version:

def borderCheck(values):
    #check if the footprint center is on a nonzero value
    if values[13] != 0:
        #replace border voxels with the mean of nonzero neighbours
        if 0 in values:
            return np.sum(values)/np.count_nonzero(values)
        else:
            return values[13]
    else:
        return 0

result = ndimage.generic_filter(B, borderCheck, footprint=np.ones((3, 3, 3)))

print(np.allclose(B_new, result))
# True

I'm sure this isn't the most efficient way to do it, but it should still be much faster than generic_filter, which calls a Python function once for every voxel.
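A fully vectorised alternative, which swaps the k-D tree for a different technique and is offered only as a sketch, is to compute neighbourhood sums and neighbourhood counts of nonzero voxels with two `ndimage.convolve` calls and divide. It assumes, as in the question, that the background is exactly 0 and that values outside the array count as background (`mode='constant'`). The helper `edge_mean_convolve` and the small 2D test disc are hypothetical; the usage compares the result against a 2D analogue of the question's `borderCheck`.

```python
import numpy as np
from scipy import ndimage


def edge_mean_convolve(A, mask):
    """Replace edge voxels of A with the mean of the nonzero values in their
    full 3x3 (or 3x3x3) neighbourhood, the edge voxel itself included."""
    struct = ndimage.generate_binary_structure(A.ndim, A.ndim)
    edges = mask & ~ndimage.binary_erosion(mask, struct)

    kernel = np.ones((3,) * A.ndim)
    vals = np.where(mask, A, 0).astype(np.double)
    # Per-voxel sum of neighbourhood values and number of nonzero neighbours
    sums = ndimage.convolve(vals, kernel, mode='constant', cval=0.0)
    counts = ndimage.convolve(mask.astype(np.double), kernel, mode='constant', cval=0.0)

    out = A.astype(np.double, copy=True)
    out[edges] = sums[edges] / counts[edges]  # counts >= 1 at every edge voxel
    return out


# Usage on a small 2D disc, checked against a generic_filter reference
y, x = np.ogrid[-10:10, -10:10]
mask = x**2 + y**2 <= 6**2
rng = np.random.default_rng(0)
A = rng.integers(1, 5, size=mask.shape).astype(np.double)
A[~mask] = 0


def border_check_2d(values):
    # 2D analogue of borderCheck: the centre of a 3x3 window is index 4
    if values[4] != 0 and 0 in values:
        return values.sum() / np.count_nonzero(values)
    return values[4]


ref = ndimage.generic_filter(A, border_check_2d, footprint=np.ones((3, 3)))
print(np.allclose(edge_mean_convolve(A, mask), ref))  # True
```

The disc is kept away from the array boundary, where generic_filter's default `mode='reflect'` and the `mode='constant'` used here would disagree about what counts as a border voxel.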


Update

The performance could be further improved by reducing the number of points that are considered as candidate neighbours of the edge pixels/voxels:

# ...

# the edge pixels/voxels plus their immediate non-zero neighbours
erode2 = ndimage.binary_erosion(erode, struct)
candidate_neighbours = mask ^ erode2

nonzero_idx = np.vstack(np.where(candidate_neighbours)).T
nonzero_vals = B[candidate_neighbours]

# ...
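For the variant asked about in the question's edit (averaging only over nonzero neighbours that are not edge voxels themselves), one option, sketched here rather than taken from the answer above, is to build the k-D tree from the eroded (interior) mask instead of the full mask. Edge voxels then can never appear among the neighbours, and no membership test is needed. The helper name `edge_mean_interior` is hypothetical. Some edge voxels, such as the tips of a digitised disc, may have no interior neighbour at all; the sketch leaves those at their original value, which is an arbitrary choice.

```python
import numpy as np
from scipy import ndimage
from scipy.spatial import cKDTree


def edge_mean_interior(A, mask):
    """Replace edge voxels of A with the mean of their interior (non-edge)
    nonzero neighbours; hypothetical helper for illustration."""
    struct = ndimage.generate_binary_structure(A.ndim, A.ndim)
    interior = ndimage.binary_erosion(mask, struct)
    edges = mask & ~interior

    # Build the tree from interior voxels only, so edge voxels are never neighbours
    tree = cKDTree(np.argwhere(interior))
    interior_vals = A[interior].astype(np.double)

    # p=np.inf (Chebyshev distance) with r=1 covers the full 3x3(x3) block
    edge_idx = np.argwhere(edges)
    neighbours = tree.query_ball_point(edge_idx, r=1, p=np.inf)

    out = A.astype(np.double, copy=True)
    # np.argwhere and boolean indexing both use C order, so the values line up;
    # edge voxels with no interior neighbour keep their original value
    new_vals = [interior_vals[n].mean() if len(n) else out[tuple(idx)]
                for n, idx in zip(neighbours, edge_idx)]
    out[edges] = new_vals
    return out


# Usage on a 2D disc like the one in the answer
rng = np.random.default_rng(0)
y, x = np.ogrid[-50:50, -50:50]
mask = x**2 + y**2 <= 30**2
A = rng.integers(1, 5, size=mask.shape).astype(np.double)
A[~mask] = 0
A_new = edge_mean_interior(A, mask)
```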
answered Dec 29 '22 03:12 by ali_m