I'm a bit confused by the cross entropy loss in PyTorch.
Considering this example:
import torch
import torch.nn as nn
from torch.autograd import Variable
output = Variable(torch.FloatTensor([0,0,0,1])).view(1, -1)
target = Variable(torch.LongTensor([3]))
criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss)
I would expect the loss to be 0. But I get:
Variable containing:
0.7437
[torch.FloatTensor of size 1]
As far as I know, cross entropy can be calculated like this:
$H(p, q) = -\sum_x p(x) \log q(x)$
But shouldn't the result then be -1 * log(1) = 0?
I tried different inputs, such as one-hot encodings, but those don't work at all, so the input shape I pass to the loss function seems to be okay.
I would be really grateful if someone could help me out and tell me where my mistake is.
Thanks in advance!
Here I give the full formula for manually computing PyTorch's CrossEntropyLoss. There is a small precision problem, which you will see later; please post an answer if you know the exact reason.
First, understand how NLLLoss works. CrossEntropyLoss is then very similar: it is NLLLoss applied to the log-softmax of the input (LogSoftmax followed by NLLLoss).
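As a quick check on the example from the question, the sketch below compares nn.CrossEntropyLoss with nn.NLLLoss applied to F.log_softmax of the same output. It assumes PyTorch 0.4 or later, where Variable has been merged into Tensor, so plain tensors are used. Per sample the loss is $-\log\big(e^{x_y} / \sum_j e^{x_j}\big)$; the output [0, 0, 0, 1] is treated as unnormalized logits, not as probabilities, so for target 3 this gives $-\log\frac{e}{3 + e} = \ln(3 + e) - 1 \approx 0.7437$ rather than 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Logits and target from the question (plain tensors; Variable is not needed)
output = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
target = torch.tensor([3])

ce = nn.CrossEntropyLoss()(output, target)
# CrossEntropyLoss should equal NLLLoss applied to the log-softmax of the logits
nll_of_log_softmax = nn.NLLLoss()(F.log_softmax(output, dim=1), target)

print(ce.item(), nll_of_log_softmax.item())  # both about 0.7437
```

To get a loss near 0, the logit of the target class has to dominate the others (e.g. [0, 0, 0, 20]), because softmax never assigns exactly probability 1.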
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def compute_nllloss_manual(x,y0):
    """
    x is a tensor with shape (batch_size, C)
    Note: the official example passes log_softmax(some vector) as x, which turns this into CELoss.
    y0 has shape (batch_size,), and its entries are integers from 0 to C-1
    Furthermore, for C>1 classes, the other classes are ignored (see below)
    """
    loss = 0.
    n_batch, n_class = x.shape
    # print(n_class)
    for x1,y1 in zip(x,y0):
        class_index = int(y1.item())
        loss = loss + x1[class_index] # other class terms, ignore.
    loss = - loss/n_batch
    return loss
From the formula we see that it is NOT like the standard prescribed NLL formula, because the "other class" terms are ignored (see the comment in the code). Also, remember that PyTorch usually processes data in batches. In the following code, we generate 1000 random batches to verify that the formula is correct up to 15 decimal places (absolute difference below 10^-15, set by `precision`).
torch.manual_seed(0)
precision = 15
batch_size=10
C = 10
N_iter = 1000
n_correct_nll = 0
criterion = nn.NLLLoss()
for i in range(N_iter):
    x = torch.rand(size=(batch_size,C)).to(torch.float)
    y0 = torch.randint(0,C,size=(batch_size,))
    nll_loss = criterion(x,y0)
    manual_nll_loss = compute_nllloss_manual(x,y0)
    if i==0:
        print('NLLLoss:')
        print('module:%s'%(str(nll_loss)))
        print('manual:%s'%(str(manual_nll_loss)))
    nll_loss_check = np.abs((nll_loss- manual_nll_loss).item())<10**-precision
    if nll_loss_check: n_correct_nll+=1
print('percentage NLLLoss correctly computed:%s'%(str(n_correct_nll/N_iter*100)))
I got output like:
NLLLoss:
module:tensor(-0.4783)
manual:tensor(-0.4783)
percentage NLLLoss correctly computed:100.0
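The Python loop in compute_nllloss_manual can also be written with advanced indexing. In the sketch below, compute_nllloss_vectorized is a hypothetical helper (not part of PyTorch): it picks x[n, y0[n]] for every row at once and negates the mean, which should match nn.NLLLoss with its default reduction='mean' and no class weights.

```python
import torch
import torch.nn as nn

def compute_nllloss_vectorized(x, y0):
    """Hypothetical helper. x: (batch_size, C) tensor; y0: (batch_size,) class indices."""
    # Select x[n, y0[n]] for every sample n in one indexing step
    picked = x[torch.arange(x.shape[0]), y0]
    # Other class terms are ignored, exactly as in the loop version
    return -picked.mean()

torch.manual_seed(0)
x = torch.rand(10, 10)
y0 = torch.randint(0, 10, (10,))
print(nn.NLLLoss()(x, y0), compute_nllloss_vectorized(x, y0))
```

Note that NLLLoss does not check that x contains log-probabilities; it only indexes and averages, which is why random values from torch.rand work here.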
So far so good, 100% of the computations are correct. Now let us compute CrossEntropyLoss manually with the following.
def compute_crossentropyloss_manual(x,y0):
    """
    x is a tensor with shape (batch_size, C)
    y0 has shape (batch_size,), and its entries are integers from 0 to C-1
    """
    loss = 0.
    n_batch, n_class = x.shape
    # print(n_class)
    for x1,y1 in zip(x,y0):
        class_index = int(y1.item())
        loss = loss + torch.log(torch.exp(x1[class_index])/(torch.exp(x1).sum()))
    loss = - loss/n_batch
    return loss
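The formula above (exp, divide by the sum, then log) is the textbook definition, but exp() overflows float32 for large logits. A common alternative uses the identity $\log\frac{e^{x_c}}{\sum_j e^{x_j}} = x_c - \log\sum_j e^{x_j}$, where torch.logsumexp computes the second term in a numerically stable way. The sketch below is a vectorized version under that assumption; compute_crossentropyloss_logsumexp is a hypothetical name, and it is not guaranteed to follow PyTorch's internal operation order bit for bit, so it should be compared with a tolerance.

```python
import torch
import torch.nn as nn

def compute_crossentropyloss_logsumexp(x, y0):
    """Hypothetical helper. x: (batch_size, C) logits; y0: (batch_size,) class indices."""
    # Log-softmax at the target class: x[n, y0[n]] - logsumexp(x[n, :])
    target_logits = x[torch.arange(x.shape[0]), y0]
    log_probs = target_logits - torch.logsumexp(x, dim=1)
    return -log_probs.mean()

torch.manual_seed(0)
x = torch.rand(10, 10)
y0 = torch.randint(0, 10, (10,))
print(nn.CrossEntropyLoss()(x, y0), compute_crossentropyloss_logsumexp(x, y0))

# With large logits the naive exp()/sum() formula produces inf or nan in float32,
# while the logsumexp form stays finite.
big = torch.tensor([[1000.0, 0.0]])
print(compute_crossentropyloss_logsumexp(big, torch.tensor([1])))  # about 1000
```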
Then repeat the procedure for 1000 randomly generated batches.
torch.manual_seed(0)
precision = 15
batch_size=10
C = 10
N_iter = 1000
n_correct_CE = 0
criterion2 = nn.CrossEntropyLoss()
for i in range(N_iter):
    x = torch.rand(size=(batch_size,C)).to(torch.float)
    y0 = torch.randint(0,C,size=(batch_size,))
    CEloss = criterion2(x,y0)
    manual_CEloss = compute_crossentropyloss_manual(x,y0)
    if i==0:
        print('CrossEntropyLoss:')
        print('module:%s'%(str(CEloss)))
        print('manual:%s'%(str(manual_CEloss)))
    CE_loss_check = np.abs((CEloss- manual_CEloss).item())<10**-precision
    if CE_loss_check: n_correct_CE+=1
print('percentage CELoss correctly computed :%s'%(str(n_correct_CE/N_iter*100)))
The result is:
CrossEntropyLoss:
module:tensor(2.3528)
manual:tensor(2.3528)
percentage CELoss correctly computed :81.39999999999999
I got 81.4% of the computations correct up to 15 decimal places. Most likely exp() and log() introduce small precision errors, but I don't know exactly how.
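One plausible explanation (an interpretation, not something stated in the PyTorch docs): the tensors here are float32, which carries only about 7 significant decimal digits (torch.finfo(torch.float32).eps is about 1.19e-7). For loss values between 2 and 4 the gap between adjacent float32 numbers is 2^-22, roughly 2.4e-7, so with precision = 15 a comparison can only pass when both results are bit-for-bit identical. CrossEntropyLoss combines LogSoftmax and NLLLoss internally, which evaluates the same math in a different order than exp, sum, divide, log plus a Python-side running sum, so the last bit sometimes differs. The sketch below reruns the comparison with a float32-appropriate tolerance and in float64; manual_ce and agreement_rate are hypothetical helpers that mirror the code above.

```python
import torch
import torch.nn as nn

def manual_ce(x, y0):
    # Hypothetical helper: same per-sample formula as compute_crossentropyloss_manual
    loss = 0.0
    for x1, y1 in zip(x, y0):
        loss = loss + torch.log(torch.exp(x1[y1]) / torch.exp(x1).sum())
    return -loss / x.shape[0]

def agreement_rate(dtype, tol, n_iter=1000, batch_size=10, n_class=10):
    # Hypothetical helper: percentage of random batches where module and manual agree within tol
    torch.manual_seed(0)
    criterion = nn.CrossEntropyLoss()
    n_ok = 0
    for _ in range(n_iter):
        x = torch.rand(batch_size, n_class, dtype=dtype)
        y0 = torch.randint(0, n_class, (batch_size,))
        diff = (criterion(x, y0) - manual_ce(x, y0)).abs().item()
        n_ok += diff < tol
    return 100.0 * n_ok / n_iter

print(torch.finfo(torch.float32).eps)        # about 1.19e-07
print(agreement_rate(torch.float32, 1e-15))  # strict: only bit-identical results pass
print(agreement_rate(torch.float32, 1e-5))   # tolerance suited to float32
print(agreement_rate(torch.float64, 1e-12))  # same formula in double precision
```

If the float32 run with a tolerance near float32 resolution and the float64 run both agree on every batch, that points to ordinary rounding differences rather than an error in the manual formula.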