diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py index 130781276..aab1ee006 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_o.py +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -16,7 +16,7 @@ from vllm.triton_utils import tl, triton from .index import prepare_chunk_indices from .op import exp -from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper +from .utils import FLA_CHUNK_SIZE, check_shared_mem, is_nvidia_hopper BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] @@ -146,11 +146,11 @@ def chunk_fwd_o( g: torch.Tensor | None = None, # cumsum of log decay scale: float | None = None, cu_seqlens: torch.Tensor | None = None, - chunk_size: int = 64, + chunk_size: int = FLA_CHUNK_SIZE, ) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] H = v.shape[-2] - BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + BT = chunk_size chunk_indices = ( prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None ) diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py index f0ec1f7a6..83b75e685 100644 --- a/vllm/model_executor/layers/fla/ops/utils.py +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -24,10 +24,12 @@ logger = logging.getLogger(__name__) COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" -FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) +# Default chunk size used across FLA triton kernels (kda, chunk, chunk_o, etc.) 
+FLA_CHUNK_SIZE = 64 + def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: """ diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index 55cd17fe5..2b952e10e 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -28,6 +28,7 @@ from vllm.model_executor.layers.fla.ops import ( fused_sigmoid_gating_delta_rule_update, ) from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd +from vllm.model_executor.layers.fla.ops.utils import FLA_CHUNK_SIZE from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -581,11 +582,9 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): results are cached globally, so only the first layer incurs actual benchmarking cost. - Most kernels use a fixed ``BT = chunk_size`` (64), but - ``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence - length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT`` - is part of its autotune key, we run warmup passes with T = 16, - 32, and 64 to cover all possible ``BT`` values. + All kernels including ``chunk_fwd_kernel_o`` now use a fixed + ``BT = chunk_size`` (64). A single warmup pass with T = 64 + is sufficient to populate the autotuner cache. The decode path uses ``fused_sigmoid_gating_delta_rule_update`` which has fixed kernel parameters (no autotuning), so only the @@ -601,66 +600,58 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase): num_v_heads = self.num_v_heads // self.tp_size _, state_dtype = self.get_state_dtype() - # Run warmup for each possible BT value of chunk_fwd_kernel_o: - # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64. - # Other kernels always use BT=chunk_size(64), so their autotune - # cache is populated on the first pass and reused thereafter. 
- for T in (16, 32, 64): - q = torch.randn( - 1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype - ) - k = torch.randn( - 1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype - ) - v = torch.randn( - 1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype - ) - # NOTE: g and beta must have the same dtypes as during - # inference, so we construct them with the same function - # (fused_gdn_gating). dummy_a and dummy_b are throwaway - # inputs required by that function. - dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype) - dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype) - g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias) - state = torch.zeros( - 1, - num_v_heads, - self.head_v_dim, - self.head_k_dim, - device=device, - dtype=state_dtype, - ) - cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32) + # All kernels use BT = chunk_size (FLA_CHUNK_SIZE), so a single pass with + # T = chunk_size is sufficient to populate every autotuner cache. + T = FLA_CHUNK_SIZE + q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) + k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) + v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype) + # NOTE: g and beta must have the same dtypes as during + # inference, so we construct them with the same function + # (fused_gdn_gating). dummy_a and dummy_b are throwaway + # inputs required by that function. 
+ dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype) + dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype) + g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias) + state = torch.zeros( + 1, + num_v_heads, + self.head_v_dim, + self.head_k_dim, + device=device, + dtype=state_dtype, + ) + cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32) - try: - self.chunk_gated_delta_rule( - q=q, - k=k, - v=v, - g=g, - beta=beta, - initial_state=state, - output_final_state=True, - cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, - ) - except Exception: - logger.warning( - "GDN prefill kernel warmup (T=%d) failed for " - "layer %s. First inference may OOM due to " - "autotuner.", - T, - self.prefix, - exc_info=True, - ) - else: - logger.debug( - "GDN prefill kernel warmup (T=%d) completed for layer %s", - T, - self.prefix, - ) - finally: - del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens + try: + self.chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=state, + output_final_state=True, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + except Exception: + logger.warning( + "GDN prefill kernel warmup (T=%d) failed for " + "layer %s. First inference may OOM due to " + "autotuner.", + T, + self.prefix, + exc_info=True, + ) + else: + logger.debug( + "GDN prefill kernel warmup (T=%d) completed for layer %s", + T, + self.prefix, + ) + finally: + del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens torch.accelerator.empty_cache()