python tensorflow image-processing computer-vision

n-dimensional sliding window operation in python using gpu accelerated libraries, preferably Tensorflow?


I am looking for a GPU-accelerated n-dimensional sliding-window operation implemented in Python using TensorFlow. You can post your implementation in Torch, Caffe or Theano, but I'll choose the TensorFlow implementation as the accepted answer. Please post a working code snippet that performs a 2D median filter operation (ideally one that can be applied to n-dimensional images with little or no code change).

With my limited knowledge of TensorFlow, I believe the two potential modules to start with are sliding_window_batch or extract_image_patches, followed by some map/apply/reshape magic?

My failed attempt is posted below, for entertainment. Please note that I posted a similar question 2 years ago asking for a Theano implementation; nowadays, most people are using tf/keras or torch.

import time
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import sliding
from skimage import img_as_float, data
from scipy.signal import medfilt

# Test input: the scikit-image sample "camera" image converted to floating point.
imgs = img_as_float(data.camera())

### SCIPY median ###
# CPU reference result: 2-D median filter with a 9x9 window, timed with wall-clock time.
stime = time.time()
scipysmoothed = medfilt(imgs,(9,9))
etime = time.time()
print('scipy smoothed: {:1.4f} seconds'.format(etime-stime))

### Failed attempt of TF median ###
# `method` is only used later to label the plot of this result.
method = 'Tensorflow'
# The timer starts before the graph is built, so the reported time covers
# graph construction as well as execution.
stime = time.time()

# 50th percentile (median) of whatever tensor it is given. No `axis` is passed,
# so presumably the reduction runs over every value in the window and yields a
# single scalar -- TODO confirm against the percentile docs.
window_func = lambda x: tf.contrib.distributions.percentile(x, 50.0)

# create TensorFlow Dataset object
# NOTE: this rebinds `data`, which until now referred to the `skimage.data`
# module imported at the top; `data.camera()` is no longer callable after this line.
# The dataset is sliced along the first axis of `imgs` (presumably one element per image row).
data = tf.data.Dataset.from_tensor_slices(imgs)

# sliding window - only 1d is allowed?
# The window slides over dataset elements only (groups of `window` consecutive
# elements, advancing by `stride`); there is no second sliding axis here, which
# is why this is not a 2-D filter.
window = 3
stride = 1
data = data.apply(sliding.sliding_window_batch(window, stride)).map(lambda x: window_func(x))

# create TensorFlow Iterator object
iterator =  tf.data.Iterator.from_structure(data.output_types)
next_element = iterator.get_next()

# create initialization ops
init_op = iterator.make_initializer(data)
# `c` counts the windows consumed so far and doubles as the output row index.
c=0
# Output buffer with the same shape as the input; rows that are never written stay zero.
smoothed = np.zeros(imgs.shape)
with tf.Session() as sess:
    # initialize the iterator on the data
    sess.run(init_op)
    # One sess.run call per window, until the dataset signals exhaustion.
    while True:
        try:
            elem = sess.run(next_element)
            # Writes the window's result into the whole of row `c`. If `elem`
            # is a scalar (see the note on window_func) it is broadcast across
            # the row, so every pixel in the row gets the same value.
            smoothed[c,:]=elem
            # obviously WRONG.
            c+=1
        except tf.errors.OutOfRangeError:
            # Raised by sess.run once the iterator has no more elements.
            #print("End of dataset.")
            break
#print(c)
etime = time.time()
print('tf smoothed: {:1.4f} seconds'.format(etime-stime))


# Show the input, the TensorFlow result and the SciPy reference side by side
# in a single row of three grayscale panels.
panels = (
    ('original', imgs),
    ('actual smoothed\nwith {}'.format(method), smoothed),
    ('expected smoothed', scipysmoothed),
)
plt.figure(figsize=(20, 20))
for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    # interpolation='none' keeps matplotlib from smoothing the pixels itself.
    plt.imshow(picture, cmap='gray', interpolation='none')
    plt.title(caption)

.

scipy smoothed: 1.1899 seconds
tf smoothed: 0.7485 seconds

Solution

  • Proposal 1: My attempt is below. Since it relies only on tf.image.extract_image_patches and tf.extract_volume_patches, the implementation supports only 2D and 3D images.

    Proposal 2: one could just format the data as a preprocessing step (via tf.data.Dataset.map); however, this also takes a lot of time, and I am not sure why yet (example: https://gist.github.com/pangyuteng/ca5cb07fe383ebe59b521c832f2e2918).

    Proposal 3: use convolutional blocks to parallelize processing, see "Hypercolumns for Object Segmentation and Fine-grained Localization" https://arxiv.org/abs/1411.5752 .

    --

    additional solutions added 2025-06-30:

    Solution 4: monai's sliding_window_inference utilities https://docs.monai.io/en/latest/inferers.html#slidingwindowinferer

    Solution 5: nnU-Net's sliding_window_prediction utilities
    https://github.com/MIC-DKFZ/nnUNet/blob/f1851fbaf2c53dcb51b079b60a01de528a7d0c17/nnunetv2/inference/predict_from_raw_data.py#L634

    --

    Proposal 1 code:

    import time
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import sliding
    from skimage import img_as_float, data
    from scipy.signal import medfilt
    
    # Dimensionality of the test input. Despite its name, `dtype` is a rank
    # selector (2 or 3), not a NumPy dtype.
    dtype = 2
    if dtype == 2:
        # 2-D case: the scikit-image sample "camera" image as floats.
        imgs = img_as_float(data.camera())
    elif dtype == 3:
        # 3-D case: a small random volume.
        imgs = np.random.rand(28, 28, 28)
    else:
        raise ValueError('dtype must be 2 or 3, got {!r}'.format(dtype))
    # The image is intentionally NOT reloaded here: an unconditional reload
    # would overwrite the 3-D volume and make the `dtype == 3` branch dead.

    ### SCIPY median ###
    # CPU reference result. A scalar kernel size is applied along every axis,
    # so this is a 9x9 window for the 2-D image (same as the former `(9,9)`)
    # and a 9x9x9 window for the 3-D volume.
    stime = time.time()
    scipysmoothed = medfilt(imgs, 9)
    etime = time.time()
    print('scipy smoothed: {:1.4f} seconds'.format(etime-stime))
    
    ### TF median ###
    # `method` is only used later to label the plot of this result.
    method = 'Tensorflow'
    # The patch-extraction ops take a batch axis in front and a channel axis at
    # the end, so wrap the 2-D/3-D array in singleton batch and channel axes.
    imgs = np.expand_dims(imgs,axis=-1)
    imgs = np.expand_dims(imgs,axis=0)
    print('imgs.shape:{}'.format(imgs.shape))
    imgs = tf.cast(imgs,tf.float32)

    # The timer starts before the graph is built, so the reported time covers
    # graph construction as well as execution.
    stime = time.time()

    # Extract one flattened neighbourhood per pixel/voxel. With unit strides
    # and 'SAME' padding the spatial size of `patches` equals that of `imgs`,
    # and the last axis holds the window's values.
    if len(imgs.shape) == 4:
        kernel = (1, 9, 9, 1)
        stride = (1, 1, 1, 1)
        rates = (1, 1, 1, 1)
        padding = 'SAME'
        patches = tf.image.extract_image_patches(
            imgs, kernel, stride, rates, padding,
        )
    elif len(imgs.shape) == 5:
        # NOTE(review): this 12x12x12 window is even-sized and differs from the
        # 9-wide window used for the SciPy reference -- confirm it is intended.
        kernel = (1, 12, 12, 12, 1)
        stride = (1, 1, 1, 1, 1)
        padding = 'SAME'
        patches = tf.extract_volume_patches(
            imgs, kernel, stride, padding,
        )
    else:
        # `NotImplemented` is a comparison sentinel, not an exception; calling
        # or raising it produces a confusing TypeError.
        raise NotImplementedError(
            'only 2-D and 3-D inputs are supported, got tensor of rank {}'.format(
                len(imgs.shape)))

    # Median of every window in one batched reduction over the last axis,
    # instead of a tf.map_fn loop issuing one percentile op per pixel.
    window_axis = len(imgs.shape) - 1
    smoothed = tf.contrib.distributions.percentile(patches, 50.0, axis=window_axis)
    # Drop the singleton batch axis so the result has the input's spatial shape.
    smoothed = tf.squeeze(smoothed, axis=0)
    
    # Execute the graph once and fetch the filtered result as a NumPy array.
    with tf.Session() as session:
        output = session.run(smoothed)

    # `stime` was taken before the graph was built, so this interval covers
    # graph construction plus the session run.
    etime = time.time()
    print('tf smoothed: {:1.4f} seconds'.format(etime-stime))

    print(output.shape)

    # `imgs` was turned into a TF tensor earlier, so reload the NumPy image
    # for display.
    imgs = img_as_float(data.camera())

    # Show the input, the TensorFlow result and the SciPy reference side by
    # side in a single row of three grayscale panels.
    panels = (
        ('original', imgs.squeeze()),
        ('actual smoothed\nwith {}'.format(method), output.squeeze()),
        ('expected smoothed', scipysmoothed),
    )
    plt.figure(figsize=(20, 20))
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture, cmap='gray', interpolation='none')
        plt.title(caption)