Deepseek-v3 Batch Invariant on 8xH100 (#26609)
Signed-off-by: Bram Wasti <bwasti@meta.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
346
tests/v1/generation/test_rms_norm_batch_invariant.py
Normal file
346
tests/v1/generation/test_rms_norm_batch_invariant.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test batch-invariant RMS normalization against standard implementations.
|
||||
|
||||
This test compares the Triton-based batch-invariant RMS norm implementation
|
||||
with the standard CUDA-based implementation to ensure numerical accuracy.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
|
||||
def test_rms_norm_batch_invariant_vs_standard(
|
||||
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
|
||||
):
|
||||
"""
|
||||
Compare batch-invariant Triton RMS norm against standard CUDA implementation.
|
||||
|
||||
Tests that the Triton-based batch-invariant RMS norm produces numerically
|
||||
equivalent results to the standard CUDA implementation across various
|
||||
configurations.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Create test input and weight
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation (CUDA ops)
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation (Triton)
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare outputs
|
||||
# Use looser tolerance for bfloat16 due to its lower precision
|
||||
if dtype == torch.bfloat16:
|
||||
rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2 # 1% for float16/float32
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"RMS norm mismatch for batch_size={batch_size}, "
|
||||
f"hidden_size={hidden_size}, "
|
||||
f"dtype={dtype}, eps={eps}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [1, 16, 128])
|
||||
@pytest.mark.parametrize("seq_len", [1, 32, 512])
|
||||
@pytest.mark.parametrize("hidden_size", [2048, 4096])
|
||||
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
|
||||
"""
|
||||
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
|
||||
|
||||
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
|
||||
inputs that are common in transformer models.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(
|
||||
batch_size, seq_len, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Use looser tolerance for bfloat16
|
||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, "
|
||||
f"seq_len={seq_len}, hidden_size={hidden_size}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
def test_rms_norm_numerical_stability():
|
||||
"""
|
||||
Test RMS norm numerical stability with extreme values.
|
||||
|
||||
Ensures that both implementations handle edge cases like very small or large
|
||||
values without producing NaN or Inf.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
eps = 1e-6
|
||||
hidden_size = 2048
|
||||
|
||||
# Test cases with extreme values
|
||||
test_cases = [
|
||||
# Very small values
|
||||
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5,
|
||||
# Very large values
|
||||
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4,
|
||||
# Mixed small and large
|
||||
torch.randn(4, hidden_size, dtype=dtype, device=device) * 100,
|
||||
# Values near zero
|
||||
torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6,
|
||||
]
|
||||
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
for idx, input_tensor in enumerate(test_cases):
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Check for NaN or Inf
|
||||
assert not torch.isnan(standard_output).any(), (
|
||||
f"Standard RMS norm produced NaN for test case {idx}"
|
||||
)
|
||||
assert not torch.isinf(standard_output).any(), (
|
||||
f"Standard RMS norm produced Inf for test case {idx}"
|
||||
)
|
||||
assert not torch.isnan(triton_output).any(), (
|
||||
f"Triton RMS norm produced NaN for test case {idx}"
|
||||
)
|
||||
assert not torch.isinf(triton_output).any(), (
|
||||
f"Triton RMS norm produced Inf for test case {idx}"
|
||||
)
|
||||
|
||||
# Compare outputs - very lenient for extreme values with float16
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=2e-1, # 20% tolerance for extreme values
|
||||
atol=2e-1,
|
||||
msg=f"RMS norm mismatch for extreme value test case {idx}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
def test_rms_norm_formula():
|
||||
"""
|
||||
Test that RMS norm follows the correct mathematical formula.
|
||||
|
||||
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float32 # Use float32 for higher precision in formula check
|
||||
eps = 1e-6
|
||||
hidden_size = 1024
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Compute expected output using the formula
|
||||
variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype)
|
||||
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare against formula
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
expected_output,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
msg="Triton RMS norm doesn't match expected formula",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
|
||||
def test_rms_norm_different_hidden_sizes(hidden_size: int):
|
||||
"""
|
||||
Test RMS norm with various hidden sizes to ensure block size handling.
|
||||
|
||||
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
|
||||
correctly handles hidden sizes both smaller and larger than the block size.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
batch_size = 16
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Use looser tolerance for bfloat16
|
||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"RMS norm mismatch for hidden_size={hidden_size}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
def test_rms_norm_determinism():
|
||||
"""
|
||||
Test that batch-invariant RMS norm produces deterministic results.
|
||||
|
||||
Runs the same input through the kernel multiple times and verifies
|
||||
identical outputs.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
hidden_size = 4096
|
||||
batch_size = 32
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Run multiple times
|
||||
outputs = []
|
||||
for _ in range(5):
|
||||
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
|
||||
outputs.append(output)
|
||||
|
||||
# All outputs should be identical
|
||||
reference = outputs[0]
|
||||
for idx, output in enumerate(outputs[1:], start=1):
|
||||
torch.testing.assert_close(
|
||||
output,
|
||||
reference,
|
||||
rtol=0.0,
|
||||
atol=0.0,
|
||||
msg=f"RMS norm not deterministic: run {idx} differs from reference",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run a quick smoke test
|
||||
print("Running quick smoke test of RMS norm implementations...")
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 8
|
||||
hidden_size = 4096
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare
|
||||
max_diff = (triton_output - standard_output).abs().max().item()
|
||||
mean_diff = (triton_output - standard_output).abs().mean().item()
|
||||
|
||||
print(f"Max difference: {max_diff:.6e}")
|
||||
print(f"Mean difference: {mean_diff:.6e}")
|
||||
print(f"Standard output sample: {standard_output[0, :5].tolist()}")
|
||||
print(f"Triton output sample: {triton_output[0, :5].tolist()}")
|
||||
|
||||
if max_diff < 1e-3:
|
||||
print("✓ Smoke test passed!")
|
||||
else:
|
||||
print("✗ Smoke test failed - differences too large")
|
||||
Reference in New Issue
Block a user