
Generative Adversarial Networks (GANs) in Keras - creating the combined model


I'm trying to create a pretty simple GAN model, and I'm not sure how to combine the generator and the discriminator in order to train the generator:

from keras import optimizers
from keras.layers import Input, Dense
from keras.models import Sequential, Model
import numpy as np

def build_generator(input_dim=10, output_dim=40, hidden_dim=28):
    # Two sigmoid Dense layers mapping a noise vector (input_dim) to a sample (output_dim)
    model = Sequential()
    model.add(Dense(hidden_dim, input_dim=input_dim, activation='sigmoid', kernel_initializer="random_uniform"))
    model.add(Dense(output_dim, activation='sigmoid', kernel_initializer="random_uniform"))
    return model

def build_discriminator(input_dim=40, hidden_dim=28, output_dim=50):
    # Functional model with three outputs: a decoded vector and two scalar heads
    input_d = Input(shape=(input_dim,))
    encoded = Dense(hidden_dim, activation='sigmoid', kernel_initializer="random_uniform")(input_d)
    decoded = Dense(output_dim, activation='sigmoid', kernel_initializer="random_uniform")(encoded)
    x = Dense(1, activation='relu')(encoded)
    y = Dense(1, activation='sigmoid')(encoded)
    model = Model(inputs=input_d, outputs=[decoded, x, y])
    return model

generator = build_generator(10, 100, 70)
discriminator = build_discriminator(100, 60, 80)
# A separate SGD instance for each model (`learning_rate` replaces the deprecated `lr`)
generator.compile(loss='mean_squared_error', optimizer=optimizers.SGD(learning_rate=0.1))
discriminator.trainable = True
discriminator.compile(loss='mean_squared_error', optimizer=optimizers.SGD(learning_rate=0.1))
# Freeze the discriminator before it is used inside the combined model
discriminator.trainable = False

Now I'm not sure how to combine the two, so that the discriminator receives the generator's output and then backpropagates the gradients to the generator.


Solution

  • The best way to do this is with the functional Model API, which is suited to more complex models with branches, concatenations, etc.

    (In this specific case it's still possible to use Sequential models, but the functional API has always seemed better to me, because it gives more freedom for further experiments with the models.)
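    For comparison, a minimal sketch of the Sequential route, assuming a recent Keras (tf.keras 2.x or Keras 3). The networks here are hypothetical single-output versions of the question's models: a Sequential model expects each layer to have exactly one output, so the question's three-output discriminator could not be chained this way.

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Sequential

# Hypothetical single-output networks: same sizes as in the question, but the
# discriminator returns one sigmoid score instead of three outputs.
generator = Sequential([Input(shape=(10,)),
                        Dense(70, activation='sigmoid'),
                        Dense(100, activation='sigmoid')])
discriminator = Sequential([Input(shape=(100,)),
                            Dense(60, activation='sigmoid'),
                            Dense(1, activation='sigmoid')])

# A Sequential model accepts whole models as "layers" and chains them in order.
gan = Sequential([generator, discriminator])

noise = np.random.uniform(size=(4, 10)).astype('float32')
scores = gan.predict(noise, verbose=0)
print(scores.shape)  # (4, 1)
```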

    So you can keep your two models as they are. All you have to do is build a third model that contains both of them.

    generator = build_generator(....)  # don't create a new generator; reuse the one you already have
    discriminator = build_discriminator(....)
    

    Now, a functional API model has its input shape defined like this:

    inputTensor = Input(shape=inputShape)  # inputShape must match the generator's input shape, e.g. (10,)
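    Instead of typing the shape by hand, it can be read from the generator. A minimal sketch, assuming a recent Keras and a hypothetical generator with the question's sizes: `input_shape` includes the batch dimension as `None`, so that entry is dropped.

```python
from keras.layers import Input, Dense
from keras.models import Sequential

# Hypothetical generator with the question's sizes (10 -> 70 -> 100)
generator = Sequential([Input(shape=(10,)),
                        Dense(70, activation='sigmoid'),
                        Dense(100, activation='sigmoid')])

# generator.input_shape is (None, 10); skip the batch dimension
inputTensor = Input(shape=generator.input_shape[1:])
print(generator.input_shape, tuple(inputTensor.shape))
```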
    

    And we work by passing inputs to layers and getting outputs:

    #Getting the output of the generator given our input tensor:
    genOut = generator(inputTensor) #you call a model just like you call a layer    
    
    #and we pass the generator's output to the discriminator, getting its output:
    discOut = discriminator(genOut)
    

    Finally, we create the actual model by defining its start and end points:

    GAN = Model(inputTensor, discOut)
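    Putting these steps together, a runnable sketch assuming a recent Keras and a hypothetical single-output discriminator (with the question's three-output discriminator, `GAN` would simply have three outputs). Because the nested models are the same objects, their weights are shared rather than copied, so `GAN` computes exactly `discriminator(generator(x))`:

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Sequential, Model

generator = Sequential([Input(shape=(10,)),
                        Dense(70, activation='sigmoid'),
                        Dense(100, activation='sigmoid')])
# Hypothetical single-output discriminator, for brevity
discriminator = Sequential([Input(shape=(100,)),
                            Dense(60, activation='sigmoid'),
                            Dense(1, activation='sigmoid')])

inputTensor = Input(shape=(10,))
genOut = generator(inputTensor)   # models are called like layers
discOut = discriminator(genOut)
GAN = Model(inputTensor, discOut)

# Same weights: running the two models one after the other matches GAN
noise = np.random.uniform(size=(4, 10)).astype('float32')
direct = discriminator.predict(generator.predict(noise, verbose=0), verbose=0)
combined = GAN.predict(noise, verbose=0)
print(np.allclose(direct, combined, atol=1e-6))
```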
    

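    Set the `trainable` attribute before calling `compile` to choose which weights each model may update, either on a whole model (e.g. `discriminator.trainable = False`) or per layer with `model.layers[i].trainable`. For the combined model, freeze the discriminator and then compile `GAN`: training `GAN` then updates only the generator, while the gradients still flow back through the discriminator.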
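A minimal sketch of one generator training step, assuming a recent Keras (tf.keras 2.x or Keras 3), hypothetical single-output networks, and binary cross-entropy with generated samples labelled as "real" (1), a common GAN convention rather than anything from the question. It checks that the frozen discriminator's weights stay unchanged while the generator's weights move. Training the discriminator itself is omitted here.

```python
import numpy as np
from keras import optimizers
from keras.layers import Input, Dense
from keras.models import Sequential, Model

generator = Sequential([Input(shape=(10,)),
                        Dense(70, activation='sigmoid'),
                        Dense(100, activation='sigmoid')])
# Hypothetical single-output discriminator
discriminator = Sequential([Input(shape=(100,)),
                            Dense(60, activation='sigmoid'),
                            Dense(1, activation='sigmoid')])

# Freeze the discriminator BEFORE compiling the combined model
discriminator.trainable = False
inputTensor = Input(shape=(10,))
GAN = Model(inputTensor, discriminator(generator(inputTensor)))
GAN.compile(loss='binary_crossentropy', optimizer=optimizers.SGD(learning_rate=0.1))

disc_before = [w.copy() for w in discriminator.get_weights()]
gen_before = [w.copy() for w in generator.get_weights()]

# Generator step: label generated samples as "real" (1) to fool the discriminator
noise = np.random.uniform(size=(32, 10)).astype('float32')
GAN.train_on_batch(noise, np.ones((32, 1), dtype='float32'))

disc_changed = any(not np.array_equal(a, b)
                   for a, b in zip(disc_before, discriminator.get_weights()))
gen_changed = any(not np.array_equal(a, b)
                  for a, b in zip(gen_before, generator.get_weights()))
print(disc_changed, gen_changed)  # False True
```

In a full training loop this step would alternate with a step that trains the discriminator on real and generated samples.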