[Bugfix] Warm up Triton autotuner for GDN layers during V1 profiling (#36599)

Signed-off-by: AuYang <459461160@qq.com>
Xu Jinyang
2026-03-12 15:51:09 +08:00
committed by GitHub
parent 8cb24d3aed
commit 3e64fe4a18


@@ -645,6 +645,101 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
        core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
        output[:num_tokens], _ = self.out_proj(core_attn_out)

    def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
        """Warm up GDN prefill kernels during V1 profiling.

        During V1 profile runs, ``_forward_core`` returns early because
        ``attn_metadata`` is ``None``, so the autotuned kernels used by
        ``chunk_gated_delta_rule`` (e.g. ``solve_tril``,
        ``chunk_scaled_dot_kkt``) are never invoked. After profiling,
        vLLM allocates the KV cache using most of the remaining GPU
        memory. When the first real inference triggers the autotuner,
        it OOMs because there is not enough memory left for
        benchmarking.

        This method runs minimal forward passes through
        ``chunk_gated_delta_rule`` with small dummy tensors to force
        autotuning while GPU memory is still plentiful. The autotuner
        results are cached globally, so only the first layer incurs the
        actual benchmarking cost.

        Most kernels use a fixed ``BT = chunk_size`` (64), but
        ``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence
        length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT``
        is part of its autotune key, we run warmup passes with T = 16,
        32, and 64 to cover all possible ``BT`` values.

        The decode path uses ``fused_sigmoid_gating_delta_rule_update``,
        which has fixed kernel parameters (no autotuning), so only the
        prefill (chunked) path needs warming up.
        """
        if hasattr(self, "_prefill_kernels_warmed_up"):
            return
        self._prefill_kernels_warmed_up = True

        device = mixed_qkv.device
        dtype = mixed_qkv.dtype
        num_k_heads = self.num_k_heads // self.tp_size
        num_v_heads = self.num_v_heads // self.tp_size
        _, state_dtype = self.get_state_dtype()

        # Run warmup for each possible BT value of chunk_fwd_kernel_o:
        # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64.
        # Other kernels always use BT=chunk_size(64), so their autotune
        # cache is populated on the first pass and reused thereafter.
        for T in (16, 32, 64):
            q = torch.randn(
                1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
            )
            k = torch.randn(
                1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype
            )
            v = torch.randn(
                1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
            )
            g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
            beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
            state = torch.zeros(
                1,
                num_v_heads,
                self.head_v_dim,
                self.head_k_dim,
                device=device,
                dtype=state_dtype,
            )
            cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)
            try:
                self.chunk_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=state,
                    output_final_state=False,
                    cu_seqlens=cu_seqlens,
                    use_qk_l2norm_in_kernel=True,
                )
            except Exception:
                logger.warning(
                    "GDN prefill kernel warmup (T=%d) failed for "
                    "layer %s. First inference may OOM due to "
                    "autotuner.",
                    T,
                    self.prefix,
                    exc_info=True,
                )
            else:
                logger.debug(
                    "GDN prefill kernel warmup (T=%d) completed for layer %s",
                    T,
                    self.prefix,
                )
            finally:
                del q, k, v, g, beta, state, cu_seqlens
                torch.accelerator.empty_cache()

    def _forward_core(
        self,
        mixed_qkv: torch.Tensor,
@@ -659,7 +754,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if attn_metadata is None:
            # V1 profile run — warm up prefill kernels so that
            # autotuning completes before KV cache allocation.
            self._warmup_prefill_kernels(mixed_qkv)
            return
        assert isinstance(attn_metadata, dict)
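
Why exactly three warmup lengths: the docstring's BT formula for ``chunk_fwd_kernel_o`` can only ever produce 16, 32, or 64, so passes with T = 16, 32, and 64 cover every autotune key. A minimal sketch of that reasoning, assuming the quoted formula and Triton's ``triton.next_power_of_2`` helper (``bt_for_seq_len`` is an illustrative name, not part of the commit):

    import triton

    def bt_for_seq_len(T: int) -> int:
        # BT as recomputed by chunk_fwd_kernel_o per the docstring:
        # min(64, max(16, next_power_of_2(T)))
        return min(64, max(16, triton.next_power_of_2(T)))

    # Every sequence length up to the chunk size of 64 maps to one of exactly
    # three BT values, so warmup passes with T = 16, 32, 64 hit all keys.
    assert {bt_for_seq_len(T) for T in range(1, 65)} == {16, 32, 64}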