torch.nn.functional.scaled_dot_product_attention
torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) → Tensor
Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
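# Efficient implementation equivalent to the following:
import math
import torch

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    if is_causal:  # lower-triangular boolean mask: True = may attend
        attn_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
    if attn_mask is not None and attn_mask.dtype == torch.bool:  # bool mask -> additive float mask
        attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(attn_mask.logical_not(), float("-inf"))
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = torch.dropout(torch.softmax(attn_weight, dim=-1), dropout_p, train=True)
    return attn_weight @ value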
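The formulation above can be checked against the built-in function on CPU. The sketch below restates it as a self-contained helper; the name `reference_sdpa` is hypothetical and not part of PyTorch. It assumes float64 inputs, no dropout, and a PyTorch build that provides `torch.nn.functional.scaled_dot_product_attention`.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(query, key, value, attn_mask=None, is_causal=False, scale=None):
    """Hypothetical helper restating the documented formulation (no dropout)."""
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    if is_causal:
        # Lower-triangular boolean mask: query i may attend to keys 0..i.
        attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
    scores = query @ key.transpose(-2, -1) * scale_factor
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # True = take part in attention; False positions get -inf before softmax.
            scores = scores.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ value


torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8, dtype=torch.float64)  # (N, heads, L, E)
k = torch.randn(2, 4, 5, 8, dtype=torch.float64)  # (N, heads, S, E)
v = torch.randn(2, 4, 5, 8, dtype=torch.float64)  # (N, heads, S, Ev)

plain_ok = torch.allclose(F.scaled_dot_product_attention(q, k, v), reference_sdpa(q, k, v))
causal_ok = torch.allclose(
    F.scaled_dot_product_attention(q, k, v, is_causal=True),
    reference_sdpa(q, k, v, is_causal=True),
)
print(plain_ok, causal_ok)  # both expected to be True
```

Float64 is used so that rounding differences between the fused path and the explicit formula stay far below the default `allclose` tolerances.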
Warning
This function is beta and subject to change.
Note
There are currently three supported implementations of scaled dot product attention:
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Memory-Efficient Attention
A PyTorch implementation defined in C++ matching the above formulation
The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used.
All implementations are enabled by default. Scaled dot product attention attempts to automatically select the most efficient implementation that supports the given inputs. To provide more fine-grained control over which implementation is used, the following functions are provided for enabling and disabling implementations. The context manager is the preferred mechanism:
torch.backends.cuda.sdp_kernel(): A context manager used to enable or disable any of the implementations.
torch.backends.cuda.enable_flash_sdp(): Enables or disables FlashAttention.
torch.backends.cuda.enable_mem_efficient_sdp(): Enables or disables Memory-Efficient Attention.
torch.backends.cuda.enable_math_sdp(): Enables or disables the PyTorch C++ implementation.
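A minimal sketch of the two control styles listed above. It reads the current flags with `torch.backends.cuda.flash_sdp_enabled()`, `mem_efficient_sdp_enabled()` and `math_sdp_enabled()` (real query functions not named in the list above); `sdp_flags` is a hypothetical helper. Flipping the flags does not need a GPU, although they only affect kernel choice for CUDA tensors. Newer releases may emit a deprecation warning for `sdp_kernel()`.

```python
import torch


def sdp_flags():
    """Hypothetical helper: return the current (flash, mem_efficient, math) enable flags."""
    return (
        torch.backends.cuda.flash_sdp_enabled(),
        torch.backends.cuda.mem_efficient_sdp_enabled(),
        torch.backends.cuda.math_sdp_enabled(),
    )


before = sdp_flags()
print("before:", before)

# Preferred: the context manager only changes the flags inside the `with` block.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
    inside = sdp_flags()
    print("inside:", inside)  # (True, False, False)

after = sdp_flags()
print("after:", after)  # restored to the values seen before the block

# The global setters persist until changed again, so restore the previous value by hand.
torch.backends.cuda.enable_math_sdp(False)
math_disabled = torch.backends.cuda.math_sdp_enabled()
torch.backends.cuda.enable_math_sdp(before[2])
```

The automatic restore on exit is the main reason the context manager is preferred over the global setters.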
Each of the fused kernels has specific input limitations. If you require a specific fused implementation, disable the PyTorch C++ implementation using torch.backends.cuda.sdp_kernel(). If a fused implementation is not available for the given inputs, an error is raised explaining why the fused implementation cannot run.
Due to the nature of fusing floating-point operations, the output of this function may differ depending on which backend kernel is chosen. The C++ implementation supports torch.float64 and can be used when higher precision is required. For more information, see Numerical accuracy.
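Since results can differ slightly between kernels and dtypes, comparisons should use a tolerance rather than exact equality. The sketch below assumes CPU tensors, where float64 inputs are accepted, and uses arbitrary sizes and seed. It measures how far a float32 result is from a float64 baseline.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q64 = torch.randn(1, 2, 16, 32, dtype=torch.float64)
k64 = torch.randn(1, 2, 16, 32, dtype=torch.float64)
v64 = torch.randn(1, 2, 16, 32, dtype=torch.float64)

out64 = F.scaled_dot_product_attention(q64, k64, v64)  # higher-precision baseline
out32 = F.scaled_dot_product_attention(q64.float(), k64.float(), v64.float())

# The gap comes from float32 rounding, so check it against a tolerance.
max_abs_diff = (out64 - out32.double()).abs().max().item()
print(f"max |float64 - float32| = {max_abs_diff:.2e}")
```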
Note
In some circumstances, when given tensors on a CUDA device and using cuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See Reproducibility for more information.
- Parameters
query (Tensor) – Query tensor; shape (N, ..., L, E).
key (Tensor) – Key tensor; shape (N, ..., S, E).
value (Tensor) – Value tensor; shape (N, ..., S, Ev).
attn_mask (optional Tensor) – Attention mask; shape (N, ..., L, S). Two types of masks are supported: a boolean mask, where a value of True indicates that the element should take part in attention, and a float mask of the same dtype as query, key and value, which is added to the attention score.
dropout_p (float) – Dropout probability; if greater than 0.0, dropout is applied.
is_causal (bool) – If True, assumes causal attention masking; an error is raised if both attn_mask and is_causal are set.
scale (optional float) – Scaling factor applied prior to softmax. If None, the default value is 1/sqrt(E).
- Returns
Attention output; shape (N, ..., L, Ev).
- Return type
output (Tensor)
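- Shape legend:
N: Batch size
...: Any number of other batch dimensions (optional)
S: Source sequence length
L: Target sequence length
E: Embedding dimension of the query and key
Ev: Embedding dimension of the value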
Examples:
>>> import torch
>>> import torch.nn.functional as F
>>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with torch.backends.cuda.sdp_kernel(enable_math=False):
...     F.scaled_dot_product_attention(query, key, value)
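The example above requires a CUDA device. A device-agnostic variant is sketched below. It assumes that restricting kernels is only worthwhile on CUDA, so on other devices it uses `contextlib.nullcontext()` and the default selection, with float32 instead of float16. On CUDA, disabling the math path can still raise an error if no fused kernel supports the inputs on that GPU.

```python
import contextlib

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 matches the example above on CUDA; float32 is the safer choice on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)

# Restrict to fused kernels only on CUDA; elsewhere keep the default selection.
ctx = torch.backends.cuda.sdp_kernel(enable_math=False) if device == "cuda" else contextlib.nullcontext()
with ctx:
    out = F.scaled_dot_product_attention(query, key, value)

print(out.shape, out.dtype, out.device)
```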