(PyTorch) Pitfalls in KL Divergence

 

Kullback-Leibler divergence (KL divergence) is a preferable alternative to cross-entropy with soft labels, for the reasons explained in the previous post. The mathematical form of KL divergence is given by the following equation:
\(D_{KL}(\mathbf{g} \,\|\, \mathbf{f}) = \sum_x \mathbf{g}(x) \log\frac{\mathbf{g}(x)}{\mathbf{f}(x)}\)
where $\mathbf{f}(x)$ denotes the predicted probabilities and $\mathbf{g}(x)$ the soft labels. Using KL divergence correctly in PyTorch (torch.nn.KLDivLoss()) requires some care. In this post, we describe three pitfalls in using torch.nn.KLDivLoss(). For reference, the default signature is listed as follows.

torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)
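
As a quick sanity check of the definition itself, the divergence can be computed directly from the equation above. The following is a minimal sketch with made-up toy distributions:

import torch

g = torch.tensor([0.7, 0.3])      # soft labels g(x), made-up numbers
f = torch.tensor([0.6, 0.4])      # predicted probabilities f(x), made-up numbers
d_kl = (g * (g / f).log()).sum()  # D_KL(g || f) per the equation above
print(d_kl.item())                # approximately 0.0216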

Reduction Should Be Batchmean Rather Than Mean

In the default reduction mode mean, the pointwise losses are averaged over every element in the minibatch, i.e., over the batch dimension as well as the class dimension. The batchmean mode gives the correct KL divergence, where the summed losses are averaged over the batch dimension only. The PyTorch documentation notes that the behavior of mean will be changed to match batchmean in the next major release.
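
The difference is easy to verify numerically. The following sketch uses an arbitrary toy batch of 2 samples and 3 classes (made-up numbers, with both tensors in log space and log_target=True, as recommended below): mean divides the summed pointwise losses by all 6 elements, while batchmean divides by the batch size 2 only.

import torch
import torch.nn.functional as F

# toy batch: 2 samples, 3 classes (made-up numbers), both tensors in log space
target_log = torch.log(torch.tensor([[0.7, 0.2, 0.1],
                                     [0.1, 0.1, 0.8]]))
pred_log = torch.log_softmax(torch.randn(2, 3), dim=-1)

loss_sum = F.kl_div(pred_log, target_log, reduction='sum', log_target=True)
loss_mean = F.kl_div(pred_log, target_log, reduction='mean', log_target=True)            # divides by all 6 elements
loss_batchmean = F.kl_div(pred_log, target_log, reduction='batchmean', log_target=True)  # divides by batch size 2

print(torch.isclose(loss_mean, loss_sum / 6))       # True
print(torch.isclose(loss_batchmean, loss_sum / 2))  # True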

Log Target Should Be True if Working with Probabilities

According to the mathematical definition above, both the target and the input are probabilities, which implies that each of them sums to one (whether the labels are hard or soft). In this situation, log_target has to be True, whereas it defaults to False.

import torch
import torch.nn as nn

# two loss instances that differ only in log_target
loss_log = nn.KLDivLoss(reduction='batchmean', log_target=True)
loss_nonlog = nn.KLDivLoss(reduction='batchmean', log_target=False)

# identical predicted and target distributions
target = torch.tensor([[0.95, 0.05]])
pred = target.clone()

print(f'log_target=True, loss is {loss_log(pred, target).item():.5f}')
print(f'log_target=False, loss is {loss_nonlog(pred, target).item():.5f}')

Output:

log_target=True, loss is 0.00000
log_target=False, loss is -1.10352
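
For two genuinely different distributions, the same loss instance reproduces the value computed by hand from the equation above, provided both tensors are passed in log space. A small sketch with made-up numbers, reusing loss_log from the snippet above:

g = torch.tensor([[0.95, 0.05]])     # soft labels, made-up numbers
f = torch.tensor([[0.90, 0.10]])     # predicted probabilities, made-up numbers
by_hand = (g * (g / f).log()).sum()  # D_KL(g || f) from the definition

loss = loss_log(f.log(), g.log())    # both arguments in log space, log_target=True
print(torch.isclose(loss, by_hand))  # True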

Input and Target Should Never Be Inverted

Although the mathematical formula may look as if $\mathbf{f}(x)$ and $\mathbf{g}(x)$ could simply trade places, KL divergence is not symmetric, and swapping the two distributions generally changes its value. In PyTorch, likewise, the input cannot be interchanged with the target, especially when log_target=False.

import torch.nn.functional as F

# loss instance with the default log_target=False
loss_nonlog = nn.KLDivLoss(reduction='batchmean', log_target=False)

# initialize tensors
init = torch.tensor([[0.95, 0.05, 0]])
pred = F.softmax(init, dim=-1)
target = F.log_softmax(init, dim=-1)

print(f'log_target=False, loss is {loss_nonlog(pred, target).item()}') # correct
print(f'log_target=False, loss is {loss_nonlog(target, pred).item()}') # incorrect
print(f'log_target=True, loss is {loss_log(pred, pred).item()}') # correct with log_target=True, reusing loss_log from above

Output:

log_target=False, loss is 0.0
log_target=False, loss is -2.7026482385394957e-08
log_target=True, loss is 0.0
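
Because KL divergence is asymmetric, swapping the two distributions changes the value whenever they genuinely differ. A short sketch with made-up distributions (log_target=True, both tensors in log space, reusing the imports above):

log_p = torch.tensor([[0.9, 0.1]]).log()
log_q = torch.tensor([[0.6, 0.4]]).log()

kl_qp = F.kl_div(log_p, log_q, reduction='batchmean', log_target=True)  # D_KL(q || p), approximately 0.311
kl_pq = F.kl_div(log_q, log_p, reduction='batchmean', log_target=True)  # D_KL(p || q), approximately 0.226
print(kl_qp.item(), kl_pq.item())  # two clearly different values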