Refactor hard-coded device strings in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -13,6 +13,9 @@ from utils import skip_unsupported
|
||||
|
||||
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@@ -34,7 +37,7 @@ def test_rms_norm_batch_invariant_vs_standard(
|
||||
equivalent results to the standard CUDA implementation across various
|
||||
configurations.
|
||||
"""
|
||||
-device = torch.device("cuda")
+device = torch.device(DEVICE_TYPE)
|
||||
# Create test input and weight
|
||||
torch.manual_seed(42)
|
||||
@@ -81,7 +84,7 @@ def test_rms_norm_3d_input(
|
||||
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
|
||||
inputs that are common in transformer models.
|
||||
"""
|
||||
-device = torch.device("cuda")
+device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
|
||||
@@ -120,7 +123,7 @@ def test_rms_norm_numerical_stability(default_vllm_config):
|
||||
Ensures that both implementations handle edge cases like very small or large
|
||||
values without producing NaN or Inf.
|
||||
"""
|
||||
-device = torch.device("cuda")
+device = torch.device(DEVICE_TYPE)
dtype = torch.float16
|
||||
eps = 1e-6
|
||||
hidden_size = 2048
|
||||
@@ -179,7 +182,7 @@ def test_rms_norm_formula(default_vllm_config):
|
||||
|
||||
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
|
||||
"""
|
||||
-device = torch.device("cuda")
+device = torch.device(DEVICE_TYPE)
dtype = torch.float32 # Use float32 for higher precision in formula check
|
||||
eps = 1e-6
|
||||
hidden_size = 1024
|
||||
@@ -214,7 +217,7 @@ def test_rms_norm_different_hidden_sizes(default_vllm_config, hidden_size: int):
|
||||
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
|
||||
correctly handles hidden sizes both smaller and larger than the block size.
|
||||
"""
|
||||
-device = torch.device("cuda")
+device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
batch_size = 16
|
||||
@@ -251,7 +254,7 @@ def test_rms_norm_determinism(default_vllm_config):
|
||||
Runs the same input through the kernel multiple times and verifies
|
||||
identical outputs.
|
||||
"""
|
||||
-device = torch.device("cuda")
+device = torch.device(DEVICE_TYPE)
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
hidden_size = 4096
|
||||
@@ -283,7 +286,7 @@ if __name__ == "__main__":
|
||||
# Run a quick smoke test
|
||||
print("Running quick smoke test of RMS norm implementations...")
|
||||
|
||||
-device = torch.device("cuda")
+device = torch.device(DEVICE_TYPE)
batch_size = 8
|
||||
hidden_size = 4096
|
||||
dtype = torch.bfloat16
|
||||
|
||||
Reference in New Issue
Block a user