[AMD][FP8] Using MI300 FP8 format on ROCm for block_quant (#12134)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-01-17 12:12:26 -05:00
committed by GitHub
parent 54cacf008f
commit b5b57e301e
2 changed files with 44 additions and 3 deletions

View File

@@ -247,6 +247,15 @@ class Fp8LinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm():
weight, weight_scale, _ = \
normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale)
layer.weight = Parameter(weight, requires_grad=False)
layer.weight_scale_inv = Parameter(weight_scale,
requires_grad=False)
return
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
@@ -495,6 +504,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
if self.block_quant:
if current_platform.is_rocm():
w13_weight, w13_weight_scale_inv, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale_inv,
layer.w13_input_scale)
w2_weight, w2_weight_scale_inv, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale_inv,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale_inv, requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale_inv, requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False)
return
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized: