Loss function for class-imbalanced binary classifier in TensorFlow


Classification Problem Overview


I am trying to apply deep learning to a binary classification problem with high class imbalance between the target classes (500K vs. 31K examples). I want to write a custom loss function along the lines of: minimize(100 - ((predicted_smallerclass) / (total_smallerclass)) * 100)

Appreciate any pointers on how I can build this logic.
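
For reference, the quantity in the question is 100 minus the recall of the smaller class, in percent. Here is a minimal plain-Python sketch of it; the function name and the toy data are made up for illustration and are not from the original post:

```python
def minority_miss_rate(y_true, y_pred, minority=1):
    """Return 100 - 100 * (correctly predicted minority / total minority)."""
    total = sum(1 for t in y_true if t == minority)
    hit = sum(1 for t, p in zip(y_true, y_pred) if t == minority and p == minority)
    return 100.0 - 100.0 * hit / total


# Toy hard predictions (assumption: label 1 is the smaller class).
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 1, 0, 1, 1, 0, 0]
score = minority_miss_rate(y_true, y_pred)
print(score)  # 50.0
```

Because it counts hard predictions, this quantity is piecewise constant in the model parameters and gives no useful gradient on its own, which is presumably why the answers below weight a differentiable cross-entropy loss instead.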

Classification Solutions


Solution 1 - Classification

You can add class weights to the loss function by multiplying the logits. The regular cross-entropy loss is:

loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
               = -x[class] + log(\sum_j exp(x[j]))

In the weighted case:

loss(x, class) = weights[class] * -x[class] + log(\sum_j exp(weights[class] * x[j]))

So by multiplying logits, you are re-scaling predictions of each class by its class weight.

For example:

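import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([ratio, 1.0 - ratio])
logits = ... # shape [batch_size, 2]
# tf.mul was renamed to tf.multiply in TensorFlow 1.0
weighted_logits = tf.multiply(logits, class_weight) # shape [batch_size, 2]
# labels and logits must be passed as keyword arguments since TensorFlow 1.0
xent = tf.nn.softmax_cross_entropy_with_logits(
  labels=labels, logits=weighted_logits, name="xent_raw")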

There is now a standard loss function that supports a weight for each example in the batch:

tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits, weights=weights)

Here, weights should be converted from class weights to one weight per example (with shape [batch_size]). See the TensorFlow documentation for tf.losses.sparse_softmax_cross_entropy.
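
The conversion from class weights to per-example weights is just a lookup by label. A rough plain-Python sketch of the idea follows; the toy batch and the helper softmax_xent are made up for illustration, and I am assuming index 1 is the smaller class. In TensorFlow the same lookup could be written as tf.gather(class_weights, labels).

```python
import math


def softmax_xent(logits, label):
    """Cross entropy for one example: -x[label] + log(sum_j exp(x[j]))."""
    return -logits[label] + math.log(sum(math.exp(x) for x in logits))


ratio = 31.0 / (500.0 + 31.0)
class_weights = [ratio, 1.0 - ratio]  # assumption: index 1 is the smaller class

labels = [0, 1, 0, 0]  # toy batch of sparse integer labels
logits = [[2.0, -1.0], [0.5, 0.3], [1.5, 0.0], [-0.2, 0.4]]

# One weight per example, looked up from its label: shape [batch_size]
weights = [class_weights[y] for y in labels]

# Each example's loss is scaled by its own weight.
per_example = [softmax_xent(x, y) for x, y in zip(logits, labels)]
weighted = [w * loss for w, loss in zip(weights, per_example)]
print(weights)
```

How the weighted per-example losses are then reduced to a scalar depends on the loss function's reduction setting, so check that before comparing loss values across runs.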

Solution 2 - Classification

The code you proposed seems wrong to me. The loss should be multiplied by the weight, I agree.

But if you multiply the logits by the class weights, you end up with:

weights[class] * -x[class] + log( \sum_j exp(x[j] * weights[class]) )

The second term is not equal to:

weights[class] * log(\sum_j exp(x[j]))

To show this, we can rewrite the latter as:

log( (\sum_j exp(x[j])) ^ weights[class] )
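
A quick numeric check of this, in plain Python with made-up toy values for the logits and the weight:

```python
import math


def log_sum_exp(values):
    return math.log(sum(math.exp(v) for v in values))


x = [2.0, -1.0]  # toy logits
w = 0.25         # toy weight of the true class

scaled_logits_term = log_sum_exp([w * v for v in x])     # log(sum_j exp(w * x[j]))
scaled_loss_term = w * log_sum_exp(x)                    # w * log(sum_j exp(x[j]))
power_form = math.log(sum(math.exp(v) for v in x) ** w)  # log((sum_j exp(x[j])) ** w)

print(scaled_logits_term, scaled_loss_term, power_form)
```

The last two values agree, while the first one differs, so scaling the logits is not the same as scaling the loss.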

So here is the code I'm proposing:

import tensorflow as tf

ratio = 31.0 / (500.0 + 31.0)
class_weight = tf.constant([[ratio, 1.0 - ratio]])
logits = ... # shape [batch_size, 2]

# labels is expected to be a one-hot float tensor of shape [batch_size, 2]
weight_per_label = tf.transpose(
    tf.matmul(labels, tf.transpose(class_weight))) # shape [1, batch_size]
# this is the weight for each datapoint, depending on its label

# tf.mul was renamed to tf.multiply in TensorFlow 1.0
xent = tf.multiply(
    weight_per_label,
    tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xent_raw")) # shape [1, batch_size]
loss = tf.reduce_mean(xent) # scalar

Solution 3 - Classification

Use tf.nn.weighted_cross_entropy_with_logits() and set pos_weight to 1 / (expected ratio of positives).
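
For intuition, the per-example formula documented for this op is labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits)). A plain-Python sketch of that formula follows; the helper names are made up, and I am assuming the positives are the 31K class from the question:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def weighted_xent(label, logit, pos_weight):
    """Per-example loss: the positive term is scaled by pos_weight, the negative term is not."""
    p = sigmoid(logit)
    return pos_weight * label * -math.log(p) + (1.0 - label) * -math.log(1.0 - p)


# Assumption: positives are the 31K class out of 531K examples in total.
positive_ratio = 31.0 / (500.0 + 31.0)
pos_weight = 1.0 / positive_ratio  # roughly 17.1

loss_pos = weighted_xent(1.0, 0.0, pos_weight)  # positive example at logit 0
loss_neg = weighted_xent(0.0, 0.0, pos_weight)  # negative example at logit 0
print(pos_weight, loss_pos / loss_neg)
```

At the same logit, a misclassified positive costs pos_weight times as much as a misclassified negative.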

Solution 4 - Classification

You can check the TensorFlow guide at https://www.tensorflow.org/api_guides/python/contrib.losses

...

While specifying a scalar loss rescales the loss over the entire batch, we sometimes want to rescale the loss per batch sample. For example, if we have certain examples that matter more to us to get right, we might want a higher loss for them than for other samples whose mistakes matter less. In this case, we can provide a weight vector of length batch_size, which results in the loss for each sample in the batch being scaled by the corresponding weight element. For example, consider a classification problem where we want to maximize our accuracy but are especially interested in obtaining high accuracy for a specific class:

inputs, labels = LoadData(batch_size=3)
logits = MyModelPredictions(inputs)

# Ensures that the loss for examples whose ground truth class is `3` is 5x
# higher than the loss for all other examples.
weight = tf.multiply(4, tf.cast(tf.equal(labels, 3), tf.float32)) + 1

onehot_labels = tf.one_hot(labels, num_classes=5)
tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
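
To see what the weight expression above produces, here is the same arithmetic in plain Python on a made-up batch of three labels:

```python
labels = [3, 0, 3]  # toy batch of 3 integer class labels (made up for illustration)

# 4 * indicator(label == 3) + 1 gives 5 for class 3 and 1 for every other class
weight = [4 * float(y == 3) + 1 for y in labels]
print(weight)  # [5.0, 1.0, 5.0]
```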

Solution 5 - Classification

I had to work with a similarly unbalanced dataset with multiple classes, and this is how I worked through it. I hope it helps somebody looking for a similar solution:

This goes inside your training module:

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

# use class weights for handling unbalanced dataset
if mode == 'INFER':  # test/dev mode, not weighing loss in test mode
    sample_weights = np.ones(labels.shape)
else:
    sample_weights = compute_sample_weight(class_weight='balanced', y=labels)

This goes inside your model class definition:

#an extra placeholder for sample weights
#assuming you already have batch_size tensor
self.sample_weight = tf.placeholder(dtype=tf.float32, shape=[None],
                       name='sample_weights')
cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                       labels=self.label, logits=logits, 
                       name='cross_entropy_loss')
cross_entropy_loss = tf.reduce_sum(cross_entropy_loss*self.sample_weight) / batch_size
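
For reference, scikit-learn documents the 'balanced' mode as n_samples / (n_classes * np.bincount(y)). A plain-Python sketch of that formula follows; the function name and the toy labels are made up, and this is not the library's implementation:

```python
from collections import Counter


def balanced_sample_weights(labels):
    """Weight each example by n_samples / (n_classes * count of its class)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return [n_samples / (n_classes * counts[y]) for y in labels]


labels = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]  # toy labels, not from the original answer
weights = balanced_sample_weights(labels)
print(weights[0], weights[6])
```

With this weighting every class contributes the same total weight, and the weights sum to the number of samples, so dividing by batch_size as in the code above keeps the loss on a comparable scale.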

Solution 6 - Classification

I used the op tf.nn.weighted_cross_entropy_with_logits() for two classes:

classes_weights = tf.constant([0.1, 1.0])
# note: the `targets` argument is named `labels` in TensorFlow 2.x
cross_entropy = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=labels, pos_weight=classes_weights)

Solution 7 - Classification

import tensorflow as tf
from tensorflow.keras import backend as K


def weighted_binary_crossentropy(pos_weight=1):
    """Weighted binary crossentropy between an output tensor and a target tensor.
    # Arguments
        pos_weight: A coefficient to use on the positive examples.
    # Returns
        A loss function supposed to be used in model.compile().
    """
    def _to_tensor(x, dtype):
        """Convert the input `x` to a tensor of type `dtype`.
        # Arguments
            x: An object to be converted (numpy array, list, tensors).
            dtype: The destination type.
        # Returns
            A tensor.
        """
        return tf.convert_to_tensor(x, dtype=dtype)
  
  
    def _calculate_weighted_binary_crossentropy(target, output, from_logits=False):
        """Calculate weighted binary crossentropy between an output tensor and a target tensor.
        # Arguments
            target: A tensor with the same shape as `output`.
            output: A tensor.
            from_logits: Whether `output` is expected to be a logits tensor.
                By default, we consider that `output`
                encodes a probability distribution.
        # Returns
            A tensor.
        """
        # Note: tf.nn.sigmoid_cross_entropy_with_logits
        # expects logits, Keras expects probabilities.
        if not from_logits:
            # transform back to logits
            _epsilon = _to_tensor(K.epsilon(), output.dtype.base_dtype)
            output = tf.clip_by_value(output, _epsilon, 1 - _epsilon)
            output = tf.math.log(output / (1 - output))
        target = tf.dtypes.cast(target, tf.float32)
        return tf.nn.weighted_cross_entropy_with_logits(labels=target, logits=output, pos_weight=pos_weight)


    def _weighted_binary_crossentropy(y_true, y_pred):
        return K.mean(_calculate_weighted_binary_crossentropy(y_true, y_pred), axis=-1)
    
    return _weighted_binary_crossentropy

For usage:

pos = ...  # count of positive class
neg = ...  # count of negative class
total = pos + neg
weight_for_0 = (1 / neg) * total / 2.0
weight_for_1 = (1 / pos) * total / 2.0

class_weight = {0: weight_for_0, 1: weight_for_1}

model = ...  # your model

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=weighted_binary_crossentropy(weight_for_1),
    metrics=[tf.keras.metrics.Precision(name='precision')]
)
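
As a worked example of the weight arithmetic, here it is with the class counts from the question plugged in. I am assuming the 31K class is the positive one and rounding both counts to whole thousands:

```python
pos = 31_000   # smaller class (assumed to be the positive class)
neg = 500_000  # larger class
total = pos + neg

weight_for_0 = (1 / neg) * total / 2.0
weight_for_1 = (1 / pos) * total / 2.0

print(round(weight_for_0, 3), round(weight_for_1, 3))  # 0.531 8.565
```

With these weights each class contributes the same total weight (total / 2), so the smaller class is no longer outweighed by the larger one.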

Attributions

All content for this solution is sourced from the original question on Stack Overflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type | Original Author | Original Content on Stack Overflow
Question | Venkata Dikshit Pappu | View Question on Stack Overflow
Solution 1 - Classification | ilblackdragon | View Answer on Stack Overflow
Solution 2 - Classification | JL Meunier | View Answer on Stack Overflow
Solution 3 - Classification | Malay Haldar | View Answer on Stack Overflow
Solution 4 - Classification | Victor Mondejar-Guerra | View Answer on Stack Overflow
Solution 5 - Classification | bitspersecond | View Answer on Stack Overflow
Solution 6 - Classification | Denis Shcheglov | View Answer on Stack Overflow
Solution 7 - Classification | tttzof351 | View Answer on Stack Overflow