[Feature] Batch invariant: Enable TRITON_MLA without prefix-caching (#29125)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-12-08 19:31:57 -05:00
committed by GitHub
parent 9d6235ca9a
commit d9417096d1
5 changed files with 43 additions and 7 deletions

View File

@@ -185,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
# enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported
@@ -393,7 +393,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="bfloat16",
enable_prefix_caching=False,
)
prompt = "the capital of france is"
@@ -457,7 +456,6 @@ def test_logprobs_without_batch_invariance_should_fail(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
@@ -681,7 +679,6 @@ def test_decode_logprobs_match_prefill_logprobs(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
@@ -928,7 +925,6 @@ def LLM_with_max_seqs(
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
# Enable for MOE models
# enable_expert_parallel=True,
)