Tags: python, tensorflow, tensorflow-layers

How to use tf.contrib.model_pruning on MNIST?


I'm struggling to use TensorFlow's pruning library and haven't found many helpful examples, so I'm looking for help pruning a simple model trained on the MNIST dataset. If anyone can either fix my attempt or provide a working example of how to use the library on MNIST, I would be very grateful.

The first half of my code is pretty standard, except that my model has two hidden layers of 300 units each, built with layers.masked_fully_connected so they can be pruned.

import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = tf.contrib.layers.fully_connected(layer2, 10, activation_fn=None)  # linear logits, no ReLU on the output layer

# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# Training op
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Then I attempt to define the necessary pruning operations, but I get an error.

############ Pruning Operations ##############
# Create global step variable
global_step = tf.contrib.framework.get_or_create_global_step()

# Create a pruning object using the pruning specification
pruning_hparams = pruning.get_pruning_hparams()
p = pruning.Pruning(pruning_hparams, global_step=global_step)

# Mask Update op
mask_update_op = p.conditional_mask_update_op()

# Set up the specification for model pruning
prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

Error on this line:

prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [?,10]
    [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
    [[Node: global_step/_57 = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_71_global_step", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

I assume it wants a different type of operation in place of train_op, but I haven't found any adjustment that works.
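
My guess is that model_pruning.train runs the train op in its own managed session without a feed_dict, so the inputs would have to come from an input pipeline rather than placeholders. An untested sketch of what I mean, using tf.data in place of the placeholders:

# Untested sketch: feed MNIST through tf.data so the training loop inside
# model_pruning.train does not need a feed_dict for the placeholders
dataset = tf.data.Dataset.from_tensor_slices(
    (mnist.train.images, mnist.train.labels))
dataset = dataset.shuffle(55000).batch(64).repeat()
image, label = dataset.make_one_shot_iterator().get_next()
# ... rebuild the model, loss, and train_op on these tensors as above ...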

Again, if you have a different working example that prunes a model trained on MNIST, I would consider that an answer.


Solution

  • This is the simplest pruning-library example I could get working; figured I'd post it here in case it helps another newbie who has a hard time with the documentation.

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python import pruning
    from tensorflow.contrib.model_pruning.python.layers import layers
    from tensorflow.examples.tutorials.mnist import input_data
    
    epochs = 250
    batch_size = 55000 # Entire training set
    
    # Import dataset
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    batches = int(len(mnist.train.images) / batch_size)
    
    # Define Placeholders
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])
    
    # Define the model
    layer1 = layers.masked_fully_connected(image, 300)
    layer2 = layers.masked_fully_connected(layer1, 300)
    logits = layers.masked_fully_connected(layer2, 10, activation_fn=None)  # linear logits, no ReLU on the output layer
    
    # Create global step variable (needed for pruning)
    global_step = tf.train.get_or_create_global_step()
    reset_global_step_op = tf.assign(global_step, 0)
    
    # Loss function
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))
    
    # Training op; the global step is critical here, make sure it matches the
    # one used by the pruning object later. Running this op increments global_step.
    train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)
    
    # Accuracy ops
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
    # Get, Print, and Edit Pruning Hyperparameters
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)
    
    # Change hyperparameters to meet our needs
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = .9
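    # Note (my reading of these hparams; worth double-checking against the docs):
    # masks are updated every `pruning_frequency` global steps between
    # begin_pruning_step and end_pruning_step, and the sparsity schedule ramps
    # from initial_sparsity up to target_sparsity by sparsity_function_end_step.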
    
    # Create a pruning object using the pruning specification; the explicit
    # sparsity argument seems to take priority over the target_sparsity hparam
    p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
    prune_op = p.conditional_mask_update_op()
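    # Optional addition: the Pruning object can also emit TensorBoard summaries
    # of per-layer sparsity and thresholds (skip this if you don't log summaries)
    p.add_pruning_summaries()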
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # initialize_all_variables is deprecated
    
        # Train the model before pruning (optional)
        for epoch in range(epochs):
            for batch in range(batches):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})
    
            # Calculate Test Accuracy every 10 epochs
            if epoch % 10 == 0:
                acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
                print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))
    
        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
        print("Pre-Pruning accuracy:", acc_print)
        print("Sparsity of layers (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
    
        # Reset the global step counter and begin pruning
        sess.run(reset_global_step_op)
        for epoch in range(epochs):
            for batch in range(batches):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # Prune and retrain
                sess.run(prune_op)
                sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})
    
            # Calculate Test Accuracy every 10 epochs
            if epoch % 10 == 0:
                acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
                print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
                print("Weight sparsities:", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
    
        # Print final accuracy
        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
        print("Final accuracy:", acc_print)
        print("Final sparsity by layer (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))