Disable Cascade Attention for Batch Invariance (#32561)

Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
Signed-off-by: Frank Wang <41319051+frankwang28@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Frank Wang
2026-01-30 07:00:46 -08:00
committed by GitHub
parent ae5b7aff2b
commit 8f5d51203b
6 changed files with 60 additions and 9 deletions

View File

@@ -188,7 +188,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
max_num_seqs=32,
max_num_seqs=128,
max_model_len=8192,
dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9,
@@ -197,12 +197,20 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
)
# Use more realistic prompts for better token generation
prompts = [_random_prompt(10, 50) for i in range(32)]
prompts = [_random_prompt(10, 50) for _ in range(32)]
# TODO: Update prompts to have ragged lengths in order to test chunked prefill
# The above tests are not currently long enough to exercise chunking.
# prompts = (
# [_random_prompt(10, 50) for _ in range(28)]
# + [_random_prompt(256, 512) for _ in range(50)]
# + [_random_prompt(2048, 4096) for _ in range(50)]
# )
sp = SamplingParams(
temperature=0.6,
top_p=1.0,
max_tokens=8,
max_tokens=16,
seed=1234,
logprobs=5,
)

View File

@@ -7,7 +7,6 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
skip_unsupported = pytest.mark.skipif(
@@ -22,8 +21,10 @@ BACKENDS: list[str] = [
"TRITON_MLA",
]
if has_flashinfer():
BACKENDS.append("FLASHINFER")
# FlashInfer temporarily disabled due to invariant CTA sizes.
# See FlashInfer issue #2424
# if has_flashinfer():
# BACKENDS.append("FLASHINFER")
if flash_attn_supports_mla():
BACKENDS.append("FLASH_ATTN_MLA")
@@ -78,9 +79,10 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# For longer prompts, repeat context
padding_text = (
" This is an interesting topic that deserves more explanation. "
# TODO: Update to * (target_words // 10) to better align with word ratio
* (target_words // 50)
)
base_prompt = base_prompt + padding_text
base_prompt = padding_text + base_prompt
return base_prompt