[Bugfix] Fix DeepGemm E8M0 accuracy degradation for Qwen3.5 FP8 on Blackwell (#38083)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-03-26 12:21:47 +04:00
committed by GitHub
parent 71161e8b63
commit 52069012fe
10 changed files with 69 additions and 11 deletions

View File

@@ -129,6 +129,7 @@ class Fp8Config(QuantizationConfig):
f"{activation_scheme} activation scheme."
)
self.weight_block_size = weight_block_size
self.use_deep_gemm: bool | None = None
@classmethod
def get_name(cls) -> QuantizationMethods:
@@ -276,7 +277,10 @@ class Fp8LinearMethod(LinearMethodBase):
self.marlin_input_dtype = None
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
self.use_deep_gemm = is_deep_gemm_supported()
if self.quant_config.use_deep_gemm is not None:
self.use_deep_gemm = self.quant_config.use_deep_gemm
else:
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
@@ -311,6 +315,7 @@ class Fp8LinearMethod(LinearMethodBase):
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
use_deep_gemm=self.use_deep_gemm,
)
def create_weights(
@@ -428,7 +433,7 @@ class Fp8LinearMethod(LinearMethodBase):
else:
layer.input_scale = None
if self.block_quant:
if self.block_quant and self.use_deep_gemm:
maybe_post_process_fp8_weight_block(layer)
def apply(

View File

@@ -91,6 +91,7 @@ class QuantFP8(CustomOp):
if (
self.is_group_quant
and self.use_ue8m0
and self.use_deep_gemm_supported
and (DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0)
):

View File

@@ -359,10 +359,14 @@ class W8A8BlockFp8LinearOp:
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
use_deep_gemm: bool | None = None,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
self.is_deep_gemm_supported = is_deep_gemm_supported()
if use_deep_gemm is not None:
self.is_deep_gemm_supported = use_deep_gemm
else:
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()