Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
@@ -1,162 +1,187 @@
|
||||
# Quantized KV Cache
|
||||
|
||||
## FP8 KV Cache
|
||||
## FP8 KV Cache Overview
|
||||
|
||||
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput.
|
||||
Efficient memory usage is crucial for working with large language models. Quantizing the KV (Key-Value) cache to FP8 format can significantly reduce its memory footprint. This optimization enables you to store more tokens in memory, leading to improved throughput and support for longer context windows.
|
||||
|
||||
### FP8 Formats
|
||||
> **Note:** When using the Flash Attention 3 backend with FP8 KV cache, attention operations are also performed in the quantized (FP8) domain. In this configuration, queries are quantized to FP8 in addition to keys and values.
|
||||
|
||||
[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats:
|
||||
### Supported FP8 KV-Cache Quantization Schemes
|
||||
|
||||
- E5M2 (5 exponent bits and 2 mantissa bits)
|
||||
- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened as E4M3)
|
||||
vLLM supports two main quantization strategies for the FP8 KV-cache:
|
||||
|
||||
The E4M3 format offers higher precision compared to E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor.
|
||||
- **Per-tensor quantization:**
|
||||
A single scale is applied for each Q, K, and V tensor individually. (`q/k/v_scale = [1]`)
|
||||
- **Per-attention-head quantization:**
|
||||
Each scale corresponds to an attention head: `q_scale = [num_heads]`, `k/v_scale = [num_kv_heads]`.
|
||||
|
||||
### Current Limitations
|
||||
> **Note:**
|
||||
> Per-attention-head quantization is currently available **only with the Flash Attention backend** and requires the calibration pathway provided by **llm-compressor**.
|
||||
|
||||
For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
|
||||
### Scale Calibration Approaches
|
||||
|
||||
### How FP8 KV Cache Works
|
||||
You can configure how the quantization scales are computed in vLLM using three different approaches:
|
||||
|
||||
The FP8 KV cache implementation follows this workflow:
|
||||
1. **No calibration (default scales):**
|
||||
All quantization scales are set to `1.0`.
|
||||
_Configure with:_
|
||||
```python
|
||||
kv_cache_dtype="fp8"
|
||||
calculate_kv_scales=False
|
||||
```
|
||||
|
||||
1. **Storage**: Key and Value tensors are quantized to FP8 format using scaling factors before being stored in the KV cache
|
||||
2. **Retrieval**: When needed for attention computation, cached KV tensors are dequantized back to higher precision (FP16/BF16)
|
||||
3. **Attention**: The attention-value multiplication (softmax output × V) is performed using the dequantized higher-precision V tensor
|
||||
2. **Random token calibration (on-the-fly):**
|
||||
Scales are automatically estimated from a single batch of random tokens during warmup and then fixed.
|
||||
_Configure with:_
|
||||
```python
|
||||
kv_cache_dtype="fp8"
|
||||
calculate_kv_scales=True
|
||||
```
|
||||
|
||||
This means the final attention computation operates on dequantized values, not FP8 tensors. The quantization reduces memory usage during storage but maintains computation accuracy by using higher precision during the actual attention operations.
|
||||
3. **[Recommended] Calibration with a dataset (via llm-compressor):**
|
||||
Scales are estimated using a curated calibration dataset for maximum accuracy.
|
||||
This requires the [llm-compressor](https://github.com/vllm-project/llm-compressor) library.
|
||||
_See example below!_
|
||||
|
||||
### Performance Impact
|
||||
#### Additional `kv_cache_dtype` Options
|
||||
|
||||
The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either:
|
||||
- `kv_cache_dtype="auto"`: Use the model's default data type
|
||||
- `kv_cache_dtype="fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPUs)
|
||||
- `kv_cache_dtype="fp8_e5m2"`: Supported on CUDA 11.8+
|
||||
|
||||
- Processing longer context lengths for individual requests, or
|
||||
- Handling more concurrent request batches
|
||||
---
|
||||
|
||||
However, there are currently no latency improvements as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (fp32, fp16, bf16), this benefit is not yet fully realized.
|
||||
## Examples
|
||||
|
||||
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy, making it a practical choice for throughput optimization.
|
||||
### 1. No Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=False`)
|
||||
|
||||
## Usage Example
|
||||
|
||||
Here is an example of how to enable FP8 quantization:
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
# To calculate kv cache scales on the fly enable the calculate_kv_scales
|
||||
# parameter
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-2-7b-chat-hf",
|
||||
kv_cache_dtype="fp8",
|
||||
calculate_kv_scales=True,
|
||||
)
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
```
|
||||
|
||||
The `kv_cache_dtype` argument specifies the data type for KV cache storage:
|
||||
|
||||
- `"auto"`: Uses the model's default "unquantized" data type
|
||||
- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU)
|
||||
- `"fp8_e5m2"`: Supported on CUDA 11.8+
|
||||
|
||||
## Calibrated Scales for Better Accuracy
|
||||
|
||||
For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. [LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process.
|
||||
|
||||
### Installation
|
||||
|
||||
First, install the required dependencies:
|
||||
|
||||
```bash
|
||||
pip install llmcompressor
|
||||
```
|
||||
|
||||
### Example Usage
|
||||
|
||||
Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from llmcompressor import oneshot
|
||||
|
||||
# Select model and load it
|
||||
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
|
||||
# Select calibration dataset
|
||||
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
|
||||
DATASET_SPLIT = "train_sft"
|
||||
|
||||
# Configure calibration parameters
|
||||
NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point
|
||||
MAX_SEQUENCE_LENGTH = 2048
|
||||
|
||||
# Load and preprocess dataset
|
||||
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
|
||||
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||
|
||||
def process_and_tokenize(example):
|
||||
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
|
||||
return tokenizer(
|
||||
text,
|
||||
padding=False,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
|
||||
ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)
|
||||
|
||||
# Configure quantization settings
|
||||
recipe = """
|
||||
quant_stage:
|
||||
quant_modifiers:
|
||||
QuantizationModifier:
|
||||
kv_cache_scheme:
|
||||
num_bits: 8
|
||||
type: float
|
||||
strategy: tensor
|
||||
dynamic: false
|
||||
symmetric: true
|
||||
"""
|
||||
|
||||
# Apply quantization
|
||||
oneshot(
|
||||
model=model,
|
||||
dataset=ds,
|
||||
recipe=recipe,
|
||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||
)
|
||||
|
||||
# Save quantized model: Llama-3.1-8B-Instruct-FP8-KV
|
||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
|
||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||
tokenizer.save_pretrained(SAVE_DIR)
|
||||
```
|
||||
|
||||
The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales.
|
||||
|
||||
When running the model you must specify `kv_cache_dtype="fp8"` in order to enable the kv cache quantization and use the scales.
|
||||
All quantization scales are set to 1.0.
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
|
||||
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-2-7b-chat-hf",
|
||||
kv_cache_dtype="fp8",
|
||||
calculate_kv_scales=False,
|
||||
)
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 2. Random Token Calibration (`kv_cache_dtype="fp8"`, `calculate_kv_scales=True`)
|
||||
|
||||
Scales are automatically estimated from a single batch of tokens during warmup.
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-2-7b-chat-hf",
|
||||
kv_cache_dtype="fp8",
|
||||
calculate_kv_scales=True,
|
||||
)
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 3. **[Recommended] Calibration Using a Dataset (with `llm-compressor`)**
|
||||
|
||||
For the highest-quality quantization, we recommend calibrating against a dataset using `llm-compressor`. This enables advanced strategies such as per-attention-head quantization.
|
||||
|
||||
#### Install the required package
|
||||
|
||||
```bash
|
||||
pip install llmcompressor
|
||||
```
|
||||
|
||||
#### Example: Quantize Llama Attention & KV Cache to FP8
|
||||
|
||||
```python
|
||||
"""
|
||||
Quantize Llama attention + KV cache to FP8 (choose either 'tensor' or 'attn_head' strategy)
|
||||
using llm-compressor one-shot calibration.
|
||||
"""
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from llmcompressor import oneshot
|
||||
from llmcompressor.modifiers.quantization import QuantizationModifier
|
||||
from compressed_tensors.quantization import QuantizationScheme, QuantizationArgs
|
||||
|
||||
# -----------------------------
|
||||
# Config
|
||||
# -----------------------------
|
||||
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
|
||||
DATASET_SPLIT = "train_sft"
|
||||
STRATEGY = "tensor" # or "attn_head"
|
||||
NUM_CALIB_SAMPLES = 512 # Good starting value
|
||||
MAX_SEQ_LEN = 2048
|
||||
|
||||
# -----------------------------
|
||||
# Helpers
|
||||
# -----------------------------
|
||||
def process_and_tokenize(example, tokenizer: AutoTokenizer):
|
||||
"""Convert chat messages to tokens."""
|
||||
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
|
||||
return tokenizer(
|
||||
text,
|
||||
padding=False,
|
||||
max_length=MAX_SEQ_LEN,
|
||||
truncation=True,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
|
||||
def build_recipe(strategy: str) -> QuantizationModifier:
|
||||
fp8_args = QuantizationArgs(num_bits=8, type="float", strategy=strategy)
|
||||
return QuantizationModifier(
|
||||
config_groups={
|
||||
"attention": QuantizationScheme(
|
||||
targets=["LlamaAttention"], # Quantize queries: q_scale
|
||||
input_activations=fp8_args,
|
||||
)
|
||||
},
|
||||
kv_cache_scheme=fp8_args, # Quantize KV cache: k/v_scale
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Main
|
||||
# -----------------------------
|
||||
def main():
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIB_SAMPLES}]")
|
||||
ds = ds.shuffle(seed=42)
|
||||
ds = ds.map(
|
||||
lambda ex: process_and_tokenize(ex, tokenizer),
|
||||
remove_columns=ds.column_names,
|
||||
)
|
||||
|
||||
recipe = build_recipe(STRATEGY)
|
||||
oneshot(
|
||||
model=model,
|
||||
dataset=ds,
|
||||
recipe=recipe,
|
||||
max_seq_length=MAX_SEQ_LEN,
|
||||
num_calibration_samples=NUM_CALIB_SAMPLES,
|
||||
)
|
||||
|
||||
save_dir = f"{MODEL_ID.rstrip('/').split('/')[-1]}-kvattn-fp8-{STRATEGY}"
|
||||
model.save_pretrained(save_dir, save_compressed=True)
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
For more detailed and up-to-date examples, see the [`llm-compressor` official examples](https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_kv_cache).
|
||||
|
||||
Reference in New Issue
Block a user