[Model] Add BNB quantization support for Mllama (#9720)

Isotr0py authored on 2024-10-29 20:20:02 +08:00, committed by GitHub
parent ef7865b4f9
commit 09500f7dde
3 changed files with 84 additions and 12 deletions

vllm/model_executor/layers/quantization/bitsandbytes.py

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -23,7 +24,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         bnb_4bit_use_double_quant: bool = False,
         llm_int8_enable_fp32_cpu_offload: bool = False,
         llm_int8_has_fp16_weight: bool = False,
-        llm_int8_skip_modules: Optional[Any] = None,
+        llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 0.0,
     ) -> None:
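
Tightening llm_int8_skip_modules from Any to List[str] matches how the option is consumed below, where each entry is substring-matched against layer prefixes. A hedged construction sketch (the module names are illustrative, not taken from the Mllama port):

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_skip_modules=["vision_model", "multi_modal_projector"],
)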
@@ -34,11 +35,15 @@ class BitsAndBytesConfig(QuantizationConfig):
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold
def __repr__(self) -> str:
return "BitsAndBytesConfig"
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
f"load_in_4bit={self.load_in_4bit}, "
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod
def get_name(self) -> str:
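
With the expanded __repr__, a logged config now surfaces its effective settings rather than just the class name. A hedged sketch of the output shape (the field values shown are illustrative, since the constructor defaults sit outside this hunk):

>>> print(BitsAndBytesConfig(load_in_4bit=True,
...                          llm_int8_skip_modules=["vision_model"]))
BitsAndBytesConfig(load_in_8bit=False, load_in_4bit=True, bnb_4bit_compute_dtype=float32, bnb_4bit_quant_type=nf4, llm_int8_skip_modules=['vision_model'])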
@@ -102,8 +107,10 @@ class BitsAndBytesConfig(QuantizationConfig):
             llm_int8_threshold=llm_int8_threshold)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+                return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         return None
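
Because get_quant_method now returns the broader Optional["LinearMethodBase"], skip-listed projections can transparently fall back to the unquantized kernel. A hedged illustration, where linear is any LinearBase module and the prefixes are hypothetical:

quant_config = BitsAndBytesConfig(load_in_4bit=True,
                                  llm_int8_skip_modules=["cross_attn"])
quant_config.get_quant_method(linear, prefix="model.layers.3.cross_attn.q_proj")
# -> UnquantizedLinearMethod()
quant_config.get_quant_method(linear, prefix="model.layers.3.self_attn.q_proj")
# -> BitsAndBytesLinearMethod(quant_config)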
@@ -111,6 +118,10 @@ class BitsAndBytesConfig(QuantizationConfig):
         return []
 
 
+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+    return any(module_name in prefix for module_name in llm_int8_skip_modules)
+
+
 class BitsAndBytesLinearMethod(LinearMethodBase):
     """Linear method for BitsAndBytes.
@@ -211,6 +222,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         from bitsandbytes import MatmulLtState, matmul
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -265,6 +281,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
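
Both apply paths now share this guard: flatten activations with more than two dimensions before the 2-D bitsandbytes kernels, then restore the leading dimensions afterwards. A minimal standalone sketch of the pattern, with a plain matmul standing in for the quantized kernel:

import torch

def _matmul_nd(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Mirror of the guard above: bitsandbytes matmul/matmul_4bit expect 2-D
    # inputs, but multimodal (e.g. Mllama) activations can arrive as
    # [batch, seq, hidden].
    original_shape = x.shape
    reshape_after_matmul = False
    if x.ndim > 2:
        x = x.reshape(-1, x.size(-1))   # collapse all leading dims into one
        reshape_after_matmul = True
    out = x @ weight.t()                # stand-in for the quantized matmul
    if reshape_after_matmul:
        # Restore the leading dims, keeping the new output feature size.
        out = out.view(*original_shape[:-1], out.size(-1))
    return out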
@@ -282,6 +301,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         from bitsandbytes import matmul_4bit
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -310,6 +334,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
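
End to end, these guards are what let an Mllama checkpoint load with in-flight BNB quantization. A hedged usage sketch (the model ID is illustrative; vLLM at this point expects both the quantization and load_format flags):

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.2-11B-Vision-Instruct",
          quantization="bitsandbytes",
          load_format="bitsandbytes")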