[Bugfix] Warm up Triton autotuner for GDN layers during V1 profiling (#36599)
Signed-off-by: AuYang <459461160@qq.com>
This commit is contained in:
@@ -645,6 +645,101 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
|
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
|
||||||
output[:num_tokens], _ = self.out_proj(core_attn_out)
|
output[:num_tokens], _ = self.out_proj(core_attn_out)
|
||||||
|
|
||||||
|
def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
    """Force Triton autotuning of the GDN prefill kernels while profiling.

    During a V1 profile run ``_forward_core`` bails out early because
    ``attn_metadata`` is ``None``, so the autotuned kernels behind
    ``chunk_gated_delta_rule`` (``solve_tril``, ``chunk_scaled_dot_kkt``,
    ...) never run. vLLM then hands most of the remaining GPU memory to
    the KV cache, and the first real request OOMs when the autotuner
    finally benchmarks its configs. Running tiny dummy prefills here —
    while memory is still plentiful — populates the (global) autotune
    cache up front, so only the first layer pays the benchmarking cost.

    Most kernels key their autotune cache on a fixed ``BT = chunk_size``
    (64), but ``chunk_fwd_kernel_o`` derives
    ``BT = min(64, max(16, next_power_of_2(T)))`` from the sequence
    length and includes it in the key; passes with T = 16, 32 and 64
    therefore cover every reachable ``BT``. The decode path
    (``fused_sigmoid_gating_delta_rule_update``) has fixed launch
    parameters and needs no warmup.
    """
    # One-shot per layer instance; the underlying autotune cache is
    # shared globally anyway.
    if getattr(self, "_prefill_kernels_warmed_up", False):
        return
    self._prefill_kernels_warmed_up = True

    device, dtype = mixed_qkv.device, mixed_qkv.dtype
    num_k_heads = self.num_k_heads // self.tp_size
    num_v_heads = self.num_v_heads // self.tp_size
    _, state_dtype = self.get_state_dtype()

    # One pass per distinct BT of chunk_fwd_kernel_o (T=16/32/64 map to
    # BT=16/32/64); every other kernel is covered by the first pass.
    for seq_len in (16, 32, 64):
        act_kwargs = {"device": device, "dtype": dtype}
        q = torch.randn(1, seq_len, num_k_heads, self.head_k_dim, **act_kwargs)
        k = torch.randn(1, seq_len, num_k_heads, self.head_k_dim, **act_kwargs)
        v = torch.randn(1, seq_len, num_v_heads, self.head_v_dim, **act_kwargs)
        g = torch.randn(1, seq_len, num_v_heads, **act_kwargs)
        beta = torch.randn(1, seq_len, num_v_heads, **act_kwargs)
        state = torch.zeros(
            1,
            num_v_heads,
            self.head_v_dim,
            self.head_k_dim,
            device=device,
            dtype=state_dtype,
        )
        cu_seqlens = torch.tensor([0, seq_len], device=device, dtype=torch.long)

        try:
            self.chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=state,
                output_final_state=False,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=True,
            )
        except Exception:
            # Best-effort: a failed warmup only means the first real
            # prefill may hit the autotuner (and possibly OOM) instead.
            logger.warning(
                "GDN prefill kernel warmup (T=%d) failed for "
                "layer %s. First inference may OOM due to "
                "autotuner.",
                seq_len,
                self.prefix,
                exc_info=True,
            )
        else:
            logger.debug(
                "GDN prefill kernel warmup (T=%d) completed for layer %s",
                seq_len,
                self.prefix,
            )
        finally:
            # Drop the dummy tensors promptly so they never overlap the
            # next iteration's allocations.
            del q, k, v, g, beta, state, cu_seqlens

    # Return the autotuner's benchmarking scratch memory to the allocator
    # before vLLM sizes the KV cache.
    torch.accelerator.empty_cache()
def _forward_core(
|
def _forward_core(
|
||||||
self,
|
self,
|
||||||
mixed_qkv: torch.Tensor,
|
mixed_qkv: torch.Tensor,
|
||||||
@@ -659,7 +754,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||||
|
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# V1 profile run
|
# V1 profile run — warm up prefill kernels so that
|
||||||
|
# autotuning completes before KV cache allocation.
|
||||||
|
self._warmup_prefill_kernels(mixed_qkv)
|
||||||
return
|
return
|
||||||
|
|
||||||
assert isinstance(attn_metadata, dict)
|
assert isinstance(attn_metadata, dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user