Tags: python, pytorch, tensor, blending, tiling

How can I make this PyTorch tensor (B, C, H, W) tiling & blending code simpler and more efficient?


I wrote the code below many months ago, and it has worked pretty well. However, I am struggling to see how I can simplify it and make it more efficient.

The functions below split an image tensor (B, C, H, W) into equally sized tiles (B, C, H, W), so that you can process the tiles individually in order to save memory. When rebuilding the tensor from the tiles, masks are used to ensure that the tiles are seamlessly blended back together. The 'special masks' in the masking function handle the case where tiles in the rightmost column or in the bottom row can't use the same overlap as the other tiles. This means that the right-edge tiles and the bottom tiles may sometimes have almost none of their content visible. This is done to ensure that the tiles are always exactly the specified size, regardless of the original image/tensor's size (important for visualization/DeepDream, neural style transfer, etc.). The row/column adjacent to the edge row/column also has special masks where it overlaps with the edge row/column.

There are 8 possible masks for every tile, and up to 4 of them can be used at once. The 4 basic masks are left, right, top, and bottom, and each of them has a special version.

# Improved version of: https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py
import torch


# Apply blend masks to tiles
def mask_tile(tile, overlap, side='bottom'):
    """Multiply a tile by linear blend ramps on the requested edges.

    ``tile`` is indexed as (batch, channel, height, width). ``overlap`` is
    indexed as [top, bottom, right, left] and gives each ramp's length in
    pixels. ``side`` is a string that is searched for the edge names
    'left', 'right', 'top' and 'bottom' and for their '-special' variants;
    several names may appear in the same string.

    A plain edge fades linearly between 0 at the tile border and 1 at the
    inner end of the ramp. The '-special' variants build a full-length
    profile instead:

    * 'left-special' / 'top-special': zeros, then a 0 -> 1 ramp, then a
      run of ones as long as the ramp.
    * 'right-special' / 'bottom-special': ones, then a 1 -> 0 ramp that
      ends at the tile border.

    Returns a new tensor, ``tile`` multiplied by the combined mask.
    """
    rows, cols = tile.size(2), tile.size(3)
    top_px, bottom_px, right_px, left_px = overlap[0], overlap[1], overlap[2], overlap[3]
    dev = tile.device

    def plain(edge):
        # The plain variant applies only when its '-special' form is absent.
        return edge in side and edge + '-special' not in side

    def special(edge):
        return edge + '-special' in side

    def along_width(profile):
        # Shape (1, 1, 1, n): broadcasts over batch, channel and height.
        return profile[None, None, None, :]

    def along_height(profile):
        # Shape (1, 1, n, 1): broadcasts over batch, channel and width.
        return profile[None, None, :, None]

    def fade_in(length):
        return torch.linspace(0, 1, length, device=dev)

    def fade_out(length):
        return torch.linspace(1, 0, length, device=dev)

    weights = torch.ones_like(tile)

    # Plain ramps only touch the strip next to their edge.
    if plain('left'):
        strip = weights[:, :, :, :left_px]
        weights[:, :, :, :left_px] = strip * along_width(fade_in(left_px))
    if plain('right'):
        strip = weights[:, :, :, cols - right_px:]
        weights[:, :, :, cols - right_px:] = strip * along_width(fade_out(right_px))
    if plain('top'):
        strip = weights[:, :, :top_px, :]
        weights[:, :, :top_px, :] = strip * along_height(fade_in(top_px))
    if plain('bottom'):
        strip = weights[:, :, rows - bottom_px:, :]
        weights[:, :, rows - bottom_px:, :] = strip * along_height(fade_out(bottom_px))

    # Special ramps are full-length profiles multiplied into the whole mask.
    if special('left'):
        profile = torch.cat([
            torch.zeros(cols - (left_px * 2), device=dev),
            fade_in(left_px),
            torch.ones(left_px, device=dev),
        ], 0)
        weights = weights * along_width(profile)
    if special('right'):
        profile = torch.cat([
            torch.ones(cols - right_px, device=dev),
            fade_out(right_px),
        ], 0)
        weights = weights * along_width(profile)
    if special('top'):
        profile = torch.cat([
            torch.zeros(rows - (top_px * 2), device=dev),
            fade_in(top_px),
            torch.ones(top_px, device=dev),
        ], 0)
        weights = weights * along_height(profile)
    if special('bottom'):
        profile = torch.cat([
            torch.ones(rows - bottom_px, device=dev),
            fade_out(bottom_px),
        ], 0)
        weights = weights * along_height(profile)

    # Apply mask to tile and return masked tile
    return tile * weights


def _axis_masks(index, count, last_step, lead, trail):
    """Return (side names, {side: overlap override}) for one grid axis.

    ``lead`` / ``trail`` are the edge names facing the previous / next tile
    on the axis ('top'/'bottom' for rows, 'left'/'right' for columns).
    """
    sides, overrides = [], {}
    if index < count - 1:
        # A tile follows on this axis, so the trailing edge must fade out.
        if index == count - 2:
            # The neighbour is the final tile, which is positioned flush
            # with the image border rather than on the regular stride.
            sides.append(trail + '-special')
            overrides[trail] = last_step
        else:
            sides.append(trail)
    if index > 0:
        # A tile precedes on this axis, so the leading edge must fade in.
        if index == count - 1:
            sides.append(lead + '-special')
            overrides[lead] = last_step
        else:
            sides.append(lead)
    return sides, overrides


def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
    """Mask every tile and accumulate it into ``base_img`` at its grid position.

    Args:
        tiles: Indexable sequence of tile tensors in row-major grid order
            (every column of the first row, then the second row, ...).
        base_img: Tensor indexed as (batch, channel, height, width) that
            the masked tiles are added into. It is modified in place.
        tile_coords: Pair ``(y_coords, x_coords)`` of tile start offsets.
        tile_size: Pair ``(tile_height, tile_width)``.
        overlap: ``[top, bottom, right, left]`` blend lengths in pixels.
            It is copied per tile, never modified.

    Returns:
        ``base_img`` after all tiles have been added.

    An edge is only masked when another tile exists on that side. A grid
    with a single row (or column) therefore gets no vertical (or
    horizontal) fade at all; previously such tiles were faded towards the
    bottom/right even though nothing was blended over them.
    """
    slot = {'top': 0, 'bottom': 1, 'right': 2, 'left': 3}
    y_coords, x_coords = tile_coords[0], tile_coords[1]
    n_rows, n_cols = len(y_coords), len(x_coords)
    tile_h, tile_w = tile_size[0], tile_size[1]

    # Distance between the last two tile origins on each axis. The final
    # tile and its neighbour use it as their blend length.
    last_step_y = y_coords[-1] - y_coords[-2] if n_rows > 1 else 0
    last_step_x = x_coords[-1] - x_coords[-2] if n_cols > 1 else 0

    # The column plan does not depend on the row, so build it once.
    col_plans = [
        _axis_masks(col, n_cols, last_step_x, 'left', 'right')
        for col in range(n_cols)
    ]

    for row, y in enumerate(y_coords):
        row_sides, row_overrides = _axis_masks(row, n_rows, last_step_y, 'top', 'bottom')
        for col, x in enumerate(x_coords):
            col_sides, col_overrides = col_plans[col]

            tile_overlap = list(overlap)
            for edge, length in list(row_overrides.items()) + list(col_overrides.items()):
                tile_overlap[slot[edge]] = length

            masked = mask_tile(
                tiles[row * n_cols + col],
                tile_overlap,
                side=','.join(row_sides + col_sides),
            )
            region = (slice(None), slice(None), slice(y, y + tile_h), slice(x, x + tile_w))
            base_img[region] = base_img[region] + masked
    return base_img


# Calculate the coordinates for tiles
def get_tile_coords(d, tile_dim, overlap=0):
    """Return the start offsets of the tiles along one image dimension.

    Args:
        d: Length of the image dimension in pixels.
        tile_dim: Length of a tile along that dimension in pixels.
        overlap: Fraction of ``tile_dim`` shared by neighbouring tiles.

    Returns:
        A list of offsets starting at 0. Offsets advance by
        ``int(tile_dim * (1 - overlap))``; the last one is pulled back to
        ``d - tile_dim`` so the final tile ends exactly at the border. When
        ``tile_dim >= d`` the list is just ``[0]``.

    Raises:
        ValueError: If the stride truncates to less than one pixel, which
            would otherwise keep the loop from ever advancing.
    """
    move = int(tile_dim * (1 - overlap))
    if move < 1:
        raise ValueError(
            'overlap {} with tile size {} gives a stride of {} pixels; '
            'the stride must be at least 1'.format(overlap, tile_dim, move)
        )

    coords = [0]
    tile_start = 0
    while tile_start + tile_dim < d:
        tile_start += move
        # Clamp so the last tile stays inside the image.
        coords.append(min(tile_start, d - tile_dim))
    return coords


# Calculates info required for tiling
def tile_setup(tile_size, overlap_percent, base_size):
    """Compute tile coordinates, tile size and pixel overlaps for an image.

    Args:
        tile_size: Either one number used for both dimensions or a
            ``(height, width)`` tuple/list.
        overlap_percent: Either one fraction used for both dimensions or a
            ``(vertical, horizontal)`` tuple/list.
        base_size: ``(height, width)`` of the image being tiled.

    Returns:
        A 3-tuple ``((y_coords, x_coords), tile_size, overlap)`` where
        ``tile_size`` is the normalised pair and ``overlap`` is the list
        ``[y_overlap, y_overlap, x_overlap, x_overlap]`` in pixels.
    """
    # isinstance() also accepts tuple/list subclasses such as torch.Size,
    # which an exact type() comparison would wrap a second time.
    if not isinstance(tile_size, (tuple, list)):
        tile_size = (tile_size, tile_size)
    if not isinstance(overlap_percent, (tuple, list)):
        overlap_percent = (overlap_percent, overlap_percent)

    x_coords = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
    y_coords = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])

    y_ovlp = int(tile_size[0] * overlap_percent[0])
    x_ovlp = int(tile_size[1] * overlap_percent[1])
    return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]


# Split tensor into tiles
def tile_image(img, tile_size, overlap_percent, info_only=False):
    """Cut ``img`` into a list of overlapping tiles.

    ``img`` is indexed as (batch, channel, height, width). The tiles are
    slices of ``img`` and are returned in row-major order: every column
    position of the first row, then the second row, and so on.

    ``info_only`` is accepted but not used by this function.
    """
    coords, size, _ = tile_setup(tile_size, overlap_percent, (img.size(2), img.size(3)))
    row_starts, col_starts = coords[0], coords[1]

    # Cut out tiles
    return [
        img[:, :, top:top + size[0], left:left + size[1]]
        for top in row_starts
        for left in col_starts
    ]


# Put tiles back into the original tensor
def rebuild_image(tiles, image_size, tile_size, overlap_percent):
    """Blend a list of tiles back into one tensor of shape ``image_size``.

    Args:
        tiles: Tile tensors in the order produced by ``tile_image``.
        image_size: Shape of the tensor to rebuild, indexed as
            (batch, channel, height, width).
        tile_size: Tile size, as accepted by ``tile_setup``.
        overlap_percent: Overlap fraction, as accepted by ``tile_setup``.

    Returns:
        A new tensor on the same device and with the same dtype as the
        first tile, holding the sum of the masked tiles.
    """
    # Allocate the canvas with the tiles' dtype so that, for example,
    # float64 tiles are not silently cast to the default dtype.
    base_img = torch.zeros(image_size, dtype=tiles[0].dtype, device=tiles[0].device)
    tile_coords, tile_size, overlap = tile_setup(
        tile_size, overlap_percent, (base_img.size(2), base_img.size(3))
    )
    return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)

The above code can be tested with the code below:

import torchvision.transforms as transforms
from PIL import Image
import random

# Load image
def preprocess_simple(image_name, image_size):
    """Load an image file, resize it and return it as a batched tensor.

    Args:
        image_name: Path of the image file to open.
        image_size: Target size passed to ``transforms.Resize``.

    Returns:
        The ``transforms.ToTensor`` result for the RGB image with a
        leading batch dimension of 1 added.
    """
    to_tensor = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    # Close the underlying file as soon as the RGB copy has been made
    # instead of leaving it open until garbage collection.
    with Image.open(image_name) as source:
        image = source.convert('RGB')
    return to_tensor(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    """Clamp a batched image tensor to [0, 1] and save it as an image file.

    Args:
        output_tensor: Tensor whose leading dimension is squeezed away
            before conversion to a PIL image.
        output_name: Path the image is written to.
    """
    # Use the out-of-place clamp so that saving does not modify the
    # caller's tensor as a side effect.
    clamped = output_tensor.clamp(0, 1)
    to_pil = transforms.ToPILImage()
    image = to_pil(clamped.squeeze(0))
    image.save(output_name)

# Demo: split an image into tiles, optionally shuffle them, blend them back.
tile_size, overlap_percent = 260, 0.5
test_input = preprocess_simple('tubingen.jpg', (1024, 1024))

# Cut the image into overlapping tiles
img_tiles = tile_image(test_input, tile_size, overlap_percent)

# Comment this out to not randomize tile positions
random.shuffle(img_tiles)

# Blend the tiles into a tensor of the original size and save the result
output_tensor = rebuild_image(img_tiles, test_input.size(), tile_size, overlap_percent)
deprocess_simple(output_tensor, 'tiled_image.jpg')

I've included an example of what it does below (top is the original image, and the bottom is when I place the tiles back randomly to show off the blending system):

Original Image Tiled Image with random tile placement


Solution

  • I was able to remove all the bugs and simplify the code here: https://github.com/ProGamerGov/dream-creator/blob/master/utils/tile_utils.py

    The special masks were really only needed for 2 situations, and there were bugs in rebuild_tensor that I had to fix. Overlap percentages should be equal to or less than 50%.