
Augment MNIST dataset tensorflow


I am trying to augment the MNIST dataset. This is what I tried, but I can't get it to work.

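from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np

mnist = input_data.read_data_sets("MNIST_data/")  # local download directory; any path works
X = mnist.train.images
y = mnist.train.labels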

def flip_images(X_imgs):
    X_flip = []
    tf.reset_default_graph()
    X = tf.placeholder(tf.float32, shape = (28, 28, 1))
    input_d = tf.reshape(X_imgs, [-1, 28, 28, 1])
    tf_img1 = tf.image.flip_left_right(X)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for img in input_d:
            flipped_imgs = sess.run([tf_img1], feed_dict = {X: img})
            X_flip.extend(flipped_imgs)
    X_flip = np.array(X_flip, dtype = np.float32)
    return X_flip

flip = flip_images(X)

What am I doing wrong? I can't seem to figure it out.

Error:

Line: for img in input_d:
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable

Solution

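  • First, note that tf.reshape turns your ndarray into a Tensor. In graph mode a Tensor is a symbolic node in the graph, not a container of values, so Python's for loop cannot iterate over it; that is what raises the TypeError. To get actual values back you have to evaluate it in a session (for example with .eval()). One way to do that is to index the tensor numerically and evaluate each slice: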

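    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("MNIST_data/")  # local download directory; any path works
    X = mnist.train.images
    y = mnist.train.labels

    def flip_images(X_imgs):

        X_flip = []
        tf.reset_default_graph()
        X = tf.placeholder(tf.float32, shape=(28, 28, 1))

        # tf.reshape returns a symbolic Tensor, not an ndarray.
        input_d = tf.reshape(X_imgs, [-1, 28, 28, 1])
        tf_img1 = tf.image.flip_left_right(X)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Index the Tensor by position instead of iterating over it.
            # shape[0] is a Dimension object, so convert it to int for range().
            for img_ind in range(int(input_d.shape[0])):
                # .eval() runs the slice in the default session (sess) and
                # returns an ndarray that can be fed to the placeholder.
                img = input_d[img_ind].eval()
                flipped_imgs = sess.run([tf_img1], feed_dict={X: img})
                X_flip.extend(flipped_imgs)
        X_flip = np.array(X_flip, dtype=np.float32)
        return X_flip

    flip = flip_images(X)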
    

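    Let me know if this resolves your issue! You might want to set the range to a small constant for testing: the loop makes one session call per image (and adds a new slice op to the graph on every iteration), so running it over the full training set can take a while.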
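If all you need is a horizontal flip, you may not need a session at all. Below is a minimal NumPy-only sketch, assuming the input is the flattened (N, 784) float32 array that mnist.train.images provides. Doing the reshape in NumPy keeps the data an ndarray, so the "Tensor is not iterable" problem never comes up, and reversing the width axis of an NHWC batch is the same left-right flip that tf.image.flip_left_right performs. The helper names flip_images_np and augment_with_flips are hypothetical, and the random arrays are stand-ins for the real MNIST data.

```python
import numpy as np


def flip_images_np(X_imgs):
    """Horizontally flip flattened 28x28 images.

    X_imgs: array of shape (N, 784), e.g. mnist.train.images.
    Returns an array of shape (N, 28, 28, 1), the same layout the
    session-based flip_images produces.
    """
    # Reshape in NumPy so the result stays an ndarray (no Tensor involved).
    imgs = np.reshape(X_imgs, (-1, 28, 28, 1)).astype(np.float32)
    # Axis 2 is the width axis in NHWC layout; reversing it is a
    # left-right flip of every image in the batch at once.
    return np.flip(imgs, axis=2)


def augment_with_flips(X, y):
    """Append flipped copies to the dataset and duplicate the labels."""
    # Flatten the flipped images back to (N, 784) to match the original X.
    flipped = flip_images_np(X).reshape(len(X), -1)
    X_aug = np.concatenate([X, flipped], axis=0)
    y_aug = np.concatenate([y, y], axis=0)
    return X_aug, y_aug


# Synthetic stand-ins for mnist.train.images / mnist.train.labels
# (an assumption, so the sketch runs without downloading MNIST).
X_demo = np.random.rand(5, 784).astype(np.float32)
y_demo = np.arange(5)
X_aug, y_aug = augment_with_flips(X_demo, y_demo)
print(X_aug.shape, y_aug.shape)  # (10, 784) (10,)
```

Because the flip is one vectorized call, this handles all 55,000 training images without a per-image loop. Keep in mind that a mirrored digit is often no longer a valid digit (a flipped 3 or 7, for instance), so check whether horizontal flips actually suit your labels before training on them.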