
Loss function for class-imbalanced binary classifier in TensorFlow

  • I am trying to apply deep learning to a binary classification problem with a high class imbalance between the target classes (500k vs. 31k examples). I want to write a custom loss function along the lines of: minimize(100 - ((predicted_smallerclass) / (total_smallerclass)) * 100)

    Appreciate any pointers on how I can build this logic.

      July 31, 2020 4:17 PM IST
    0
  • Use tf.nn.weighted_cross_entropy_with_logits() and set pos_weight to 1 / (expected ratio of positives).
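    A minimal sketch of that idea, assuming `labels` and `logits` tensors of shape [batch_size, 1] and using the approximate class counts from the question (the variable names here are illustrative, not from the original answer):

    import tensorflow as tf

    # Approximate counts from the question: ~500k negatives, ~31k positives.
    num_neg, num_pos = 500000.0, 31000.0
    pos_fraction = num_pos / (num_pos + num_neg)   # expected ratio of positives
    pos_weight = 1.0 / pos_fraction                # up-weights the rare positive class

    # labels and logits: float tensors of shape [batch_size, 1]
    per_example_loss = tf.nn.weighted_cross_entropy_with_logits(
        labels=labels, logits=logits, pos_weight=pos_weight)
    loss = tf.reduce_mean(per_example_loss)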
      July 31, 2020 4:19 PM IST
    1
  • You can add class weights to the loss function by multiplying the logits. The regular cross-entropy loss is:

    loss(x, class) = -log(exp(x[class]) / (\sum_j exp(x[j])))
                   = -x[class] + log(\sum_j exp(x[j]))

    In the weighted case (with the logits multiplied element-wise by the class weights):

    loss(x, class) = -weights[class] * x[class] + log(\sum_j exp(weights[j] * x[j]))

    So by multiplying the logits, you re-scale each class's predictions by its class weight.

    For example:

    ratio = 31.0 / (500.0 + 31.0)                         # fraction of the minority class
    class_weight = tf.constant([ratio, 1.0 - ratio])      # per-class weights for the two classes
    logits = ...  # shape [batch_size, 2]
    weighted_logits = tf.multiply(logits, class_weight)   # shape [batch_size, 2]
    xent = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=weighted_logits, name="xent_raw")

    There is now a standard loss function that supports per-example weights within a batch:

    tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits, weights=weights)

    where weights should be transformed from class weights into a weight per example (with shape [batch_size]); see the tf.losses documentation for details.
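    One way to build that per-example weight vector is to gather each example's class weight from a class-weight constant; a minimal sketch under that assumption (class_weights and its values are illustrative, not from the original answer):

    class_weights = tf.constant([0.06, 0.94])   # one weight per class, e.g. from class frequencies
    weights = tf.gather(class_weights, label)   # shape [batch_size]: one weight per example
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=label, logits=logits, weights=weights)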

     
      September 12, 2020 4:57 PM IST
    1
  • You can check the TensorFlow guides at https://www.tensorflow.org/api_guides/python/contrib.losses

    While specifying a scalar loss rescales the loss over the entire batch, we sometimes want to rescale the loss per batch sample. For example, if we have certain examples that matter more to us to get right, we might want a higher loss for them than for other samples whose mistakes matter less. In this case, we can provide a weight vector of length batch_size, which results in the loss for each sample in the batch being scaled by the corresponding weight element. For example, consider a classification problem where we want to maximize accuracy but are especially interested in obtaining high accuracy for a specific class:

    inputs, labels = LoadData(batch_size=3)
    logits = MyModelPredictions(inputs)

    # Ensures that the loss for examples whose ground-truth class is `3` is 5x
    # higher than the loss for all other examples.
    weight = tf.multiply(4.0, tf.cast(tf.equal(labels, 3), tf.float32)) + 1.0

    onehot_labels = tf.one_hot(labels, depth=5)
    loss = tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
     
      September 12, 2020 5:00 PM IST
    0
  • I had to work with a similarly imbalanced dataset with multiple classes, and this is how I worked through it; I hope it helps somebody looking for a similar solution.

    This goes inside your training module:

    import numpy as np
    from sklearn.utils.class_weight import compute_sample_weight

    # Use class weights to handle the unbalanced dataset.
    if mode == 'INFER':  # test/dev mode: do not weight the loss
        sample_weights = np.ones(labels.shape)
    else:
        sample_weights = compute_sample_weight(class_weight='balanced', y=labels)

    This goes inside your model class definition:

    # An extra placeholder for the per-example sample weights
    # (assuming you already have a batch_size tensor).
    self.sample_weight = tf.placeholder(dtype=tf.float32, shape=[None],
                                        name='sample_weights')
    cross_entropy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                             labels=self.label, logits=logits,
                             name='cross_entropy_loss')
    # Weight each example's loss and average over the batch.
    cross_entropy_loss = tf.reduce_sum(cross_entropy_loss * self.sample_weight) / batch_size
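    At each training step, the computed sample_weights array is then fed into the placeholder; a minimal sketch of that wiring, assuming a sess, a train_op, and input/label placeholders on the model (these names are illustrative, not from the original answer):

    # Feed the per-example weights alongside the usual inputs and labels.
    _, loss_val = sess.run(
        [train_op, model.cross_entropy_loss],
        feed_dict={model.inputs: batch_inputs,
                   model.label: batch_labels,
                   model.sample_weight: sample_weights})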
      September 12, 2020 5:09 PM IST
    0