[Bugfix][FP8] Fix dynamic FP8 Marlin quantization (#7219)

Michael Goin
2024-08-07 14:23:12 -04:00
committed by GitHub
parent fde47d3bc2
commit 5223199e03
3 changed files with 33 additions and 5 deletions

@@ -4,6 +4,7 @@ import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
@@ -118,7 +119,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
-        self.use_marlin = capability < 89
+        self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
 
     def create_weights(
         self,
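
For reference, SM 8.9 (Ada Lovelace) is the first NVIDIA architecture with native FP8 support, so anything older routes through the Marlin weight-only (W8A16) kernel. Below is a minimal sketch of an equivalent gate in plain PyTorch; vLLM goes through its current_platform abstraction rather than torch.cuda directly, so this is only an approximation:

    import torch

    # Approximate vLLM's capability gate using torch.cuda directly.
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    # Pre-Ada GPUs (SM < 8.9) lack native FP8, so fall back to Marlin W8A16.
    use_marlin = capability < 89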
@@ -174,6 +175,14 @@ class Fp8LinearMethod(LinearMethodBase):
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
                                                          scale=None)
 
+            # If using marlin (w8a16), kernel uses channelwise weights,
+            # so extend the weight scales to be channelwise.
+            if self.use_marlin:
+                assert weight_scale.numel() == 1
+                weight_scale = convert_to_channelwise(
+                    weight_scale.expand(len(layer.logical_widths)),
+                    layer.logical_widths)
+
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
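
The substance of the fix is in the last hunk: dynamic quantization via ops.scaled_fp8_quant with scale=None produces a single per-tensor scale (hence the numel() == 1 assert), but the Marlin W8A16 kernel consumes one scale per output channel, so the scalar must be expanded before the weights are handed to Marlin. A rough sketch of that expansion follows; the widths below are hypothetical, and in vLLM the convert_to_channelwise utility does the real work:

    import torch

    # Hypothetical fused QKV projection: three logical shards in one weight.
    logical_widths = [4096, 1024, 1024]
    per_tensor_scale = torch.tensor([0.02])  # single dynamic FP8 scale

    # One copy of the scale per logical shard, mirroring
    # weight_scale.expand(len(layer.logical_widths)) in the diff...
    per_shard_scales = per_tensor_scale.expand(len(logical_widths))
    # ...then broadcast each shard's scale across its output channels,
    # giving the channelwise layout Marlin expects.
    channelwise = torch.cat(
        [s.expand(w) for s, w in zip(per_shard_scales, logical_widths)])
    assert channelwise.numel() == sum(logical_widths)

The VLLM_TEST_FORCE_FP8_MARLIN environment variable added in the first two hunks forces this Marlin path even on FP8-capable GPUs, which makes the conversion testable on newer hardware.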