[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)

Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
Bram Wasti
2025-10-16 14:40:25 -07:00
committed by GitHub
parent 23583ee28c
commit b2f78cbad4
20 changed files with 61 additions and 61 deletions

View File

@@ -19,14 +19,14 @@ hopper_only = pytest.mark.skipif(
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
os.environ["VLLM_BATCH_INVARIANT"] = "1"
yield
# Restore original value after test
if old_value is None:
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
os.environ["VLLM_BATCH_INVARIANT"] = old_value
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
@@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
# For batch invariance, disable custom all-reduce to ensure deterministic
# all-reduce operations (custom all-reduce may not be deterministic)
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
disable_custom_ar = vllm_kernel_override_batch_invariant()
disable_custom_ar = vllm_is_batch_invariant()
if disable_custom_ar:
print(f"\n{'=' * 80}")
@@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
os.environ["VLLM_ATTENTION_BACKEND"] = backend
# CRITICAL: Disable batch invariance for this test
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
os.environ["VLLM_BATCH_INVARIANT"] = "0"
try:
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
@@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
finally:
# Restore original value
if old_value is None:
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
os.environ.pop("VLLM_BATCH_INVARIANT", None)
else:
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
os.environ["VLLM_BATCH_INVARIANT"] = old_value
@hopper_only
@@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
disable_custom_ar = vllm_kernel_override_batch_invariant()
disable_custom_ar = vllm_is_batch_invariant()
if disable_custom_ar:
print(f"\n{'=' * 80}")