ReduceLROnPlateau¶
- class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose='deprecated')[source]¶
Reduce learning rate when a metric has stopped improving. Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates. This scheduler reads a metric quantity and, if no improvement is seen for a ‘patience’ number of epochs, reduces the learning rate.
- Parameters
optimizer (Optimizer) – Wrapped optimizer.
mode (str) – One of min, max. In min mode, lr will be reduced when the quantity monitored has stopped decreasing; in max mode it will be reduced when the quantity monitored has stopped increasing. Default: ‘min’.
factor (float) – Factor by which the learning rate will be reduced. new_lr = lr * factor. Default: 0.1.
patience (int) – The number of allowed epochs with no improvement after which the learning rate will be reduced. For example, consider the case of having no patience (patience = 0). In the first epoch, a baseline is established and is always considered good as there’s no previous baseline. In the second epoch, if the performance is worse than the baseline, we have what is considered an intolerable epoch. Since the count of intolerable epochs (1) is greater than the patience level (0), the learning rate is reduced at the end of this epoch. From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch if the performance is worse than the baseline. If the performance improves or remains the same, the learning rate is not adjusted. Default: 10.
threshold (float) – Threshold for measuring the new optimum, to only focus on significant changes. Default: 1e-4.
threshold_mode (str) – One of rel, abs. In rel mode, dynamic_threshold = best * ( 1 + threshold ) in max mode or best * ( 1 - threshold ) in min mode. In abs mode, dynamic_threshold = best + threshold in max mode or best - threshold in min mode. Default: ‘rel’. See the sketch after this parameter list.
cooldown (int) – Number of epochs to wait before resuming normal operation after lr has been reduced. Default: 0.
min_lr (float or list) – A scalar or a list of scalars. A lower bound on the learning rate of all param groups or each group respectively. Default: 0.
eps (float) – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
verbose (bool) – If True, prints a message to stdout for each update. Default: False.
Deprecated since version 2.2: verbose is deprecated. Please use get_last_lr() to access the learning rate.
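For illustration, the improvement test that threshold and threshold_mode control can be written as a small standalone helper. This is a simplified sketch mirroring the formulas above, not the scheduler's internal code; the function name is_improvement is hypothetical:

def is_improvement(metric, best, mode='min', threshold=1e-4, threshold_mode='rel'):
    # In 'rel' mode the threshold is a fraction of the current best value;
    # in 'abs' mode it is an absolute margin.
    if mode == 'min':
        target = best * (1 - threshold) if threshold_mode == 'rel' else best - threshold
        return metric < target
    else:  # mode == 'max'
        target = best * (1 + threshold) if threshold_mode == 'rel' else best + threshold
        return metric > target

An epoch whose metric fails this test counts as one with no improvement; once the number of such consecutive epochs exceeds patience, each learning rate is multiplied by factor (subject to min_lr and eps).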
Example
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, 'min')
>>> for epoch in range(10):
>>>     train(...)
>>>     val_loss = validate(...)
>>>     # Note that step should be called after validate()
>>>     scheduler.step(val_loss)
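A variation of the example above monitors validation accuracy instead of loss (mode='max') and passes min_lr as a list to give each parameter group its own lower bound. This is a sketch: model.base, model.head, train(...) and validate(...) are placeholders, as in the example above.
>>> optimizer = torch.optim.SGD([
>>>     {'params': model.base.parameters(), 'lr': 0.01},
>>>     {'params': model.head.parameters(), 'lr': 0.1},
>>> ], momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
>>>                               patience=5, min_lr=[1e-5, 1e-4])
>>> for epoch in range(20):
>>>     train(...)
>>>     val_acc = validate(...)
>>>     scheduler.step(val_acc)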
- get_last_lr()¶
Return last computed learning rate by current scheduler.
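For example, get_last_lr() can be compared before and after a step to detect that a reduction has happened (a sketch continuing the example above):
>>> prev_lrs = scheduler.get_last_lr()
>>> scheduler.step(val_loss)
>>> if scheduler.get_last_lr() != prev_lrs:
>>>     print("learning rate reduced to", scheduler.get_last_lr())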
- print_lr(is_verbose, group, lr, epoch=None)¶
Display the current learning rate.