[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)
Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
@@ -10,7 +10,7 @@ import vllm.envs as envs
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
vllm_kernel_override_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@@ -25,7 +25,7 @@ def rms_norm(
|
||||
) -> torch.Tensor:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
return rms_norm_batch_invariant(x, weight, variance_epsilon)
|
||||
out = torch.empty_like(x)
|
||||
ops.rms_norm(
|
||||
@@ -45,7 +45,7 @@ def fused_add_rms_norm(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
if vllm_is_batch_invariant():
|
||||
return rms_norm_batch_invariant(
|
||||
x + residual, weight, variance_epsilon
|
||||
), x + residual
|
||||
|
||||
Reference in New Issue
Block a user