[Misc] Removed force_fp8_e4m3fnuz from Fp8LinearOp (#23725)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
nvjullin
2025-09-04 21:25:40 +08:00
committed by GitHub
parent c9f7081f9c
commit 37241077d5
5 changed files with 45 additions and 30 deletions

View File

@@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
"""
def __init__(self, quant_config: PTPCFp8Config):
assert current_platform.is_rocm(), \
"PTPCFp8LinearMethod is only supported on ROCm."
super().__init__(quant_config=quant_config)
# Force weight quantization
self.quant_config.is_checkpoint_fp8_serialized = False
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=GroupShape.PER_TOKEN,
force_fp8_e4m3fnuz=True)
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,

View File

@@ -355,12 +355,10 @@ class Fp8LinearOp:
def __init__(self,
act_quant_static: bool,
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
pad_output: Optional[bool] = None,
force_fp8_e4m3fnuz: bool = False):
pad_output: Optional[bool] = None):
if current_platform.is_rocm():
self.preferred_backend = "rocm"
elif current_platform.is_cuda(
) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
elif current_platform.is_cuda() and cutlass_fp8_supported():
if has_flashinfer() and current_platform.has_device_capability(
100):
self.preferred_backend = "flashinfer"