Disable Cascade Attention for Batch Invariance (#32561)
Signed-off-by: frankwang28 <frank.wbb@hotmail.com> Signed-off-by: Frank Wang <41319051+frankwang28@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -188,7 +188,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
max_num_seqs=32,
|
||||
max_num_seqs=128,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16", # not everything is supported
|
||||
gpu_memory_utilization=0.9,
|
||||
@@ -197,12 +197,20 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
||||
)
|
||||
|
||||
# Use more realistic prompts for better token generation
|
||||
prompts = [_random_prompt(10, 50) for i in range(32)]
|
||||
prompts = [_random_prompt(10, 50) for _ in range(32)]
|
||||
|
||||
# TODO: Update prompts to have ragged lengths in order to test chunked prefill
|
||||
# The above tests are not currently long enough to exercise chunking.
|
||||
# prompts = (
|
||||
# [_random_prompt(10, 50) for _ in range(28)]
|
||||
# + [_random_prompt(256, 512) for _ in range(50)]
|
||||
# + [_random_prompt(2048, 4096) for _ in range(50)]
|
||||
# )
|
||||
|
||||
sp = SamplingParams(
|
||||
temperature=0.6,
|
||||
top_p=1.0,
|
||||
max_tokens=8,
|
||||
max_tokens=16,
|
||||
seed=1234,
|
||||
logprobs=5,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
@@ -22,8 +21,10 @@ BACKENDS: list[str] = [
|
||||
"TRITON_MLA",
|
||||
]
|
||||
|
||||
if has_flashinfer():
|
||||
BACKENDS.append("FLASHINFER")
|
||||
# FlashInfer temporarily disabled due to invariant CTA sizes.
|
||||
# See FlashInfer issue #2424
|
||||
# if has_flashinfer():
|
||||
# BACKENDS.append("FLASHINFER")
|
||||
|
||||
if flash_attn_supports_mla():
|
||||
BACKENDS.append("FLASH_ATTN_MLA")
|
||||
@@ -78,9 +79,10 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
# For longer prompts, repeat context
|
||||
padding_text = (
|
||||
" This is an interesting topic that deserves more explanation. "
|
||||
# TODO: Update to * (target_words // 10) to better align with word ratio
|
||||
* (target_words // 50)
|
||||
)
|
||||
base_prompt = base_prompt + padding_text
|
||||
base_prompt = padding_text + base_prompt
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@@ -959,6 +959,18 @@ class VllmConfig:
|
||||
"when cudagraph_mode piecewise cudagraphs is used, "
|
||||
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
|
||||
)
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
|
||||
if (
|
||||
self.model_config
|
||||
and vllm_is_batch_invariant()
|
||||
and not self.model_config.disable_cascade_attn
|
||||
):
|
||||
self.model_config.disable_cascade_attn = True
|
||||
logger.warning_once(
|
||||
"Disabling cascade attention when VLLM_BATCH_INVARIANT is enabled.",
|
||||
scope="local",
|
||||
)
|
||||
|
||||
if self.parallel_config.use_ubatching:
|
||||
a2a_backend = self.parallel_config.all2all_backend
|
||||
|
||||
@@ -1005,7 +1005,9 @@ def override_envs_for_invariance(
|
||||
):
|
||||
supported_backends = [
|
||||
AttentionBackendEnum.FLASH_ATTN, # best supported backend
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
# FlashInfer temporarily disabled due to invariant CTA sizes.
|
||||
# See FlashInfer issue #2424
|
||||
# AttentionBackendEnum.FLASHINFER,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
# Not yet supported MLA backends
|
||||
|
||||
@@ -18,11 +18,18 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
linear_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.layers.utils import (
|
||||
dispatch_unquantized_gemm,
|
||||
is_layer_moe_router_gate,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
BlockQuantScaleParameter,
|
||||
@@ -236,6 +243,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
vllm_is_batch_invariant()
|
||||
and current_platform.is_cuda_alike()
|
||||
and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
|
||||
):
|
||||
return linear_batch_invariant(x, layer.weight, bias)
|
||||
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,20 @@ from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MOE_LAYER_ROUTER_GATE_SUFFIXES = {
|
||||
"gate",
|
||||
"router",
|
||||
"router_gate",
|
||||
"shared_expert_gate",
|
||||
"expert_gate",
|
||||
}
|
||||
|
||||
|
||||
def is_layer_moe_router_gate(prefix: str) -> bool:
|
||||
if not prefix:
|
||||
return False
|
||||
return prefix.rsplit(".", 1)[-1] in MOE_LAYER_ROUTER_GATE_SUFFIXES
|
||||
|
||||
|
||||
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
|
||||
# Shuffle weight along the last dimension so that
|
||||
|
||||
Reference in New Issue
Block a user