python, image-processing, interpolation, random-forest

Random forest for image resizing in Python


I'm currently working on an experimental project in which I resize a low-resolution image to the size of its high-resolution counterpart using a random forest regression model, and then compare the result against other resizing techniques that are staples of image enhancement (basically, interpolation techniques).

So far I have the code provided here, which runs smoothly and doesn't take long, but it only returns one patch of the new image. My actual goal is to obtain the whole image, as I do with the interpolation techniques, or, if that requires too much effort and time, at least a recognizable fragment of the image in question. I'm following a scientific paper that used decision trees and obtained good results, both computation-wise and quality-wise (measured with metrics like PSNR and SSIM).

How can the provided code be modified to achieve this?

from sklearn.ensemble import RandomForestRegressor
from matplotlib import pyplot as plt
from skimage.transform import downscale_local_mean, resize
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage import io, color
import numpy as np
import time


def safely_convert_to_gray(image):
    if len(image.shape) == 3 and image.shape[2] == 3:  # RGB image
        return color.rgb2gray(image)
    elif len(image.shape) == 2:  # Grayscale image
        return image
    else:
        raise ValueError("The image is neither RGB nor grayscale in a recognized format")

def create_patches(lr_images, hr_images, patch_size, scale):
    lr_patches = []
    hr_patches = []
    for lr_img, hr_img in zip(lr_images, hr_images):
        # Make sure the images are grayscale and appropriately rescaled
        for i in range(0, lr_img.shape[0] - patch_size + 1, patch_size):
            for j in range(0, lr_img.shape[1] - patch_size + 1, patch_size):
                # Extract patches from the low-resolution image
                lr_patch = lr_img[i:i + patch_size, j:j + patch_size]
                # Make sure the HR patch has the correct size, accounting for the scale factor
                hr_patch = hr_img[i*scale:(i+patch_size)*scale, j*scale:(j+patch_size)*scale]
                if lr_patch.shape == (patch_size, patch_size) and hr_patch.shape == (patch_size*scale, patch_size*scale):
                    lr_patches.append(lr_patch.flatten())
                    hr_patches.append(hr_patch.flatten())
    return np.array(lr_patches), np.array(hr_patches)


def load_images_and_features(patch_size, scale):
    lr_image_url = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/camera.png'
    hr_image_url = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/astronaut.png'
    lr_image = safely_convert_to_gray(io.imread(lr_image_url))
    hr_image = safely_convert_to_gray(io.imread(hr_image_url))
    lr_images = [lr_image]
    hr_images = [hr_image]
    features, labels = create_patches(lr_images, hr_images, patch_size, scale)
    return features, labels

patch_size = 8  # Size of the patches extracted from the low-resolution image
scale = 2       # Scale factor between the low- and high-resolution images

features, labels = load_images_and_features(patch_size, scale)

# Model training
rf = RandomForestRegressor(n_estimators=10, random_state=42)
rf.fit(features, labels)

def upscale_image(lr_img, model, patch_size, scale):
    upscaled_img = resize(lr_img, (lr_img.shape[0] * scale, lr_img.shape[1] * scale), anti_aliasing=True)
    for i in range(0, upscaled_img.shape[0] - patch_size * scale + 1, scale):
        for j in range(0, upscaled_img.shape[1] - patch_size * scale + 1, scale):
            lr_patch = upscaled_img[i//scale:(i//scale)+patch_size, j//scale:(j//scale)+patch_size].flatten()
            hr_patch_predicted = model.predict([lr_patch])
            upscaled_img[i:i + patch_size * scale, j:j + patch_size * scale] = hr_patch_predicted.reshape(patch_size * scale, patch_size * scale)
    return upscaled_img

# Image upscaling test: note that only a single 8x8 patch is passed in here,
# which is why the output is just one patch rather than the whole image
lr_test_image = features[0].reshape(patch_size, patch_size)
upscaled_image = upscale_image(lr_test_image, rf, patch_size, scale)


# timing (note: process_time() reports CPU time since the process started)
tiempo = time.process_time()
print("time: ", round(tiempo, 2))

# Display the upscaled image
plt.title("Upscaled Image")
plt.imshow(upscaled_image, cmap='gray')
plt.show()

As explained earlier, the current code does run, but it does not meet my needs: I want the full resized image, or at the very least a recognizable fragment of it rather than a single patch, so that I can later compare it to the original HR version.

The picture below shows the current outcome:

[screenshot: a single upscaled patch rather than the full image]


Solution

  • In the original code, the low-resolution and high-resolution patches come from two different images (camera.png and astronaut.png) rather than being low- and high-res versions of the same image, so the model cannot learn a meaningful low-to-high-res mapping. A minimal sketch of the correct pairing is shown below.
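
    For example, a sketch of the correct pairing, using skimage's bundled astronaut image (the full pipeline below builds such pairs patch by patch):

    from skimage import data, color
    from skimage.transform import rescale

    hr_image = color.rgb2gray(data.astronaut())              # ground-truth hi-res
    lr_image = rescale(hr_image, 1 / 2, anti_aliasing=True)  # low-res derived from the SAME image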

    The code below is an attempt at the upsampling task you described. Images are first selected for a training set and a validation set. The images are converted to patches, where the input data X are the low-resolution patches and the targets y are the original patches:

    [figure: sample training pairs, each a low-res input patch X[i] with its hi-res target patch y[i]]
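
    (For reference, sklearn's extract_patches_2d returns every overlapping patch of an image; a toy example on an assumed 10x10 array:)

    from sklearn.feature_extraction.image import extract_patches_2d
    import numpy as np

    img = np.arange(100, dtype=np.float32).reshape(10, 10)
    patches = extract_patches_2d(img, (8, 8))  # every overlapping 8x8 patch
    print(patches.shape)                       # (9, 8, 8): (10-8+1)**2 = 9 patches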

    The model is trained on the data, and evaluated on the validation image patches:

    [figure: validation results showing the original image, the low-res input, and the upsampled output]

    from sklearn.linear_model import LinearRegression
    from sklearn.feature_extraction.image import extract_patches_2d, reconstruct_from_patches_2d
    from matplotlib import pyplot as plt
    from skimage.transform import rescale
    from skimage import io, color
    import numpy as np
    
    def safely_convert_to_gray(image):
        if len(image.shape) == 3 and image.shape[2] == 3:  # RGB image
            return color.rgb2gray(image)
        elif len(image.shape) == 2:  # Grayscale image
            return image
        else:
            raise ValueError("The image is neither RGB nor grayscale in a recognized format")
    
    
    def image_to_Xy(image, patch_size, scale):
        """
        Returns: (X, y) tuple
        where X: (n_patches, patch_size**2 / scale**2) array of low-res patches
              y: (n_patches, patch_size**2) array of hi-res patches
        """
        
        hires_patches = extract_patches_2d(image, [patch_size] * 2)    
        lowres_patches = np.array(
            [rescale(patch, 1 / scale, anti_aliasing=True) for patch in hires_patches]
        )
        
        hires_patches = hires_patches.reshape(-1, patch_size ** 2)
        lowres_patches = lowres_patches.reshape(-1, (patch_size // scale)**2)
        return lowres_patches.astype(np.float32), hires_patches.astype(np.float32)
    
    def images_to_Xy(images, patch_size, scale, shuffle=True, random_state=None):
        """
        Returns (X, y)
        where X: (patches of all images, patch_size**2 / scale**2) array of low-res patches
        where y: (patches of all images, patch_size**2) array of hi-res patches
        """
        Xy_perimage = [image_to_Xy(image, patch_size, scale) for image in images]
        
        X_arr = np.concatenate([X for X, y in Xy_perimage], axis=0)
        y_arr = np.concatenate([y for X, y in Xy_perimage], axis=0)
        
        if shuffle:
            ixs = np.random.default_rng(random_state).permutation(range(len(X_arr)))
            X_arr, y_arr = [arr[ixs] for arr in (X_arr, y_arr)]
            
        return X_arr, y_arr
    
    def hires_patches_to_image(patches, patch_size, image_size):
        """
        Reconstructs image from patches.
        Only for the hi-res original & hi-res prediction.
        Doesn't work for the scaled patches: use lowres_patches_to_image()
        """
        patches_unflat = patches.reshape(-1, patch_size, patch_size)
        return reconstruct_from_patches_2d(patches_unflat, image_size).astype(np.float32)
    
    def lowres_patches_to_image(patches, patch_size, scale, original_image_shape):
        """
        Low-res patches assembled into the corresponding low-res image
        """
        lr_patch_size = patch_size // scale
        
        patches_unflat = patches.reshape(-1, lr_patch_size, lr_patch_size)
    
        lr_shape = [original_image_shape[i] - lr_patch_size for i in [0, 1]]
        data = np.zeros(lr_shape)
        counts = np.zeros_like(data) + 1e-10
    
        for r in range(lr_shape[0] - lr_patch_size + 1):
            for c in range(lr_shape[1] - lr_patch_size + 1):
                patch_idx = r * (original_image_shape[1] - patch_size + 1) + c
                data[r:r + lr_patch_size, c:c + lr_patch_size] += patches_unflat[patch_idx]
                counts[r:r + lr_patch_size, c:c + lr_patch_size] += np.ones([lr_patch_size] * 2)
        
        return (data / counts).astype(np.float32)
    
    #
    # Load train and validation images
    #
    image_urls = [
        'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/camera.png',
        'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/chelsea.png',
        'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/coffee.png',
        'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/brick.png',
        'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/coins.png',
        'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/grass.png',
    ]
    images = [safely_convert_to_gray(io.imread(url)) for url in image_urls]
    images = [rescale(image, 0.15) for image in images]  # work with smaller images to keep RAM usage down
    
    
    train_images = images[:2]  # camera.png and chelsea.png (the cat) for training
    val_images = images[2:]    # coffee, brick, coins and grass for validation
                               # Don't mix patches from the same
                               # image across sets (data leakage)
    patch_size = 8
    scale = 2
    
    trn_features, trn_labels = images_to_Xy(train_images, patch_size, scale, shuffle=True, random_state=0)
    
    #
    #View some samples
    #
    f, axs = plt.subplots(4, 10, figsize=(9, 5))
    axs = axs.flatten()
    for i in range(0, 40, 2):
        ax = axs[i]
        ax.imshow(trn_features[i].reshape(patch_size // scale, patch_size // scale), cmap='gray')
        ax.axis('off')
        ax.set_title(' ' * 20 + f'X[{i}], y[{i}]', fontsize=8)
        
        ax = axs[i + 1]
        ax.imshow(trn_labels[i].reshape(patch_size, patch_size), cmap='gray')
        ax.axis('off')
    f.suptitle('Training set samples')
    plt.show()
    
    #
    #Fit model on train data
    # Can limit number of samples for RAM
    #
    trn_limit = None
    model = LinearRegression()
    model.fit(trn_features[:trn_limit], trn_labels[:trn_limit])
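
    If you specifically want the random forest from the question, it is a drop-in replacement for the LinearRegression above (a sketch; expect it to be noticeably slower and heavier on RAM):

    from sklearn.ensemble import RandomForestRegressor

    model = RandomForestRegressor(n_estimators=10, random_state=42, n_jobs=-1)
    model.fit(trn_features[:trn_limit], trn_labels[:trn_limit])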
    
    #
    #Assess on validation set
    #
    for i, image in enumerate(val_images):
        val_features, _ = image_to_Xy(image, patch_size, scale)
        val_predictions = model.predict(val_features).astype(np.float32)
        pred_image = hires_patches_to_image(val_predictions, patch_size, image.shape)
    
        f, axs = plt.subplots(1, 3, figsize=(8, 2.8))
        f.suptitle(f'Validation image {i} results')
    
        ax = axs[0]
        ax.imshow(image, cmap='gray')
        ax.set_title('original', fontsize=9)
    
        ax = axs[1]
        ax.imshow(lowres_patches_to_image(val_features, patch_size, scale, image.shape), cmap='gray')
        ax.set_title(f'low-res ({1/scale}) patches input', fontsize=9)
    
        ax = axs[2]
        ax.imshow(pred_image, cmap='gray')
        ax.set_title('upsampled output', fontsize=9)
    
        [ax.axis('off') for ax in axs]
        plt.show()
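
    To get the PSNR/SSIM numbers mentioned in the question, a sketch that can be appended inside the loop above (image and pred_image have the same shape; data_range is required for float images):

    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    data_range = image.max() - image.min()
    psnr = peak_signal_noise_ratio(image, pred_image, data_range=data_range)
    ssim = structural_similarity(image, pred_image, data_range=data_range)
    print(f'Validation image {i}: PSNR {psnr:.2f} dB, SSIM {ssim:.3f}')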