I am looking for a GPU-accelerated, n-dimensional sliding window operation implemented in Python with TensorFlow. You can post your implementation in Torch, Caffe or Theano, but I'll choose the TensorFlow implementation as the accepted answer. Please post a working code snippet that performs a 2D median filter (ideally one that can be applied to n-dimensional images with no or minimal code changes).
With my limited knowledge of TensorFlow, I believe the two potential modules to start with are sliding_window_batch or extract_image_patches, and then some map, apply, reshape magic?
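For reference, here is a minimal sketch of what tf.image.extract_image_patches returns (TF 1.x positional arguments; the 9x9 kernel, strides, rates and padding are illustrative only), since undoing this shape bookkeeping is presumably where the map/reshape magic comes in:
import tensorflow as tf
img = tf.zeros((1, 512, 512, 1))  # [batch, height, width, channels]
patches = tf.image.extract_image_patches(img, (1,9,9,1), (1,1,1,1), (1,1,1,1), 'SAME')
print(patches.shape)  # (1, 512, 512, 81): each output pixel holds its flattened 9x9 window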
My failed attempt is posted below, for entertainment. Note that I posted a similar question two years ago asking for a Theano implementation; nowadays most people are using tf/keras or torch.
import time
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import sliding
from skimage import img_as_float, data
from scipy.signal import medfilt
imgs = img_as_float(data.camera())
### SCIPY median ###
stime = time.time()
scipysmoothed = medfilt(imgs,(9,9))
etime = time.time()
print('scipy smoothed: {:1.4f} seconds'.format(etime-stime))
### Failed attempt of TF median ###
method = 'Tensorflow'
stime = time.time()
window_func = lambda x: tf.contrib.distributions.percentile(x, 50.0)
# create TensorFlow Dataset object
data = tf.data.Dataset.from_tensor_slices(imgs)
# sliding window - only 1d is allowed?
window = 3
stride = 1
data = data.apply(sliding.sliding_window_batch(window, stride)).map(lambda x: window_func(x))
# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(data.output_types)
next_element = iterator.get_next()
# create initialization ops
init_op = iterator.make_initializer(data)
c=0
smoothed = np.zeros(imgs.shape)
with tf.Session() as sess:
    # initialize the iterator on the data
    sess.run(init_op)
    while True:
        try:
            elem = sess.run(next_element)
            smoothed[c,:]=elem
            # obviously WRONG.
            c+=1
        except tf.errors.OutOfRangeError:
            #print("End of dataset.")
            break
    #print(c)
etime = time.time()
print('tf smoothed: {:1.4f} seconds'.format(etime-stime))
plt.figure(figsize=(20,20))
plt.subplot(131)
plt.imshow(imgs,cmap='gray',interpolation='none')
plt.title('original')
plt.subplot(132)
plt.imshow(smoothed,cmap='gray',interpolation='none')
plt.title('actual smoothed\nwith {}'.format(method))
plt.subplot(133)
plt.imshow(scipysmoothed,cmap='gray',interpolation='none')
_=plt.title('expected smoothed')
Output:
scipy smoothed: 1.1899 seconds
tf smoothed: 0.7485 seconds
Proposal 1: My attempt is below; since it just uses tf.image.extract_image_patches and tf.extract_volume_patches, the implementation supports only 2D and 3D images.
Proposal 2: one could format the data as a preprocessing step (via tf.data.Dataset.map); however, this also takes a lot of time and I am not sure why yet (example: https://gist.github.com/pangyuteng/ca5cb07fe383ebe59b521c832f2e2918). A rough sketch of this approach follows.
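A minimal sketch of what that preprocessing map might look like (my guess at the structure; the linked gist may differ, and the 9x9 window / SAME padding are only for illustration):
import numpy as np
import tensorflow as tf
def _median_filter_2d(img):
    # img: [H, W] float32; add batch/channel dims for extract_image_patches
    x = img[tf.newaxis, :, :, tf.newaxis]
    patches = tf.image.extract_image_patches(x, (1,9,9,1), (1,1,1,1), (1,1,1,1), 'SAME')  # [1, H, W, 81]
    # median over the last axis collapses each 9x9 window to a single value
    return tf.contrib.distributions.percentile(patches, 50.0, axis=3)[0]
dataset = tf.data.Dataset.from_tensor_slices(np.random.rand(4, 512, 512).astype('float32'))
dataset = dataset.map(_median_filter_2d, num_parallel_calls=4)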
Proposal 3: use convolutional blocks to parallelize processing, as sketched below; see "Hypercolumns for Object Segmentation and Fine-grained Localization" (https://arxiv.org/abs/1411.5752).
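A rough sketch of that idea for the median case (my own interpretation, not code from the paper): a fixed one-hot convolution kernel copies each of the k*k window positions into its own channel, so a single conv2d materializes every window and the median becomes a reduction over channels.
import numpy as np
import tensorflow as tf
k = 9
# kernel[i, j, 0, i*k + j] = 1 routes window offset (i, j) into output channel i*k + j
kernel = np.zeros((k, k, 1, k * k), dtype=np.float32)
for i in range(k):
    for j in range(k):
        kernel[i, j, 0, i * k + j] = 1.0
x = tf.placeholder(tf.float32, shape=(None, 512, 512, 1))              # [batch, H, W, 1]
windows = tf.nn.conv2d(x, tf.constant(kernel), [1, 1, 1, 1], 'SAME')   # [batch, H, W, 81]
median = tf.contrib.distributions.percentile(windows, 50.0, axis=3)    # [batch, H, W]
The trade-off is memory (k*k channels are materialized) in exchange for a single, well-parallelized conv op; borders follow conv2d's zero padding.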
--
Additional solutions added 2025-06-30:
Solution 4: MONAI's sliding_window_inference utility (usage sketch below): https://docs.monai.io/en/latest/inferers.html#slidingwindowinferer
Solution 5: nnUNet's sliding-window prediction utilities: https://github.com/MIC-DKFZ/nnUNet/blob/f1851fbaf2c53dcb51b079b60a01de528a7d0c17/nnunetv2/inference/predict_from_raw_data.py#L634
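For completeness, a usage sketch for Solution 4 (MONAI is PyTorch-based; the identity predictor below only stands in for a real network, and the roi_size/overlap values are arbitrary):
import torch
from monai.inferers import sliding_window_inference
volume = torch.rand(1, 1, 512, 512)   # [batch, channel, H, W]; 3D volumes work the same way
output = sliding_window_inference(
    inputs=volume, roi_size=(96, 96), sw_batch_size=4,
    predictor=lambda patch: patch,    # normally a trained torch.nn.Module
    overlap=0.25)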
--
Proposal 1 code:
import time
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import sliding
from skimage import img_as_float, data
from scipy.signal import medfilt
dtype = 2
if dtype==2:
    imgs = img_as_float(data.camera())
elif dtype==3:
    imgs = np.random.rand(28,28,28)
### SCIPY median ###
stime = time.time()
scipysmoothed = medfilt(imgs,(9,9))
etime = time.time()
print('scipy smoothed: {:1.4f} seconds'.format(etime-stime))
### TF median ###
method = 'Tensorflow'
imgs = np.expand_dims(imgs,axis=-1)
imgs = np.expand_dims(imgs,axis=0)
print('imgs.shape:{}'.format(imgs.shape))
imgs = tf.cast(imgs,tf.float32)
stime = time.time()
if len(imgs.shape) == 4:
    kernel=(1,9,9,1)
    stride=(1,1,1,1)
    rates=(1,1,1,1)
    padding='SAME'
    patches=tf.image.extract_image_patches(
        imgs,kernel,stride,rates,padding,
    )
    _,x,y,n = patches.shape
    _,sx,sy,_ = kernel
    window_func = lambda x: tf.contrib.distributions.percentile(x, 50.0)
    patches = tf.reshape(patches,[x*y,sx,sy])
    smoothed = tf.map_fn(lambda i: window_func(patches[i,:,:]), tf.range(x*y), dtype=tf.float32)
    smoothed = tf.reshape(smoothed,[x,y])
elif len(imgs.shape) == 5:
    kernel=(1,12,12,12,1)
    stride=(1,1,1,1,1)
    padding='SAME'
    patches=tf.extract_volume_patches(
        imgs,kernel,stride,padding,
    )
    _,x,y,z,n = patches.shape
    _,sx,sy,sz,_ = kernel
    window_func = lambda x: tf.contrib.distributions.percentile(x, 50.0)
    patches = tf.reshape(patches,[x*y*z,sx,sy,sz])
    smoothed = tf.map_fn(lambda i: window_func(patches[i,:,:,:]), tf.range(x*y*z), dtype=tf.float32)
    smoothed = tf.reshape(smoothed,[x,y,z])
else:
    raise NotImplementedError()
with tf.Session() as sess:
    output = sess.run(smoothed)
etime = time.time()
print('tf smoothed: {:1.4f} seconds'.format(etime-stime))
print(output.shape)
plt.figure(figsize=(20,20))
plt.subplot(131)
imgs = img_as_float(data.camera())
plt.imshow(imgs.squeeze(),cmap='gray',interpolation='none')
plt.title('original')
plt.subplot(132)
plt.imshow(output.squeeze(),cmap='gray',interpolation='none')
plt.title('actual smoothed\nwith {}'.format(method))
plt.subplot(133)
plt.imshow(scipysmoothed,cmap='gray',interpolation='none')
_=plt.title('expected smoothed')