[Mamba] Add stochastic rounding support (#35753)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-03-30 19:33:49 +03:00
committed by GitHub
parent dbdd9ae067
commit 8e6293e838
7 changed files with 166 additions and 2 deletions

View File

@@ -12,6 +12,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn,
selective_state_update,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -429,6 +430,59 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("philox_rounds", [0, 4])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.skipif(
not (
current_platform.is_cuda() and current_platform.is_device_capability_family(100)
),
reason="Stochastic rounding in triton is only supported"
" on compute capability 10.0 CUDA devices.",
)
def test_selective_state_update_stochastic_rounding(dim, dstate, has_z, philox_rounds):
device = "cuda"
rtol, atol = 5e-3, 1e-1
# set seed
set_random_seed(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=torch.float16, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16)
out = torch.empty_like(x)
dt = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
# Reference uses fp32 state to get ground truth
state_ref = state.float()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out,
enable_stochastic_rounding=True,
cache_philox_rounds=philox_rounds,
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
assert state.dtype == torch.float16
assert torch.allclose(state, state_ref.to(torch.float16), rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])

View File

@@ -109,12 +109,20 @@ class CacheConfig:
mamba_cache_mode: MambaCacheMode = "none"
"""The cache strategy for Mamba layers.
- "none": set when prefix caching is disabled.
- "all": cache the mamba state of all tokens at position i * block_size. This is
- "all": cache the mamba state of all tokens at position i * block_size. This is
the default behavior (for models that support it) when prefix caching is
enabled.
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
enable_mamba_cache_stochastic_rounding: bool = False
"""Enable stochastic rounding when writing SSM state to fp16 cache.
Uses random bits to unbias the rounding error, which can improve
numerical stability for long sequences."""
mamba_cache_philox_rounds: int = 0
"""Number of Philox PRNG rounds for stochastic rounding random number
generation. 0 uses the Triton default. Higher values improve randomness
quality at the cost of compute."""
# Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False)
@@ -231,3 +239,29 @@ class CacheConfig:
"scaling factor."
)
return cache_dtype
def __post_init__(self):
if self.enable_mamba_cache_stochastic_rounding:
from vllm.platforms import current_platform
if not current_platform.is_cuda():
raise ValueError(
"Stochastic rounding for Mamba cache is only supported "
"on NVIDIA CUDA platforms. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if not current_platform.is_device_capability_family(100):
raise ValueError(
"Stochastic rounding for Mamba cache requires compute "
"capability 10.0 (data center Blackwell). The `cvt.rs` PTX "
"instruction is not supported on your GPU. Please do not specify "
"`--enable-mamba-cache-stochastic-rounding`."
)
if self.mamba_ssm_cache_dtype != "float16":
raise ValueError(
"Stochastic rounding for Mamba cache requires "
"the SSM cache to be float16. Please set it explicitly, "
"by specifying `--mamba-ssm-cache-dtype float16`, or disable "
"stochastic rounding by not specifying "
"`--enable-mamba-cache-stochastic-rounding`."
)

View File

@@ -604,6 +604,10 @@ class EngineArgs:
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
enable_mamba_cache_stochastic_rounding: bool = (
CacheConfig.enable_mamba_cache_stochastic_rounding
)
mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
@@ -1024,6 +1028,13 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
)
cache_group.add_argument(
"--enable-mamba-cache-stochastic-rounding",
**cache_kwargs["enable_mamba_cache_stochastic_rounding"],
)
cache_group.add_argument(
"--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"]
)
cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
)
@@ -1590,6 +1601,8 @@ class EngineArgs:
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
mamba_cache_mode=self.mamba_cache_mode,
enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding,
mamba_cache_philox_rounds=self.mamba_cache_philox_rounds,
kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend,
)

View File

@@ -428,6 +428,8 @@ class MambaMixer(MambaBase, PluggableLayer):
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=scan_outputs_d,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)

View File

@@ -888,6 +888,8 @@ class MambaMixer2(MambaBase, PluggableLayer):
num_accepted_tokens=num_accepted_tokens,
cu_seqlens=query_start_loc_d,
is_blackwell=self.is_blackwell,
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:

View File

@@ -28,6 +28,21 @@ else:
return dt
@triton.jit
def convert_rs_fp16x2(x: tl.tensor, rand: tl.tensor) -> tl.tensor:
y = tl.inline_asm_elementwise(
asm="""{
cvt.rs.f16x2.f32 $0, $2, $1, $3;
}""",
constraints="=r,r,r,r,r",
args=(x, rand),
dtype=tl.float16,
is_pure=True,
pack=2,
)
return y
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@@ -48,6 +63,7 @@ else:
def _selective_scan_update_kernel(
# Pointers to matrices
state_ptr,
rand_seed_ptr,
x_ptr,
dt_ptr,
dt_bias_ptr,
@@ -113,6 +129,8 @@ def _selective_scan_update_kernel(
IS_SPEC_DECODING: tl.constexpr,
IS_VARLEN: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
USE_RS_ROUNDING: tl.constexpr,
PHILOX_ROUNDS: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
@@ -267,7 +285,35 @@ def _selective_scan_update_kernel(
z_ptr += stride_z_batch
if not IS_SPEC_DECODING:
tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
if USE_RS_ROUNDING:
# Load random seed
rand_seed = tl.load(rand_seed_ptr)
# Generate random offsets for each element in state
if HAS_STATE_BATCH_INDICES:
rand_offsets = (
state_batch_idx * stride_state_batch + pid_h * stride_state_head
)
else:
rand_offsets = pid_b * stride_state_batch + pid_h * stride_state_head
rand_offsets += (
offs_m[:, None] * stride_state_dim
+ offs_n[None, :] * stride_state_dstate
)
# Generate random 32-bits for each element in state
if PHILOX_ROUNDS > 0:
rand = tl.randint(rand_seed, rand_offsets, PHILOX_ROUNDS)
else:
rand = tl.randint(rand_seed, rand_offsets)
# Convert state to fp16 with RS rounding
state = convert_rs_fp16x2(state, rand)
tl.static_assert(state.dtype == tl.float16, "state must be fp16")
tl.static_assert(
dst_state_ptrs.dtype.element_ty == tl.float16,
"dst_state_ptrs must be fp16",
)
else:
state = state.to(dst_state_ptrs.dtype.element_ty)
tl.store(dst_state_ptrs, state, mask=mask)
def selective_state_update(
@@ -288,6 +334,8 @@ def selective_state_update(
num_accepted_tokens=None,
cu_seqlens=None,
is_blackwell=False,
enable_stochastic_rounding=False,
cache_philox_rounds=0,
):
"""
Argument:
@@ -419,9 +467,16 @@ def selective_state_update(
and dt.stride(-1) == 0
and dt_bias.stride(-1) == 0
)
rand_seed = (
torch.randint(0, 2**32, (1,), device=state.device)
if enable_stochastic_rounding
else None
)
with torch.accelerator.device_index(x.device.index):
_selective_scan_update_kernel[grid](
state,
rand_seed,
x,
dt,
dt_bias,
@@ -476,6 +531,8 @@ def selective_state_update(
tie_hdim,
BLOCK_SIZE_M,
num_warps=num_warps,
USE_RS_ROUNDING=enable_stochastic_rounding,
PHILOX_ROUNDS=cache_philox_rounds,
)

View File

@@ -445,6 +445,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding,
cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds,
)
# 4. Final linear projection