[ROCm] Enable bitsandbytes quantization support on ROCm (#34688)
Signed-off-by: badaoui <abdennacerbadaoui0@gmail.com>
This commit is contained in:
committed by
GitHub
parent
2aab2bb543
commit
8dc8a99b56
@@ -7,7 +7,7 @@ Compared to other quantization methods, BitsAndBytes eliminates the need for cal
|
||||
Below are the steps to utilize BitsAndBytes with vLLM.
|
||||
|
||||
```bash
|
||||
pip install bitsandbytes>=0.46.1
|
||||
pip install "bitsandbytes>=0.49.2"
|
||||
```
|
||||
|
||||
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoints.
|
||||
|
||||
@@ -33,7 +33,7 @@ transformers==4.57.5
|
||||
tokenizers==0.22.0
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
# quantization
|
||||
bitsandbytes>=0.46.1
|
||||
bitsandbytes>=0.49.2
|
||||
buildkite-test-collector==0.1.9
|
||||
|
||||
|
||||
|
||||
@@ -102,3 +102,5 @@ terratorch==1.2.2
|
||||
segmentation-models-pytorch==0.5.0
|
||||
# Required for Prithvi tests
|
||||
imagehash==4.3.2
|
||||
# Required for bitsandbytes quantization test
|
||||
bitsandbytes==0.49.2
|
||||
|
||||
@@ -41,7 +41,7 @@ transformers==4.57.5
|
||||
tokenizers==0.22.0
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
# quantization
|
||||
bitsandbytes==0.46.1
|
||||
bitsandbytes==0.49.2
|
||||
buildkite-test-collector==0.1.9
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ backoff==2.2.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# schemathesis
|
||||
bitsandbytes==0.46.1
|
||||
bitsandbytes==0.49.2
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightning
|
||||
@@ -653,6 +653,7 @@ orjson==3.11.5
|
||||
packaging==24.2
|
||||
# via
|
||||
# accelerate
|
||||
# bitsandbytes
|
||||
# black
|
||||
# datamodel-code-generator
|
||||
# datasets
|
||||
|
||||
@@ -6,8 +6,6 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..conftest import HfRunner, VllmRunner
|
||||
from ..utils import multi_gpu_test, prep_prompts
|
||||
from .registry import HF_EXAMPLE_MODELS
|
||||
@@ -131,6 +129,7 @@ def test_distributed(
|
||||
"quantization": "bitsandbytes",
|
||||
},
|
||||
),
|
||||
("unsloth/tinyllama-bnb-4bit", {}),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@@ -143,12 +142,6 @@ def test_quantization(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
if (
|
||||
current_platform.is_rocm()
|
||||
and quantization_kwargs.get("quantization", "") == "bitsandbytes"
|
||||
):
|
||||
pytest.skip("bitsandbytes quantization is currently not supported in rocm.")
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
model_impl="auto",
|
||||
|
||||
@@ -28,6 +28,24 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def _check_bitsandbytes_version():
    """Verify that a recent enough bitsandbytes package can be imported.

    The minimum accepted version is 0.49.2 when
    ``current_platform.is_rocm()`` is true and 0.48.1 otherwise.

    Raises:
        ImportError: If bitsandbytes cannot be imported, or if its
            ``__version__`` is older than the platform minimum. The
            version-mismatch error is raised inside the ``try`` body, so it
            is caught by the handler below and re-raised wrapped in the
            install hint, with the original error chained as the cause.
    """
    if current_platform.is_rocm():
        required = "0.49.2"
    else:
        required = "0.48.1"

    try:
        import bitsandbytes

        installed = version.parse(bitsandbytes.__version__)
        minimum = version.parse(required)
        if installed < minimum:
            mismatch_msg = (
                "bitsandbytes version is wrong. Please "
                f"install bitsandbytes>={required}."
            )
            raise ImportError(mismatch_msg)
    except ImportError as err:
        # Covers both a missing package and the too-old case raised above.
        install_hint = (
            f"Please install bitsandbytes>={required} via "
            f"`pip install bitsandbytes>={required}` to use "
            "bitsandbytes quantizer."
        )
        raise ImportError(install_hint) from err
|
||||
|
||||
|
||||
class BitsAndBytesConfig(QuantizationConfig):
|
||||
"""Config class for BitsAndBytes Quantization.
|
||||
|
||||
@@ -183,21 +201,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: BitsAndBytesConfig):
|
||||
try:
|
||||
import bitsandbytes
|
||||
|
||||
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
|
||||
raise ImportError(
|
||||
"bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.46.1."
|
||||
)
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes>=0.46.1 via "
|
||||
"`pip install bitsandbytes>=0.46.1` to use "
|
||||
"bitsandbytes quantizer."
|
||||
) from err
|
||||
|
||||
_check_bitsandbytes_version()
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
@@ -442,20 +446,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
try:
|
||||
import bitsandbytes
|
||||
|
||||
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
|
||||
raise ImportError(
|
||||
"bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.46.1."
|
||||
)
|
||||
except ImportError as err:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes>=0.46.1 via "
|
||||
"`pip install bitsandbytes>=0.46.1` to use "
|
||||
"bitsandbytes quantizer."
|
||||
) from err
|
||||
_check_bitsandbytes_version()
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
|
||||
@@ -244,10 +244,8 @@ class RocmPlatform(Platform):
|
||||
"mxfp4",
|
||||
"petit_nvfp4",
|
||||
"torchao",
|
||||
"bitsandbytes",
|
||||
]
|
||||
# bitsandbytes not supported on gfx9 (warp size 64 limitation)
|
||||
if not on_gfx9():
|
||||
supported_quantization += ["bitsandbytes"]
|
||||
|
||||
@classmethod
|
||||
def import_kernels(cls) -> None:
|
||||
|
||||
Reference in New Issue
Block a user