This is what I did in each epoch:
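import copy
import torch

# Assumes loaded_model, criterion, num_epochs, n_epochs_stop, trainloader
# and valloader (the validation DataLoader; name assumed) are already defined.
min_val_loss = float('inf')
epochs_no_improve = 0
best_model = copy.deepcopy(loaded_model.state_dict())

for epoch in range(num_epochs):
    # ... training pass over trainloader (with loaded_model.train()) ...

    loaded_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets in valloader:
            loss = criterion(loaded_model(inputs), targets)
            val_loss += loss.item()  # .item() so the graph isn't kept alive
    val_loss = val_loss / len(valloader)  # average over validation batches

    if val_loss < min_val_loss:
        # Saving the model: keep a copy of the best weights so far
        min_val_loss = val_loss
        best_model = copy.deepcopy(loaded_model.state_dict())
        print('Min val loss %0.2f' % min_val_loss)
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        # Check early stopping condition
        if epochs_no_improve == n_epochs_stop:
            print('Early stopping!')
            break

loaded_model.load_state_dict(best_model)  # restore the best weights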
I'm not sure how correct it is. I took most of this code from a post on another website but forgot where, so I can't link the reference; I have only modified it a bit. I hope you find it useful, and if I'm wrong, kindly point out the mistake. Thank you.
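To check the early-stopping bookkeeping on its own, without a model or data, here is a minimal framework-agnostic sketch. `EarlyStopping`, `step` and `patience` are hypothetical names chosen for illustration (not a PyTorch API); `patience` plays the role of `n_epochs_stop`, and `state` stands in for whatever holds the weights, such as a PyTorch `state_dict()`. The simulated losses are made up purely to exercise the logic.

```python
import copy


class EarlyStopping:
    """Track the best validation loss and signal when to stop (hypothetical helper)."""

    def __init__(self, patience):
        self.patience = patience          # same role as n_epochs_stop
        self.min_val_loss = float('inf')
        self.epochs_no_improve = 0
        self.best_state = None
        self.best_epoch = None

    def step(self, epoch, val_loss, state):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.min_val_loss:
            self.min_val_loss = val_loss
            # Snapshot the weights; storing a plain reference would follow later updates
            self.best_state = copy.deepcopy(state)
            self.best_epoch = epoch
            self.epochs_no_improve = 0
        else:
            self.epochs_no_improve += 1
        return self.epochs_no_improve >= self.patience


# Simulated per-epoch validation losses and "weights" (illustrative values only)
val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.67, 0.50]
stopper = EarlyStopping(patience=3)
weights = {'w': 0}
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    weights['w'] = epoch  # pretend training changed the weights this epoch
    if stopper.step(epoch, val_loss, weights):
        stopped_at = epoch
        break

print(stopper.best_epoch, stopper.min_val_loss, stopper.best_state, stopped_at)
# prints: 2 0.65 {'w': 2} 5
```

Training stops at epoch 5 after three epochs without improvement, and the saved state is still the one from epoch 2 because it was deep-copied. This is also why the `copy.deepcopy(loaded_model.state_dict())` call matters: a `state_dict()` holds references to the live parameter tensors, so without the copy the "best" weights would silently track the latest ones.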