I want to create a custom activation function in TF2. The logic looks like this:
```python
def sqrt_activation(x):
    if x >= 0:
        return tf.math.sqrt(x)
    else:
        return -tf.math.sqrt(-x)
```
The problem is that I can't compare `x` with `0`, since `x` is a tensor. How can I achieve this functionality?
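To make the problem concrete, here is a small sketch (the input values are arbitrary). Comparing a tensor with a scalar does not give a single `True`/`False`. It gives an element-wise boolean tensor, and a Python `if` cannot decide on a multi-element tensor.

```python
import tensorflow as tf

x = tf.constant([-4.0, 0.0, 9.0])  # arbitrary example input

# Tensor comparison is element-wise and returns a boolean tensor
mask = x >= 0  # element-wise result: [False, True, True]

# A Python `if` needs one truth value; in eager mode, converting a
# multi-element boolean tensor to bool raises a ValueError
raised = False
try:
    if mask:
        pass
except ValueError:
    raised = True

print(mask.numpy(), raised)
```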
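You can skip the comparison entirely by multiplying the sign of `x` by the square root of its absolute value, which gives `sqrt(x)` for positive inputs and `-sqrt(-x)` for negative ones: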
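```python
import tensorflow as tf

def sqrt_activation(x):
    # sign(x) * sqrt(|x|): element-wise, no Python-level comparison needed
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))
```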
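If you would rather keep the original if/else shape, `tf.where` performs the comparison element-wise and picks between two tensors. The sketch below compares both versions and passes one to a Keras layer. `sqrt_activation_where`, the input values, and the layer sizes are illustrative assumptions, not part of the answer above.

```python
import tensorflow as tf


def sqrt_activation(x):
    # sign(x) * sqrt(|x|): sqrt(x) for x >= 0, -sqrt(-x) for x < 0
    return tf.math.sign(x) * tf.math.sqrt(tf.abs(x))


def sqrt_activation_where(x):
    # Hypothetical alternative: tf.where selects element-wise using the mask.
    # Both branches are computed for every element, so tf.abs stops the unused
    # branch from taking the square root of a negative number (NaN).
    root = tf.math.sqrt(tf.abs(x))
    return tf.where(x >= 0, root, -root)


x = tf.constant([-4.0, 0.0, 9.0])
print(sqrt_activation(x).numpy())        # [-2.  0.  3.]
print(sqrt_activation_where(x).numpy())  # [-2.  0.  3.]

# Either function can be passed directly as a Keras activation
layer = tf.keras.layers.Dense(3, activation=sqrt_activation)
out = layer(tf.ones((2, 4)))
print(out.shape)  # (2, 3)
```

The `sign`-based version is shorter. The `tf.where` version is useful when the two branches are not simple mirror images of each other.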