[Feature] Batch-Invariant Support for FA2 and LoRA (#30018)
Signed-off-by: quanliu <18646313696@163.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from utils import (
     BACKENDS,
     _extract_step_logprobs,
     _random_prompt,
+    is_device_capability_below_90,
     resolve_model_name,
     skip_unsupported,
 )
@@ -17,6 +18,8 @@ from utils import (
 import vllm.model_executor.layers.batch_invariant as batch_invariant
 from vllm import LLM, SamplingParams
 
+IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
+
 
 @skip_unsupported
 @pytest.mark.timeout(1000)
@@ -190,6 +193,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
         gpu_memory_utilization=0.9,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # Use more realistic prompts for better token generation
@@ -393,6 +397,8 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
+        enable_prefix_caching=False,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     prompt = "the capital of france is"
@@ -459,6 +465,7 @@ def test_logprobs_without_batch_invariance_should_fail(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # build ragged prompts to change shapes significantly across BS=1 vs BS=N
@@ -682,6 +689,7 @@ def test_decode_logprobs_match_prefill_logprobs(
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
     )
 
     # Use a few test prompts
@@ -925,6 +933,8 @@ def LLM_with_max_seqs(
         max_model_len=max_model_len,
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        enable_prefix_caching=False,
+        enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )

Reference in New Issue
Block a user