[ Kernel ] Enable fp8-marlin for fbgemm-fp8 models (#6606)
This commit is contained in:
@@ -9,9 +9,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear, create_per_channel_scale_param)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -31,6 +34,12 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
self.ignore_list = ignore_list
|
||||
self.input_scale_ub = input_scale_ub
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
self.use_marlin = capability < 89
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "fbgemm_fp8"
|
||||
@@ -41,7 +50,7 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@@ -143,11 +152,26 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
weight = layer.weight
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
|
||||
if self.quant_config.use_marlin:
|
||||
prepare_fp8_layer_for_marlin(layer)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.input_scale_ub
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
if self.quant_config.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
|
||||
return apply_fp8_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
|
||||
Reference in New Issue
Block a user