[Misc] Removed force_fp8_e4m3fnuz from FP8LinearOp (#23725)
Signed-off-by: Julien Lin <jullin@nvidia.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -92,13 +92,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: PTPCFp8Config):
|
||||
assert current_platform.is_rocm(), \
|
||||
"PTPCFp8LinearMethod is only supported on ROCm."
|
||||
super().__init__(quant_config=quant_config)
|
||||
# Force weight quantization
|
||||
self.quant_config.is_checkpoint_fp8_serialized = False
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False,
|
||||
act_quant_group_shape=GroupShape.PER_TOKEN,
|
||||
force_fp8_e4m3fnuz=True)
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
|
||||
@@ -355,12 +355,10 @@ class Fp8LinearOp:
|
||||
def __init__(self,
|
||||
act_quant_static: bool,
|
||||
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
|
||||
pad_output: Optional[bool] = None,
|
||||
force_fp8_e4m3fnuz: bool = False):
|
||||
pad_output: Optional[bool] = None):
|
||||
if current_platform.is_rocm():
|
||||
self.preferred_backend = "rocm"
|
||||
elif current_platform.is_cuda(
|
||||
) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
|
||||
elif current_platform.is_cuda() and cutlass_fp8_supported():
|
||||
if has_flashinfer() and current_platform.has_device_capability(
|
||||
100):
|
||||
self.preferred_backend = "flashinfer"
|
||||
|
||||
Reference in New Issue
Block a user