[Mamba] Add stochastic rounding support (#35753)
Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn,
|
||||
selective_state_update,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
|
||||
@@ -429,6 +430,59 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("philox_rounds", [0, 4])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.skipif(
    not (
        current_platform.is_cuda() and current_platform.is_device_capability_family(100)
    ),
    reason="Stochastic rounding in triton is only supported"
    " on compute capability 10.0 CUDA devices.",
)
def test_selective_state_update_stochastic_rounding(dim, dstate, has_z, philox_rounds):
    """Compare the fp16 stochastic-rounding kernel against an fp32 reference.

    The state cache is kept in fp16; the reference path snapshots it as fp32
    before the kernel mutates it in place, so both the updated state and the
    output can be checked within loose tolerances.
    """
    dev = "cuda"
    rtol = 5e-3
    atol = 1e-1
    # Deterministic inputs so the tolerance checks are reproducible.
    set_random_seed(0)
    bsz = 1
    # NOTE: the order of these tensor creations is significant — it fixes
    # the consumption of the seeded RNG stream.
    state = torch.randn(bsz, dim, dstate, dtype=torch.float16, device=dev)
    x = torch.randn(bsz, dim, device=dev, dtype=torch.bfloat16)
    out = torch.empty_like(x)
    dt = torch.randn(bsz, dim, device=dev, dtype=torch.bfloat16)
    dt_bias = torch.rand(dim, device=dev) - 4.0
    A = -torch.rand(dim, dstate, device=dev) - 1.0
    B = torch.randn(bsz, dstate, device=dev)
    C = torch.randn(bsz, dstate, device=dev)
    D = torch.randn(dim, device=dev)
    z = torch.randn_like(x) if has_z else None

    # Snapshot BEFORE the kernel call: selective_state_update mutates `state`
    # in place, and the reference runs at full precision as ground truth.
    state_ref = state.float()

    shared_kwargs = dict(D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        out=out,
        enable_stochastic_rounding=True,
        cache_philox_rounds=philox_rounds,
        **shared_kwargs,
    )
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, **shared_kwargs)

    # Kernel must keep the cache in fp16 and stay close to the fp32 reference.
    assert state.dtype == torch.float16
    assert torch.allclose(state, state_ref.to(torch.float16), rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("has_z", [False, True])
|
||||
@pytest.mark.parametrize("dstate", [16, 64])
|
||||
|
||||
Reference in New Issue
Block a user