I'm working on a GAN with generator and discriminator.
@tf.function
def train_step(input_image, target, step):
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
gen_output = generator(input_image, training=True)
disc_real_output = discriminator([input_image, target], training=True)
disc_generated_output = discriminator([input_image, gen_output], training=True)
gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
generator_gradients = gen_tape.gradient(gen_total_loss,
generator.trainable_variables)
discriminator_gradients = disc_tape.gradient(disc_loss,
discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(generator_gradients,
generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
discriminator.trainable_variables))
with summary_writer.as_default():
tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
tf.summary.scalar('disc_loss', disc_loss, step=step//1000)
This function throws an error:
TypeError: in user code:
File "/tmp/ipykernel_34/3224399777.py", line 9, in train_step *
gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
File "/tmp/ipykernel_34/3072633757.py", line 5, in generator_loss *
l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
TypeError: Input 'y' of 'Sub' Op has type float32 that does not match type uint8 of argument 'x'.
But I try to subtract it manually, it works just fine, they are both float32
target - gen_output
<tf.Tensor: shape=(1, 256, 256, 3), dtype=float32, numpy=
array([[[[185.98402 , 151.92749 , 81.13361 ],
[186.15788 , 151.78894 , 80.930176],
[185.86765 , 151.81358 , 80.65687 ],
...,
[183.64613 , 151.91382 , 87.36469 ],
[183.17218 , 152.08833 , 86.43396 ],
[183.51439 , 152.04149 , 87.40147 ]],
...
Just convert target to float32 from the beginning.
target.asarray(tf.float32)