[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)
Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
@@ -19,14 +19,14 @@ hopper_only = pytest.mark.skipif(
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_batch_invariant_mode():
|
||||
"""Automatically enable batch invariant kernel overrides for all tests."""
|
||||
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
|
||||
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||
yield
|
||||
# Restore original value after test
|
||||
if old_value is None:
|
||||
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
|
||||
os.environ.pop("VLLM_BATCH_INVARIANT", None)
|
||||
else:
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = old_value
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
@@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
||||
# For batch invariance, disable custom all-reduce to ensure deterministic
|
||||
# all-reduce operations (custom all-reduce may not be deterministic)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
|
||||
disable_custom_ar = vllm_kernel_override_batch_invariant()
|
||||
disable_custom_ar = vllm_is_batch_invariant()
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
@@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
|
||||
# CRITICAL: Disable batch invariance for this test
|
||||
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
|
||||
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = "0"
|
||||
|
||||
try:
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
@@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
||||
finally:
|
||||
# Restore original value
|
||||
if old_value is None:
|
||||
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
|
||||
os.environ.pop("VLLM_BATCH_INVARIANT", None)
|
||||
else:
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
|
||||
os.environ["VLLM_BATCH_INVARIANT"] = old_value
|
||||
|
||||
|
||||
@hopper_only
|
||||
@@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
|
||||
disable_custom_ar = vllm_kernel_override_batch_invariant()
|
||||
disable_custom_ar = vllm_is_batch_invariant()
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
|
||||
Reference in New Issue
Block a user