[Feature] Batch-Invariant Support for FA2 and LoRA (#30018)

Signed-off-by: quanliu <18646313696@163.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
quanliu
2025-12-09 23:01:38 +08:00
committed by GitHub
parent 5c213d2899
commit 5dcd593baf
3 changed files with 23 additions and 3 deletions

View File

@@ -10,6 +10,7 @@ from utils import (
BACKENDS,
_extract_step_logprobs,
_random_prompt,
is_device_capability_below_90,
resolve_model_name,
skip_unsupported,
)
@@ -17,6 +18,8 @@ from utils import (
import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
@skip_unsupported
@pytest.mark.timeout(1000)
@@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
max_model_len=8192,
dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# Use more realistic prompts for better token generation
@@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="bfloat16",
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
prompt = "the capital of france is"
@@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# build ragged prompts to change shapes significantly across BS=1 vs BS=N
@@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
)
# Use a few test prompts
@@ -925,6 +933,8 @@ def LLM_with_max_seqs(
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
# Enable for MOE models
# enable_expert_parallel=True,
)