I am attempting to use a CNN in Python with Keras to classify medical images. These images also come with non-image information, such as the patient's age and gender, that can influence the model's decision. How can I train a CNN that uses both the images and this additional information, so that its classifications are based on both?
There are a couple of possibilities that come to mind off the top of my head, but the simplest is to extract features from the medical images with a CNN, flatten the CNN's output, and concatenate the non-image data to it before the final dense layers. Here is an example assuming 512x512 RGB images and 10 classes. It uses the Keras functional API, which allows a model to have multiple inputs.
```python
import tensorflow as tf

num_classes = 10
H, W = 512, 512

# Define the three inputs and their shapes (batch dimension excluded)
imgs = tf.keras.Input(shape=(H, W, 3), dtype=tf.float32)
genders = tf.keras.Input(shape=(1,), dtype=tf.float32)  # numeric encoding, e.g. 0 = female, 1 = male
ages = tf.keras.Input(shape=(1,), dtype=tf.float32)     # ideally normalized (e.g. standardized)

# Extract image features
features = tf.keras.layers.Conv2D(64, 4, strides=4, activation='relu')(imgs)
features = tf.keras.layers.MaxPooling2D()(features)
features = tf.keras.layers.Conv2D(128, 3, strides=2, activation='relu')(features)
features = tf.keras.layers.MaxPooling2D()(features)
features = tf.keras.layers.Conv2D(256, 3, strides=2, activation='relu')(features)
features = tf.keras.layers.Conv2D(512, 3, strides=2, activation='relu')(features)

# Flatten the feature map (3 x 3 x 512 = 4608 values for a 512x512 input)
flat_features = tf.keras.layers.Flatten()(features)

# Concatenate the image features with gender and age
flat_features = tf.keras.layers.Concatenate(axis=-1)([flat_features, genders, ages])

# Fully connected layers that progressively reduce the width
xx = tf.keras.layers.Dense(2048, activation='relu')(flat_features)
xx = tf.keras.layers.Dense(1024, activation='relu')(xx)
xx = tf.keras.layers.Dense(512, activation='relu')(xx)

# Calculate probabilities for each class
logits = tf.keras.layers.Dense(num_classes)(xx)
probs = tf.keras.layers.Softmax()(logits)

model = tf.keras.Model(inputs=[imgs, genders, ages], outputs=probs)
model.summary()
```
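To train a multi-input model, pass the inputs to `fit` as a list in the same order as `inputs=[imgs, genders, ages]`. The sketch below uses a deliberately tiny version of the model (64x64 images and 3 classes, both assumptions chosen so it runs quickly) and random stand-in data rather than real images. Encoding gender as 0/1 and standardizing age are assumed preprocessing choices, not requirements of Keras.

```python
import numpy as np
import tensorflow as tf

# Assumed small sizes so the sketch runs fast; the answer above uses 512x512 and 10 classes.
num_classes = 3
H, W = 64, 64
n_samples = 16

# Three named inputs; fit/predict receive arrays in this same order.
imgs = tf.keras.Input(shape=(H, W, 3), name="image")
genders = tf.keras.Input(shape=(1,), name="gender")
ages = tf.keras.Input(shape=(1,), name="age")

x = tf.keras.layers.Conv2D(16, 3, strides=2, activation="relu")(imgs)
x = tf.keras.layers.MaxPooling2D()(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Concatenate()([x, genders, ages])
x = tf.keras.layers.Dense(32, activation="relu")(x)
probs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
model = tf.keras.Model(inputs=[imgs, genders, ages], outputs=probs)

# Random stand-in data (NOT real medical data): images in [0, 1], integer class labels.
rng = np.random.default_rng(0)
x_img = rng.random((n_samples, H, W, 3), dtype=np.float32)
raw_gender = rng.choice(["F", "M"], size=n_samples)
raw_age = rng.integers(20, 90, size=n_samples).astype(np.float32)
y = rng.integers(0, num_classes, size=n_samples)

# Assumed preprocessing: gender as 0/1, age standardized to zero mean and unit variance.
x_gender = (raw_gender == "M").astype(np.float32).reshape(-1, 1)
x_age = ((raw_age - raw_age.mean()) / raw_age.std()).reshape(-1, 1)

# Integer labels, so use the sparse variant of categorical cross-entropy.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
history = model.fit([x_img, x_gender, x_age], y, epochs=2, batch_size=8, verbose=0)

pred = model.predict([x_img, x_gender, x_age], verbose=0)
print(pred.shape)  # (16, 3)
```

In a real pipeline, compute the age mean and standard deviation on the training set only and reuse them for validation and inference. Because the inputs are named, you can also pass a dict such as `{"image": x_img, "gender": x_gender, "age": x_age}` instead of a list.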
This architecture is not especially standard, and you might want to make the dense classification head deeper and/or reduce the number of parameters in the CNN feature extractor.
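In the model above, most of the parameters actually sit in the first Dense layer rather than in the convolutions: the flattened 3x3x512 feature map gives 4608 values, so that one layer has about 9.4M weights. One common way to shrink it, swapped in here as an illustration rather than taken from the answer, is to replace `Flatten` with `GlobalAveragePooling2D`, which averages each of the 512 channels down to a single value. The hypothetical `build_model` helper rebuilds the same architecture with either option so the parameter counts can be compared.

```python
import tensorflow as tf

num_classes = 10
H, W = 512, 512

def build_model(use_gap):
    """Hypothetical helper: the answer's three-input model, using either Flatten
    (use_gap=False) or GlobalAveragePooling2D (use_gap=True) after the convolutions."""
    imgs = tf.keras.Input(shape=(H, W, 3))
    genders = tf.keras.Input(shape=(1,))
    ages = tf.keras.Input(shape=(1,))

    # Same convolutional feature extractor as in the answer
    x = tf.keras.layers.Conv2D(64, 4, strides=4, activation="relu")(imgs)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(128, 3, strides=2, activation="relu")(x)
    x = tf.keras.layers.MaxPooling2D()(x)
    x = tf.keras.layers.Conv2D(256, 3, strides=2, activation="relu")(x)
    x = tf.keras.layers.Conv2D(512, 3, strides=2, activation="relu")(x)

    if use_gap:
        # Average each channel over the 3x3 grid: 512 values instead of 4608
        x = tf.keras.layers.GlobalAveragePooling2D()(x)
    else:
        x = tf.keras.layers.Flatten()(x)

    # Same metadata concatenation and dense head as in the answer
    x = tf.keras.layers.Concatenate()([x, genders, ages])
    x = tf.keras.layers.Dense(2048, activation="relu")(x)
    x = tf.keras.layers.Dense(1024, activation="relu")(x)
    x = tf.keras.layers.Dense(512, activation="relu")(x)
    probs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs=[imgs, genders, ages], outputs=probs)

flat_params = build_model(use_gap=False).count_params()
gap_params = build_model(use_gap=True).count_params()
print(f"Flatten: {flat_params:,} parameters")
print(f"GlobalAveragePooling2D: {gap_params:,} parameters")
```

With these sizes the Flatten version has 13,623,754 parameters and the pooled version 5,235,146; the convolutions account for 1,552,320 in both. Global pooling discards the spatial layout of the final 3x3 map, which may or may not matter for a given imaging task, so it is worth validating both variants.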