[ROCm] Enable bitsandbytes quantization support on ROCm (#34688)

Signed-off-by: badaoui <abdennacerbadaoui0@gmail.com>
This commit is contained in:
BADAOUI Abdennacer
2026-02-21 09:34:55 +01:00
committed by GitHub
parent 2aab2bb543
commit 8dc8a99b56
8 changed files with 29 additions and 44 deletions

View File

@@ -28,6 +28,24 @@ from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def _check_bitsandbytes_version():
    """Ensure a sufficiently recent bitsandbytes is importable.

    The minimum accepted version is 0.49.2 when ``current_platform.is_rocm()``
    is true and 0.48.1 otherwise.

    Raises:
        ImportError: If bitsandbytes cannot be imported, or if the installed
            version is older than the minimum for the current platform.
    """
    min_version = "0.49.2" if current_platform.is_rocm() else "0.48.1"
    # The requirement is quoted so the hint can be pasted into a shell
    # without `>` being interpreted as an output redirection.
    install_hint = (
        f"Please install bitsandbytes>={min_version} via "
        f'`pip install "bitsandbytes>={min_version}"` to use '
        "bitsandbytes quantizer."
    )

    # Only the import lives in the try body: a version mismatch raised inside
    # it would be caught by the same handler and reported as if the package
    # were missing.
    try:
        import bitsandbytes
    except ImportError as err:
        raise ImportError(install_hint) from err

    installed = bitsandbytes.__version__
    if version.parse(installed) < version.parse(min_version):
        raise ImportError(
            f"bitsandbytes version is wrong (found {installed}). {install_hint}"
        )
class BitsAndBytesConfig(QuantizationConfig):
"""Config class for BitsAndBytes Quantization.
@@ -183,21 +201,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
"""
def __init__(self, quant_config: BitsAndBytesConfig):
    """Validate the bitsandbytes installation and store the config.

    Args:
        quant_config: Quantization settings, kept on the instance as
            ``self.quant_config``.

    Raises:
        ImportError: Propagated from ``_check_bitsandbytes_version()`` when
            bitsandbytes is missing or older than the required minimum.
    """
    # The shared helper owns the platform-specific minimum version, so the
    # separate inline check against a hard-coded 0.46.1 is not repeated here.
    _check_bitsandbytes_version()
    self.quant_config = quant_config
def create_weights(
@@ -442,20 +446,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
moe: FusedMoEConfig,
):
super().__init__(moe)
try:
import bitsandbytes
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
_check_bitsandbytes_version()
self.quant_config = quant_config
def create_weights(