MultiheadAttention

class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

Allows the model to jointly attend to information from different representation subspaces, as described in the paper “Attention Is All You Need”.

Multi-Head Attention is defined as:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O$$

where $head_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.
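
The definition above can be sketched with plain tensor operations. This is an illustrative sketch only, not the module's actual implementation: the weight tensors `W_q`, `W_k`, `W_v`, `W_o` and the helper `attention` are hypothetical names, the weights are random, biases are omitted, and Attention is assumed to be the scaled dot-product attention from the paper.

```python
import math
import torch

torch.manual_seed(0)

L, S, E, h = 4, 6, 8, 2   # target length, source length, embed_dim, num_heads
d = E // h                # per-head dimension

Q = torch.randn(L, E)
K = torch.randn(S, E)
V = torch.randn(S, E)

# Hypothetical per-head projections W_i^Q, W_i^K, W_i^V and output projection W^O.
W_q = torch.randn(h, E, d)
W_k = torch.randn(h, E, d)
W_v = torch.randn(h, E, d)
W_o = torch.randn(h * d, E)


def attention(q, k, v):
    """Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


# head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
heads = [attention(Q @ W_q[i], K @ W_k[i], V @ W_v[i]) for i in range(h)]

# MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
multi_head = torch.cat(heads, dim=-1) @ W_o

print(multi_head.shape)  # torch.Size([4, 8])
```

Each head works on a `d`-dimensional projection, so concatenating `h` heads restores the full `embed_dim` before the output projection.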

forward() will use a special optimized implementation if all of the following conditions are met:

  • self attention is being computed (i.e., query, key, and value are the same tensor). This restriction will be loosened in the future.

  • Either autograd is disabled (using torch.inference_mode or torch.no_grad) or no tensor argument requires_grad

  • training is disabled (using .eval())

  • dropout is 0

  • add_bias_kv is False

  • add_zero_attn is False

  • batch_first is True and the input is batched

  • kdim and vdim are equal to embed_dim

  • at most one of key_padding_mask or attn_mask is passed

  • if a NestedTensor is passed, neither key_padding_mask nor attn_mask is passed

If the optimized implementation is in use, a NestedTensor can be passed for query/key/value to represent padding more efficiently than using a padding mask. In this case, a NestedTensor will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected.
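
As a rough illustration, the setup below is meant to satisfy the listed conditions: self attention, `batch_first=True` with batched input, zero dropout, default `kdim`/`vdim`, no masks, evaluation mode, and autograd disabled. It is a sketch under those assumptions; whether the optimized implementation is actually selected depends on the PyTorch version and build, and the return values do not reveal which path ran.

```python
import torch
from torch import nn

torch.manual_seed(0)

# dropout=0.0, add_bias_kv=False, add_zero_attn=False, kdim=vdim=embed_dim (defaults)
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, dropout=0.0, batch_first=True)
mha.eval()  # training disabled

x = torch.randn(2, 5, 16)  # batched input of shape (N, L, E)

# Autograd disabled; query, key, and value are the same tensor; no masks passed.
with torch.inference_mode():
    out, _ = mha(x, x, x)

print(out.shape)  # torch.Size([2, 5, 16])
```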

Parameters
  • embed_dim – Total dimension of the model.

  • num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

  • dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).

  • bias – If True, adds a bias to the input and output projection layers. Default: True.

  • add_bias_kv – If True, adds bias to the key and value sequences at dim=0. Default: False.

  • add_zero_attn – If True, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

  • kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).

  • vdim – Total number of features for values. Default: None (uses vdim=embed_dim).

  • batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
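
The following sketch shows how the constructor parameters interact when keys and values have their own feature sizes. The concrete sizes are arbitrary choices for illustration, and `head_dim` is read from the module attribute of that name as found in current PyTorch releases.

```python
import torch
from torch import nn

embed_dim, num_heads, kdim, vdim = 12, 3, 20, 7
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim, batch_first=True)

# embed_dim is split across the heads: 12 // 3 = 4 features per head.
print(mha.head_dim)  # 4

N, L, S = 2, 5, 9
query = torch.randn(N, L, embed_dim)  # (batch, seq, feature) because batch_first=True
key = torch.randn(N, S, kdim)
value = torch.randn(N, S, vdim)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([2, 5, 12])
print(attn_weights.shape)  # torch.Size([2, 5, 9])
```

Keys and values are projected to `embed_dim` internally, so the output keeps the query's feature size regardless of `kdim` and `vdim`.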

Examples:

>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
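
The snippet above leaves `embed_dim`, `num_heads`, `query`, `key`, and `value` undefined. A self-contained variant might look like the following; the sizes are arbitrary and the inputs are random, chosen only to make the shapes visible with the default `batch_first=False` layout.

```python
import torch
from torch import nn

embed_dim, num_heads = 8, 2
L, S, N = 3, 5, 4  # target length, source length, batch size

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

# Default batch_first=False, so inputs are laid out as (seq, batch, feature).
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)

attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)          # torch.Size([3, 4, 8])
print(attn_output_weights.shape)  # torch.Size([4, 3, 5])
```
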

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)
Parameters
  • query – Query embeddings of shape $(L, E_q)$ for unbatched input, $(L, N, E_q)$ when batch_first=False or $(N, L, E_q)$ when batch_first=True, where $L$ is the target sequence length, $N$ is the batch size, and $E_q$ is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.

  • key – Key embeddings of shape $(S, E_k)$ for unbatched input, $(S, N, E_k)$ when batch_first=False or $(N, S, E_k)$ when batch_first=True, where $S$ is the source sequence length, $N$ is the batch size, and $E_k$ is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

  • value – Value embeddings of shape $(S, E_v)$ for unbatched input, $(S, N, E_v)$ when batch_first=False or $(N, S, E_v)$ when batch_first=True, where $S$ is the source sequence length, $N$ is the batch size, and $E_v$ is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

  • key_padding_mask – If specified, a mask of shape $(N, S)$ indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be $(S)$. Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.

  • need_weights – If True, returns attn_output_weights in addition to attn_output. Default: True.

  • attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape $(L, S)$ or $(N \cdot \text{num\_heads}, L, S)$, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length. A 2D mask will be broadcast across the batch, while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.

  • average_attn_weights – If True, indicates that the returned attn_output_weights should be averaged across heads. Otherwise, attn_output_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads).
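
The two mask arguments can be combined, as in the sketch below. It assumes boolean masks, a padded last position in the second sequence, and a causal `attn_mask` built with `torch.triu`; all of these are illustrative choices rather than requirements. Padding is placed at the end so that no query row has every key masked.

```python
import torch
from torch import nn

torch.manual_seed(0)

N, L, E, H = 2, 4, 8, 2
mha = nn.MultiheadAttention(E, H, batch_first=True)
x = torch.randn(N, L, E)

# True marks key positions to ignore; here the last token of sequence 1 is padding.
key_padding_mask = torch.tensor([
    [False, False, False, False],
    [False, False, False, True],
])

# Causal 2D mask of shape (L, S): True above the diagonal blocks attention to later keys.
attn_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

out, weights = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

print(weights.shape)     # torch.Size([2, 4, 4])
print(weights[1, :, 3])  # weights on the padded key of sequence 1
print(weights[0])        # lower-triangular pattern from the causal mask
```

Masked positions receive zero attention weight, and each row of `weights` still sums to one over the remaining keys.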

Outputs:
  • attn_output - Attention outputs of shape $(L, E)$ when input is unbatched, $(L, N, E)$ when batch_first=False or $(N, L, E)$ when batch_first=True, where $L$ is the target sequence length, $N$ is the batch size, and $E$ is the embedding dimension embed_dim.

  • attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape $(L, S)$ when input is unbatched or $(N, L, S)$ when input is batched, where $N$ is the batch size, $L$ is the target sequence length, and $S$ is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape $(\text{num\_heads}, L, S)$ when input is unbatched or $(N, \text{num\_heads}, L, S)$ when input is batched.
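
The output variants can be compared side by side. The sketch below uses arbitrary sizes and random inputs; in the PyTorch releases this was written against, the second element of the returned tuple is `None` when `need_weights=False`, which is treated here as an assumption about current behavior.

```python
import torch
from torch import nn

torch.manual_seed(0)

N, L, S, E, H = 2, 3, 5, 8, 4
mha = nn.MultiheadAttention(E, H, batch_first=True)
q = torch.randn(N, L, E)
kv = torch.randn(N, S, E)

# Default: weights averaged across heads, shape (N, L, S).
_, avg_w = mha(q, kv, kv)

# Per-head weights, shape (N, num_heads, L, S).
_, per_head_w = mha(q, kv, kv, average_attn_weights=False)

# No weights requested: only attn_output carries data.
out, no_w = mha(q, kv, kv, need_weights=False)

print(avg_w.shape)       # torch.Size([2, 3, 5])
print(per_head_w.shape)  # torch.Size([2, 4, 3, 5])
print(out.shape, no_w)   # torch.Size([2, 3, 8]) None
```

Averaging `per_head_w` over its head dimension reproduces `avg_w` up to floating-point error.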

Note

The batch_first argument is ignored for unbatched inputs.
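
A small sketch of the unbatched case, assuming a PyTorch version that accepts 2D inputs: the module is built with `batch_first=True`, yet the 2D tensors are interpreted as (seq, feature) and the outputs carry no batch dimension.

```python
import torch
from torch import nn

L, S, E, H = 3, 5, 8, 2
mha = nn.MultiheadAttention(E, H, batch_first=True)

# Unbatched inputs: (L, E) for the query, (S, E) for key and value.
query = torch.randn(L, E)
key = torch.randn(S, E)
value = torch.randn(S, E)

attn_output, attn_output_weights = mha(query, key, value)
print(attn_output.shape)          # torch.Size([3, 8])
print(attn_output_weights.shape)  # torch.Size([3, 5])
```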
