RMSNorm
- class torch.nn.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)
Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
This layer implements the operation described in the paper "Root Mean Square Layer Normalization".
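The operation can be written out as follows. Here x_1, …, x_n are the n elements covered by normalized_shape (n is the product of its sizes), ε is eps, and γ is the learnable weight. When elementwise_affine=False there is no γ, which amounts to γ_i = 1:

$$
y_i = \frac{x_i}{\mathrm{RMS}(x)} \, \gamma_i,
\qquad
\mathrm{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{j=1}^{n} x_j^2}
$$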
The root mean square is taken over the last D dimensions, where D is the number of dimensions of normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), the RMS norm is computed over the last 2 dimensions of the input.
- Parameters
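normalized_shape (int or list or torch.Size) – input shape from an expected input of size $[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]]$.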
If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.
eps (Optional[float]) – a value added to the denominator for numerical stability. Default: None, which uses torch.finfo(x.dtype).eps (the machine epsilon of the input's dtype).
elementwise_affine (bool) – when set to True, this module has a learnable per-element scale (weight) of shape normalized_shape, initialized to ones. RMSNorm has no bias term. Default: True.
- Shape:
Input: $(N, *)$
Output: $(N, *)$ (same shape as input)
Examples:
>>> import torch
>>> from torch import nn
>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)