[Feature] Integrate SM100 DeepGEMM support (#20087)

This commit is contained in:
Wentao Ye
2025-07-10 23:18:05 -04:00
committed by GitHub
parent 5b032352cc
commit e2de455c34
16 changed files with 397 additions and 114 deletions

View File

@@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@@ -1171,9 +1172,15 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# However, on B200, we use DeepGemm for all cases becuase they only support
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
N = w1.size(1)
if (allow_deep_gemm and use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2)):
should_use_deep_gemm = ((N > 512
and _valid_deep_gemm(hidden_states, w1, w2))
or is_blackwell_deep_gemm_used())
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
@@ -1363,7 +1370,6 @@ def fused_experts_impl(
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
A_scale=a1_scale,