torch.nn.attention.bias.CausalBias
- class torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)
A bias representing causal attention patterns. For an overview of the bias structure, see the CausalVariant enum.

This class is used for defining causal (triangular) attention biases. To construct the bias, use one of the two factory functions: causal_upper_left() and causal_lower_right().

Example:
    import torch
    import torch.nn.functional as F
    from torch.nn.attention.bias import causal_lower_right

    bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

    # Create a lower-right causal bias
    attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

    # The example places the tensors on a CUDA device in half precision
    q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

    out = F.scaled_dot_product_attention(q, k, v, attn_bias)
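The difference between the two factory functions is easiest to see by printing the masks they describe. The following is a minimal pure-Python sketch, not the library's implementation: the helper names `causal_mask` and `render` are hypothetical, and the sketch assumes the alignment described by the CausalVariant enum, where the upper-left variant anchors the diagonal at the top-left corner and the lower-right variant shifts it by `seq_len_kv - seq_len_q` so that it ends at the bottom-right corner.

```python
def causal_mask(variant, seq_len_q, seq_len_kv):
    """Return a nested list of booleans; True means the query may attend to the key.

    Hypothetical illustration helper, not part of torch.nn.attention.bias.
    """
    if variant == "upper_left":
        offset = 0
    elif variant == "lower_right":
        # Shift the diagonal so it ends at the bottom-right corner.
        offset = seq_len_kv - seq_len_q
    else:
        raise ValueError(f"unknown variant: {variant!r}")
    return [
        [kv_idx <= q_idx + offset for kv_idx in range(seq_len_kv)]
        for q_idx in range(seq_len_q)
    ]


def render(mask):
    """Format a mask as rows of 1s and 0s for quick visual inspection."""
    return "\n".join(" ".join("1" if allowed else "0" for allowed in row) for row in mask)


if __name__ == "__main__":
    for variant in ("upper_left", "lower_right"):
        print(variant)
        print(render(causal_mask(variant, 2, 4)))
```

When `seq_len_q == seq_len_kv` the offset is zero and the two variants coincide; they differ only for non-square attention, as in the example above where `seqlen_q = 4` and `seqlen_kv = 12`.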
Warning
This class is a prototype and subject to change.