Disable Cascade Attention for Batch Invariance (#32561)

Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
Signed-off-by: Frank Wang <41319051+frankwang28@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Frank Wang authored 2026-01-30 07:00:46 -08:00, committed by GitHub
commit 8f5d51203b, parent ae5b7aff2b
6 changed files with 60 additions and 9 deletions
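The hunks below show the MoE-router and backend-list changes; the cascade-attention gate named in the title follows the same pattern. As a rough, hypothetical sketch of that gate (function and variable names are assumptions, not taken from this diff): cascade attention computes the shared prompt prefix once per batch, so a sequence's numerics depend on which other sequences it is batched with — exactly what batch-invariant mode must rule out.

```python
# Hypothetical sketch of the gating pattern the commit title describes;
# names here are assumptions, not code from this diff.
def use_cascade_attention(common_prefix_len: int, num_reqs: int) -> bool:
    # Cascade attention shares the common-prefix KV computation across the
    # batch, so per-sequence results vary with batch composition.
    if vllm_is_batch_invariant():
        return False
    return common_prefix_len > 0 and num_reqs > 1
```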


@@ -1005,7 +1005,9 @@ def override_envs_for_invariance(
 ):
     supported_backends = [
         AttentionBackendEnum.FLASH_ATTN,  # best supported backend
-        AttentionBackendEnum.FLASHINFER,
+        # FlashInfer temporarily disabled due to non-invariant CTA sizes.
+        # See FlashInfer issue #2424
+        # AttentionBackendEnum.FLASHINFER,
         AttentionBackendEnum.FLASH_ATTN_MLA,
         AttentionBackendEnum.TRITON_MLA,
         # Not yet supported MLA backends
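For context on what batch invariance requires here (an illustrative check, not part of this commit): the same row must produce bitwise-identical output no matter how many other rows share the batch, which stock GPU kernels do not guarantee, because different batch shapes can select different tilings and reduction orders.

```python
import torch

# Illustrative only: batch invariance means a row's result is bitwise
# identical regardless of batch size.
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

out_bs1 = x @ w                              # row computed at batch size 1
out_bs8 = x.expand(8, -1).contiguous() @ w   # same row inside a batch of 8

# With default kernels this may print False: different batch shapes can pick
# different tilings/split-K strategies, reordering the floating-point sums.
print(torch.equal(out_bs1[0], out_bs8[0]))
```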


@@ -18,11 +18,18 @@ from vllm.distributed import (
 )
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import PluggableLayer
+from vllm.model_executor.layers.batch_invariant import (
+    linear_batch_invariant,
+    vllm_is_batch_invariant,
+)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
+from vllm.model_executor.layers.utils import (
+    dispatch_unquantized_gemm,
+    is_layer_moe_router_gate,
+)
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -236,6 +243,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        if (
+            vllm_is_batch_invariant()
+            and current_platform.is_cuda_alike()
+            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
+        ):
+            return linear_batch_invariant(x, layer.weight, bias)
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
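Why the router gate in particular gets the batch-invariant GEMM (an illustrative note, not from the diff): MoE routing applies a top-k to the gate logits, so a sub-ulp wobble in those logits can flip the expert assignment, turning a tiny numeric difference into a categorically different forward pass.

```python
import torch

# Illustrative only: near-tied gate logits make top-k routing hypersensitive
# to the reduction-order noise that batch-variant GEMMs introduce.
logits_a = torch.tensor([0.5000001, 0.5000000, 0.31, 0.12])
logits_b = torch.tensor([0.5000000, 0.5000001, 0.31, 0.12])  # ~1e-7 shift

print(torch.topk(logits_a, k=1).indices)  # tensor([0])
print(torch.topk(logits_b, k=1).indices)  # tensor([1]) -- different expert
```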


@@ -16,6 +16,20 @@ from vllm.utils.torch_utils import direct_register_custom_op
 logger = init_logger(__name__)
+
+MOE_LAYER_ROUTER_GATE_SUFFIXES = {
+    "gate",
+    "router",
+    "router_gate",
+    "shared_expert_gate",
+    "expert_gate",
+}
+
+
+def is_layer_moe_router_gate(prefix: str) -> bool:
+    if not prefix:
+        return False
+    return prefix.rsplit(".", 1)[-1] in MOE_LAYER_ROUTER_GATE_SUFFIXES
 
 
 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
     # Shuffle weight along the last dimension so that
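Expected behavior of the new helper, derived directly from the suffix set added above (the module paths are made-up examples):

```python
is_layer_moe_router_gate("model.layers.0.mlp.gate")                 # True
is_layer_moe_router_gate("model.layers.0.mlp.shared_expert_gate")   # True
is_layer_moe_router_gate("model.layers.0.mlp.gate_proj")            # False, not a router suffix
is_layer_moe_router_gate("gate")                                    # True, no dots required
is_layer_moe_router_gate("")                                        # False, empty prefix
```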