Extending Bfloat16 Support to GemLite
Google's Gemma3 models, trained on TPUs using bfloat16 precision (bfp16), often have activations that exceed the range of float16 (fp16). Converting the weights and activations to fp16 has resulted in reported accuracy issues (alternative analysis).
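To make the range issue concrete: float16's largest finite value is 65504, while bfloat16 keeps float32's 8-bit exponent and gives up mantissa bits in exchange. The pure-Python sketch below round-trips a value through both formats using the standard `struct` module. It emulates bfloat16 by rounding a float32 to its upper 16 bits. The helper names and the activation value are illustrative only and do not come from GemLite.

```python
import struct

def to_float16(x: float) -> float:
    """Round x to IEEE float16; struct raises OverflowError if x is out of range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bfloat16(x: float) -> float:
    """Round x to bfloat16: round-to-nearest-even on the upper 16 bits of its float32 encoding.
    NaN and overflow handling are omitted in this sketch."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

activation = 1e5  # hypothetical activation value beyond float16's range
print(to_bfloat16(activation))  # 99840.0: representable, with a coarse mantissa
try:
    to_float16(activation)
except OverflowError:
    print("float16 overflow")
```

bfloat16 absorbs the large value at the cost of precision. float16 cannot represent it at all.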
Our fast low-bit matrix-multiplication library, GemLite, did not support bfloat16 until recently, because Triton's atomic_add function supports only float16 and float32.
In this blog post, we explore various methods to incorporate bfloat16 support into GemLite. We discuss the limitations of each approach and provide a performance comparison in vLLM.
Furthermore, we are releasing a 4-bit bfloat16 HQQ-quantized version of Gemma3, which offers strong performance and compatibility across a wider range of hardware environments!
Introduction
Supporting bfloat16 in matrix multiplication might seem straightforward. In practice, it is challenging for Triton kernels that rely on atomic addition, which in GemLite includes all GEMV kernels and the Split-K kernels.
Let's assume the kernel computes C = matmul(A, B), where A is an M x K matrix and B is a K x N matrix. Atomic addition is typically used to accumulate the partial sums of the dot product along the K dimension. For example, the Split-K GEMM implementation splits the K dimension into SPLIT_K chunks. A separate program handles each chunk and atomically adds its partial result into C. This improves performance for smaller batch sizes compared to standard GEMM.
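The accumulation pattern can be sketched on the CPU in plain Python. The function name is hypothetical and no GPU is involved. Each split computes a partial dot product over its slice of K and adds it into C. In a real Split-K kernel, that `+=` is the `tl.atomic_add`, because several programs may update the same output element concurrently.

```python
from typing import List

Matrix = List[List[float]]

def matmul_split_k(A: Matrix, B: Matrix, split_k: int) -> Matrix:
    """Compute C = A @ B by splitting K into split_k chunks and accumulating partial sums into C."""
    M, K, N = len(A), len(A[0]), len(B[0])
    chunk = -(-K // split_k)  # ceil(K / split_k)
    C = [[0.0] * N for _ in range(M)]  # must start at zero: every split adds into it
    for s in range(split_k):  # on a GPU, each split runs as an independent program
        k_start, k_end = s * chunk, min((s + 1) * chunk, K)
        for i in range(M):
            for j in range(N):
                partial = sum(A[i][k] * B[k][j] for k in range(k_start, k_end))
                C[i][j] += partial  # the step a Split-K kernel performs with tl.atomic_add
    return C

A = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
B = [[1.0, 0.0, 2.0], [0.0, 1.0, 2.0], [1.0, 1.0, 0.0], [2.0, 0.0, 1.0]]
print(matmul_split_k(A, B, split_k=2))  # same result as split_k=1
```

The splits are independent until the final accumulation. That independence is what lets Split-K expose more parallelism when M is small.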
Triton's atomic_add currently supports only float16 and float32. While native atomic addition for bfloat16 is available starting with sm_90 (Hopper or newer), as specified in the PTX manual, Triton does not yet expose this capability. This limitation remains an open issue on Triton's GitHub repository.
These limitations prompted us to explore various ways to adapt our kernels to support bfloat16. In the next section, we present several methods and discuss their advantages and disadvantages.
Approaches
We explored several approaches to efficiently support bfloat16 kernels. Below, we detail these methods, highlighting their unique trade-offs, and ultimately explain the reasoning behind our final choice.
Compare-And-Swap
Initially, we attempted to emulate atomic addition using atomic Compare-And-Swap (CAS). This method uses a spinlock to prevent race conditions, ensuring that partial sums are correctly computed, accumulated, and stored. Without this synchronization, concurrent read-modify-write operations could overwrite each other and produce inconsistent outputs.
A notable advantage of emulating atomic addition with CAS is its flexibility: the lock-protected read-modify-write works for data types beyond float16 and float32. However, the overhead of the spinlock loop significantly hurts performance, making this approach impractical for real-world kernels.
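```python
import triton
import triton.language as tl

@triton.jit
def atomic_add_cas(ptr, value, Lock, mask, sem: tl.constexpr):
    # Spin until this program flips the lock from 0 (free) to 1 (held)
    while tl.atomic_cas(Lock, 0, 1, sem=sem) == 1:
        pass
    # Critical section: plain read-modify-write, usable with any dtype (e.g. bfloat16)
    tl.store(ptr, tl.load(ptr, mask=mask) + value, mask=mask)
    # Synchronize all threads of the program before releasing the lock
    tl.debug_barrier()
    # Release the lock
    tl.atomic_xchg(Lock, 0)
```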
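Python has no hardware compare-and-swap. The CPU sketch below therefore emulates one with a mutex-guarded `AtomicInt` class, which is our own stand-in and not a Triton or GemLite API. The sketch only shows the protocol used by `atomic_add_cas`: spin on CAS until the lock flips from 0 to 1, do a plain read-modify-write, then release with an exchange. It says nothing about GPU performance.

```python
import threading

class AtomicInt:
    """CPU stand-in for a lock word in GPU memory; the internal mutex makes each method one atomic step."""

    def __init__(self, value: int = 0) -> None:
        self._value = value
        self._guard = threading.Lock()

    def cas(self, expected: int, desired: int) -> int:
        """Compare-and-swap: write `desired` only if the value equals `expected`; return the old value."""
        with self._guard:
            old = self._value
            if old == expected:
                self._value = desired
            return old

    def xchg(self, desired: int) -> int:
        """Atomically replace the value and return the old one."""
        with self._guard:
            old, self._value = self._value, desired
            return old

def atomic_add_cas(buffer: list, idx: int, value: float, lock: AtomicInt) -> None:
    while lock.cas(0, 1) == 1:  # spin until we flip the lock from 0 (free) to 1 (held)
        pass
    buffer[idx] = buffer[idx] + value  # plain read-modify-write, safe only while holding the lock
    lock.xchg(0)  # release the lock

buffer, lock = [0.0], AtomicInt(0)

def worker() -> None:
    for _ in range(500):
        atomic_add_cas(buffer, 0, 1.0, lock)

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(buffer[0])  # 2000.0
```

Every waiting thread burns cycles in the `while` loop. On a GPU, that is the overhead that makes this approach slow.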
Custom PTX
Next, we considered custom inline PTX instructions. Although potentially efficient, this solution is constrained by hardware compatibility: the instruction atom.add.noftz.bf16 is only available on NVIDIA Hopper (sm_90) and newer architectures. GemLite requires broad device support across both consumer and data center GPUs, so this approach was not viable.
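A dispatcher could gate such a PTX path on the device's compute capability. PyTorch reports the compute capability as a `(major, minor)` tuple via `torch.cuda.get_device_capability()`. The helper below is hypothetical and takes the tuple directly, so it runs without a GPU.

```python
from typing import Tuple

def has_native_bf16_atomic_add(capability: Tuple[int, int]) -> bool:
    """True if atom.add.noftz.bf16 is available, i.e. sm_90 (Hopper) or newer."""
    return capability >= (9, 0)

# Compute capabilities of the GPUs used in this post, plus the A100 for reference
for name, cap in [("RTX 4090", (8, 9)), ("A100", (8, 0)), ("H100", (9, 0))]:
    print(name, has_native_bf16_atomic_add(cap))
```

On the RTX 4090 (sm_89), this check fails, so a PTX-only path would exclude it.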
In-Kernel Casting
We further investigated in-kernel casting, particularly within the Split-K implementations. This strategy uses three buffers:
- c: the final output buffer, in bfloat16
- c_tmp: an intermediate float32 buffer that receives the atomic additions of the partial sums
- pid_acc: a small buffer of counters indexed by program ID (PID)
Each partial accumulation increments its counter in pid_acc. Once the count reaches SPLIT_K, all partial sums are complete, and the result can be safely stored in the bfloat16 output buffer.
Initializing the pid_acc buffer posed some challenges. The main difficulty was that the tuning parameters BLOCK_SIZE_M and BLOCK_SIZE_N cannot be accessed directly during autotuning, so the number of output tiles is not known when the buffer is allocated. Additionally, M and N might not be multiples of the block sizes. To mitigate these issues, we defaulted to the minimum autotuning values (BLOCK_SIZE_M=16, BLOCK_SIZE_N=32) and rounded M and N up to the nearest power of two. Every autotuned configuration uses blocks at least this large, so the resulting buffer covers the entire computation grid.
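The sizing rule can be sketched as follows, assuming pid_acc holds one counter per output tile. `next_power_of_2` and `cdiv` mirror `triton.next_power_of_2` and `triton.cdiv` but are reimplemented so the snippet runs without Triton. `pid_acc_numel` is a hypothetical name.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()

def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)

# Minimum autotuning block sizes, as described above
MIN_BLOCK_SIZE_M, MIN_BLOCK_SIZE_N = 16, 32

def pid_acc_numel(M: int, N: int) -> int:
    """Upper bound on the number of output tiles for any config with block sizes >= the minimums."""
    return cdiv(next_power_of_2(M), MIN_BLOCK_SIZE_M) * cdiv(next_power_of_2(N), MIN_BLOCK_SIZE_N)

print(pid_acc_numel(1, 4096))    # 1 * 128 tiles
print(pid_acc_numel(100, 3000))  # M -> 128, N -> 4096: 8 * 128 tiles
```

The counters could then be allocated with something like `torch.zeros(pid_acc_numel(M, N), dtype=torch.int32, device=...)`. Because they are incremented in place, they would need to be reset (e.g. by a pre-hook) before each launch.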
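```python
# Accumulate this program's partial sums (float32 when native_atomic is False)
if SPLIT_K > 1:
    tl.atomic_add(c_ptrs, acc, mask=mask, sem=atomic_mode)
else:
    tl.store(c_ptrs, acc, mask=mask)
if not native_atomic:
    # Count finished partial accumulations for this tile; atomic_add returns the previous value
    done = tl.atomic_add(pid_acc + pid, 1)
    # Only the last program sees SPLIT_K - 1: all partial sums are in, so copy the result
    # to the other buffer (tl.store casts to that buffer's dtype)
    if done == SPLIT_K - 1:
        tl.store(c_tmp_ptrs, tl.load(c_ptrs, mask=mask), mask=mask)
```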
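The counter logic above can be simulated on the CPU for a single output tile. In this sketch (our own illustration, not GemLite code), threads stand in for the SPLIT_K programs and a mutex makes each "atomic" update indivisible. Python floats stand in for float32. To keep the sketch short, the bfloat16 cast truncates the float32 encoding to its upper 16 bits (round toward zero).

```python
import struct
import threading

def to_bfloat16_truncate(x: float) -> float:
    """Keep the upper 16 bits of x's float32 encoding (round toward zero; a simplification)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0] & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def split_k_tile(partials: list) -> float:
    """Run one program per partial sum; the last program to finish casts the total to bfloat16."""
    split_k = len(partials)
    c_tmp = [0.0]  # float32 intermediate accumulator for this tile
    pid_acc = [0]  # completion counter for this tile
    c = [None]     # bfloat16 output for this tile
    atomic = threading.Lock()  # makes each read-modify-write below indivisible

    def program(partial: float) -> None:
        with atomic:  # like tl.atomic_add on the float32 accumulator
            c_tmp[0] += partial
        with atomic:  # like done = tl.atomic_add(pid_acc + pid, 1): yields the old count
            done = pid_acc[0]
            pid_acc[0] += 1
        if done == split_k - 1:  # exactly one program observes SPLIT_K - 1
            c[0] = to_bfloat16_truncate(c_tmp[0])

    threads = [threading.Thread(target=program, args=(p,)) for p in partials]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return c[0]

print(split_k_tile([1.0, 2.0, 3.0, 4.5]))  # 10.5
```

Each program adds its partial sum before incrementing the counter. The program that observes SPLIT_K - 1 therefore knows that every partial sum is already in the accumulator.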
While effective, this method adds overhead from the extra buffer management. It also becomes complex when supporting multiple configurations, particularly across data types and both the SPLIT_K = 1 and SPLIT_K > 1 cases. For instance, pre-hooks are needed to initialize the buffers, but the data type in use (bfloat16 vs. float16) isn't always known in advance, which complicates buffer setup.
Post-Casting (Final Solution)
Ultimately, we settled on performing atomic addition in float32 and then post-casting the accumulated results to bfloat16. Although float16 atomics would be faster, float16's limited dynamic range and precision pose significant risks of overflow and rounding errors during accumulation. Accumulating in float32 provides better numerical stability and greatly reduces the risk of accuracy loss. Nonetheless, the float32 atomic additions and the final cast introduce a performance penalty. In the worst case, the slowdown is 5-10%, depending on the model parameters, the device, and the backend implementation.
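```python
import torch

# Triton's atomic_add natively supports float16 and float32 only
native_atomic = output_dtype in [torch.float16, torch.float32]
# For other dtypes (e.g. bfloat16), accumulate in float32 instead
dtype = output_dtype if native_atomic else torch.float32
# Zero-initialized because Split-K partial sums are atomically added into it
output = torch.zeros((M, N), device=device, dtype=dtype)
kernel_call(input, output)
if not native_atomic:
    # Post-cast the float32 accumulator to the requested dtype
    output = output.to(output_dtype)
```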
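The precision argument for float32 accumulation can be checked with a small CPU experiment. This sketch uses Python's `struct` module to round a running total to float16 (`'e'`) or float32 (`'f'`) after every addition, roughly the way repeated atomic adds into a buffer of that dtype would. The helper names and input values are illustrative.

```python
import struct

def round_to(fmt: str, x: float) -> float:
    """Round a Python float to IEEE float16 ('<e') or float32 ('<f') via a pack/unpack round trip."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

def accumulate(values, fmt: str) -> float:
    """Sum values, rounding the running total to the given format after every addition."""
    total = 0.0
    for v in values:
        total = round_to(fmt, total + round_to(fmt, v))
    return total

values = [0.1] * 4096  # exact sum is about 409.6
fp16_sum = accumulate(values, "<e")
fp32_sum = accumulate(values, "<f")
print(fp16_sum, fp32_sum)
```

The float16 total eventually grows large enough that the gap between adjacent float16 values exceeds twice the addend. From then on, each addition rounds back to the same total and the sum stops growing. The float32 total stays close to the exact value.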
Benchmarks
To evaluate the practical impact of our chosen post-casting approach, we benchmarked the performance of vLLM integrated with GemLite and HQQ. We conducted tests on both consumer GPUs (A16W4 quantization on the RTX 4090) and data center GPUs (A8W8 dynamic quantization on the H100 SXM4).
As expected, the performance difference between float16 and bfloat16 varies with the device and the batch size. Our results indicate that, in the worst case, the post-casting method is 5-10% slower than native float16.
The script used to run vLLM with GemLite for these benchmarks can be found here.
Conclusion
In this work, we explored various methods for adding bfloat16 support to GemLite. We discussed how each approach can be implemented, along with the advantages and disadvantages of each. We hope this blog post provides value to anyone interested in using bfloat16 for atomic addition in Triton kernels!
Citation
@misc{badri2025bfp16,
title = {Extending Bfloat16 Support to GemLite},
url = {https://mobiusml.github.io/gemlite_bfp16_blogpost/},
author = {Hicham Badri},
month = {March},
year = {2025}
}
Please feel free to contact us.