diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index e8cbba29f..065739cf9 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -12,6 +12,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update, ) +from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed from vllm.v1.attention.backends.utils import PAD_SLOT_ID @@ -429,6 +430,59 @@ def test_selective_state_update(dim, dstate, has_z, itype): assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) +@pytest.mark.parametrize("philox_rounds", [0, 4]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 64]) +@pytest.mark.parametrize("dim", [2048, 4096]) +@pytest.mark.skipif( + not ( + current_platform.is_cuda() and current_platform.is_device_capability_family(100) + ), + reason="Stochastic rounding in triton is only supported" + " on compute capability 10.0 CUDA devices.", +) +def test_selective_state_update_stochastic_rounding(dim, dstate, has_z, philox_rounds): + device = "cuda" + rtol, atol = 5e-3, 1e-1 + # set seed + set_random_seed(0) + batch_size = 1 + state = torch.randn(batch_size, dim, dstate, dtype=torch.float16, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16) + out = torch.empty_like(x) + dt = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + # Reference uses fp32 state to get ground truth + state_ref = state.float() + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + enable_stochastic_rounding=True, + cache_philox_rounds=philox_rounds, + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) + + assert state.dtype == torch.float16 + assert torch.allclose(state, state_ref.to(torch.float16), rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("dstate", [16, 64]) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index dcc93d987..1cadb4318 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -109,12 +109,20 @@ class CacheConfig: mamba_cache_mode: MambaCacheMode = "none" """The cache strategy for Mamba layers. - "none": set when prefix caching is disabled. - - "all": cache the mamba state of all tokens at position i * block_size. This is + - "all": cache the mamba state of all tokens at position i * block_size. This is the default behavior (for models that support it) when prefix caching is enabled. - "align": only cache the mamba state of the last token of each scheduler step and when the token is at position i * block_size. """ + enable_mamba_cache_stochastic_rounding: bool = False + """Enable stochastic rounding when writing SSM state to fp16 cache. 
+ Uses random bits to unbias the rounding error, which can improve + numerical stability for long sequences.""" + mamba_cache_philox_rounds: int = 0 + """Number of Philox PRNG rounds for stochastic rounding random number + generation. 0 uses the Triton default. Higher values improve randomness + quality at the cost of compute.""" # Will be set after profiling. num_gpu_blocks: int | None = field(default=None, init=False) @@ -231,3 +239,29 @@ class CacheConfig: "scaling factor." ) return cache_dtype + + def __post_init__(self): + if self.enable_mamba_cache_stochastic_rounding: + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + raise ValueError( + "Stochastic rounding for Mamba cache is only supported " + "on NVIDIA CUDA platforms. Please do not specify " + "`--enable-mamba-cache-stochastic-rounding`." + ) + if not current_platform.is_device_capability_family(100): + raise ValueError( + "Stochastic rounding for Mamba cache requires compute " + "capability 10.0 (data center Blackwell). The `cvt.rs` PTX " + "instruction is not supported on your GPU. Please do not specify " + "`--enable-mamba-cache-stochastic-rounding`." + ) + if self.mamba_ssm_cache_dtype != "float16": + raise ValueError( + "Stochastic rounding for Mamba cache requires " + "the SSM cache to be float16. Please set it explicitly, " + "by specifying `--mamba-ssm-cache-dtype float16`, or disable " + "stochastic rounding by not specifying " + "`--enable-mamba-cache-stochastic-rounding`." + ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e1772ab1d..0c9cf2ae9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -604,6 +604,10 @@ class EngineArgs: mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode + enable_mamba_cache_stochastic_rounding: bool = ( + CacheConfig.enable_mamba_cache_stochastic_rounding + ) + mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -1024,6 +1028,13 @@ class EngineArgs: cache_group.add_argument( "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] ) + cache_group.add_argument( + "--enable-mamba-cache-stochastic-rounding", + **cache_kwargs["enable_mamba_cache_stochastic_rounding"], + ) + cache_group.add_argument( + "--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"] + ) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1590,6 +1601,8 @@ class EngineArgs: mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_block_size=self.mamba_block_size, mamba_cache_mode=self.mamba_cache_mode, + enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding, + mamba_cache_philox_rounds=self.mamba_cache_philox_rounds, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 82ca367fb..d79af2e27 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -428,6 +428,8 @@ class MambaMixer(MambaBase, PluggableLayer): state_batch_indices=state_indices_tensor_d_input, dst_state_batch_indices=state_indices_tensor_d_output, out=scan_outputs_d, + 
enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, + cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) scan_outputs_d = scan_outputs_d.transpose(0, 1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 9486e182e..041405b05 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -888,6 +888,8 @@ class MambaMixer2(MambaBase, PluggableLayer): num_accepted_tokens=num_accepted_tokens, cu_seqlens=query_start_loc_d, is_blackwell=self.is_blackwell, + enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, + cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 1cd077758..793471fda 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -28,6 +28,21 @@ else: return dt +@triton.jit +def convert_rs_fp16x2(x: tl.tensor, rand: tl.tensor) -> tl.tensor: + y = tl.inline_asm_elementwise( + asm="""{ +cvt.rs.f16x2.f32 $0, $2, $1, $3; +}""", + constraints="=r,r,r,r,r", + args=(x, rand), + dtype=tl.float16, + is_pure=True, + pack=2, + ) + return y + + @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) @@ -48,6 +63,7 @@ else: def _selective_scan_update_kernel( # Pointers to matrices state_ptr, + rand_seed_ptr, x_ptr, dt_ptr, dt_bias_ptr, @@ -113,6 +129,8 @@ def _selective_scan_update_kernel( IS_SPEC_DECODING: tl.constexpr, IS_VARLEN: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, + USE_RS_ROUNDING: tl.constexpr, + PHILOX_ROUNDS: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_b = tl.program_id(axis=1) @@ -267,7 +285,35 @@ def _selective_scan_update_kernel( z_ptr += stride_z_batch if not IS_SPEC_DECODING: - tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask) + if USE_RS_ROUNDING: + # Load random seed + rand_seed = tl.load(rand_seed_ptr) + # Generate random offsets for each element in state + if HAS_STATE_BATCH_INDICES: + rand_offsets = ( + state_batch_idx * stride_state_batch + pid_h * stride_state_head + ) + else: + rand_offsets = pid_b * stride_state_batch + pid_h * stride_state_head + rand_offsets += ( + offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate + ) + # Generate random 32-bits for each element in state + if PHILOX_ROUNDS > 0: + rand = tl.randint(rand_seed, rand_offsets, PHILOX_ROUNDS) + else: + rand = tl.randint(rand_seed, rand_offsets) + # Convert state to fp16 with RS rounding + state = convert_rs_fp16x2(state, rand) + tl.static_assert(state.dtype == tl.float16, "state must be fp16") + tl.static_assert( + dst_state_ptrs.dtype.element_ty == tl.float16, + "dst_state_ptrs must be fp16", + ) + else: + state = state.to(dst_state_ptrs.dtype.element_ty) + tl.store(dst_state_ptrs, state, mask=mask) def selective_state_update( @@ -288,6 +334,8 @@ def selective_state_update( num_accepted_tokens=None, cu_seqlens=None, is_blackwell=False, + enable_stochastic_rounding=False, + cache_philox_rounds=0, ): """ Argument: @@ -419,9 +467,16 @@ def selective_state_update( and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0 ) + 
rand_seed = ( + torch.randint(0, 2**32, (1,), device=state.device) + if enable_stochastic_rounding + else None + ) + with torch.accelerator.device_index(x.device.index): _selective_scan_update_kernel[grid]( state, + rand_seed, x, dt, dt_bias, @@ -476,6 +531,8 @@ def selective_state_update( tie_hdim, BLOCK_SIZE_M, num_warps=num_warps, + USE_RS_ROUNDING=enable_stochastic_rounding, + PHILOX_ROUNDS=cache_philox_rounds, ) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e38b7b166..44b120774 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -445,6 +445,8 @@ class Plamo2MambaMixer(MambaBase, PluggableLayer): dt_softplus=True, state_batch_indices=state_indices_tensor_d, out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), + enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, + cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) # 4. Final linear projection
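
For reference, the `cvt.rs.f16x2.f32` path added above consumes caller-supplied random bits to round the fp32 SSM state down to fp16 stochastically instead of to-nearest. A rough PyTorch sketch of that rounding rule, and of why it matters when a small update is folded into the fp16 cache every decode step, is below. It is illustrative only: `stochastic_round_to_fp16` is not a function from this PR, it ignores NaN/Inf, overflow and fp16-subnormal corner cases, and it only approximates what the PTX instruction does.

```python
import torch


def stochastic_round_to_fp16(x: torch.Tensor, rand_bits: torch.Tensor) -> torch.Tensor:
    """Round fp32 -> fp16 using random bits, roughly what cvt.rs does.

    Illustrative sketch: assumes normal-range values, no NaN/Inf or
    fp16-subnormal handling.
    """
    xi = x.view(torch.int32)                  # reinterpret the fp32 bit pattern
    r = (rand_bits & 0x1FFF).to(torch.int32)  # 13 mantissa bits fp16 cannot keep
    yi = (xi + r) & ~0x1FFF                   # add randomness, then truncate
    return yi.view(torch.float32).to(torch.float16)  # exact: low bits are zero


# Why it helps: repeatedly folding a small update into an fp16 cache stalls under
# round-to-nearest once the update falls below half an fp16 ulp, while stochastic
# rounding keeps the expected value on track over a long decode.
torch.manual_seed(0)
rtn = torch.zeros((), dtype=torch.float16)
sr = torch.zeros((), dtype=torch.float16)
step = 1e-4
for _ in range(20_000):
    rtn = (rtn.float() + step).half()  # round-to-nearest write-back each step
    sr = stochastic_round_to_fp16(
        sr.float() + step, torch.randint(0, 2**13, (), dtype=torch.int32)
    )
print(rtn.item(), sr.item())  # rtn stalls near 0.25; sr lands close to the true 2.0
```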
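A minimal sketch of enabling the feature from Python, assuming the `EngineArgs` and `CacheConfig` fields land exactly as named in this diff; the model id is a stand-in. The fp16 requirement is enforced by the new `CacheConfig.__post_init__` check when the engine config is built.

```python
from vllm.engine.arg_utils import EngineArgs

# Stand-in model id; any Mamba/Mamba2-based model served by vLLM would apply here.
args = EngineArgs(
    model="<your-mamba-model>",
    mamba_ssm_cache_dtype="float16",              # required by the new validation
    enable_mamba_cache_stochastic_rounding=True,  # new flag added in this diff
    mamba_cache_philox_rounds=0,                  # 0 = Triton's default Philox rounds
)
# Equivalent CLI flags: --mamba-ssm-cache-dtype float16,
# --enable-mamba-cache-stochastic-rounding, --mamba-cache-philox-rounds N.
```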