[Quantization][V1] BitsAndBytes support V1 (#15611)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
class BitsAndBytesConfig(QuantizationConfig):
|
||||
@@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import matmul_4bit
|
||||
|
||||
original_type = x.dtype
|
||||
original_shape = x.shape
|
||||
reshape_after_matmul = False
|
||||
@@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
out_dim_1,
|
||||
dtype=torch.bfloat16,
|
||||
device=x.device)
|
||||
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
# It is more efficient to use out kwarg like
|
||||
# matmul_4bit(..., out = ...). Infeasible now due to the bug
|
||||
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
|
||||
# Need to change after the bug is fixed.
|
||||
out[:, current_index:current_index + output_size] = matmul_4bit(
|
||||
bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
|
||||
|
||||
current_index += output_size
|
||||
|
||||
apply_bnb_4bit(bf_x, qweight, offsets, out)
|
||||
out = out.to(original_type)
|
||||
|
||||
if reshape_after_matmul:
|
||||
@@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
out += bias
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _apply_bnb_4bit(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
# only load the bitsandbytes module when needed
|
||||
from bitsandbytes import matmul_4bit
|
||||
quant_states = weight.bnb_quant_state
|
||||
current_index = 0
|
||||
for i in range(len(quant_states)):
|
||||
output_size = quant_states[i].shape[0]
|
||||
# It is more efficient to use out kwarg like
|
||||
# matmul_4bit(..., out = ...). Infeasible now due to the bug
|
||||
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
|
||||
# Need to change after the bug is fixed.
|
||||
out[:, current_index:current_index + output_size] = matmul_4bit(
|
||||
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
|
||||
current_index += output_size
|
||||
|
||||
|
||||
def _apply_bnb_4bit_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
offsets: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="apply_bnb_4bit",
|
||||
op_func=_apply_bnb_4bit,
|
||||
mutates_args=["out"],
|
||||
fake_impl=_apply_bnb_4bit_fake,
|
||||
)
|
||||
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
Reference in New Issue
Block a user