torch.set_float32_matmul_precision
- torch.set_float32_matmul_precision(precision)
Sets the internal precision of float32 matrix multiplications.
Running float32 matrix multiplications in lower precision may significantly increase performance, and in some programs the loss of precision has a negligible impact.
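Whether the precision loss is negligible depends on the program, so one way to decide is to measure it. The sketch below compares a float32 matmul under each setting against a float64 reference on the same random inputs. The helper name `max_abs_error`, the matrix size, the seed, and the `device` parameter are illustrative assumptions, not part of the API. Because the setting currently affects only CUDA, the three numbers are expected to be similar on a CPU; passing `device="cuda"` on a GPU with TensorFloat32 support is where differences would show up.

```python
import torch

def max_abs_error(precision: str, n: int = 64, device: str = "cpu", seed: int = 0) -> float:
    """Hypothetical helper: largest absolute deviation of a float32 matmul from a float64 reference."""
    previous = torch.get_float32_matmul_precision()
    try:
        torch.set_float32_matmul_precision(precision)
        gen = torch.Generator().manual_seed(seed)  # same inputs for every setting
        a = torch.randn(n, n, generator=gen, dtype=torch.float32).to(device)
        b = torch.randn(n, n, generator=gen, dtype=torch.float32).to(device)
        reference = a.double() @ b.double()  # float64 reference result
        result = a @ b                       # float32 matmul, governed by the setting
        return (result.double() - reference).abs().max().item()
    finally:
        torch.set_float32_matmul_precision(previous)  # restore the global setting

for p in ("highest", "high", "medium"):
    print(p, max_abs_error(p))
```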
Supports three settings:
- “highest”: float32 matrix multiplications use the float32 datatype for internal computations.
- “high”: float32 matrix multiplications use the TensorFloat32 or bfloat16_3x datatypes for internal computations, if fast matrix multiplication algorithms using those datatypes internally are available. Otherwise, float32 matrix multiplications are computed as if the precision is “highest”.
- “medium”: float32 matrix multiplications use the bfloat16 datatype for internal computations, if a fast matrix multiplication algorithm using that datatype internally is available. Otherwise, float32 matrix multiplications are computed as if the precision is “high”.
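A minimal sketch of switching between the three settings and reading the current one back. It uses `torch.get_float32_matmul_precision`, the companion getter in PyTorch; since the setting is process-wide global state, the example ends by returning to the default.

```python
import torch

# Read the current setting; "highest" is the documented default.
print(torch.get_float32_matmul_precision())

# Allow TensorFloat32 (or bfloat16_3x) internal computation where available.
torch.set_float32_matmul_precision("high")
assert torch.get_float32_matmul_precision() == "high"

# Allow bfloat16 internal computation where available.
torch.set_float32_matmul_precision("medium")
assert torch.get_float32_matmul_precision() == "medium"

# Return to full float32 internal computation.
torch.set_float32_matmul_precision("highest")
```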
Note
This does not change the output dtype of float32 matrix multiplications; it controls only how the internal computation of the matrix multiplication is performed.
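To illustrate the note above, this sketch multiplies the same float32 tensors under each setting and records the result dtype, which is expected to stay `torch.float32` every time. The tensor shapes are arbitrary choices for the example.

```python
import torch

a = torch.randn(32, 16, dtype=torch.float32)
b = torch.randn(16, 8, dtype=torch.float32)

previous = torch.get_float32_matmul_precision()
dtypes = {}
for precision in ("highest", "high", "medium"):
    torch.set_float32_matmul_precision(precision)
    # Only the internal computation may change; the output dtype does not.
    dtypes[precision] = (a @ b).dtype
torch.set_float32_matmul_precision(previous)  # restore the global setting

print(dtypes)
```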
Note
This does not change the precision of convolution operations. Other flags, like torch.backends.cudnn.allow_tf32, may control the precision of convolution operations.
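Because convolutions are governed by a separate flag, a program that wants float32 internal precision for both matrix multiplications and cuDNN convolutions has to set both. The sketch below checks that changing the matmul precision leaves `torch.backends.cudnn.allow_tf32` untouched, then sets that flag explicitly and restores both global settings afterwards.

```python
import torch

previous_matmul = torch.get_float32_matmul_precision()
previous_conv = torch.backends.cudnn.allow_tf32

torch.set_float32_matmul_precision("highest")
# The cuDNN convolution flag is separate and keeps whatever value it had.
conv_flag_unchanged = torch.backends.cudnn.allow_tf32 == previous_conv

# To also request float32 internal precision for cuDNN convolutions,
# set the convolution flag explicitly.
torch.backends.cudnn.allow_tf32 = False
conv_flag_now = torch.backends.cudnn.allow_tf32

# Restore both global settings.
torch.set_float32_matmul_precision(previous_matmul)
torch.backends.cudnn.allow_tf32 = previous_conv
```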
Note
This flag currently affects only one native device type: CUDA. If “high” or “medium” is set, the TensorFloat32 datatype is used when computing float32 matrix multiplications, equivalent to setting torch.backends.cuda.matmul.allow_tf32 = True. When “highest” (the default) is set, the float32 datatype is used for internal computations, equivalent to setting torch.backends.cuda.matmul.allow_tf32 = False.
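The equivalence described in this note can be observed by reading `torch.backends.cuda.matmul.allow_tf32` after each call. This is a sketch that assumes only the precision setting is used to control TF32 in the process; the flag can be read without a GPU present.

```python
import torch

previous = torch.get_float32_matmul_precision()
observed = {}
for precision in ("highest", "high", "medium"):
    torch.set_float32_matmul_precision(precision)
    # Read back the cuBLAS TF32 flag that this setting corresponds to.
    observed[precision] = torch.backends.cuda.matmul.allow_tf32
torch.set_float32_matmul_precision(previous)  # restore the global setting

print(observed)
```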
- Parameters
precision (str) – can be set to “highest” (default), “high”, or “medium” (see above).
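Since the setting is global, code that needs a different precision for one region may want to restore the previous value afterwards. The context manager below is a hypothetical helper, not part of PyTorch: `float32_matmul_precision` and its `ValueError` check are assumptions made for this sketch, which accepts only the three documented values.

```python
import contextlib
import torch

_VALID_PRECISIONS = ("highest", "high", "medium")

@contextlib.contextmanager
def float32_matmul_precision(precision: str):
    """Hypothetical helper: apply `precision` inside the block, then restore the old value."""
    if precision not in _VALID_PRECISIONS:
        raise ValueError(f"precision must be one of {_VALID_PRECISIONS}, got {precision!r}")
    previous = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision(precision)
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(previous)  # restored even if the block raises

with float32_matmul_precision("high"):
    a = torch.randn(8, 8)
    c = a @ a  # computed with the "high" setting
```

Restoring in a `finally` clause keeps the global setting consistent even when the wrapped code raises.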