[Quantization][V1] BitsAndBytes support V1 (#15611)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-03-28 10:12:47 +08:00
committed by GitHub
parent bd45912b99
commit 726efc6a32
7 changed files with 52 additions and 24 deletions

View File

@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import direct_register_custom_op
class BitsAndBytesConfig(QuantizationConfig):
@@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
@@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out_dim_1,
dtype=torch.bfloat16,
device=x.device)
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
apply_bnb_4bit(bf_x, qweight, offsets, out)
out = out.to(original_type)
if reshape_after_matmul:
@@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out += bias
return out
def _apply_bnb_4bit(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
quant_states = weight.bnb_quant_state
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
def _apply_bnb_4bit_fake(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
return
try:
direct_register_custom_op(
op_name="apply_bnb_4bit",
op_func=_apply_bnb_4bit,
mutates_args=["out"],
fake_impl=_apply_bnb_4bit_fake,
)
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
except AttributeError as error:
raise error