Deepseek-v3 Batch Invariant on 8xH100 (#26609)
Signed-off-by: Bram Wasti <bwasti@meta.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,9 @@ import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -837,6 +840,10 @@ def get_moe_configs(
|
||||
be picked and the associated configuration chosen to invoke the kernel.
|
||||
"""
|
||||
|
||||
# Avoid optimizing for the batch invariant case. Use default config
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
return None
|
||||
|
||||
# First look up if an optimized configuration is available in the configs
|
||||
# directory
|
||||
block_shape = [block_n, block_k] if block_n and block_k else None
|
||||
@@ -969,6 +976,15 @@ def get_default_config(
|
||||
dtype: str | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
}
|
||||
return config
|
||||
|
||||
if dtype == "fp8_w8a8" and block_shape is not None:
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
|
||||
# BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||
@@ -1118,7 +1134,10 @@ def fused_topk_bias(
|
||||
scores_for_choice = scores.view(
|
||||
-1, n_routed_experts
|
||||
) + e_score_correction_bias.unsqueeze(0)
|
||||
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_kernel_override_batch_invariant()
|
||||
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
@@ -1179,7 +1198,10 @@ def grouped_topk(
|
||||
group_scores = (
|
||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_kernel_override_batch_invariant()
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
@@ -1192,11 +1214,13 @@ def grouped_topk(
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
|
||||
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_scores.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
topk_weights, topk_ids = torch.topk(
|
||||
tmp_scores, k=topk, dim=-1, sorted=use_sorted
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
Reference in New Issue
Block a user