tensorflow, deep-learning

How to stop gradient for some entry of a tensor in tensorflow


I am trying to implement an embedding layer. The embedding will be initialized with pre-trained GloVe vectors. For words found in GloVe, the embedding should stay fixed; for words not in GloVe, it should be initialized randomly and remain trainable. How do I do this in TensorFlow? I know tf.stop_gradient applies to a whole tensor; is there a stop_gradient-style API for this kind of scenario, or is there some workaround? Any suggestion is appreciated.


Solution

  • The idea is to use a mask together with tf.stop_gradient to solve this problem:

    res_matrix = tf.stop_gradient(mask_h * E) + mask * E,

    where in the matrix mask, 1 marks an entry that should receive a gradient and 0 marks an entry whose gradient should be blocked (set to 0), and mask_h is the inverse of mask (1 flipped to 0, 0 flipped to 1). We can then fetch from res_matrix. Here is the test code:

    import tensorflow as tf
    import numpy as np
    
    def entry_stop_gradients(target, mask):
        # Entries where mask == 1 keep their gradient; entries where mask == 0
        # only pass through tf.stop_gradient, so their gradient is zero.
        mask_h = tf.abs(mask - 1)
        return tf.stop_gradient(mask_h * target) + mask * target
    
    mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
    mask_h = np.abs(mask-1)
    
    emb = tf.constant(np.ones([10, 5]))
    
    matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))
    
    parm = np.random.randn(5, 1)
    t_parm = tf.constant(parm)
    
    loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
    grad1 = tf.gradients(loss, emb)     # rows where mask == 0 get zero gradient
    grad2 = tf.gradients(loss, matrix)  # full gradient, unaffected by the mask
    print(matrix)
    with tf.Session() as sess:
        print(sess.run(loss))
        print(sess.run([grad1, grad2]))
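
    Applied back to the original question, here is a rough sketch of how I would use this for the embedding layer (my assumptions: the GloVe rows come first in the vocabulary, and vocab_size, emb_dim, glove_vectors and oov_count are placeholder names, not real data). A single embedding variable is built from the pre-trained GloVe rows plus randomly initialized out-of-vocabulary rows, and a per-row mask of 0s (fixed GloVe rows) and 1s (trainable rows) is passed to entry_stop_gradients before the lookup:

    vocab_size, emb_dim = 10000, 300
    glove_vectors = np.random.randn(8000, emb_dim)   # stand-in for the real GloVe rows
    oov_count = vocab_size - glove_vectors.shape[0]  # words missing from GloVe

    # One variable holds all rows: GloVe rows first, random OOV rows last.
    init = np.vstack([glove_vectors,
                      np.random.randn(oov_count, emb_dim) * 0.01]).astype(np.float32)
    embeddings = tf.Variable(init, name="embeddings")

    # Per-row mask: 0 for fixed GloVe rows, 1 for trainable OOV rows.
    row_mask = np.concatenate([np.zeros(glove_vectors.shape[0]),
                               np.ones(oov_count)]).astype(np.float32)
    masked_embeddings = entry_stop_gradients(embeddings,
                                             tf.expand_dims(tf.constant(row_mask), 1))

    word_ids = tf.placeholder(tf.int32, [None])
    word_vectors = tf.nn.embedding_lookup(masked_embeddings, word_ids)
    # Gradients flowing back into `embeddings` are zero for the GloVe rows,
    # so an optimizer only updates the OOV rows.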