pythonimage-processingcomputer-visionmathematical-optimization

How to do Image Alpha Matting in Python with a Given Trimap (Based on Closed Form Solution to Natural Image Matting)


How to implement alpha matting algorithm in Python?

More specifically, how to extract the alpha channel of an image, given a trimap which marks pixels as either

Input image

enter image description here

Input trimap

enter image description here


Solution

  • Using a library for alpha matting

    Here are two options, both based on the paper "A Closed Form Solution to Natural Image Matting" by Levin and Lischinski.

    The math behind it

    1. Calculate the matrix L based on image I which describes how similar neighboring pixels are:

    matrix L

    1. Solve the linear system for alpha using a constraint matrix D and vector b to fix the known alpha values:

    solve for alpha

    Implementation in python

    import numpy as np
    import numpy.linalg
    import scipy.sparse
    import scipy.sparse.linalg
    from PIL import Image
    from numba import njit
    
    def main():
        # configure paths here
        image_path  = "cat_image.png"
        trimap_path = "cat_trimap.png"
        alpha_path  = "cat_alpha.png"
        cutout_path = "cat_cutout.png"
    
        # load and convert to [0, 1] range
        image  = np.array(Image.open( image_path).convert("RGB"))/255.0
        trimap = np.array(Image.open(trimap_path).convert(  "L"))/255.0
    
        # make matting laplacian
        i,j,v = closed_form_laplacian(image)
        h,w = trimap.shape
        L = scipy.sparse.csr_matrix((v, (i, j)), shape=(w*h, w*h))
    
        # build linear system
        A, b = make_system(L, trimap)
    
        # solve sparse linear system
        print("solving linear system...")
        alpha = scipy.sparse.linalg.spsolve(A, b).reshape(h, w)
    
        # stack rgb and alpha
        cutout = np.concatenate([image, alpha[:, :, np.newaxis]], axis=2)
    
        # clip and convert to uint8 for PIL
        cutout = np.clip(cutout*255, 0, 255).astype(np.uint8)
        alpha  = np.clip( alpha*255, 0, 255).astype(np.uint8)
    
        # save and show
        Image.fromarray(alpha ).save( alpha_path)
        Image.fromarray(cutout).save(cutout_path)
        Image.fromarray(alpha ).show()
        Image.fromarray(cutout).show()
    
    @njit
    def closed_form_laplacian(image, epsilon=1e-7, r=1):
        h,w = image.shape[:2]
        window_area = (2*r + 1)**2
        n_vals = (w - 2*r)*(h - 2*r)*window_area**2
        k = 0
        # data for matting laplacian in coordinate form
        i = np.empty(n_vals, dtype=np.int32)
        j = np.empty(n_vals, dtype=np.int32)
        v = np.empty(n_vals, dtype=np.float64)
    
        # for each pixel of image
        for y in range(r, h - r):
            for x in range(r, w - r):
    
                # gather neighbors of current pixel in 3x3 window
                n = image[y-r:y+r+1, x-r:x+r+1]
                u = np.zeros(3)
                for p in range(3):
                    u[p] = n[:, :, p].mean()
                c = n - u
    
                # calculate covariance matrix over color channels
                cov = np.zeros((3, 3))
                for p in range(3):
                    for q in range(3):
                        cov[p, q] = np.mean(c[:, :, p]*c[:, :, q])
    
                # calculate inverse covariance of window
                inv_cov = np.linalg.inv(cov + epsilon/window_area * np.eye(3))
    
                # for each pair ((xi, yi), (xj, yj)) in a 3x3 window
                for dyi in range(2*r + 1):
                    for dxi in range(2*r + 1):
                        for dyj in range(2*r + 1):
                            for dxj in range(2*r + 1):
                                i[k] = (x + dxi - r) + (y + dyi - r)*w
                                j[k] = (x + dxj - r) + (y + dyj - r)*w
                                temp = c[dyi, dxi].dot(inv_cov).dot(c[dyj, dxj])
                                v[k] = (1.0 if (i[k] == j[k]) else 0.0) - (1 + temp)/window_area
                                k += 1
            print("generating matting laplacian", y - r + 1, "/", h - 2*r)
    
        return i, j, v
    
    def make_system(L, trimap, constraint_factor=100.0):
        # split trimap into foreground, background, known and unknown masks
        is_fg = (trimap > 0.9).flatten()
        is_bg = (trimap < 0.1).flatten()
        is_known = is_fg | is_bg
        is_unknown = ~is_known
    
        # diagonal matrix to constrain known alpha values
        d = is_known.astype(np.float64)
        D = scipy.sparse.diags(d)
    
        # combine constraints and graph laplacian
        A = constraint_factor*D + L
        # constrained values of known alpha values
        b = constraint_factor*is_fg.astype(np.float64)
    
        return A, b
    
    if __name__ == "__main__":
        main()
    

    Output alpha

    cat alpha

    Output cutout

    cat cutout