[AMD][FP8] Using MI300 FP8 format on ROCm for block_quant (#12134)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent 54cacf008f
commit b5b57e301e
@@ -247,6 +247,15 @@ class Fp8LinearMethod(LinearMethodBase):
    def process_weights_after_loading(self, layer: Module) -> None:
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            if current_platform.is_rocm():
                weight, weight_scale, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv,
                        input_scale=layer.input_scale)
                layer.weight = Parameter(weight, requires_grad=False)
                layer.weight_scale_inv = Parameter(weight_scale,
                                                   requires_grad=False)
            return
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
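For reference, normalize_e4m3fn_to_e4m3fnuz converts weights stored in the OCP float8_e4m3fn encoding into the float8_e4m3fnuz encoding that MI300 hardware uses natively. The snippet below is a minimal sketch of what such a conversion typically involves, not the exact vLLM helper: e4m3fnuz uses an exponent bias of 8 rather than 7, so identical bit patterns decode to half the value and the dequantization scales must be doubled, and the fn bit pattern 0x80 (negative zero) encodes NaN in fnuz and has to be remapped to zero.

import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    # Reinterpret the raw bytes: same bits, different FP8 encoding.
    # The remap below happens in place on the original storage, which avoids
    # allocating a second copy of a potentially large weight tensor.
    as_int8 = weight.view(torch.int8)
    # 0x80 is -0 in e4m3fn but NaN in e4m3fnuz; remap it to +0.
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # fnuz's extra bias step halves every decoded value, so double the scales.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale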
@@ -495,6 +504,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
    def process_weights_after_loading(self, layer: Module) -> None:
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            if current_platform.is_rocm():
                w13_weight, w13_weight_scale_inv, w13_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w13_weight, layer.w13_weight_scale_inv,
                        layer.w13_input_scale)
                w2_weight, w2_weight_scale_inv, w2_input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
                # Reset the parameter
                layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                      requires_grad=False)
                layer.w13_weight_scale_inv = torch.nn.Parameter(
                    w13_weight_scale_inv, requires_grad=False)
                if w13_input_scale is not None:
                    layer.w13_input_scale = torch.nn.Parameter(
                        w13_input_scale, requires_grad=False)
                layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                     requires_grad=False)
                layer.w2_weight_scale_inv = torch.nn.Parameter(
                    w2_weight_scale_inv, requires_grad=False)
                if w2_input_scale is not None:
                    layer.w2_input_scale = torch.nn.Parameter(
                        w2_input_scale, requires_grad=False)
            return
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
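As a quick sanity check on the scale doubling (assuming a PyTorch build that exposes both FP8 dtypes), decoding the same byte under both encodings shows the factor-of-two difference that the doubled scale compensates for:

import torch

raw = torch.tensor([0x38], dtype=torch.uint8)  # sign 0, exponent 0111, mantissa 000
as_fn = raw.view(torch.float8_e4m3fn).to(torch.float32)      # bias 7 -> 1.0
as_fnuz = raw.view(torch.float8_e4m3fnuz).to(torch.float32)  # bias 8 -> 0.5
print(as_fn.item(), as_fnuz.item())  # 1.0 0.5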