Reusing models from grabcut in OpenCV

I used the interactive grabcut.py from the OpenCV samples to segment an image and saved the foreground and background models. Then I used these models to segment more images of the same kind, as I don't want to retrain the model each time.

However, after running GrabCut with these models, the mask is all zeros (all background), so nothing gets segmented.

from matplotlib import pyplot as plt
import numpy as np
import cv2

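img = cv2.imread('usimg1.jpg')                  # OpenCV loads images as BGR
mask = np.zeros(img.shape[:2], np.uint8)        # every pixel starts as 0 (cv2.GC_BGD)
bgdModel = np.load('bgdmodel.npy')              # saved 1x65 float64 background GMM
fgdModel = np.load('fgdmodel.npy')              # saved 1x65 float64 foreground GMM

# Evaluate only, hoping to reuse the saved models (rect is not used with GC_EVAL)
cv2.grabCut(img, mask, None, bgdModel, fgdModel, 5, cv2.GC_EVAL)

# Definite/probable background (0, 2) -> 0; definite/probable foreground (1, 3) -> 1
mask = np.where((mask==2) | (mask==0), 0, 1).astype('uint8')
img = img * mask[:, :, np.newaxis]

plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # matplotlib expects RGB
plt.show()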
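For reference, here is a minimal numpy sketch of the four mask labels involved (their values match cv2.GC_BGD, cv2.GC_FGD, cv2.GC_PR_BGD and cv2.GC_PR_FGD). Note that a mask created with np.zeros marks every pixel as GC_BGD, i.e. definite background; as far as I can tell from the OpenCV source, GrabCut only relabels pixels marked as probable (GC_PR_BGD or GC_PR_FGD), so an all-zeros mask will stay all zeros. The helpers init_probable_mask and to_binary are hypothetical names for illustration, not OpenCV functions.

```python
import numpy as np

# Label values used by cv2.grabCut masks (equal to cv2.GC_BGD, cv2.GC_FGD,
# cv2.GC_PR_BGD and cv2.GC_PR_FGD); redefined here so numpy alone suffices.
GC_BGD, GC_FGD, GC_PR_BGD, GC_PR_FGD = 0, 1, 2, 3


def init_probable_mask(shape, label=GC_PR_BGD):
    """Hypothetical helper: a mask whose pixels are all 'probable', so
    GrabCut is allowed to relabel them (unlike an all-GC_BGD mask)."""
    if label not in (GC_PR_BGD, GC_PR_FGD):
        raise ValueError("use GC_PR_BGD or GC_PR_FGD")
    return np.full(shape, label, dtype=np.uint8)


def to_binary(mask):
    """Collapse the four GrabCut labels to 0 (background) / 1 (foreground)."""
    mask = np.asarray(mask)
    return np.where((mask == GC_BGD) | (mask == GC_PR_BGD), 0, 1).astype(np.uint8)


# Example: one pixel of each label
labels = np.array([[GC_BGD, GC_FGD], [GC_PR_BGD, GC_PR_FGD]], dtype=np.uint8)
print(to_binary(labels).tolist())  # [[0, 1], [0, 1]]
```

With OpenCV available, something like mask = init_probable_mask(img.shape[:2]) before the cv2.grabCut call would give the algorithm pixels it is permitted to change.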

I tried to initialize the algorithm with a mask or a rectangle, but this produces an error because the models are not empty (which is what I actually want).

How should I pass the pre-trained models to the algorithm so that they are not retrained from scratch each time I segment an image?
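The model arrays are not opaque: OpenCV's grabcut.py sample allocates each one as np.zeros((1, 65), np.float64), i.e. 5 Gaussian components with 13 values each. A minimal sketch for inspecting a saved model, assuming the layout used by OpenCV's internal GMM class (all 5 weights first, then the 5 colour means, then the 5 row-major 3x3 covariances). split_gmm_model and looks_trained are hypothetical helper names, not OpenCV functions.

```python
import numpy as np

N_COMPONENTS = 5            # GrabCut uses 5 Gaussians per model
VALUES_PER_COMPONENT = 13   # 1 weight + 3 mean + 9 covariance values


def split_gmm_model(model):
    """Split a 1x65 grabCut model into weights (5,), means (5, 3) and
    covariances (5, 3, 3). Assumes OpenCV's layout: all weights, then all
    means, then all covariances."""
    flat = np.asarray(model, dtype=np.float64).ravel()
    expected = N_COMPONENTS * VALUES_PER_COMPONENT
    if flat.size != expected:
        raise ValueError(f"expected {expected} values, got {flat.size}")
    weights = flat[:N_COMPONENTS]
    means = flat[N_COMPONENTS:4 * N_COMPONENTS].reshape(N_COMPONENTS, 3)
    covs = flat[4 * N_COMPONENTS:].reshape(N_COMPONENTS, 3, 3)
    return weights, means, covs


def looks_trained(model, tol=1e-6):
    """A learned model's component weights sum to ~1; an all-zero one does not."""
    weights, _, _ = split_gmm_model(model)
    return abs(weights.sum() - 1.0) < tol
```

For example, w, mu, cov = split_gmm_model(np.load('bgdmodel.npy')) shows the parameters you are trying to preserve, and looks_trained is a quick way to tell a learned model from a freshly zeroed one.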

EDIT: After rayryeng's comment, I implemented the following code:

cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 0, cv2.GC_INIT_WITH_RECT)
cv2.grabCut(img, mask, rect, bgdModel, fgdModel, 2, cv2.GC_EVAL)

It seems to work, but the first call now changes my models. In the source code, learnGMMs is called without checking whether a pre-trained model was provided.
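Because cv2.grabCut writes the updated GMMs back into the bgdModel and fgdModel arrays you pass in, it can help to hand it copies and compare afterwards; that keeps the loaded models intact and makes the overwriting easy to confirm. A minimal numpy sketch; working_copies and model_changed are hypothetical helper names, and bgd_saved/fgd_saved in the comments stand for the arrays loaded with np.load.

```python
import numpy as np


def working_copies(bgd_model, fgd_model):
    """Return float64 copies for cv2.grabCut to update in place, so the
    loaded pre-trained arrays themselves are never modified."""
    return (np.array(bgd_model, dtype=np.float64, copy=True),
            np.array(fgd_model, dtype=np.float64, copy=True))


def model_changed(before, after, atol=1e-12):
    """True if any of the GMM parameters differs between the two arrays."""
    return not np.allclose(before, after, rtol=0.0, atol=atol)


# Intended use (needs OpenCV, so shown as comments only):
#   bgd, fgd = working_copies(bgd_saved, fgd_saved)
#   cv2.grabCut(img, mask, rect, bgd, fgd, 0, cv2.GC_INIT_WITH_RECT)
#   print(model_changed(bgd_saved, bgd), model_changed(fgd_saved, fgd))
```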

Asked Jul 19 '17 at 15:07 by ipa


1 Answer

You have the right idea in using cv2.GC_EVAL so that only the segmentation is performed, without recomputing the models. Unfortunately, the models are re-learned even with this flag; this is a limitation of the OpenCV source itself. Note that the Python cv2.grabCut function is a wrapper around the C++ cv::grabCut. If you look at the C++ implementation, this is what happens for the GC_EVAL case towards the end of cv::grabCut:

if( mode == GC_EVAL )
    checkMask( img, mask );

const double gamma = 50;
const double lambda = 9*gamma;
const double beta = calcBeta( img );

Mat leftW, upleftW, upW, uprightW;
calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );

for( int i = 0; i < iterCount; i++ )
{
    GCGraph<double> graph;
    assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );
    learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );
    constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );
    estimateSegmentation( graph, mask );
}

GC_EVAL appears only once in this code, and only to check the validity of the mask. The culprit is learnGMMs: it is called on every iteration regardless of the mode flag, so even though you pass in trained models, they are re-estimated from the current mask and your pre-trained parameters are overwritten.
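To see why a pre-trained model cannot survive this step, here is a simplified numpy sketch of what re-learning a GMM from hard assignments amounts to: each component's weight, mean and covariance are recomputed purely from the pixels currently assigned to it (the job assignGMMsComponents does before learnGMMs), so nothing of the previous parameters is kept. This is an illustration, not a port of learnGMMs; the function name learn_gmm and the 0.01 diagonal regulariser for degenerate covariances are assumptions.

```python
import numpy as np


def learn_gmm(pixels, comp_idx, n_components=5, reg=0.01):
    """Recompute every component's weight, mean and covariance from the pixels
    currently assigned to it (hard assignment). No previous model is passed
    in: the result depends only on the current image and mask."""
    pixels = np.asarray(pixels, dtype=np.float64)   # (N, 3) colour samples
    comp_idx = np.asarray(comp_idx)                 # (N,) component index per pixel
    weights = np.zeros(n_components)
    means = np.zeros((n_components, 3))
    covs = np.zeros((n_components, 3, 3))
    total = len(pixels)
    for k in range(n_components):
        members = pixels[comp_idx == k]
        if len(members) == 0:
            continue  # empty component: weight stays 0
        weights[k] = len(members) / total
        means[k] = members.mean(axis=0)
        diff = members - means[k]
        covs[k] = diff.T @ diff / len(members)      # maximum-likelihood covariance
        if np.linalg.det(covs[k]) <= np.finfo(np.float64).eps:
            covs[k] += reg * np.eye(3)              # assumed regulariser for singular covariances
    return weights, means, covs
```

Whatever values bgdModel held before the call, the re-estimated parameters would be the same, which is why the saved models are lost.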

Inspired by the post "OpenCV - GrabCut with custom foreground/background models", one option is to modify the OpenCV source yourself: inside the loop, add an if statement that checks for the GC_EVAL flag before calling learnGMMs:

if( mode == GC_EVAL )
    checkMask( img, mask );

const double gamma = 50;
const double lambda = 9*gamma;
const double beta = calcBeta( img );

Mat leftW, upleftW, upW, uprightW;
calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );

for( int i = 0; i < iterCount; i++ )
{
    GCGraph<double> graph;
    assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );
    if (mode != GC_EVAL) // New: skip re-learning so the passed-in models are kept
        learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );
    constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );
    estimateSegmentation( graph, mask );
}

With this change, GC_EVAL should use the pre-trained models instead of re-learning them at every iteration. After making the change, you'll need to recompile OpenCV (including its Python bindings); cv2.grabCut with the cv2.GC_EVAL flag should then hopefully keep your pre-trained models instead of overwriting them.
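With the patch, every GC_EVAL iteration builds its data term from the fixed, pre-trained GMMs: for each probable pixel, the graph uses -log of the background and foreground mixture densities as edge weights. Below is a minimal numpy sketch of that data term only (it ignores the pairwise smoothness edges, and normalisation constants may differ from OpenCV's), assuming each model has already been unpacked into weights (5,), means (5, 3) and covariances (5, 3, 3). data_term_labels is a hypothetical helper that picks the cheaper label per pixel; such a labelling could also seed a GC_PR_BGD/GC_PR_FGD mask without recompiling, but that is a cruder, data-term-only alternative, not what the patched GC_EVAL does.

```python
import numpy as np

GC_PR_BGD, GC_PR_FGD = 2, 3   # same values as cv2.GC_PR_BGD / cv2.GC_PR_FGD


def gmm_density(pixels, weights, means, covs):
    """Mixture density sum_k w_k * N(x | mu_k, Sigma_k) for each row of pixels."""
    pixels = np.asarray(pixels, dtype=np.float64)   # (N, 3)
    dens = np.zeros(len(pixels))
    for w, mu, cov in zip(weights, means, covs):
        if w <= 0:
            continue  # unused component contributes nothing
        inv = np.linalg.inv(cov)
        norm = 1.0 / np.sqrt((2.0 * np.pi) ** 3 * np.linalg.det(cov))
        diff = pixels - mu
        mahal = np.einsum("ni,ij,nj->n", diff, inv, diff)  # squared Mahalanobis distance
        dens += w * norm * np.exp(-0.5 * mahal)
    return dens


def data_term_labels(pixels, bgd, fgd):
    """Hypothetical helper: label each pixel GC_PR_FGD or GC_PR_BGD by comparing
    the data costs -log p_fgd(x) and -log p_bgd(x). bgd/fgd are fixed
    (weights, means, covs) tuples that are never re-estimated here."""
    tiny = np.finfo(np.float64).tiny  # avoid log(0) for far-away colours
    cost_fg = -np.log(np.maximum(gmm_density(pixels, *fgd), tiny))
    cost_bg = -np.log(np.maximum(gmm_density(pixels, *bgd), tiny))
    return np.where(cost_fg < cost_bg, GC_PR_FGD, GC_PR_BGD).astype(np.uint8)
```

For a whole image, something like data_term_labels(img.reshape(-1, 3), bgd, fgd).reshape(img.shape[:2]) would give one label per pixel.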

For the future, I have opened an issue on the official OpenCV repository. Hopefully they'll fix this when they have time: https://github.com/opencv/opencv/issues/9191

Answered Oct 01 '22 at 21:10 by rayryeng