[ROCm] Enable bitsandbytes quantization support on ROCm (#34688)

Signed-off-by: badaoui <abdennacerbadaoui0@gmail.com>
This commit is contained in:
BADAOUI Abdennacer
2026-02-21 09:34:55 +01:00
committed by GitHub
parent 2aab2bb543
commit 8dc8a99b56
8 changed files with 29 additions and 44 deletions

View File

@@ -7,7 +7,7 @@ Compared to other quantization methods, BitsAndBytes eliminates the need for cal
Below are the steps to utilize BitsAndBytes with vLLM.
```bash
pip install bitsandbytes>=0.46.1
pip install bitsandbytes>=0.49.2
```
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint.

View File

@@ -33,7 +33,7 @@ transformers==4.57.5
tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test.
# quantization
bitsandbytes>=0.46.1
bitsandbytes>=0.49.2
buildkite-test-collector==0.1.9

View File

@@ -102,3 +102,5 @@ terratorch==1.2.2
segmentation-models-pytorch==0.5.0
# Required for Prithvi tests
imagehash==4.3.2
# Required for bitsandbytes quantization test
bitsandbytes==0.49.2

View File

@@ -41,7 +41,7 @@ transformers==4.57.5
tokenizers==0.22.0
schemathesis>=3.39.15 # Required for openai schema test.
# quantization
bitsandbytes==0.46.1
bitsandbytes==0.49.2
buildkite-test-collector==0.1.9

View File

@@ -66,7 +66,7 @@ backoff==2.2.1
# via
# -r requirements/test.in
# schemathesis
bitsandbytes==0.46.1
bitsandbytes==0.49.2
# via
# -r requirements/test.in
# lightning
@@ -653,6 +653,7 @@ orjson==3.11.5
packaging==24.2
# via
# accelerate
# bitsandbytes
# black
# datamodel-code-generator
# datasets

View File

@@ -6,8 +6,6 @@ from typing import Any
import pytest
from vllm.platforms import current_platform
from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test, prep_prompts
from .registry import HF_EXAMPLE_MODELS
@@ -131,6 +129,7 @@ def test_distributed(
"quantization": "bitsandbytes",
},
),
("unsloth/tinyllama-bnb-4bit", {}),
],
)
@pytest.mark.parametrize("max_tokens", [32])
@@ -143,12 +142,6 @@ def test_quantization(
max_tokens: int,
num_logprobs: int,
) -> None:
if (
current_platform.is_rocm()
and quantization_kwargs.get("quantization", "") == "bitsandbytes"
):
pytest.skip("bitsandbytes quantization is currently not supported in rocm.")
with vllm_runner(
model,
model_impl="auto",

View File

@@ -28,6 +28,24 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def _check_bitsandbytes_version():
    """Verify that a compatible ``bitsandbytes`` package is importable.

    Raises:
        ImportError: if bitsandbytes is not installed, or if the installed
            version is older than the platform-specific minimum.
    """
    # ROCm support landed in bitsandbytes 0.49.2; CUDA only needs 0.48.1.
    min_version = "0.49.2" if current_platform.is_rocm() else "0.48.1"
    # Keep the try body to the import alone: a version-mismatch error raised
    # inside this try would otherwise be swallowed by the except below and
    # re-reported with the misleading "not installed" message.
    try:
        import bitsandbytes
    except ImportError as err:
        raise ImportError(
            f"Please install bitsandbytes>={min_version} via "
            f"`pip install bitsandbytes>={min_version}` to use "
            "bitsandbytes quantizer."
        ) from err
    if version.parse(bitsandbytes.__version__) < version.parse(min_version):
        raise ImportError(
            "bitsandbytes version is wrong. Please "
            f"install bitsandbytes>={min_version}."
        )
class BitsAndBytesConfig(QuantizationConfig):
"""Config class for BitsAndBytes Quantization.
@@ -183,21 +201,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
"""
def __init__(self, quant_config: BitsAndBytesConfig):
try:
import bitsandbytes
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
_check_bitsandbytes_version()
self.quant_config = quant_config
def create_weights(
@@ -442,20 +446,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
moe: FusedMoEConfig,
):
super().__init__(moe)
try:
import bitsandbytes
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
_check_bitsandbytes_version()
self.quant_config = quant_config
def create_weights(

View File

@@ -244,10 +244,8 @@ class RocmPlatform(Platform):
"mxfp4",
"petit_nvfp4",
"torchao",
"bitsandbytes",
]
# bitsandbytes not supported on gfx9 (warp size 64 limitation)
if not on_gfx9():
supported_quantization += ["bitsandbytes"]
@classmethod
def import_kernels(cls) -> None: