How to do alpha matting in python

More specifically, how can I extract the alpha channel of an image, given a trimap that marks each pixel as one of the following:

  • 100% foreground (white)
  • 100% background (black)
  • or unknown (gray)
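
As a minimal sketch of how such a trimap could be read in Python, the snippet below splits a grayscale trimap into the three pixel classes. The helper name `split_trimap` and the thresholds 0.1 and 0.9 are assumptions chosen for illustration; adjust them to how your trimap was painted.

```python
import numpy as np

def split_trimap(trimap, lo=0.1, hi=0.9):
    """Split a trimap with values in [0, 1] into three boolean masks.

    `lo` and `hi` are illustrative thresholds, not fixed constants.
    """
    trimap = np.asarray(trimap, dtype=np.float64)
    is_fg = trimap > hi              # white: 100% foreground
    is_bg = trimap < lo              # black: 100% background
    is_unknown = ~(is_fg | is_bg)    # gray: alpha has to be estimated
    return is_fg, is_bg, is_unknown

# tiny hand-made trimap with one black, one gray and one white pixel
example = np.array([[0, 128, 255]], dtype=np.uint8) / 255.0
is_fg, is_bg, is_unknown = split_trimap(example)
print(is_fg, is_bg, is_unknown)
```

Only the pixels in `is_unknown` need an estimated alpha value; the other two masks act as constraints.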

Input image

[image: input image]

Input trimap

[image: input trimap]

John Avatar asked Mar 26 '19 09:03


People also ask

What is alpha matting?

Alpha matting is used to extract a foreground object with soft boundaries from a background image. This module is dedicated to computing the alpha matte of objects in images from a given input image and a greyscale trimap image that contains information about the foreground, background and unknown pixels.
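
For context, matting methods usually assume the standard compositing model, in which every observed pixel color is a blend of an unknown foreground color and an unknown background color. The symbols below are notation introduced here for illustration: I_i is the observed color of pixel i, F_i and B_i are its foreground and background colors, and \alpha_i is its alpha value.

```latex
I_i = \alpha_i F_i + (1 - \alpha_i) B_i, \qquad \alpha_i \in [0, 1]
```

Under this model, pixels marked as foreground in the trimap have \alpha_i = 1, pixels marked as background have \alpha_i = 0, and matting estimates \alpha_i for the unknown pixels.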

What is trimap in image matting?

Matting methods typically take a mask, called a trimap, as an additional input. The trimap is a rough segmentation of an image into three region types: certain foreground, unknown, and certain background.

What is a Trimap?

To extract a meaningful foreground object, almost all matting techniques rely on user intervention: the user segments the input image into three regions, namely definite foreground, definite background, and an unknown region. This three-level map is called a trimap.

What is image matting?

Image Matting is the process of accurately estimating the foreground object in images and videos. It is a very important technique in image and video editing applications, particularly in film production for creating visual effects.

1 Answer

Using a library for alpha matting

Here are two options, both based on the paper "A Closed-Form Solution to Natural Image Matting" by Levin, Lischinski and Weiss.

  • https://github.com/MarcoForte/closed-form-matting
  • https://github.com/99991/matting

The math behind it

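  1. Calculate the matrix L from the image I. It describes how similar neighboring pixels are. Here w_k is a (2r+1) x (2r+1) window, |w_k| is its number of pixels, \mu_k and \Sigma_k are the mean and covariance of the colors inside w_k, and I_3 is the 3 x 3 identity matrix:

$$L_{ij} = \sum_{k \,\mid\, (i,j) \in w_k} \left( \delta_{ij} - \frac{1}{|w_k|} \left( 1 + (I_i - \mu_k)^T \left( \Sigma_k + \frac{\epsilon}{|w_k|} I_3 \right)^{-1} (I_j - \mu_k) \right) \right)$$

  2. Solve the linear system for alpha, using a diagonal constraint matrix D and a vector b to fix the known alpha values. D has a one on the diagonal for every known pixel, \lambda is a large constraint weight, and b_i is \lambda for known foreground pixels and zero everywhere else:

$$(\lambda D + L)\,\alpha = b$$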
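
To make the second step concrete, here is a toy sketch, not the real matting problem: the Laplacian of a five-pixel chain stands in for the matting Laplacian, the first pixel is known foreground and the last is known background. The chain, the masks and the value of `lam` are assumptions picked for illustration.

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg

n = 5
# graph Laplacian of the chain 0-1-2-3-4 (stand-in for the matting Laplacian)
main_diag = np.array([1.0, 2.0, 2.0, 2.0, 1.0])
off_diag = -np.ones(n - 1)
L = scipy.sparse.diags([off_diag, main_diag, off_diag], [-1, 0, 1], format="csr")

# pixel 0 is known foreground, pixel 4 is known background
is_fg = np.array([True, False, False, False, False])
is_bg = np.array([False, False, False, False, True])
is_known = is_fg | is_bg

lam = 100.0  # constraint weight (illustrative value)
D = scipy.sparse.diags(is_known.astype(np.float64))
A = (lam * D + L).tocsc()
b = lam * is_fg.astype(np.float64)

# the unknown pixels are interpolated between the two constrained ends
alpha = scipy.sparse.linalg.spsolve(A, b)
print(np.round(alpha, 3))
```

Because the constraints are soft, the known pixels end up close to 1 and 0 rather than exactly at those values; a larger `lam` pulls them closer.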

Implementation in python

import numpy as np
import numpy.linalg
import scipy.sparse
import scipy.sparse.linalg
from PIL import Image
from numba import njit

def main():
    # configure paths here
    image_path  = "cat_image.png"
    trimap_path = "cat_trimap.png"
    alpha_path  = "cat_alpha.png"
    cutout_path = "cat_cutout.png"

    # load and convert to [0, 1] range
    image  = np.array(Image.open( image_path).convert("RGB"))/255.0
    trimap = np.array(Image.open(trimap_path).convert(  "L"))/255.0

    # make matting laplacian
    i,j,v = closed_form_laplacian(image)
    h,w = trimap.shape
    L = scipy.sparse.csr_matrix((v, (i, j)), shape=(w*h, w*h))

    # build linear system
    A, b = make_system(L, trimap)

    # solve sparse linear system
    print("solving linear system...")
    alpha = scipy.sparse.linalg.spsolve(A, b).reshape(h, w)

    # stack rgb and alpha
    cutout = np.concatenate([image, alpha[:, :, np.newaxis]], axis=2)

    # clip and convert to uint8 for PIL
    cutout = np.clip(cutout*255, 0, 255).astype(np.uint8)
    alpha  = np.clip( alpha*255, 0, 255).astype(np.uint8)

    # save and show
    Image.fromarray(alpha ).save( alpha_path)
    Image.fromarray(cutout).save(cutout_path)
    Image.fromarray(alpha ).show()
    Image.fromarray(cutout).show()

# JIT-compile the nested loops with numba; without it they are very slow
@njit
def closed_form_laplacian(image, epsilon=1e-7, r=1):
    h,w = image.shape[:2]
    window_area = (2*r + 1)**2
    n_vals = (w - 2*r)*(h - 2*r)*window_area**2
    k = 0
    # data for matting laplacian in coordinate form
    i = np.empty(n_vals, dtype=np.int32)
    j = np.empty(n_vals, dtype=np.int32)
    v = np.empty(n_vals, dtype=np.float64)

    # for each pixel of image
    for y in range(r, h - r):
        for x in range(r, w - r):

            # gather neighbors of current pixel in 3x3 window
            n = image[y-r:y+r+1, x-r:x+r+1]
            u = np.zeros(3)
            for p in range(3):
                u[p] = n[:, :, p].mean()
            c = n - u

            # calculate covariance matrix over color channels
            cov = np.zeros((3, 3))
            for p in range(3):
                for q in range(3):
                    cov[p, q] = np.mean(c[:, :, p]*c[:, :, q])

            # calculate inverse covariance of window
            inv_cov = np.linalg.inv(cov + epsilon/window_area * np.eye(3))

            # for each pair ((xi, yi), (xj, yj)) in a 3x3 window
            for dyi in range(2*r + 1):
                for dxi in range(2*r + 1):
                    for dyj in range(2*r + 1):
                        for dxj in range(2*r + 1):
                            i[k] = (x + dxi - r) + (y + dyi - r)*w
                            j[k] = (x + dxj - r) + (y + dyj - r)*w
                            temp = c[dyi, dxi].dot(inv_cov).dot(c[dyj, dxj])
                            v[k] = (1.0 if (i[k] == j[k]) else 0.0) - (1 + temp)/window_area
                            k += 1
        print("generating matting laplacian", y - r + 1, "/", h - 2*r)

    return i, j, v

def make_system(L, trimap, constraint_factor=100.0):
    # split trimap into foreground, background, known and unknown masks
    is_fg = (trimap > 0.9).flatten()
    is_bg = (trimap < 0.1).flatten()
    is_known = is_fg | is_bg
    is_unknown = ~is_known

    # diagonal matrix to constrain known alpha values
    d = is_known.astype(np.float64)
    D = scipy.sparse.diags(d)

    # combine constraints and graph laplacian
    A = constraint_factor*D + L
    # constrained values of known alpha values
    b = constraint_factor*is_fg.astype(np.float64)

    return A, b

if __name__ == "__main__":
    main()
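
As a hedged sanity check of the construction above, the sketch below rebuilds the same per-window formula in compact NumPy for a tiny random image and verifies two properties a matting Laplacian should have: it is symmetric and each row sums to zero, so a constant alpha costs nothing. The function name `tiny_matting_laplacian` and the random test image are assumptions for illustration; the function is far too slow for real images.

```python
import numpy as np
import scipy.sparse

def tiny_matting_laplacian(image, epsilon=1e-7, r=1):
    """Matting Laplacian of a small RGB image, one dense block per window."""
    h, w = image.shape[:2]
    size = 2 * r + 1
    area = size * size
    rows, cols, vals = [], [], []
    for y in range(r, h - r):
        for x in range(r, w - r):
            # flatten the window to (area, 3) colors and center them
            window = image[y - r:y + r + 1, x - r:x + r + 1].reshape(area, 3)
            c = window - window.mean(axis=0)
            cov = c.T @ c / area
            inv_cov = np.linalg.inv(cov + epsilon / area * np.eye(3))
            # block[a, b] = delta_ab - (1 + c_a^T inv_cov c_b) / area
            block = np.eye(area) - (1.0 + c @ inv_cov @ c.T) / area
            # flat pixel indices of the window, same order as `window`
            ys, xs = np.mgrid[y - r:y + r + 1, x - r:x + r + 1]
            idx = (xs + ys * w).ravel()
            rows.append(np.repeat(idx, area))
            cols.append(np.tile(idx, area))
            vals.append(block.ravel())
    rows, cols, vals = map(np.concatenate, (rows, cols, vals))
    # duplicate (row, col) entries are summed by the constructor
    return scipy.sparse.csr_matrix((vals, (rows, cols)), shape=(h * w, h * w))

rng = np.random.default_rng(0)
image = rng.random((6, 5, 3))
L = tiny_matting_laplacian(image)

row_sums = np.asarray(L.sum(axis=1)).ravel()
asymmetry = abs(L - L.T).max()
print("largest |row sum|:", np.abs(row_sums).max())
print("largest asymmetry:", asymmetry)
```

Both printed values should be zero up to floating point rounding.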

Output alpha

cat alpha

Output cutout

cat cutout
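
The cutout can then be placed over a new background by alpha blending. The sketch below assumes an RGBA cutout and an RGB background of the same size, both as float arrays in [0, 1]; the helper name `composite` and the two-pixel example are made up for illustration.

```python
import numpy as np

def composite(cutout_rgba, background_rgb):
    """Blend an RGBA cutout (floats in [0, 1]) over an RGB background."""
    rgb = cutout_rgba[:, :, :3]
    alpha = cutout_rgba[:, :, 3:4]  # keep the last axis for broadcasting
    return alpha * rgb + (1.0 - alpha) * background_rgb

# 1x2 example: one opaque red pixel and one half-transparent red pixel
cutout = np.array([[[1.0, 0.0, 0.0, 1.0],
                    [1.0, 0.0, 0.0, 0.5]]])
blue = np.zeros((1, 2, 3))
blue[:, :, 2] = 1.0

result = composite(cutout, blue)
print(result)
```

Note that the cutout produced by the code above stores the original image colors rather than estimated foreground colors, so some of the old background color can remain visible in semi-transparent regions.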

John Avatar answered Sep 28 '22 19:09
