[Bugfix] Fix DeepGemm E8M0 accuracy degradation for Qwen3.5 FP8 on Blackwell (#38083)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -129,6 +129,7 @@ class Fp8Config(QuantizationConfig):
|
||||
f"{activation_scheme} activation scheme."
|
||||
)
|
||||
self.weight_block_size = weight_block_size
|
||||
+ self.use_deep_gemm: bool | None = None
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
@@ -276,7 +277,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.marlin_input_dtype = None
|
||||
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
|
||||
- self.use_deep_gemm = is_deep_gemm_supported()
|
||||
+ if self.quant_config.use_deep_gemm is not None:
|
||||
+ self.use_deep_gemm = self.quant_config.use_deep_gemm
|
||||
+ else:
|
||||
+ self.use_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
@@ -311,6 +315,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
+ use_deep_gemm=self.use_deep_gemm,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
@@ -428,7 +433,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
- if self.block_quant:
|
||||
+ if self.block_quant and self.use_deep_gemm:
|
||||
maybe_post_process_fp8_weight_block(layer)
|
||||
|
||||
def apply(
|
||||
|
||||
@@ -91,6 +91,7 @@ class QuantFP8(CustomOp):
|
||||
|
||||
if (
|
||||
self.is_group_quant
|
||||
and self.use_ue8m0
|
||||
+ and self.use_deep_gemm_supported
|
||||
and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
|
||||
):
|
||||
|
||||
@@ -359,10 +359,14 @@ class W8A8BlockFp8LinearOp:
|
||||
act_quant_group_shape: GroupShape,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
+ use_deep_gemm: bool | None = None,
|
||||
):
|
||||
self.weight_group_shape = weight_group_shape
|
||||
self.act_quant_group_shape = act_quant_group_shape
|
||||
- self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
+ if use_deep_gemm is not None:
|
||||
+ self.is_deep_gemm_supported = use_deep_gemm
|
||||
+ else:
|
||||
+ self.is_deep_gemm_supported = is_deep_gemm_supported()
|
||||
self.is_hopper = current_platform.is_device_capability(90)
|
||||
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
|
||||
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
|
||||
|
||||
Reference in New Issue
Block a user