[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (#7210)
@@ -20,10 +20,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     create_per_tensor_scale_param, cutlass_fp8_supported,
-    per_tensor_dequantize, requantize_with_max_scale)
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+from vllm.utils import is_hip, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
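Note on the new import: `normalize_e4m3fn_to_e4m3fnuz` is needed because AMD hardware uses the `float8_e4m3fnuz` format rather than the OCP `float8_e4m3fn` used on NVIDIA GPUs. e4m3fnuz has no negative zero and an exponent bias one higher, so the same bit pattern encodes half the e4m3fn value, and the 0x80 pattern (negative zero in e4m3fn) is NaN in e4m3fnuz. A minimal sketch of what such a normalization amounts to (illustrative only, not vLLM's exact implementation; the function name below is hypothetical):

from typing import Optional, Tuple
import torch

def fn_to_fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Reinterpret the e4m3fn bits as e4m3fnuz and fix up the scales.
    assert weight.dtype == torch.float8_e4m3fn
    as_int8 = weight.view(torch.int8).clone()
    # Bit pattern 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz: rewrite it to +0.0.
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # The same bit pattern represents half the value in e4m3fnuz,
    # so double the scales to keep the dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale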
@@ -120,6 +121,9 @@ class Fp8LinearMethod(LinearMethodBase):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
+        # Disable marlin for rocm
+        if is_hip():
+            self.use_marlin = False
 
     def create_weights(
         self,
@@ -168,6 +172,8 @@ class Fp8LinearMethod(LinearMethodBase):
                 scale = create_per_tensor_scale_param(output_partition_sizes,
                                                       **extra_weight_attrs)
                 layer.register_parameter("input_scale", scale)
+            else:
+                layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint not serialized fp8, quantize the weights.
@@ -202,9 +208,23 @@ class Fp8LinearMethod(LinearMethodBase):
             # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If rocm, use float8_e4m3fnuz.
+                if is_hip():
+                    weight, weight_scale, input_scale = \
+                        normalize_e4m3fn_to_e4m3fnuz(
+                            weight=weight,
+                            weight_scale=weight_scale,
+                            input_scale=layer.input_scale)
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale,
+                                                      requires_grad=False)
+
                 weight_scale, weight = requantize_with_max_scale(
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
+                    weight=weight,
+                    weight_scale=weight_scale,
                     logical_widths=layer.logical_widths,
                 )
 
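The "Dequant -> Quant with max scale" step collapses the per-shard scales of a fused module (e.g. merged QKV or gate/up projections) into a single per-tensor scale: each logical shard is dequantized with its own scale and requantized with the shared maximum, which is what `requantize_with_max_scale` does. A rough, illustrative sketch of the idea (layout and names assumed, not the library code):

import torch

def requant_with_max_scale_sketch(weight, weight_scale, logical_widths,
                                  fp8_dtype=torch.float8_e4m3fn):
    # weight: fp8 tensor of shape [sum(logical_widths), in_features]
    # weight_scale: one per-tensor scale per logical shard
    dq = torch.empty(weight.shape, dtype=torch.float32)
    start = 0
    for idx, width in enumerate(logical_widths):
        end = start + width
        # Dequantize each shard with its own scale ...
        dq[start:end] = weight[start:end].to(torch.float32) * weight_scale[idx]
        start = end
    # ... then requantize everything with the shared max scale.
    max_scale = weight_scale.max()
    fp8_max = torch.finfo(fp8_dtype).max
    requant = (dq / max_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return max_scale, requant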
@@ -214,8 +234,6 @@ class Fp8LinearMethod(LinearMethodBase):
             if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
-            else:
-                layer.input_scale = None
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
@@ -346,10 +364,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If rocm, use float8_e4m3fnuz as dtype
+            fp8_dtype = torch.float8_e4m3fnuz \
+                if is_hip() else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(layer.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+                                          dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
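For fp16 checkpoints the MoE weights are quantized in place, now into `fp8_dtype` so ROCm gets e4m3fnuz tensors (vLLM's scaled fp8 quantization op does the actual work on device). As a plain-PyTorch illustration of per-tensor dynamic quantization into the selected dtype (a sketch, not the kernel):

import torch

def per_tensor_fp8_quant_sketch(w: torch.Tensor, fp8_dtype: torch.dtype):
    # On ROCm fp8_dtype would be torch.float8_e4m3fnuz, otherwise float8_e4m3fn.
    # Pick the scale so the largest magnitude maps to the fp8 max value.
    fp8_max = torch.finfo(fp8_dtype).max
    scale = w.abs().max().to(torch.float32) / fp8_max
    q = (w.to(torch.float32) / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q, scale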
@@ -393,6 +413,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_input_scale.max(), requires_grad=False)
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False)
+            # If rocm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
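End-to-end, nothing changes in how fp8 is requested; on ROCm the same flag now routes through the e4m3fnuz handling above. A minimal usage sketch (model name is just an example):

from vllm import LLM, SamplingParams

# quantization="fp8" now works on ROCm (MI300-class GPUs) as well as CUDA.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", quantization="fp8")
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)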