Kullback-Leibler divergence (KL divergence) is a preferable alternative to cross-entropy when training with soft labels, for the reasons explained in the previous post. Its mathematical form is given by the following equation:
\(D_{KL}(\mathbf{g} \parallel \mathbf{f}) = \sum_x \mathbf{g}(x) \log\frac{\mathbf{g}(x)}{\mathbf{f}(x)}\),
where $\mathbf{f}(x)$ denotes the predicted probabilities and $\mathbf{g}(x)$ the soft labels. Using KL divergence correctly in PyTorch (torch.nn.KLDivLoss()) takes some care. In this post, we describe three pitfalls in using torch.nn.KLDivLoss().
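As a sanity check on the notation, the definition can be evaluated directly and compared against PyTorch's functional form; the two toy distributions below are made up for illustration.
import torch
import torch.nn.functional as F
# toy distributions: g holds the soft labels, f the predicted probabilities
g = torch.tensor([0.7, 0.2, 0.1])
f = torch.tensor([0.6, 0.3, 0.1])
# direct evaluation of the definition
d_kl_manual = torch.sum(g * torch.log(g / f))
# F.kl_div expects the prediction in log space and, by default, the target as plain probabilities
d_kl_torch = F.kl_div(f.log(), g, reduction='sum')
print(d_kl_manual.item(), d_kl_torch.item())  # both ≈ 0.0268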
For reference, the default signature is listed as follows.
torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)
Reduction Should Be Batchmean Rather Than Mean
In the default reduction mode mean, the loss is averaged over every element of the minibatch, i.e. over the batch dimension as well as over the class dimensions. The batchmean mode gives the mathematically correct KL divergence, where the pointwise losses are summed and divided by the batch size only. The behavior of mean will be changed to match batchmean in the next major release.
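As a quick check of the difference (the tensor shapes below are arbitrary), the two modes differ exactly by a factor of the number of classes:
import torch
import torch.nn as nn
import torch.nn.functional as F
# arbitrary batch of 4 samples over 5 classes
pred = F.log_softmax(torch.randn(4, 5), dim=-1)  # input in log space
target = F.softmax(torch.randn(4, 5), dim=-1)    # target as probabilities
loss_mean = nn.KLDivLoss(reduction='mean')(pred, target)            # divides the sum by all 4 * 5 elements
loss_batchmean = nn.KLDivLoss(reduction='batchmean')(pred, target)  # divides the sum by the batch size 4
print(loss_mean.item() * 5, loss_batchmean.item())  # equal up to floating point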
Log Target Should Be True if Working with Probabilities
According to the mathematical definition above, both the input $\mathbf{f}(x)$ and the target $\mathbf{g}(x)$ are probability distributions, i.e. they sum to one (for hard as well as soft labels). KLDivLoss, however, always expects its input in log space; the log_target flag only declares whether the target is in log space as well (the default False treats it as plain probabilities). So when the prediction and the target are supplied in the same form, for example both taken from log_softmax, log_target has to be True, whereas its default is False. Getting this wrong does not raise an error: in the toy example below both arguments hold the same probability tensor, so the true divergence is zero, yet only the log_target=True call returns zero, while the default setting returns a negative value even though KL divergence can never be negative.
import torch
import torch.nn as nn
# initialize two losses that differ only in log_target
loss_log = nn.KLDivLoss(reduction='batchmean', log_target=True)
loss_nonlog = nn.KLDivLoss(reduction='batchmean', log_target=False)
# pred and target hold the same distribution, so the true KL divergence is 0
target = torch.tensor([[0.95, 0.05]])
pred = target.clone()
print(f'log_target=True, loss is {loss_log(pred, target).item():.5f}')
print(f'log_target=False, loss is {loss_nonlog(pred, target).item():.5f}')
Output:
log_target=True, loss is 0.00000
log_target=False, loss is -1.10352
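For comparison, the documented convention for the default log_target=False is a log-space input and a plain-probability target; under that convention the same toy distribution also gives the expected zero (the loss object is re-created here only to keep the snippet self-contained).
import torch
import torch.nn as nn
loss_nonlog = nn.KLDivLoss(reduction='batchmean', log_target=False)
p = torch.tensor([[0.95, 0.05]])
# input as log-probabilities, target as probabilities: identical distributions give zero loss
print(f'log_target=False, loss is {loss_nonlog(p.log(), p).item():.5f}')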
Input and Target Should Never Be Inverted
In the mathematical definition, which distribution is written as $\mathbf{f}(x)$ and which as $\mathbf{g}(x)$ is only a matter of notation, but KL divergence itself is not symmetric, and PyTorch additionally fixes the roles of the two arguments: the first argument (input) is the prediction in log space, and the second (target) is the reference distribution, given as plain probabilities when log_target=False. The input therefore must never be interchanged with the target, especially when log_target=False, where the two arguments do not even live in the same space.
import torch.nn.functional as F
# initialize tensors
loss_log = nn.KLDivLoss(reduction='batchmean', log_target=True)
loss_nonlog = nn.KLDivLoss(reduction='batchmean', log_target=False)
init = torch.tensor([[0.95, 0.05, 0.0]])
pred = F.log_softmax(init, dim=-1)  # prediction in log space, as the input always should be
target = F.softmax(init, dim=-1)    # target as plain probabilities, matching log_target=False
print(f'log_target=False, loss is {loss_nonlog(pred, target).item()}')  # correct argument order
print(f'log_target=False, loss is {loss_nonlog(target, pred).item()}')  # incorrect: arguments swapped
print(f'log_target=True, loss is {loss_log(pred, pred).item()}')        # correct with log_target=True: both arguments in log space
Output:
log_target=False, loss is -2.7026482385394957e-08
log_target=False, loss is 0.0
log_target=True, loss is 0.0
The tiny negative number from the correct call is only floating-point error around a true divergence of zero. The swapped call also prints 0.0 here, but not because the distributions match: its target argument contains log-probabilities, which are all negative, and KLDivLoss does not validate its arguments, so the reported zero is meaningless.
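Putting the three pitfalls together, here is a minimal sketch of a soft-label loss that uses batchmean, keeps both arguments in log space with log_target=True, and passes the prediction first; the logits and soft_labels tensors are made up for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F
# illustrative data: a batch of 8 samples over 10 classes
logits = torch.randn(8, 10)                          # raw model outputs
soft_labels = F.softmax(torch.randn(8, 10), dim=-1)  # soft labels as probabilities
kl = nn.KLDivLoss(reduction='batchmean', log_target=True)
# prediction first (log space), target second (also log space), averaged over the batch only
loss = kl(F.log_softmax(logits, dim=-1), soft_labels.log())
print(loss.item())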