[Performance][B200] Fix deepgemm prologue (#27897)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-11-12 16:13:03 -05:00
committed by GitHub
parent 478ee511de
commit 74a9a9faad
6 changed files with 163 additions and 48 deletions

View File

@@ -60,11 +60,10 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
expert_weight_is_col_major,
deepgemm_post_process_fp8_weight_block,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy,
requant_weight_ue8m0_inplace,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -94,7 +93,6 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
get_col_major_tma_aligned_tensor,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
@@ -846,15 +844,31 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv
if self.allow_deep_gemm:
dg_w13_weight, dg_w13_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w13_weight.data,
ws=layer.w13_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv
)
dg_w2_weight, dg_w2_weight_scale_inv = (
deepgemm_post_process_fp8_weight_block(
wq=layer.w2_weight.data,
ws=layer.w2_weight_scale_inv.data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
)
layer.w13_weight = Parameter(dg_w13_weight, requires_grad=False)
layer.w13_weight_scale_inv = Parameter(
dg_w13_weight_scale_inv, requires_grad=False
)
layer.w2_weight = Parameter(dg_w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(
dg_w2_weight_scale_inv, requires_grad=False
)
# If checkpoint is fp16, quantize in place.
elif not self.quant_config.is_checkpoint_fp8_serialized:
@@ -990,31 +1004,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
if is_deep_gemm_e8m0_used() and self.block_quant:
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
layer.w13_weight.data,
layer.w13_weight_scale_inv.data,
block_sz,
)
requant_weight_ue8m0_inplace(
layer.w2_weight.data,
layer.w2_weight_scale_inv.data,
block_sz,
)
# Ensure column-major TMA alignment expected by DeepGEMM.
if expert_weight_is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w13_weight_scale_inv
)
if expert_weight_is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
layer.w2_weight_scale_inv
)
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
@@ -1037,7 +1026,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedTritonOrDeepGemmExperts,
BatchedDeepGemmExperts,
BatchedTritonExperts,
TritonOrDeepGemmExperts,
)
@@ -1053,20 +1043,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
experts_impl = (
BatchedDeepGemmExperts if self.allow_deep_gemm else BatchedTritonExperts
)
logger.debug(
"BatchedTritonOrDeepGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
"%s(%s): max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
experts_impl.__name__,
self.__class__.__name__,
max_num_tokens_per_rank,
self.weight_block_size,
False,
)
return BatchedTritonOrDeepGemmExperts(
return experts_impl(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
)
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
experts = select_cutlass_fp8_gemm_impl(
self.moe,