Deepseek-v3 Batch Invariant on 8xH100 (#26609)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>

Author: Bram Wasti
Date: 2025-10-15 22:06:02 -07:00
Committed by: GitHub
Parent: 785d8b6410
Commit: 7d8975de84
21 changed files with 1567 additions and 102 deletions

@@ -15,6 +15,9 @@ import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
     FusedMoEQuantConfig,
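
For context, the imported override is essentially an on/off gate. A minimal sketch, assuming it follows the usual VLLM_* environment-flag pattern (the exact name and parsing live in vllm/model_executor/layers/batch_invariant.py and may differ):

```python
# Hypothetical sketch of the imported helper -- not the exact vLLM source.
# Assumption: the override is driven by an environment variable, following
# the pattern of other VLLM_* feature flags.
import os


def vllm_kernel_override_batch_invariant() -> bool:
    # Treat any "truthy" value ("1", "true", "yes") as enabling the override.
    val = os.getenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "0")
    return val.strip().lower() in ("1", "true", "yes")
```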
@@ -837,6 +840,10 @@ def get_moe_configs(
     be picked and the associated configuration chosen to invoke the kernel.
     """
+    # Avoid optimizing for the batch invariant case. Use default config
+    if vllm_kernel_override_batch_invariant():
+        return None
     # First look up if an optimized configuration is available in the configs
     # directory
     block_shape = [block_n, block_k] if block_n and block_k else None
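
Returning None here bypasses the tuned per-batch-size configs, so the caller falls through to get_default_config and every batch size runs the same tiling. A toy sketch of that dispatch (hypothetical names and data, not vLLM's actual lookup code):

```python
# Tuned configs are keyed by the benchmarked batch size M, and the nearest
# key wins -- so the chosen tile shape (and hence numerics) varies with M.
# Batch-invariant mode skips that lookup entirely.
BATCH_INVARIANT = True

TUNED_CONFIGS = {
    1: {"BLOCK_SIZE_M": 16},
    64: {"BLOCK_SIZE_M": 64},
    1024: {"BLOCK_SIZE_M": 128},
}
DEFAULT_CONFIG = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}


def pick_config(M: int) -> dict[str, int]:
    if BATCH_INVARIANT:
        # one fixed config, independent of M
        return DEFAULT_CONFIG
    # nearest benchmarked batch size wins; tiling now depends on M
    return TUNED_CONFIGS[min(TUNED_CONFIGS, key=lambda m: abs(m - M))]


assert pick_config(1) == pick_config(2048)  # holds only in batch-invariant mode
```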
@@ -969,6 +976,15 @@ def get_default_config(
     dtype: str | None,
     block_shape: list[int] | None = None,
 ) -> dict[str, int]:
+    if vllm_kernel_override_batch_invariant():
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        return config
     if dtype == "fp8_w8a8" and block_shape is not None:
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
         # BLOCK_SIZE_K must be divisible by block_shape[1]
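
The point of pinning one tile shape: float addition is not associative, so different BLOCK_SIZE_K tilings of the same reduction can produce slightly different results. A runnable toy demo in plain PyTorch (not the Triton kernel itself):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4096, dtype=torch.float32)


def tiled_sum(v: torch.Tensor, block: int) -> torch.Tensor:
    # sum each block first, then sum the partials -- mimics a BLOCK_SIZE_K tiling
    return v.reshape(-1, block).sum(dim=1).sum()


# Different tilings of the same data usually differ in the last few ulps.
print((tiled_sum(x, 32) - tiled_sum(x, 64)).item())
```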
@@ -1118,7 +1134,10 @@ def fused_topk_bias(
     scores_for_choice = scores.view(
         -1, n_routed_experts
     ) + e_score_correction_bias.unsqueeze(0)
-    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
+    use_sorted = vllm_kernel_override_batch_invariant()
+    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
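
Why sorted=True matters here: with sorted=False, torch.topk leaves the order of the returned top-k unspecified, so the subsequent gather and renormalization sum can accumulate weights in an order that varies with launch configuration (and thus batch shape) on CUDA. A small illustration, including a tie case:

```python
import torch

scores = torch.tensor([[0.9, 0.1, 0.9, 0.5]])  # note the tie between experts 0 and 2

# sorted=False: the order (and, under ties, potentially the choice) of the
# returned experts is unspecified.
vals_u, idx_u = torch.topk(scores, k=2, dim=-1, sorted=False)

# sorted=True: canonical descending-score order, stable across batch shapes.
vals_s, idx_s = torch.topk(scores, k=2, dim=-1, sorted=True)
print(idx_u, idx_s)
```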
@@ -1179,7 +1198,10 @@ def grouped_topk(
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
+    use_sorted = vllm_kernel_override_batch_invariant()
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
         1
     ]  # [n, top_k_group]
     group_mask = torch.zeros_like(group_scores)  # [n, n_group]
@@ -1192,11 +1214,13 @@
     tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
     if e_score_correction_bias is not None:
-        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
         # Use original unbiased scores for the routing weights
         topk_weights = original_scores.gather(1, topk_ids)
     else:
-        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+        topk_weights, topk_ids = torch.topk(
+            tmp_scores, k=topk, dim=-1, sorted=use_sorted
+        )
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
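
Putting the grouped_topk changes together: use_sorted is threaded through both the group-selection topk and the within-group expert topk. A simplified, self-contained sketch of that routing path (hypothetical helper name; no correction bias or renormalization):

```python
import torch


def grouped_topk_sketch(scores, num_expert_group, topk_group, topk, use_sorted=True):
    # scores: [n_tokens, n_experts]; simplified from the diff above
    n, e = scores.shape
    # score each group by its best expert, then pick the top groups
    group_scores = scores.view(n, num_expert_group, -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)  # mark the selected groups
    # expand the group mask back to per-expert granularity
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(n, num_expert_group, e // num_expert_group)
        .reshape(n, e)
    )
    # suppress experts outside the chosen groups, then pick the top experts
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    return torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)


weights, ids = grouped_topk_sketch(torch.randn(2, 8), num_expert_group=4, topk_group=2, topk=2)
```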