[Perf] [Bugfix] Fix Triton autotuning in inference for Qwen3.5 (#37338)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Artem Perevedentsev authored on 2026-03-23 09:37:58 +02:00, committed by GitHub
parent 54ab804e87
commit a16133a0f1
10 changed files with 40 additions and 39 deletions
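
Background for review: Triton JIT-specializes and autotunes a kernel per argument signature, and tensor dtypes are part of that signature. The old warmup pass built `cu_seqlens` as `torch.long` and `g`/`beta` with `torch.randn`, while the inference path passes `int32` offsets and `fused_gdn_gating` outputs, so warmup tuned a kernel variant that real requests never hit. A minimal illustrative sketch of the cache-key mismatch (not code from this commit; `autotune_cache_key` is a hypothetical stand-in for Triton's specialization key):

import torch

def autotune_cache_key(*tensors: torch.Tensor) -> tuple:
    # Hypothetical stand-in: every tensor argument's dtype participates in
    # the key, so a warmup/inference dtype mismatch means a cache miss and
    # a fresh (slow) autotuning run on the first real request.
    return tuple(t.dtype for t in tensors)

q = torch.randn(1, 8192, 16, 128)
warmup_cu = torch.tensor([0, 8192], dtype=torch.long)   # warmup before this fix
infer_cu = torch.tensor([0, 8192], dtype=torch.int32)   # what inference passes

assert autotune_cache_key(q, warmup_cu) != autotune_cache_key(q, infer_cu)  # miss
assert autotune_cache_key(q, infer_cu) == autotune_cache_key(q, infer_cu)   # hit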

View File

@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
     scale: float,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ):
     g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
     # obtain WY representation. u is actually the new v.
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
         scale: float,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
         if use_qk_l2norm_in_kernel:
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
     scale: float = None,
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
 ):
     r"""
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
             Default: `None`.
         output_final_state (Optional[bool]):
             Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
     Returns:
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
         # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
         >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
         # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
-        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
         >>> o_var, ht_var = chunk_gated_delta_rule(
             q, k, v, g, beta,
             initial_state=h0,

View File

@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
     output_final_state: bool = False,
     chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
     save_new_value: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # This kernel is slightly different from fla to support Q/K with different head numbers.
     # In fla, Q/K always have the same head number, so Hg is always equal to H.

View File

@@ -145,7 +145,7 @@ def chunk_fwd_o(
     h: torch.Tensor,
     g: torch.Tensor | None = None, # cumsum of log decay
     scale: float | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
 ) -> torch.Tensor:
     B, T, Hg, K, V = *q.shape, v.shape[-1]

View File

@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
     k: torch.Tensor,
     g: torch.Tensor | None = None,
     beta: torch.Tensor | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
             The beta tensor of shape `[B, T, H]`.
         g (torch.Tensor):
             The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             The cumulative sequence lengths of the input tensor.
             Default: None
         chunk_size (int):
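
For readers unfamiliar with this helper, below is a dense PyTorch reference of what chunk_scaled_dot_kkt_fwd is understood to compute in its beta-only path: the per-chunk, strictly lower-triangular (beta * K) @ K^T. This is an assumption drawn from the docstring and the surrounding delta-rule code, not part of the diff; the gated (`g`) path and varlen handling are omitted, and T is assumed divisible by chunk_size.

import torch

def scaled_dot_kkt_reference(
    k: torch.Tensor,      # [B, T, H, K]
    beta: torch.Tensor,   # [B, T, H]
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, H, K = k.shape
    # Split T into chunks of size C: [B, N, H, C, K] and [B, N, H, C]
    kc = k.view(B, T // chunk_size, chunk_size, H, K).transpose(2, 3)
    bc = beta.view(B, T // chunk_size, chunk_size, H).transpose(2, 3)
    # (beta * K) @ K^T within each chunk: [B, N, H, C, C]
    A = (kc * bc.unsqueeze(-1)) @ kc.transpose(-1, -2)
    # Keep only the strictly lower triangle (causal, diagonal excluded)
    mask = torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device).triu()
    return A.masked_fill(mask, 0).to(torch.float32)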

View File

@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
     scale: float,
     initial_state: torch.Tensor,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
@@ -489,7 +489,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
         scale: float,
         initial_state: torch.Tensor,
         inplace_final_state: bool = True,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         ssm_state_indices: torch.Tensor | None = None,
         num_accepted_tokens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
@@ -521,7 +521,7 @@ def fused_recurrent_gated_delta_rule(
     scale: float = None,
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
@@ -549,7 +549,7 @@ def fused_recurrent_gated_delta_rule(
         inplace_final_state: bool:
             Whether to store the final state in-place to save memory.
             Default: `True`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
         ssm_state_indices (Optional[torch.Tensor]):
@@ -583,7 +583,7 @@ def fused_recurrent_gated_delta_rule(
         # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
         >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
         # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
-        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
         >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
             q, k, v, g, beta,
             initial_state=h0,

View File

@@ -191,7 +191,7 @@ def fused_sigmoid_gating_delta_rule_update(
     scale: float = None,
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,

View File

@@ -15,14 +15,12 @@ from .utils import tensor_cache
 
 
 @tensor_cache
-def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
     return cu_seqlens[1:] - cu_seqlens[:-1]
 
 
 @tensor_cache
-def prepare_chunk_indices(
-    cu_seqlens: torch.LongTensor, chunk_size: int
-) -> torch.LongTensor:
+def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
     indices = torch.cat(
         [
             torch.arange(n)
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
 
 
 @tensor_cache
-def prepare_chunk_offsets(
-    cu_seqlens: torch.LongTensor, chunk_size: int
-) -> torch.LongTensor:
+def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
     return torch.cat(
         [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
     ).cumsum(-1)
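
As a quick sanity check (illustrative, not part of the diff), here is what these helpers produce for a two-sequence batch; triton.cdiv is replaced by its integer-arithmetic equivalent so the snippet runs on CPU:

import torch

cu_seqlens = torch.tensor([0, 100, 228], dtype=torch.int32)
chunk_size = 64

lens = cu_seqlens[1:] - cu_seqlens[:-1]            # prepare_lens -> [100, 128]
n_chunks = (lens + chunk_size - 1) // chunk_size   # triton.cdiv  -> [2, 2]
offsets = torch.cat([cu_seqlens.new_tensor([0]), n_chunks]).cumsum(-1)
print(offsets.tolist())                            # prepare_chunk_offsets -> [0, 2, 4]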

View File

@@ -38,7 +38,7 @@ def fused_recurrent_kda_fwd(
     scale: float,
     initial_state: torch.Tensor,
     inplace_final_state: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.Tensor | None = None,
     num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
@@ -116,7 +116,7 @@ def fused_recurrent_kda(
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
     use_qk_l2norm_in_kernel: bool = True,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     ssm_state_indices: torch.LongTensor | None = None,
     **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -720,7 +720,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
     gk: torch.Tensor | None = None,
     beta: torch.Tensor | None = None,
     scale: float | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -734,7 +734,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
             The beta tensor of shape `[B, T, H]`.
         gk (torch.Tensor):
             The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
-        cu_seqlens (torch.LongTensor):
+        cu_seqlens (torch.Tensor):
             The cumulative sequence lengths of the input tensor.
             Default: None
         chunk_size (int):
@@ -964,7 +964,7 @@ def recompute_w_u_fwd(
     A: torch.Tensor,
     q: torch.Tensor | None = None,
     gk: torch.Tensor | None = None,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
     BT = A.shape[-1]
@@ -1132,7 +1132,7 @@ def chunk_gla_fwd_o_gk(
     h: torch.Tensor,
     o: torch.Tensor,
     scale: float,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     chunk_size: int = 64,
 ):
     B, T, H, K, V = *q.shape, v.shape[-1]
@@ -1176,7 +1176,7 @@ def chunk_kda_fwd(
     scale: float,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
 ):
     chunk_size = 64
     g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
@@ -1236,7 +1236,7 @@ def chunk_kda(
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     **kwargs,
 ):
     if scale is None:

View File

@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
     beta: torch.Tensor,
     g_cumsum: torch.Tensor,
     A: torch.Tensor,
-    cu_seqlens: torch.LongTensor | None,
+    cu_seqlens: torch.Tensor | None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, Hg, K, V = *k.shape, v.shape[-1]
     H = v.shape[-2]

View File

@@ -119,7 +119,7 @@ def fi_chunk_gated_delta_rule(
     beta: torch.Tensor,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: torch.LongTensor | None = None,
+    cu_seqlens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = True,
 ):
     from flashinfer.gdn_prefill import (
@@ -214,7 +214,7 @@ class ChunkGatedDeltaRule(CustomOp):
         beta: torch.Tensor,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = True,
     ):
         return fi_chunk_gated_delta_rule(
@@ -238,7 +238,7 @@ class ChunkGatedDeltaRule(CustomOp):
         beta: torch.Tensor,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: torch.LongTensor | None = None,
+        cu_seqlens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = True,
     ):
         return fla_chunk_gated_delta_rule(
@@ -755,8 +755,13 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         v = torch.randn(
             1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
         )
-        g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
-        beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
+        # NOTE: g and beta must have the same dtypes as during
+        # inference, so we construct them with the same function
+        # (fused_gdn_gating). dummy_a and dummy_b are throwaway
+        # inputs required by that function.
+        dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
+        dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
+        g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
         state = torch.zeros(
             1,
             num_v_heads,
@@ -765,7 +770,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             device=device,
             dtype=state_dtype,
         )
-        cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)
+        cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
 
         try:
             self.chunk_gated_delta_rule(
@@ -775,7 +780,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 g=g,
                 beta=beta,
                 initial_state=state,
-                output_final_state=False,
+                output_final_state=True,
                 cu_seqlens=cu_seqlens,
                 use_qk_l2norm_in_kernel=True,
             )
@@ -795,7 +800,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                     self.prefix,
                 )
         finally:
-            del q, k, v, g, beta, state, cu_seqlens
+            del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
             torch.accelerator.empty_cache()
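
Taken together, the warmup changes in this file follow one rule: every input that feeds Triton's kernel specialization, both argument dtypes (cu_seqlens, g, beta) and branch-selecting flags (output_final_state), must match the real inference call, or the tuning done at warmup is wasted. A generic sketch of that pattern (names hypothetical, not vLLM API):

import torch

def warm_up_triton_kernel(kernel, build_inference_like_inputs) -> None:
    # Build throwaway inputs whose dtypes and flags exactly mirror the real
    # inference call, so the autotune cache entry created here is the one
    # later requests will actually hit.
    inputs = build_inference_like_inputs()
    try:
        kernel(**inputs)  # first call triggers (and caches) autotuning
    finally:
        del inputs
        torch.accelerator.empty_cache()  # release warmup allocations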