[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)
Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
@@ -16,7 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
@@ -841,7 +841,7 @@ def get_moe_configs(
|
||||
"""
|
||||
|
||||
# Avoid optimizing for the batch invariant case. Use default config
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
return None
|
||||
|
||||
# First look up if an optimized configuration is available in the configs
|
||||
@@ -976,7 +976,7 @@ def get_default_config(
|
||||
dtype: str | None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> dict[str, int]:
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
@@ -1136,7 +1136,7 @@ def fused_topk_bias(
|
||||
) + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_kernel_override_batch_invariant()
|
||||
use_sorted = vllm_is_batch_invariant()
|
||||
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if renormalize:
|
||||
@@ -1200,7 +1200,7 @@ def grouped_topk(
|
||||
) # [n, n_group]
|
||||
|
||||
# For batch invariance, use sorted=True to ensure deterministic expert selection
|
||||
use_sorted = vllm_kernel_override_batch_invariant()
|
||||
use_sorted = vllm_is_batch_invariant()
|
||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
|
||||
1
|
||||
] # [n, top_k_group]
|
||||
|
||||
Reference in New Issue
Block a user