Faster and Smaller Whisper: A Deep Dive into Quantization and Torch Compilation
Jilt Sebastian, Husein Zolkepli, Hicham Badri, Appu Shaji
Mobius Labs GmbH and Mesolitica
Introduction
Recent years have witnessed remarkable advancements in artificial intelligence, propelling rapid growth in automatic speech recognition (ASR) technologies. Soon after its release, OpenAI's Whisper model quickly gained prominence due to its open licensing, competitive performance against proprietary models, and strong generalization capabilities. Despite being in the field for over two years, Whisper models continue to be highly relevant and are the go-to workhorse for many large-scale ASR systems deployed worldwide.
In this blog post, we explain the techniques we used to enhance the performance of the PyTorch-based Whisper models. By building on the transformers library, implementing a static KV-cache, and utilizing torch.compile, we significantly accelerated the model's inference speed. Additionally, we employed HQQ to quantize the Whisper models to 4 bits, preserving transcription quality with minimal degradation, as evaluated by Word Error Rate (WER) benchmarks. Our optimizations resulted in a 4.5x speedup for non-quantized models and an impressive 6x speedup for quantized models. Moreover, we report detailed quantization results and find that Whisper can be run in extremely low-bit configurations.
Furthermore, we provide benchmarks across various ASR datasets to demonstrate the effectiveness of our optimizations. This post delves into the methods and processes behind these improvements, providing insights into the power of optimized kernels and quantization.
Speed Optimization
torch.compile compiles PyTorch models into optimized Triton kernels and can often result in significant speedups for various PyTorch-based models. One of the prerequisites and critical steps toward achieving this is managing a static KV-cache, as detailed in https://pytorch.org/blog/accelerating-generative-ai-2/. During inference, the KV-cache stores intermediate outputs of the model in the attention modules. However, due to varying prompts and generation lengths, these caches are implemented via dynamic allocation, which creates overhead. Instead, torch.compile requires the KV-cache to be maintained statically by pre-allocating the maximum size.
The first step we took was to implement a static cache for the Whisper model to make it fully compatible with torch.compile. A gist of the implementation is available here. As far as we know, this is the first implementation of a static cache for encoder-decoder models in the Transformers library, and hopefully, it will be useful for other encoder-decoder architectures.
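For readers who want to try this themselves, the following is a minimal sketch of the usage pattern, assuming a recent transformers release in which the static-cache generation path for Whisper is available; the model name, compile mode, and the tiny test dataset are illustrative rather than our exact benchmark setup.

```python
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2", torch_dtype=torch.float16
).to("cuda")

# Pre-allocate the KV-cache to its maximum size instead of growing it dynamically.
model.generation_config.cache_implementation = "static"

# Compile the forward pass into fused Triton kernels; "reduce-overhead" uses CUDA graphs,
# which pairs well with a static cache.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

# A small public test clip; any 16 kHz mono waveform works.
sample = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features.to("cuda", torch.float16)

# The first few calls trigger compilation; subsequent calls run at full speed.
predicted_ids = model.generate(input_features=input_features, do_sample=False)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```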
Quantization
We further quantized the models to 4 bits using Half-Quadratic Quantization (HQQ). HQQ is a fast and accurate model quantizer that requires no calibration data, and it fully supports torch.compile. Although quantization reduces VRAM requirements, optimized fused kernels are needed for it to also run faster. Recent developments have introduced low-bit matmul kernels that offer significant speed-ups, such as TorchAO's tiny_gemm, and HQQ can leverage these kernels.
Our previous experience with quantizing LLMs suggested that quantizing to 4 bits incurs only a very marginal drop in quality (measured here as word error rate), while achieving a further speedup over the vanilla PyTorch-compiled models. Moreover, the GPU VRAM requirements of the decoder can be reduced by a factor of 3 to 4.
We only quantize the decoder linear layers of Whisper for the following reason: the encoder linear layers involve mainly matrix-matrix computations (GEMM), whereas the decoder layers primarily involve matrix-vector multiplications (GEMV) during the decoding phase. The TorchAO kernel is optimized to speed up GEMV operations with 4-bit quantized weights.
Even if the encoder performance is optimized, its impact on the overall runtime is marginal compared to the decoder. The encoder is only run once, whereas the decoder must be run many times, once per token generated.
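As a rough sketch of what decoder-only quantization can look like with the hqq library (using its BaseQuantizeConfig and HQQLinear building blocks), the helper below replaces every nn.Linear inside the decoder layers with a 4-bit, group-size-64 HQQ layer; enabling the TorchAO low-bit kernel backend and the exact integration we benchmarked are omitted here.

```python
import torch
from torch import nn
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

def quantize_whisper_decoder(model, nbits=4, group_size=64):
    """Quantize only the decoder linear layers of a WhisperForConditionalGeneration model."""
    quant_config = BaseQuantizeConfig(nbits=nbits, group_size=group_size)
    for layer in model.model.decoder.layers:
        # Collect the linear layers first so we do not mutate the module tree while iterating.
        linears = [(name, m) for name, m in layer.named_modules() if isinstance(m, nn.Linear)]
        for name, module in linears:
            parent_name, _, child_name = name.rpartition(".")
            parent = layer.get_submodule(parent_name) if parent_name else layer
            setattr(
                parent,
                child_name,
                HQQLinear(module, quant_config, compute_dtype=torch.float16, device="cuda"),
            )
    return model
```

The encoder, the embeddings, and the output projection stay in half precision, which mirrors the reasoning above: the GEMV-bound decoder is where the 4-bit kernels pay off.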
Benchmarks
We conducted two sets of benchmarking experiments to validate the speed-up from torch.compile and HQQ. The first experiment uses the Open ASR eval datasets to measure the speed improvement and its effect on WER; the second uses a test set of audio extracted from YouTube videos, included because transcripts in real-life use cases are generally long. We used the 'large-v2' model with a batch size of 1 and greedy decoding (no sampling) for all experiments.
The Real Time Factor (RTF) metric, commonly used in open-source benchmarks such as the Open ASR Leaderboard evaluation, measures the speed of offline ASR systems by comparing total processing time to audio duration. We report 1/RTF in our benchmarking experiments, as it is more interpretable once the system runs faster than real time, a common requirement for offline ASR. We used the EnglishTextNormalizer() from Open ASR Eval as the text normalizer when computing WER. Experiments use the torch nightly build as of 28 May 2024 and are run on an RTX 4090 GPU.
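To make the evaluation setup concrete, here is a small sketch of how we interpret these metrics. It uses jiwer for WER and the EnglishTextNormalizer from the openai-whisper package as a stand-in for the Open ASR Eval normalizer, so treat the exact imports as illustrative.

```python
import jiwer
from whisper.normalizers import EnglishTextNormalizer  # openai-whisper package

normalizer = EnglishTextNormalizer()

def inverse_rtf(audio_seconds: float, processing_seconds: float) -> float:
    # 1/RTF: seconds of audio transcribed per second of compute.
    # e.g. 10 minutes of audio processed in 17 s -> ~35x faster than real time.
    return audio_seconds / processing_seconds

def normalized_wer(references: list[str], hypotheses: list[str]) -> float:
    # Normalize casing, punctuation, and spellings before scoring, then report WER in percent.
    refs = [normalizer(r) for r in references]
    hyps = [normalizer(h) for h in hypotheses]
    return 100.0 * jiwer.wer(refs, hyps)

print(inverse_rtf(audio_seconds=600.0, processing_seconds=17.0))  # ~35.3
```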
Short-Form Audio
The first dataset represents short-form audio, as the average duration of each sample is less than 10 seconds. Four Open ASR eval datasets with audios shorter than 30 seconds are used for the experiments. The baseline system is the current implementation in transformers, which does not use torch.compile.
| Dataset | Avg. Duration | Baseline Speed | Baseline WER | Torch Compile Speed | Torch Compile WER | Torch Compile + HQQ Speed | Torch Compile + HQQ WER |
|---|---|---|---|---|---|---|---|
| TEDLIUM | 8 secs | 11.6x | 4.06% | 35.3x | 4.06% | 37.3x | 4.0% |
| Voxpopuli | 10 secs | 14.2x | 7.84% | 42.6x | 7.84% | 47.9x | 7.91% |
| Earnings22 | 7 secs | 12.8x | 12.13% | 36.4x | 12.15% | 42.9x | 12.83% |
| AMI | 2 secs | 9.8x | 16.67% | 20.8x | 16.71% | 20.4x | 16.69% |
Real-World Scenario: Long Audios
To check the effect of these decoding methods in real-world scenarios with long audios, we used an internal test dataset of 84 minutes total duration with verified ground truth. The test set contains 9 audios ranging from 3 to 13 minutes and covers various audio types.
| System | Baseline | Torch Compile | Torch Compile + HQQ |
|---|---|---|---|
| Speed | 15.8x | 57.6x | 72.8x |
| WER (%) | 11.45 | 11.47 | 11.49 |
| VRAM usage | 3576 MiB | 3576 MiB | 2454 MiB |
The speed-up in actual processing time depends on the audio length. For datasets with longer audios, the speed-up is higher: there are more tokens to decode per file, so the impact of optimization and quantization is more pronounced.
For example, on the AMI corpus, whose clips average only 2 seconds, the speed-up is around 2x over the baseline, whereas on the long-form dataset it reaches 4.6x for the HQQ-based approach and 3.6x for the torch-compiled one. This speed-up comes on top of the fastest sequential/batched PyTorch-based execution, resulting in ultra-fast ASR inference.
Tokens per Second
The most significant impact of optimization and quantization is on the decoder layers. The following table shows the tokens per second processed by the decoder.
| System | Decoder throughput |
|---|---|
| Baseline | 49.5 it/sec |
| Torch Compile | 243.4 it/sec |
| Torch Compile + HQQ | 316.4 it/sec |
Whisper in Extreme Low-Bit Configurations
Additionally, we tested Whisper in extremely low-bit configurations (3, 2, and 1.58 (ternary) bits) with different quantization group sizes. The quantization parameters (scale and zero point) are computed and stored per group of entries in a weight matrix, so a smaller group size means more quantization metadata has to be stored. Finding the right balance between a large group size and preserved quality is therefore crucial, and lower bit widths generally require smaller group sizes (refer to this blog).
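To make the overhead concrete, the back-of-the-envelope calculation below estimates the effective storage cost per weight, under the simplifying assumption that the scale and zero point are kept in 16 bits each per group (HQQ can also quantize this metadata, which lowers the overhead further).

```python
def effective_bits(nbits: float, group_size: int, meta_bits_per_group: int = 32) -> float:
    # Weight bits plus the per-group metadata (scale + zero point) amortized over the group.
    return nbits + meta_bits_per_group / group_size

for nbits in (4, 3, 2, 1.58):
    for group_size in (64, 32, 16):
        print(f"{nbits} bits, group size {group_size}: "
              f"{effective_bits(nbits, group_size):.2f} effective bits per weight")
```

For example, under this assumption, 2-bit weights with a group size of 16 end up costing about as much storage as 4-bit weights with a group size of 64, so aggressive bit reduction only pays off if the group size does not have to shrink too far.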
The default group size for 4-bit models is 64, so we report results for the lower bit widths at group sizes of 32 and 16. To our surprise, Whisper's performance only degrades slightly and remains quite useful even in extremely low-bit configurations. The 2-bit and 1.58-bit configurations need a group size of 16 to produce useful output, while for the higher bit widths a group size of 64 exhibits accuracy comparable to the non-quantized model.
Ternary weights (1.58 bits) are quite promising, since matrix multiplications can be reformulated as additions, potentially yielding up to a 70x improvement in efficiency (refer to this paper). We also experimented with pure 1-bit quantization; however, even with a small group size, the Word Error Rate (WER) was very high and the output was not meaningful.
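As a purely didactic illustration of why ternary weights remove multiplications, the toy function below computes a matrix-vector product with weights restricted to {-1, 0, +1} using only additions and subtractions; real ternary kernels are organized very differently.

```python
import torch

def ternary_matvec(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # With weights in {-1, 0, +1}, each output element is just a sum of selected inputs
    # minus another sum of selected inputs: no multiplications are required.
    assert torch.all((w == -1) | (w == 0) | (w == 1))
    pos = torch.where(w == 1, x, torch.zeros_like(x)).sum(dim=-1)
    neg = torch.where(w == -1, x, torch.zeros_like(x)).sum(dim=-1)
    return pos - neg

w = torch.tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]])
x = torch.tensor([0.5, -2.0, 3.0])
print(ternary_matvec(w, x))  # tensor([-2.5000,  1.0000])
print(w @ x)                 # same result via a regular matmul
```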
Note that the PyTorch acceleration kernels we use are optimized for 4 bits, so speed decreases below 4 bits. Below, we report WER along with speed, bit width, and group size for these configurations.
The table below summarizes the WER and speed for each dataset across bit widths and group sizes.
| Dataset | #bits | Group Size | WER (%) | Speed |
|---|---|---|---|---|
| TEDLIUM | 4 bit | 64 | 4.06 | 37.3x |
| | 3 bit | 16 | 4.94 | 17.9x |
| | 3 bit | 32 | 4.14 | 18.0x |
| | 3 bit | 64 | 4.72 | 15.9x |
| | 2 bit | 16 | 4.61 | 26.2x |
| | 2 bit | 32 | 4.77 | 23.2x |
| | 2 bit | 64 | 7.26 | 24.4x |
| | 1.58 bit | 16 | 5.30 | 26.3x |
| | 1.58 bit | 32 | 23.94 | 25.7x |
| Voxpopuli | 4 bit | 64 | 7.84 | 47.9x |
| | 3 bit | 16 | 9.03 | 20.7x |
| | 3 bit | 32 | 8.40 | 21.3x |
| | 3 bit | 64 | 7.82 | 21.8x |
| | 2 bit | 16 | 8.05 | 29.6x |
| | 2 bit | 32 | 8.74 | 30.0x |
| | 2 bit | 64 | 8.97 | 29.4x |
| | 1.58 bit | 16 | 9.64 | 27.8x |
| | 1.58 bit | 32 | 30.42 | 25.3x |
| Earnings22 | 4 bit | 64 | 12.15 | 42.9x |
| | 3 bit | 16 | 12.3 | 17.7x |
| | 3 bit | 32 | 12.41 | 16.3x |
| | 3 bit | 64 | 12.26 | 17.5x |
| | 2 bit | 16 | 13.60 | 25.9x |
| | 2 bit | 32 | 13.54 | 25.9x |
| | 2 bit | 64 | 16.03 | 23.0x |
| | 1.58 bit | 16 | 12.40 | 20.5x |
| | 1.58 bit | 32 | 38.50 | 23x |
| AMI | 4 bit | 64 | 16.71 | 20.4x |
| | 3 bit | 16 | 16.63 | 9.0x |
| | 3 bit | 32 | 16.64 | 8.6x |
| | 3 bit | 64 | 16.72 | 8.7x |
| | 2 bit | 16 | 17.30 | 13.4x |
| | 2 bit | 32 | 17.80 | 13.7x |
| | 2 bit | 64 | 19.52 | 12.9x |
| | 1.58 bit | 16 | 19.15 | 13.3x |
| | 1.58 bit | 32 | 31.22 | 12.7x |
| Long-form Audio | 4 bit | 64 | 11.49 | 72.8x |
| | 3 bit | 64 | 11.56 | 34.6x |
| | 3 bit | 32 | 11.40 | 33.9x |
| | 3 bit | 16 | 11.56 | 34.6x |
| | 2 bit | 64 | 12.74 | 46.0x |
| | 2 bit | 32 | 11.75 | 44.7x |
| | 2 bit | 16 | 11.58 | 45.3x |
| | 1.58 bit | 64 | 96.66 | - (highly erroneous) |
| | 1.58 bit | 32 | 41.8 | 38.7x (high hallucinations) |
| | 1.58 bit | 16 | 12.48 | 42.9x |
This warrants two future directions that we will be pursuing: developing more optimized kernels for lower bit widths, and using post-quantization training, proposed as one of the methods in our 1-bit blog, to recover the performance lost at low bit widths.
Caveats
A few caveats to note:
- The TorchAO kernels require modern GPUs.
- Benchmarking of ASR systems is not yet fully mature, so these results need to be further battle-tested. However, our internal benchmark with long-form videos is representative of a production scenario, especially for English.
Citation
@misc{sebastian2024whisper1,
title = {Faster and Smaller Whisper: A Deep Dive into Quantization and Torch Compilation},
url = {https://mobiusml.github.io/whisper-static-cache-blog/},
author = {Jilt Sebastian and Husein Zolkepli and Hicham Badri and Appu Shaji},
month = {May},
year = {2024}
}
Please feel free to contact us.