
Your loss function does not look like the one in a textbook. Should you panic?



In your Machine Learning course, you probably saw a plot of an ideal loss curve: the loss drops steeply at first, then flattens out as the model converges.

But in reality, loss curves can be quite challenging to interpret.

1. My Model Won't Train!

If your loss curve bounces around wildly and never settles into a downward trend, your model is not converging, and you should find out why.

Try these debugging steps:

  1. Validate your data. Do all data points match the data schema? Do you have outliers and/or NaNs?

  2. Do the features encode predictive signals? Using correlation matrices, you can find linear correlations between individual features and labels.

  3. Correlation matrices will not detect nonlinear relationships between features and labels. In that case, pick about 10 examples from your dataset that your model should learn easily, and verify that it can achieve a very small loss on those 10 examples. Once it can, continue debugging on the full dataset.

  4. Try to reduce your learning rate to prevent the model from bouncing around in parameter space.

  5. Simplify your model and ensure the model outperforms your baseline. Then incrementally add complexity to the model.
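Step 3 above (the overfit-a-tiny-subset check) can be sketched like this; the dataset and model below are illustrative, not from the original post:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical "easily learnable" subset: 10 points in two tight,
# well-separated clusters, one per class.
rng = np.random.default_rng(0)
X_easy = np.vstack([rng.normal(-2.0, 0.1, (5, 2)),   # class 0 cluster
                    rng.normal(+2.0, 0.1, (5, 2))])  # class 1 cluster
y_easy = np.array([0] * 5 + [1] * 5)

# A model that cannot fit even these 10 points has a bug
# (mislabeled data, broken pipeline, wrong loss), not a tuning problem.
model = LogisticRegression(C=1e6)  # effectively unregularized
model.fit(X_easy, y_easy)
print(model.score(X_easy, y_easy))  # expect a perfect 1.0 on this toy set
```

If the model cannot get near-perfect accuracy on the tiny subset, debug the pipeline before touching any hyperparameters.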


2. My Loss Exploded!

An exploding loss shows up as a sudden, sharp spike after a period of otherwise normal training.

A large increase in loss is typically caused by anomalous values in input data. Possible causes are:

  • NaNs in input data.

  • Exploding gradient due to anomalous data.

  • Division by zero.

  • Logarithm of zero or negative numbers.

To fix an exploding loss, check your batches for anomalous data. If the anomaly looks like a bug (for example, NaNs), investigate and fix the cause. If it looks like legitimate outlying data, shuffle your data so the outliers are evenly distributed across batches.
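A minimal sketch of such a batch check (the 6-sigma outlier rule here is an arbitrary illustration, not a standard):

```python
import numpy as np

def check_batch(batch: np.ndarray) -> None:
    """Raise on values that commonly blow up a loss (NaNs, infs);
    warn on extreme outliers."""
    if np.isnan(batch).any():
        raise ValueError("batch contains NaNs")
    if np.isinf(batch).any():
        raise ValueError("batch contains infs")
    # Arbitrary rule of thumb: flag values more than 6 batch-stds from the mean.
    z = np.abs(batch - batch.mean()) / (batch.std() + 1e-12)
    if (z > 6).any():
        print("warning: possible outliers in this batch")

check_batch(np.array([0.1, 0.5, -0.3, 0.9]))  # clean batch: passes silently
```

Running a check like this on the batches fed in right before the spike usually points straight at the offending examples.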


3. My Metrics are Contradictory!

Sometimes your loss curve looks just fine (finally!), but recall stays pinned at exactly 0 for the entire training run.

Recall is stuck at 0 because your examples' classification probability never exceeds the threshold for positive classification. This situation often occurs in problems with a large class imbalance. Remember that ML libraries, such as TF Keras, typically use a default threshold of 0.5 to calculate classification metrics.

Try these steps:

  • Lower your classification threshold.

  • Check threshold-invariant metrics, such as AUC.

Also, think twice before training an ML model on imbalanced data: the model will learn the dominant class much better than the minority class(es). Are you taking this into account when you apply the model?
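Both suggestions can be demonstrated with scikit-learn on a tiny imbalanced example (the scores and labels below are invented for illustration):

```python
import numpy as np
from sklearn.metrics import recall_score, roc_auc_score

# Invented imbalanced problem: the model ranks positives above negatives,
# but never scores anything above the default 0.5 threshold.
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
scores = np.array([0.01, 0.02, 0.03, 0.02, 0.01, 0.04, 0.03, 0.02, 0.30, 0.40])

# Default threshold 0.5: every prediction is negative, so recall is 0.
recall_default = recall_score(y_true, scores >= 0.5)  # 0.0

# A lower threshold recovers the positives the model ranks highest.
recall_low = recall_score(y_true, scores >= 0.2)      # 1.0

# AUC is threshold-invariant and shows the ranking is actually perfect.
auc = roc_auc_score(y_true, scores)                   # 1.0
```

A recall of 0 at threshold 0.5 alongside a perfect AUC is exactly the "contradictory metrics" pattern: the model ranks well but its raw scores sit below the default cutoff.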


4. Testing Loss is Too Damn High!

For some reason, your loss is much higher on the test set than on the training set.

Your model is overfitting to the training data and you need to fix it. 

Try these steps:

  • Reduce model capacity.

  • Add regularization.

  • Check that the training and test splits are statistically equivalent.
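The third check can be sketched as a quick statistical comparison of the splits; `split_stats_gap` is a hypothetical helper, and the "sorted dataset" below simply manufactures a bad split for illustration:

```python
import numpy as np

def split_stats_gap(train: np.ndarray, test: np.ndarray) -> float:
    """Max per-feature difference in means between splits, in units of the
    training std. Large values suggest the splits are not equivalent."""
    std = train.std(axis=0) + 1e-12
    return float(np.max(np.abs(train.mean(axis=0) - test.mean(axis=0)) / std))

rng = np.random.default_rng(0)
data = rng.normal(0.0, 1.0, (1000, 3))

# A random split should give nearly identical statistics...
idx = rng.permutation(len(data))
good = split_stats_gap(data[idx[:800]], data[idx[800:]])

# ...while naively splitting data sorted by a feature does not.
sorted_data = data[np.argsort(data[:, 0])]
bad = split_stats_gap(sorted_data[:800], sorted_data[800:])
print(f"random split gap: {good:.2f}, sorted split gap: {bad:.2f}")
```

If the gap between split statistics is large, the "overfitting" may partly be a distribution mismatch between train and test, which no amount of regularization will fix.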


5. My Model Gets Stuck

Sometimes your loss shows a repetitive, step-like (staircase) pattern.

The input data your model sees is probably itself repetitive, for example batches ordered by class or by time. Shuffle the data before training to break up the pattern.
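A minimal shuffling sketch (the tiny arrays are placeholders for your real features and labels); the key point is that features and labels must be permuted together:

```python
import numpy as np

# Placeholder dataset stored sorted by class, so each epoch the model
# sees long runs of identical labels -- a common cause of step-like loss.
X = np.arange(12).reshape(6, 2)   # 6 examples, 2 features each
y = np.array([0, 0, 0, 1, 1, 1])  # ordered, not shuffled

# Shuffle features and labels with the SAME permutation so pairs stay aligned.
rng = np.random.default_rng(42)
perm = rng.permutation(len(y))
X_shuffled, y_shuffled = X[perm], y[perm]
```

In frameworks with input pipelines, the equivalent is a dataset-level shuffle, e.g. `tf.data.Dataset.shuffle` with a buffer large enough to actually mix the repeating groups.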


If your loss curve does not look like the one in a textbook, don't panic. Treat it as an opportunity to learn more about your model and dataset. Every issue you meet along the way only makes you a stronger Machine Learning expert! Happy debugging ;)
