[Model] Add BNB quantization support for Mllama (#9720)
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -23,7 +24,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         bnb_4bit_use_double_quant: bool = False,
         llm_int8_enable_fp32_cpu_offload: bool = False,
         llm_int8_has_fp16_weight: bool = False,
-        llm_int8_skip_modules: Optional[Any] = None,
+        llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 0.0,
     ) -> None:
@@ -34,11 +35,15 @@ class BitsAndBytesConfig(QuantizationConfig):
         self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
         self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
         self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
-        self.llm_int8_skip_modules = llm_int8_skip_modules
+        self.llm_int8_skip_modules = llm_int8_skip_modules or []
         self.llm_int8_threshold = llm_int8_threshold
 
     def __repr__(self) -> str:
-        return "BitsAndBytesConfig"
+        return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
+                f"load_in_4bit={self.load_in_4bit}, "
+                f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
+                f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
+                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
 
     @classmethod
     def get_name(self) -> str:
@@ -102,8 +107,10 @@ class BitsAndBytesConfig(QuantizationConfig):
             llm_int8_threshold=llm_int8_threshold)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+                return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         return None
@@ -111,6 +118,10 @@ class BitsAndBytesConfig(QuantizationConfig):
         return []
 
 
+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+    return any(module_name in prefix for module_name in llm_int8_skip_modules)
+
+
 class BitsAndBytesLinearMethod(LinearMethodBase):
     """Linear method for BitsAndBytes.
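For context, a minimal sketch of how the new skip list feeds into get_quant_method: any configured module name that appears as a substring of a layer's fully qualified prefix makes that layer fall back to UnquantizedLinearMethod. The prefixes and skip names below are hypothetical illustrations (the real ones come from the model's weight names), not values taken from this commit.

    from typing import List

    def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
        # Skip quantization if any configured module name occurs in the prefix.
        return any(module_name in prefix for module_name in llm_int8_skip_modules)

    # Hypothetical Mllama-style setup: keep the vision tower unquantized,
    # quantize the language-model projections.
    skip = ["vision_model", "multi_modal_projector"]
    print(is_layer_skipped_bnb(
        "vision_model.transformer.layers.0.self_attn.qkv_proj", skip))   # True  -> UnquantizedLinearMethod
    print(is_layer_skipped_bnb(
        "language_model.model.layers.0.self_attn.qkv_proj", skip))       # False -> BitsAndBytesLinearMethod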
@@ -211,6 +222,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         from bitsandbytes import MatmulLtState, matmul
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -265,6 +281,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
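The 8-bit and 4-bit apply paths now flatten inputs with more than two dimensions before the bitsandbytes matmul and restore the original leading dimensions afterwards, which is what lets multimodal models like Mllama feed 3-D activations through these layers. A small standalone sketch of the same flatten/restore pattern, with plain torch.matmul standing in for the bitsandbytes kernels:

    import torch

    def matmul_2d_only(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Stand-in for a kernel that only accepts 2-D inputs,
        # like the bitsandbytes matmul / matmul_4bit calls in the diff.
        assert x.ndim == 2
        return x @ weight.t()

    x = torch.randn(2, 5, 16)           # (batch, seq, hidden) activation
    weight = torch.randn(32, 16)        # (out_features, in_features)

    original_shape = x.shape
    reshape_after_matmul = False
    if x.ndim > 2:
        x = x.reshape(-1, x.size(-1))   # flatten leading dims -> (10, 16)
        reshape_after_matmul = True

    out = matmul_2d_only(x, weight)

    if reshape_after_matmul:
        # Restore leading dims, keeping the new output feature size.
        out = out.view(*original_shape[:-1], out.size(-1))

    assert out.shape == (2, 5, 32)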
@@ -282,6 +301,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         from bitsandbytes import matmul_4bit
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -310,6 +334,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
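As a usage note, a hedged sketch of what this change is meant to enable: in-flight BitsAndBytes quantization of an Mllama checkpoint through vLLM's documented BNB loading flags. The model id and flag values are assumptions for illustration and are not part of this diff.

    from vllm import LLM

    # Assumed usage: load an Mllama model with in-flight BNB quantization.
    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        quantization="bitsandbytes",
        load_format="bitsandbytes",
    )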