diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 8f33c7cfa..ff1d9528e 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -21,7 +21,7 @@ struct SSMParamsBase { int dim_ngroups_ratio; bool is_variable_B; bool is_variable_C; - int64_t pad_slot_id; + int64_t null_block_id; bool delta_softplus; bool cache_enabled; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index d852a0ed4..ba2f0cc61 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -118,9 +118,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); - const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - // cache_index == params.pad_slot_id is defined as padding, so we exit early - if (cache_index == params.pad_slot_id){ + int cache_index; + if (cache_indices == nullptr) { + cache_index = batch_id; + } else if (params.cache_enabled) { + const int* initial_state_idx = reinterpret_cast(params.initial_state_idx_ptr); + cache_index = cache_indices[batch_id * params.cache_indices_stride + initial_state_idx[batch_id]]; + } else { + cache_index = cache_indices[batch_id]; + } + // Skip batch entries whose cache index maps to the null block (padding). + if (cache_indices != nullptr && cache_index == params.null_block_id){ return; } input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride @@ -527,7 +535,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const std::optional& cache_indices, const std::optional& has_initial_state, bool varlen, - int64_t pad_slot_id, + int64_t null_block_id, int64_t block_size, const std::optional &block_idx_first_scheduled_token, const std::optional &block_idx_last_scheduled_token, @@ -544,7 +552,7 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.dstate = dstate; params.n_groups = n_groups; params.dim_ngroups_ratio = dim / n_groups; - params.pad_slot_id = pad_slot_id; + params.null_block_id = null_block_id; params.delta_softplus = delta_softplus; @@ -658,7 +666,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const torch::Tensor &ssm_states, // used to identify padding entries if cache_indices provided // in case of padding, the kernel will return early - int64_t pad_slot_id, + int64_t null_block_id, int64_t block_size, const std::optional &block_idx_first_scheduled_token, const std::optional &block_idx_last_scheduled_token, @@ -805,7 +813,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, cache_indices, has_initial_state, varlen, - pad_slot_id, + null_block_id, block_size, block_idx_first_scheduled_token, block_idx_last_scheduled_token, diff --git a/csrc/ops.h b/csrc/ops.h index 9194c8ff0..580fdfc6b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -298,7 +298,7 @@ void selective_scan_fwd( const std::optional& query_start_loc, const std::optional& cache_indices, const std::optional& has_initial_state, - const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, + const torch::Tensor& ssm_states, int64_t null_block_id, int64_t block_size, const std::optional& block_idx_first_scheduled_token, const std::optional& block_idx_last_scheduled_token, const std::optional& initial_state_idx, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ca69d175f..b7ab51c1d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -556,7 +556,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? cache_indices," "Tensor? has_initial_state," "Tensor! ssm_states," - "int pad_slot_id," + "int null_block_id," "int block_size," "Tensor? block_idx_first_scheduled_token," "Tensor? block_idx_last_scheduled_token," diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index 1d10bd297..0ebc527d5 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_update, ) from vllm.utils.torch_utils import set_random_seed -from vllm.v1.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.utils import NULL_BLOCK_ID def causal_conv1d_ref( @@ -122,7 +122,7 @@ def causal_conv1d_opcheck_fn( has_initial_state: torch.Tensor | None = None, conv_states: torch.Tensor | None = None, activation: str | None = "silu", - pad_slot_id: int = PAD_SLOT_ID, + null_block_id: int = NULL_BLOCK_ID, ): """ x: (batch, dim, seqlen) @@ -158,15 +158,16 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity batch = 2 x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) x_ref = x.clone() - conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) + # +1 entry to reserve index 0 as null block + conv_state = torch.randn(batch + 1, dim, width - 1, device=device, dtype=itype) weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None - conv_state_ref = conv_state.detach().clone() + # Start indices from 1, skipping null block at index 0 + conv_state_indices = torch.arange(1, batch + 1, dtype=torch.int32, device=device) + conv_state_ref = conv_state[conv_state_indices].detach().clone() activation = None if not silu_activation else "silu" - conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device) - out = causal_conv1d_update( x, conv_state, @@ -179,7 +180,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity x_ref, conv_state_ref, weight, bias, activation=activation ) - assert torch.equal(conv_state, conv_state_ref) + assert torch.equal(conv_state[conv_state_indices], conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) @@ -215,7 +216,8 @@ def test_causal_conv1d_update_with_batch_gather( x_ref = x.clone() - conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + # +1 to exclude index 0 (null block) + conv_state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to( dtype=torch.int32, device=device ) unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) @@ -223,7 +225,9 @@ def test_causal_conv1d_update_with_batch_gather( padded_state_indices = torch.concat( [ conv_state_indices, - torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + torch.as_tensor( + [NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device + ), ], dim=0, ) @@ -248,7 +252,6 @@ def test_causal_conv1d_update_with_batch_gather( bias, activation=activation, conv_state_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID, ) out_ref = causal_conv1d_update_ref( x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation @@ -317,13 +320,19 @@ def test_causal_conv1d_varlen( has_initial_states = torch.randint( 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device ) - state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[ - :batch_size - ] + # +1 to exclude index 0 (null block) + state_indices = ( + torch.randperm(total_entries - 1, dtype=torch.int32, device=x.device)[ + :batch_size + ] + + 1 + ) padded_state_indices = torch.concat( [ state_indices, - torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + torch.as_tensor( + [NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device + ), ], dim=-1, ) @@ -336,7 +345,6 @@ def test_causal_conv1d_varlen( cache_indices=padded_state_indices, has_initial_state=has_initial_states, activation=activation, - pad_slot_id=PAD_SLOT_ID, ) out_ref = [] @@ -345,7 +353,7 @@ def test_causal_conv1d_varlen( splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] for i in range(len(seqlens[0])): x_s = [v[i].unsqueeze(0) for v in splits][0] - if padded_state_indices[i] == PAD_SLOT_ID: + if padded_state_indices[i] == NULL_BLOCK_ID: continue out_ref_b.append( causal_conv1d_ref( diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 065739cf9..81715be98 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( ) from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed -from vllm.v1.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.utils import NULL_BLOCK_ID def selective_state_update_ref( @@ -179,7 +179,7 @@ def selective_scan_opcheck_fn( cache_indices=None, has_initial_state=None, ssm_states=None, - pad_slot_id=PAD_SLOT_ID, + null_block_id=NULL_BLOCK_ID, block_size=2048, block_idx_first_scheduled_token=None, block_idx_last_scheduled_token=None, @@ -229,7 +229,7 @@ def selective_scan_opcheck_fn( cache_indices, has_initial_state, ssm_states, - pad_slot_id, + null_block_id, block_size, block_idx_first_scheduled_token, block_idx_last_scheduled_token, @@ -351,7 +351,6 @@ def test_selective_scan( has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool) if c > 0 else None, - pad_slot_id=PAD_SLOT_ID, block_size=2048, block_idx_first_scheduled_token=None, block_idx_last_scheduled_token=None, @@ -652,15 +651,21 @@ def test_selective_scan_varlen( prev_state_shape, device=u.device, dtype=itype, requires_grad=False ) prev_state_ref = prev_state.clone() - state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[ - :batch_size - ] + # +1 to exclude index 0 (null block) + state_indices = ( + torch.randperm(total_entries - 1, dtype=torch.int32, device=u.device)[ + :batch_size + ] + + 1 + ) unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[state_indices] = False padded_state_indices = torch.concat( [ state_indices, - torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + torch.as_tensor( + [NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device + ), ], dim=-1, ) @@ -690,7 +695,7 @@ def test_selective_scan_varlen( ] for i in range(len(seqlens[0])): u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits) - if padded_state_indices[i] == PAD_SLOT_ID: + if padded_state_indices[i] == NULL_BLOCK_ID: continue out_ref_s, _ = selective_scan_ref( u_s, @@ -758,7 +763,8 @@ def test_selective_state_update_with_batch_indices( padded_batch_size = batch_size + padding total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) - state_indices = torch.randperm(total_entries)[:batch_size].to( + # +1 to exclude index 0 (null block) + state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to( dtype=torch.int32, device=device ) unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) @@ -766,7 +772,9 @@ def test_selective_state_update_with_batch_indices( padded_state_indices = torch.concat( [ state_indices, - torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + torch.as_tensor( + [NULL_BLOCK_ID] * padding, dtype=torch.int32, device=device + ), ], dim=0, ) @@ -793,7 +801,6 @@ def test_selective_state_update_with_batch_indices( dt_bias=dt_bias, dt_softplus=True, state_batch_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID, out=out, ) out_ref = selective_state_update_ref( @@ -849,7 +856,8 @@ def test_selective_state_update_with_heads_with_batch_indices( state = torch.randn( total_entries, nheads, headdim, dstate, dtype=itype, device=device ) - state_indices = torch.randperm(total_entries)[:batch_size].to( + # +1 to exclude index 0 (null block) + state_indices = (torch.randperm(total_entries - 1)[:batch_size] + 1).to( dtype=torch.int32, device=device ) @@ -887,7 +895,6 @@ def test_selective_state_update_with_heads_with_batch_indices( dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices, - pad_slot_id=PAD_SLOT_ID, out=out, ) out_ref = selective_state_update_ref( @@ -935,17 +942,18 @@ def test_selective_state_update_with_num_accepted_tokens( state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device) state_batch_indices = torch.full( - (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device ) + # Start from 1 to exclude null block at index 0 initial_state_slots = torch.randint( - 0, 15, (batch_size,), device=device, dtype=torch.int32 + 1, 15, (batch_size,), device=device, dtype=torch.int32 ) for seq_idx in range(batch_size): token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx] dst_state_batch_indices = torch.full( - (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device ) slot_offset = 15 dst_slots_map = {} @@ -1013,7 +1021,6 @@ def test_selective_state_update_with_num_accepted_tokens( state_batch_indices=state_batch_indices, dst_state_batch_indices=dst_state_batch_indices, num_accepted_tokens=num_accepted_tokens, - pad_slot_id=PAD_SLOT_ID, ) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) @@ -1061,18 +1068,19 @@ def test_selective_state_update_varlen_with_num_accepted( state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device) state_batch_indices = torch.full( - (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device ) + # Start from 1 to exclude null block at index 0 initial_state_slots = torch.randint( - 0, 15, (batch_size,), device=device, dtype=torch.int32 + 1, 15, (batch_size,), device=device, dtype=torch.int32 ) for seq_idx in range(batch_size): token_pos = max(num_accepted_tokens[seq_idx].item() - 1, 0) state_batch_indices[seq_idx, token_pos] = initial_state_slots[seq_idx] dst_state_batch_indices = torch.full( - (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device + (batch_size, max_seq_len), NULL_BLOCK_ID, dtype=torch.int32, device=device ) slot_offset = 15 @@ -1138,7 +1146,6 @@ def test_selective_state_update_varlen_with_num_accepted( state_batch_indices=state_batch_indices, dst_state_batch_indices=dst_state_batch_indices, num_accepted_tokens=num_accepted_tokens, - pad_slot_id=PAD_SLOT_ID, ) for seq_idx in range(batch_size): diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6c9ca07db..7fef4b71a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2062,7 +2062,7 @@ def selective_scan_fwd( cache_indices: torch.Tensor | None, has_initial_state: torch.Tensor | None, ssm_states: torch.Tensor, - pad_slot_id: int, + null_block_id: int, block_size: int = 1024, block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None, @@ -2084,7 +2084,7 @@ def selective_scan_fwd( cache_indices, has_initial_state, ssm_states, - pad_slot_id, + null_block_id, block_size, block_idx_first_scheduled_token, block_idx_last_scheduled_token, diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index ffccdc122..ef7a2745a 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -5,6 +5,7 @@ import torch from einops import rearrange from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.utils import PAD_SLOT_ID @triton.jit @@ -602,6 +603,7 @@ def _linear_attn_decode_kernel( cache_h_stride, cache_d0_stride, cache_d1_stride, + pad_slot_id: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ @@ -616,8 +618,8 @@ def _linear_attn_decode_kernel( # Load slot index for the current batch slot_id = tl.load(slot_idx + pid_b).to(tl.int64) - # Skip if slot_id is -1 (padding) - if slot_id == -1: + # Skip if slot_id is PAD_SLOT_ID (padding) + if slot_id == pad_slot_id: return batch_id = pid_b @@ -727,6 +729,7 @@ def linear_decode_forward_triton( cache_h_stride, cache_d0_stride, cache_d1_stride, + pad_slot_id=PAD_SLOT_ID, BLOCK_SIZE=BLOCK_SIZE, ) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index b0c1ffb0d..a8efdc9f1 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -9,7 +9,7 @@ import numpy as np import torch from vllm.triton_utils import tl, triton -from vllm.v1.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.utils import NULL_BLOCK_ID, PAD_SLOT_ID @triton.jit() @@ -49,12 +49,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, + null_block_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_APC_ENABLED: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, + HAS_NULL_BLOCK: tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -133,9 +134,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index ).to(tl.int64) - if USE_PAD_SLOT: # noqa - if conv_states_input_coord == pad_slot_id: - # not processing as this is not the actual sequence + if HAS_NULL_BLOCK: # noqa + if conv_states_input_coord == null_block_id: + # not processing as this is a null block (padding) return conv_states_base = ( conv_states_ptr @@ -475,6 +476,7 @@ def causal_conv1d_fn( has_initial_state: torch.Tensor | None = None, activation: str | None = "silu", pad_slot_id: int = PAD_SLOT_ID, + null_block_id: int = NULL_BLOCK_ID, block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, @@ -730,12 +732,13 @@ def causal_conv1d_fn( block_size_to_align // BLOCK_M, # others pad_slot_id, + null_block_id, # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_APC_ENABLED=block_idx_last_scheduled_token is not None, - USE_PAD_SLOT=pad_slot_id is not None, + HAS_NULL_BLOCK=null_block_id is not None, NP2_STATELEN=np2_statelen, # launch_cooperative_grid=True BLOCK_M=BLOCK_M, @@ -778,7 +781,7 @@ def _causal_conv1d_update_kernel( stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, # others - pad_slot_id: tl.constexpr, + null_block_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, @@ -787,7 +790,7 @@ def _causal_conv1d_update_kernel( IS_APC_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, + HAS_NULL_BLOCK: tl.constexpr, BLOCK_N: tl.constexpr, ): # ruff: noqa: E501 @@ -811,8 +814,8 @@ def _causal_conv1d_update_kernel( conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init ).to(tl.int64) - if USE_PAD_SLOT: # noqa - if conv_states_input_coord == pad_slot_id: + if HAS_NULL_BLOCK: # noqa + if conv_states_input_coord == null_block_id: # not processing as this is not the actual sequence return @@ -1076,7 +1079,7 @@ def causal_conv1d_update( num_accepted_tokens: torch.Tensor | None = None, query_start_loc: torch.Tensor | None = None, max_query_len: int = -1, - pad_slot_id: int = PAD_SLOT_ID, + null_block_id: int = NULL_BLOCK_ID, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, validate_data=False, @@ -1111,16 +1114,16 @@ def causal_conv1d_update( max_query_len: int If query_start_loc is not None, this indicates the maximum query length in the batch. - pad_slot_id: int - if conv_state_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + null_block_id: int + Block ID used to identify padded entries in + conv_state_indices. Block 0 is the null block. + for example: conv_state_indices = [null_block_id, 1, 20, null_block_id] in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` """ if validate_data: - assert pad_slot_id is not None + assert null_block_id is not None assert x.stride(1) == 1 if isinstance(activation, bool): activation = "silu" if activation is True else None @@ -1225,7 +1228,7 @@ def causal_conv1d_update( stride_o_dim, stride_o_token, # others - pad_slot_id, + null_block_id, # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, @@ -1234,7 +1237,7 @@ def causal_conv1d_update( IS_APC_ENABLED=block_idx_last_scheduled_token is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, - USE_PAD_SLOT=pad_slot_id is not None, + HAS_NULL_BLOCK=null_block_id is not None, BLOCK_N=256, ) if unsqueeze: diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 793471fda..c4a0ef385 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -10,7 +10,7 @@ from packaging import version from vllm import _custom_ops as ops from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp from vllm.triton_utils import HAS_TRITON, tl, triton -from vllm.v1.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.utils import NULL_BLOCK_ID TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0")) @@ -75,7 +75,7 @@ def _selective_scan_update_kernel( out_ptr, state_batch_indices_ptr, dst_state_batch_indices_ptr, - pad_slot_id, + null_block_id, num_accepted_tokens_ptr, cu_seqlens_ptr, # Matrix dimensions @@ -203,7 +203,7 @@ def _selective_scan_update_kernel( mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: - mask &= state_batch_idx != pad_slot_id + mask &= state_batch_idx != null_block_id state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) if HAS_DT_BIAS: @@ -257,7 +257,7 @@ def _selective_scan_update_kernel( if IS_SPEC_DECODING: dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64) - if token_dst_idx != pad_slot_id: + if token_dst_idx != null_block_id: token_dst_ptrs = ( state_ptr_base + token_dst_idx * stride_state_batch @@ -329,7 +329,7 @@ def selective_state_update( dt_softplus=False, state_batch_indices=None, dst_state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID, + null_block_id=NULL_BLOCK_ID, out=None, num_accepted_tokens=None, cu_seqlens=None, @@ -348,12 +348,12 @@ def selective_state_update( D: (dim,) or (nheads, dim) z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 + null_block_id: int + if state_batch_indices is passed, lets the kernel identify + padded entries that will not be processed, + for example: state_batch_indices = [null_block_id, 1, 20, + null_block_id] in this case, the kernel will not process + entries at indices 0 and 3 out: Preallocated ssm output tensor. Assume same shape as x. In-place updated. num_accepted_tokens: (batch,) @@ -488,7 +488,7 @@ def selective_state_update( out, state_batch_indices, dst_state_batch_indices, - pad_slot_id, + null_block_id, num_accepted_tokens, cu_seqlens, N, @@ -550,7 +550,7 @@ def selective_scan_fn( query_start_loc=None, cache_indices=None, has_initial_state=None, - pad_slot_id=PAD_SLOT_ID, + null_block_id=NULL_BLOCK_ID, block_size=1024, block_idx_first_scheduled_token=None, block_idx_last_scheduled_token=None, @@ -588,10 +588,10 @@ def selective_scan_fn( indicate if the ssm_state at the corresponding index should be used as initial state. Not providing argument assumes there's no initial state - pad_slot_id: int + null_block_id: int if cache_indices is passed, lets the kernel identify padding entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + for example: cache_indices = [null_block_id, 1 ,20 ,null_block_id] in this case, the kernel will not process entries at indices 0 and 3 block_size: int The block size to align the cached states to @@ -643,7 +643,7 @@ def selective_scan_fn( cache_indices, has_initial_state, ssm_states, - pad_slot_id, + null_block_id, block_size, block_idx_first_scheduled_token, block_idx_last_scheduled_token, diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index f65d9a4b3..41c69deb4 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -14,7 +14,7 @@ from vllm.v1.attention.backend import ( CommonAttentionMetadata, ) from vllm.v1.attention.backends.utils import ( - PAD_SLOT_ID, + NULL_BLOCK_ID, compute_causal_conv1d_metadata, mamba_get_block_table_tensor, split_decodes_and_prefills, @@ -341,7 +341,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] spec_state_indices_tensor, non_blocking=True ) spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size] - spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) + spec_state_indices_tensor[num_spec_decodes:].fill_(NULL_BLOCK_ID) self.spec_sequence_masks[:num_spec_decodes].copy_( spec_sequence_masks[:num_spec_decodes], non_blocking=True @@ -387,7 +387,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[ :batch_size ] - non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) + non_spec_state_indices_tensor[num_decodes:].fill_(NULL_BLOCK_ID) self.non_spec_query_start_loc[: num_decodes + 1].copy_( non_spec_query_start_loc, non_blocking=True diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 59f2e7ca5..eec530322 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -15,7 +15,7 @@ from vllm.v1.attention.backend import ( CommonAttentionMetadata, ) from vllm.v1.attention.backends.utils import ( - PAD_SLOT_ID, + NULL_BLOCK_ID, compute_causal_conv1d_metadata, mamba_get_block_table_tensor, split_decodes_and_prefills, @@ -504,7 +504,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): state_indices_tensor_d, non_blocking=True ) state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] - state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID + state_indices_tensor_d[metadata.num_decodes :] = NULL_BLOCK_ID if self.use_spec_decode and num_accepted_tokens is not None: assert query_start_loc_d is not None diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 2fa9fe851..5deac4d2f 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -366,12 +366,6 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): seq_lens = common_attn_metadata.seq_lens[:num_decodes] block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...] - # Padded CUDA graph requests have block_table entries of -1. - # Clamp to 0 to prevent OOB access in the DeepGEMM kernel. - # This is safe because padded requests have seq_lens=0, so the - # kernel produces no meaningful output for those rows. - block_table.clamp_(min=0) - max_decode_len = int(decode_lens_cpu.max().item()) next_n = 1 + self.num_speculative_tokens use_native = not self.use_flattening and max_decode_len == next_n diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 59f6ca9bf..0a36e6fd4 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -42,6 +42,7 @@ KVCacheLayoutType = Literal["NHD", "HND"] _KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None PAD_SLOT_ID = -1 +NULL_BLOCK_ID = 0 def is_valid_kv_cache_layout(value: str) -> bool: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0b67490ca..a95082787 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -122,6 +122,7 @@ from vllm.v1.attention.backend import ( from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( + NULL_BLOCK_ID, create_fast_prefill_custom_backend, get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, @@ -2135,9 +2136,9 @@ class GPUModelRunner( blk_table = self.input_batch.block_table[kv_cache_gid] blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID - blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + # Fill unused block table entries with NULL_BLOCK_ID (null block) + # for CUDAGraph padding. Block 0 is reserved for padding. + blk_table_tensor[num_reqs:num_reqs_padded].fill_(NULL_BLOCK_ID) return blk_table_tensor assert slot_mappings is not None