[Feature] Integrate SM100 DeepGEMM support (#20087)
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

@@ -1171,9 +1172,15 @@ def fused_experts(
         allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor:
     # For now, disable DeepGemm for small N (<= 512) until better
     # permute/unpermute ops are available.
+    # However, on B200 we use DeepGemm for all cases because it only supports
+    # the E8M0 scale, which means we requantize the weight and input to that
+    # specific scale. Falling back to cutlass or triton for some cases would
+    # cause accuracy issues.
     N = w1.size(1)
-    if (allow_deep_gemm and use_fp8_w8a8 and N > 512
-            and _valid_deep_gemm(hidden_states, w1, w2)):
+    should_use_deep_gemm = ((N > 512
+                             and _valid_deep_gemm(hidden_states, w1, w2))
+                            or is_blackwell_deep_gemm_used())
+    if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
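Note: below is a minimal, self-contained sketch of the dispatch behavior this hunk introduces. The helper should_dispatch_to_deep_gemm and both stubs are hypothetical, added only so the gating expression runs outside vLLM; the real _valid_deep_gemm and is_blackwell_deep_gemm_used live in vllm.

# Hypothetical sketch (not part of this diff) of the gating logic above.
import torch

def is_blackwell_deep_gemm_used() -> bool:
    # Stub: the real vllm helper reports whether DeepGEMM is enabled and
    # running on a Blackwell (SM100) GPU.
    return False

def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
                     w2: torch.Tensor) -> bool:
    # Stub: the real check validates tensor shapes/alignment for DeepGEMM.
    return True

def should_dispatch_to_deep_gemm(hidden_states: torch.Tensor,
                                 w1: torch.Tensor, w2: torch.Tensor,
                                 allow_deep_gemm: bool,
                                 use_fp8_w8a8: bool) -> bool:
    N = w1.size(1)
    # On Blackwell the DeepGemm path is taken unconditionally (E8M0-only
    # scales); elsewhere it still requires N > 512 and a validity check.
    should_use_deep_gemm = ((N > 512
                             and _valid_deep_gemm(hidden_states, w1, w2))
                            or is_blackwell_deep_gemm_used())
    return allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm

With the stubs above, should_dispatch_to_deep_gemm(torch.randn(4, 128), torch.randn(8, 1024, 128), torch.randn(8, 128, 512), True, True) returns True, since N = 1024 > 512.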
@@ -1363,7 +1370,6 @@ def fused_experts_impl(

         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

         qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
             A=curr_hidden_states,
             A_scale=a1_scale,
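Note: the context lines above come from the chunked loop in fused_experts_impl. For readers, a hedged sketch of how a chunk might be sliced and quantized with an exponent-only (E8M0-style) power-of-two scale, the constraint the hunk-2 comment describes for DeepGemm on B200. quantize_chunk_e8m0, FP8_MAX, and the chunk size are illustrative assumptions, not vllm's moe_kernel_quantize_input.

# Hedged sketch (not vllm code) of per-chunk input quantization with a
# power-of-two (E8M0-style) scale. All names here are illustrative.
import torch

FP8_MAX = 448.0  # max normal value of torch.float8_e4m3fn

def quantize_chunk_e8m0(A: torch.Tensor):
    # Per-tensor absmax scale, rounded up to a power of two so it is
    # representable as an exponent-only (E8M0) scale factor.
    amax = A.abs().amax().clamp(min=1e-12)
    scale = torch.exp2(torch.ceil(torch.log2(amax / FP8_MAX)))
    q = (A / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale

# Chunked loop shaped like fused_experts_impl: slice tokens, then quantize.
def process_in_chunks(hidden_states, topk_ids, topk_weights,
                      chunk_size: int = 32768):  # chunk size assumed
    for begin_chunk_idx in range(0, hidden_states.size(0), chunk_size):
        end_chunk_idx = min(begin_chunk_idx + chunk_size,
                            hidden_states.size(0))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        qcurr_hidden_states, a1q_scale = quantize_chunk_e8m0(
            curr_hidden_states)
        # ... grouped expert GEMMs would consume the quantized chunk,
        # curr_topk_ids, and curr_topk_weights here ...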