[Mamba] Add stochastic rounding support (#35753)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-03-30 19:33:49 +03:00
committed by GitHub
parent dbdd9ae067
commit 8e6293e838
7 changed files with 166 additions and 2 deletions

View File

@@ -12,6 +12,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn,
selective_state_update,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -429,6 +430,59 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("philox_rounds", [0, 4])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.skipif(
not (
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
),
reason="Stochastic rounding in triton is only supported"
" on compute capability 10.0 CUDA devices.",
)
def test_selective_state_update_stochastic_rounding(dim, dstate, has_z, philox_rounds):
device = "cuda"
rtol, atol = 5e-3, 1e-1
# set seed
set_random_seed(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=torch.float16, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16)
out = torch.empty_like(x)
dt = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
# Reference uses fp32 state to get ground truth
state_ref = state.float()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
enable_stochastic_rounding=True,
cache_philox_rounds=philox_rounds,
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
assert state.dtype == torch.float16
assert torch.allclose(state, state_ref.to(torch.float16), rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])