torch.nn.functional.ctc_loss
- torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)
The Connectionist Temporal Classification loss.
See CTCLoss for details.
Note
In some circumstances, when given tensors on a CUDA device and using cuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See Reproducibility for more information.
Note
This operation may produce nondeterministic gradients when given tensors on a CUDA device. See Reproducibility for more information.
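As a minimal sketch of the setting mentioned in the first note, the snippet below enables the cuDNN deterministic flag before calling ctc_loss. The tensor sizes and the seed are illustrative assumptions, and the code falls back to the CPU when no CUDA device is present. As the second note says, gradients on CUDA may still be nondeterministic even with this flag set.

```python
import torch
import torch.nn.functional as F

# Request deterministic cuDNN algorithms (possibly at a performance cost).
torch.backends.cudnn.deterministic = True

# Illustrative assumption: use CUDA when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

T, N, C, S = 50, 4, 20, 10  # illustrative sizes, not prescribed by the docs
log_probs = torch.randn(T, N, C, device=device).log_softmax(2).detach().requires_grad_()
targets = torch.randint(1, C, (N, S), dtype=torch.long, device=device)
input_lengths = torch.full((N,), T, dtype=torch.long, device=device)
target_lengths = torch.full((N,), S, dtype=torch.long, device=device)

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss.backward()  # gradients are stored in log_probs.grad
```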
- Parameters
log_probs (Tensor) – (T, N, C) or (T, C), where C = number of characters in the alphabet including blank, T = input length, and N = batch size. The logarithmized probabilities of the outputs (e.g. obtained with torch.nn.functional.log_softmax()).
targets (Tensor) – (N, S), where S is the maximum target length, or (sum(target_lengths)). Targets cannot be blank. In the second form, the targets are assumed to be concatenated.
input_lengths (Tensor) – (N) or (). Lengths of the inputs (must each be ≤ T).
target_lengths (Tensor) – (N) or (). Lengths of the targets.
blank (int, optional) – Blank label. Default: 0.
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied; 'mean': the output losses will be divided by the target lengths and then the mean over the batch is taken; 'sum': the output will be summed. Default: 'mean'.
zero_infinity (bool, optional) – Whether to zero infinite losses and the associated gradients. Default: False. Infinite losses mainly occur when the inputs are too short to be aligned to the targets.
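To illustrate zero_infinity, the sketch below builds a deliberately impossible case: four target labels but only three input frames, so no valid alignment exists. The sizes and label values are illustrative assumptions. Without zero_infinity the per-sample loss should come out infinite; with zero_infinity=True it is replaced by zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, N, C = 3, 1, 5  # only 3 input frames (illustrative sizes)
log_probs = torch.randn(T, N, C).log_softmax(2)
targets = torch.tensor([[1, 2, 3, 4]])  # 4 labels cannot be aligned to 3 frames
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([4])

# Default behaviour: the unalignable sample yields an infinite loss.
loss_inf = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      reduction='none')

# With zero_infinity=True the infinite loss is zeroed out.
loss_zero = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                       reduction='none', zero_infinity=True)
```

Zeroing such losses keeps a single unalignable sample from turning the whole batch loss into inf, but it also hides the underlying data problem, so it may be worth checking input and target lengths separately.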
- Return type
Tensor
Example:
>>> import torch
>>> import torch.nn.functional as F
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()
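The following sketch explores two points from the parameter descriptions: the padded and concatenated forms of targets, and how the reduction modes relate to each other. All sizes, lengths, and variable names are illustrative assumptions. Under those assumptions, both target forms should give the same per-sample losses, and 'mean' should match dividing each loss by its target length before averaging over the batch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, N, C, S = 50, 4, 20, 12  # illustrative sizes
log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor([12, 7, 9, 5])

# Padded form: shape (N, S); entries past each target length are ignored.
padded = torch.randint(1, C, (N, S), dtype=torch.long)

# Concatenated form: shape (sum(target_lengths),), built by dropping the padding.
concatenated = torch.cat(
    [padded[i, :n] for i, n in enumerate(target_lengths.tolist())]
)

per_sample_padded = F.ctc_loss(log_probs, padded, input_lengths,
                               target_lengths, reduction='none')
per_sample_concat = F.ctc_loss(log_probs, concatenated, input_lengths,
                               target_lengths, reduction='none')

mean_loss = F.ctc_loss(log_probs, padded, input_lengths,
                       target_lengths, reduction='mean')
sum_loss = F.ctc_loss(log_probs, padded, input_lengths,
                      target_lengths, reduction='sum')

# 'mean' as described: divide by target lengths, then average over the batch.
manual_mean = (per_sample_padded / target_lengths).mean()
```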