MultiheadAttention¶
- class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)[source]¶
Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.
Multi-Head Attention is defined as:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.
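As a rough, hedged illustration of this formula (not the module's actual packed weight layout, which uses in_proj_weight and out_proj internally), the per-head projections, scaled dot-product attention, concatenation, and output projection can be written with plain tensor ops. All names and sizes below are arbitrary choices for the sketch:

>>> import math, torch
>>> E, h, L = 8, 2, 5                      # embed_dim, num_heads, sequence length
>>> d = E // h                             # per-head dimension
>>> Q = K = V = torch.randn(L, E)          # self-attention: same tensor for Q/K/V
>>> Wq, Wk, Wv = (torch.randn(h, E, d) for _ in range(3))   # per-head projections W_i^Q, W_i^K, W_i^V
>>> Wo = torch.randn(h * d, E)             # output projection W^O
>>> heads = []
>>> for i in range(h):
...     scores = (Q @ Wq[i]) @ (K @ Wk[i]).T / math.sqrt(d)
...     heads.append(scores.softmax(dim=-1) @ (V @ Wv[i]))   # head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
>>> out = torch.cat(heads, dim=-1) @ Wo    # Concat(head_1, ..., head_h) W^O
>>> out.shape
torch.Size([5, 8])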
forward() will use a special optimized implementation if all of the following conditions are met:

- self attention is being computed (i.e., query, key, and value are the same tensor; this restriction will be loosened in the future)
- either autograd is disabled (using torch.inference_mode or torch.no_grad) or no tensor argument requires_grad
- training is disabled (using .eval())
- dropout is 0
- add_bias_kv is False
- add_zero_attn is False
- batch_first is True and the input is batched
- kdim and vdim are equal to embed_dim
- at most one of key_padding_mask or attn_mask is passed
- if a NestedTensor is passed, neither key_padding_mask nor attn_mask is passed
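A minimal sketch of a configuration that satisfies the conditions above (batched input, batch_first=True, self-attention, eval mode, autograd disabled, dropout 0, default kdim/vdim); whether the optimized kernel is actually used still depends on the PyTorch version and build, and the sizes here are arbitrary:

>>> import torch
>>> from torch import nn
>>> mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, dropout=0.0, batch_first=True).eval()
>>> x = torch.randn(2, 10, 16)             # batched (batch, seq, feature)
>>> with torch.inference_mode():           # autograd disabled
...     out, _ = mha(x, x, x, need_weights=False)   # self-attention: query is key is value
>>> out.shape
torch.Size([2, 10, 16])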
If the optimized implementation is in use, a NestedTensor can be passed for query/key/value to represent padding more efficiently than using a padding mask. In this case, a NestedTensor will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected.
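A minimal sketch of that NestedTensor path, assuming a recent PyTorch where torch.nested.nested_tensor is available (the nested tensor API has been prototype-level in some releases) and reusing the fast-path setup above; sequence lengths and sizes are illustrative:

>>> import torch
>>> from torch import nn
>>> mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True).eval()
>>> # two sequences of different lengths; no key_padding_mask is needed
>>> nt = torch.nested.nested_tensor([torch.randn(3, 16), torch.randn(7, 16)])
>>> with torch.inference_mode():
...     out, _ = mha(nt, nt, nt, need_weights=False)
>>> out.is_nested
True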
- Parameters:
  - embed_dim – Total dimension of the model.
  - num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
  - dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
  - bias – If specified, adds bias to input / output projection layers. Default: True.
  - add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.
  - add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
  - kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).
  - vdim – Total number of features for values. Default: None (uses vdim=embed_dim).
  - batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
Examples:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
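The example above can be made concrete with illustrative sizes (embed_dim=16, num_heads=4, target length 5, source length 7, batch size 2) and the default batch_first=False layout of (seq, batch, feature); these numbers are arbitrary choices for the sketch:

>>> import torch
>>> from torch import nn
>>> multihead_attn = nn.MultiheadAttention(embed_dim=16, num_heads=4)   # batch_first=False
>>> query = torch.randn(5, 2, 16)          # (L, N, embed_dim)
>>> key = torch.randn(7, 2, 16)            # (S, N, embed_dim)
>>> value = torch.randn(7, 2, 16)          # (S, N, embed_dim)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> attn_output.shape, attn_output_weights.shape
(torch.Size([5, 2, 16]), torch.Size([2, 5, 7]))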
- forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)[source]¶
- Parameters:
  - query (Tensor) – Query embeddings of shape $(L, E_q)$ for unbatched input, $(L, N, E_q)$ when batch_first=False or $(N, L, E_q)$ when batch_first=True, where $L$ is the target sequence length, $N$ is the batch size, and $E_q$ is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.
  - key (Tensor) – Key embeddings of shape $(S, E_k)$ for unbatched input, $(S, N, E_k)$ when batch_first=False or $(N, S, E_k)$ when batch_first=True, where $S$ is the source sequence length, $N$ is the batch size, and $E_k$ is the key embedding dimension kdim. See “Attention Is All You Need” for more details.
  - value (Tensor) – Value embeddings of shape $(S, E_v)$ for unbatched input, $(S, N, E_v)$ when batch_first=False or $(N, S, E_v)$ when batch_first=True, where $S$ is the source sequence length, $N$ is the batch size, and $E_v$ is the value embedding dimension vdim. See “Attention Is All You Need” for more details.
  - key_padding_mask (Optional[Tensor]) – If specified, a mask of shape $(N, S)$ indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be $(S)$. Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a float mask, it will be directly added to the corresponding key value.
  - need_weights (bool) – If specified, returns attn_output_weights in addition to attn_outputs. Default: True.
  - attn_mask (Optional[Tensor]) – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape $(L, S)$ or $(N \cdot \text{num\_heads}, L, S)$, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length. A 2D mask will be broadcast across the batch while a 3D mask allows a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. A usage sketch covering both mask types follows this parameter list.
  - average_attn_weights (bool) – If true, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads).
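A sketch of both mask types on a batched, batch_first=False call, using the same illustrative sizes as earlier; for boolean masks the convention is that True excludes a position from attention:

>>> import torch
>>> from torch import nn
>>> L, S, N, E = 5, 7, 2, 16
>>> mha = nn.MultiheadAttention(embed_dim=E, num_heads=4)
>>> query, key, value = torch.randn(L, N, E), torch.randn(S, N, E), torch.randn(S, N, E)
>>> key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
>>> key_padding_mask[:, -2:] = True        # last two source positions are padding
>>> attn_mask = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)   # 2D mask, broadcast over the batch
>>> out, weights = mha(query, key, value, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
>>> weights.shape                          # (N, L, S), averaged across heads
torch.Size([2, 5, 7])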
- Return type: Tuple[Tensor, Optional[Tensor]]
- Outputs:
  - attn_output – Attention outputs of shape $(L, E)$ when input is unbatched, $(L, N, E)$ when batch_first=False or $(N, L, E)$ when batch_first=True, where $L$ is the target sequence length, $N$ is the batch size, and $E$ is the embedding dimension embed_dim.
  - attn_output_weights – Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape $(L, S)$ when input is unbatched or $(N, L, S)$, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape $(\text{num\_heads}, L, S)$ when input is unbatched or $(N, \text{num\_heads}, L, S)$.
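A sketch of how need_weights and average_attn_weights change the second return value, again with arbitrary illustrative sizes:

>>> import torch
>>> from torch import nn
>>> mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
>>> q, k, v = torch.randn(5, 2, 16), torch.randn(7, 2, 16), torch.randn(7, 2, 16)
>>> _, w_avg = mha(q, k, v)                                  # averaged: (N, L, S)
>>> _, w_per_head = mha(q, k, v, average_attn_weights=False) # per head: (N, num_heads, L, S)
>>> _, no_weights = mha(q, k, v, need_weights=False)         # second output is None
>>> w_avg.shape, w_per_head.shape, no_weights is None
(torch.Size([2, 5, 7]), torch.Size([2, 4, 5, 7]), True)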
Note
The batch_first argument is ignored for unbatched inputs.
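A minimal sketch of the unbatched case noted above, with a single sequence of shape (L, embed_dim) where batch_first has no effect; sizes are illustrative:

>>> import torch
>>> from torch import nn
>>> mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)   # ignored for unbatched input
>>> x = torch.randn(5, 16)                 # unbatched (L, embed_dim)
>>> out, weights = mha(x, x, x)
>>> out.shape, weights.shape               # (L, embed_dim) and (L, S)
(torch.Size([5, 16]), torch.Size([5, 5]))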