
Why use a combined loss function for segmentation and classification?

I modified the U-Net model for one-dimensional data by adding a branch, attached to the last encoding block, whose purpose is to classify the data into eighteen areas of the cortex (inspiration based on this paper:

The part that is responsible for semantic segmentation performs delineation of data (cortex characteristics in one-dimensional form) into six layers.

The example input shape is (15000, 256, 1), where 15000 is the number of 1D cortical profiles (or vectors), 256 is the depth of a single profile, and 1 is the number of channels, i.e. a single cortex characteristic (e.g. intensity or estimated cell size). For segmentation, the target is a one-hot array of shape (15000, 256, 6), where 6 is the number of layers; for area classification, the target has shape (15000, 18).

What puzzles me, however, is this: I tried both approaches (separate loss functions and a single joint loss function) and observed that, with the joint loss, the reported classification accuracy and classification loss are both much lower than with separate loss functions. The results are:

Separate loss functions:
883/883 ━━━━━━━━━━━━━━━━━━━━ 63s 71ms/step - classification_accuracy: 0.8877 - classification_loss: 0.2729 - loss: 0.5143 - segmentation_accuracy: 0.8440 - segmentation_loss: 0.2413

Joint loss function:
883/883 ━━━━━━━━━━━━━━━━━━━━ 56s 63ms/step - classification_accuracy: 0.0556 - classification_loss: 2.0994e-07 - loss: 0.3952 - segmentation_accuracy: 0.6453 - segmentation_loss: 0.3952

The function I use to combine the classification and segmentation losses is:

import tensorflow as tf

def joint_loss(y_true, y_pred):

    # Unpack the true and predicted values
    y_true_segmentation = y_true[0]
    y_true_classification = y_true[1]
    y_pred_segmentation = y_pred[0]
    y_pred_classification = y_pred[1]
    # segmentation loss
    segmentation_loss = tf.keras.losses.categorical_crossentropy(
        y_true_segmentation, y_pred_segmentation)
    # classification loss
    classification_loss = tf.keras.losses.categorical_crossentropy(
        y_true_classification, y_pred_classification)

    # Reduce each loss to a scalar
    segmentation_loss = tf.reduce_mean(segmentation_loss)
    classification_loss = tf.reduce_mean(classification_loss)

    alpha = 1.0  # weight for segmentation
    beta = 1.0   # weight for classification
    total_loss = alpha * segmentation_loss + beta * classification_loss

    return total_loss

Then I compile the model. For the separate-loss run, the losses and metrics are:

loss = {"segmentation": "categorical_crossentropy",
"classification": "categorical_crossentropy"}

metrics = {"segmentation": "accuracy",
"classification": "accuracy"}


Here, loss stands for the separate setup, in which categorical_crossentropy is applied independently to the segmentation and classification outputs.

Is it possible that the problem is caused by the model modification (the additional output branch)?


  • There is no such thing as "separate loss functions" in TensorFlow. When you ask it to minimize a multidimensional object (e.g. 2 numbers), it simply adds them. So there is no real choice between "joint or separate". The confusing thing about TensorFlow is that it adds these numbers implicitly, without telling you.
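The logs in the question are consistent with this. As a quick arithmetic check (plain Python, using only the numbers printed in those logs; the helper name `sum_gap` and the tolerance are illustrative choices), the reported total `loss` matches the sum of the two per-output losses up to the rounding of the printed values:

```python
# Values copied from the two training logs in the question.
separate = {"classification_loss": 0.2729, "segmentation_loss": 0.2413, "loss": 0.5143}
joint = {"classification_loss": 2.0994e-07, "segmentation_loss": 0.3952, "loss": 0.3952}


def sum_gap(log):
    """Absolute difference between the reported total and the sum of the parts."""
    parts = log["classification_loss"] + log["segmentation_loss"]
    return abs(log["loss"] - parts)


# The progress bar prints rounded running averages, so allow a small tolerance.
TOLERANCE = 5e-4

for name, log in (("separate", separate), ("joint", joint)):
    print(name, "total matches sum of parts:", sum_gap(log) < TOLERANCE)
```

In the joint run the classification term is about 2e-07, so the total is essentially the segmentation loss alone. A classification accuracy of 0.0556 is also what uniform guessing over 18 classes would give (1/18 ≈ 0.0556), which suggests the classification branch may not be receiving a useful training signal in that setup.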

    What are the real options here?

    1. A single loss function, potentially with a weight alpha (so alpha*l1 + (1-alpha)*l2). The perks are relative simplicity and only one added hyperparameter, and in general this is the most common approach in all of machine learning. Note that you don't need "beta" at all. It is meaningless here, because multiplying both weights by a common factor only rescales the total loss. Just use alpha and 1-alpha.

    2. Separate optimization procedures (not just losses!). This can take many forms: for example, you can have 2 separate Adam optimizers that you call one after the other in each iteration. This gives more flexibility, allows separate statistics per loss, etc., but it is hard to do. You end up with a huge design space and multiple hyperparameters to choose. So unless you are an experienced researcher who knows exactly what they are doing, you are unlikely to really benefit from it. This category also includes things like null space projections, or using one loss for pretraining followed by a separate training procedure that trains the main one.
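A minimal sketch of option 1 in plain Python, with made-up loss values (the function names `combined_loss` and `two_weight_loss` and all numbers are illustrative assumptions, not taken from the question). It also shows why a separate beta adds nothing: any pair (alpha, beta) with a positive sum equals the convex combination with alpha / (alpha + beta), multiplied by the constant alpha + beta.

```python
import math


def combined_loss(l1, l2, alpha):
    """Convex combination alpha*l1 + (1 - alpha)*l2 with 0 <= alpha <= 1."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * l1 + (1.0 - alpha) * l2


def two_weight_loss(l1, l2, alpha, beta):
    """The alpha*l1 + beta*l2 form used in the question."""
    return alpha * l1 + beta * l2


# Hypothetical per-batch loss values, chosen only for illustration.
seg_loss, cls_loss = 0.24, 0.27

alpha, beta = 3.0, 1.0
scale = alpha + beta              # constant factor shared by both terms
alpha_normalized = alpha / scale  # the single weight that actually matters

lhs = two_weight_loss(seg_loss, cls_loss, alpha, beta)
rhs = scale * combined_loss(seg_loss, cls_loss, alpha_normalized)
print(math.isclose(lhs, rhs))
```

With plain gradient descent, the constant factor acts like a change of learning rate, so it is usually tuned there rather than through a second loss weight.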

    When you mention a "problem", you probably mean that the main loss started behaving badly. In general, adding an auxiliary loss is not guaranteed to help. Typically, people just tune the weights. So if you have alpha*main_loss + (1-alpha)*new_loss, start with alpha=1: things should behave exactly as they did without the new loss. If they don't, you have a bug. Once this is tested, slowly decrease alpha (0.999, 0.99, 0.9, etc.), plot the behaviour, and decide what works for you.
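The sanity check and sweep described above can be sketched on a toy problem. Everything here is an assumption made for illustration: a single parameter w, a quadratic "main" loss minimized at w = 1, a quadratic "new" loss minimized at w = -2, and plain gradient descent standing in for a real optimizer.

```python
def main_loss(w):
    """Toy main loss, minimized at w = 1."""
    return (w - 1.0) ** 2


def main_grad(w):
    """Gradient of the toy main loss."""
    return 2.0 * (w - 1.0)


def aux_grad(w):
    """Gradient of the toy auxiliary loss (w + 2)**2, minimized at w = -2."""
    return 2.0 * (w + 2.0)


def train(alpha, steps=200, lr=0.1, w=0.0):
    """Gradient descent on alpha*main + (1 - alpha)*aux; returns the final w."""
    for _ in range(steps):
        w -= lr * (alpha * main_grad(w) + (1.0 - alpha) * aux_grad(w))
    return w


def train_main_only(steps=200, lr=0.1, w=0.0):
    """Baseline: gradient descent on the main loss alone."""
    for _ in range(steps):
        w -= lr * main_grad(w)
    return w


# Step 1: with alpha = 1 the run must reproduce the baseline exactly.
assert train(alpha=1.0) == train_main_only(), "alpha=1 differs from baseline: bug"

# Step 2: lower alpha gradually and watch what happens to the main loss.
for alpha in (1.0, 0.999, 0.99, 0.9):
    w_final = train(alpha)
    print(f"alpha={alpha}: final w={w_final:.4f}, main loss={main_loss(w_final):.6f}")
```

In this toy setup the two losses pull in opposite directions, so the main loss gets worse as alpha decreases. In a real model the auxiliary loss may help or hurt, which is why the sweep is worth plotting.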

    If you want to learn more about various challenges of auxiliary losses take a look at which has toy problems showing why it can sometimes help, and sometimes do the opposite.