[Perf] [Bugfix] Fix Triton autotuning in inference for Qwen3.5 (#37338)
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
54ab804e87
commit
a16133a0f1
@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
Returns:
|
||||
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
|
||||
@@ -145,7 +145,7 @@ def chunk_fwd_o(
|
||||
h: torch.Tensor,
|
||||
g: torch.Tensor | None = None, # cumsum of log decay
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
|
||||
@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
g: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
|
||||
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -489,7 +489,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -521,7 +521,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -549,7 +549,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
inplace_final_state (bool):
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
@@ -583,7 +583,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = fused_recurrent_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -191,7 +191,7 @@ def fused_sigmoid_gating_delta_rule_update(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
|
||||
@@ -15,14 +15,12 @@ from .utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
return torch.cat(
|
||||
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
|
||||
).cumsum(-1)
|
||||
|
||||
@@ -38,7 +38,7 @@ def fused_recurrent_kda_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -116,7 +116,7 @@ def fused_recurrent_kda(
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.LongTensor | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -720,7 +720,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
||||
gk: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -734,7 +734,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
gk (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
@@ -964,7 +964,7 @@ def recompute_w_u_fwd(
|
||||
A: torch.Tensor,
|
||||
q: torch.Tensor | None = None,
|
||||
gk: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
BT = A.shape[-1]
|
||||
@@ -1132,7 +1132,7 @@ def chunk_gla_fwd_o_gk(
|
||||
h: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
scale: float,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
):
|
||||
B, T, H, K, V = *q.shape, v.shape[-1]
|
||||
@@ -1176,7 +1176,7 @@ def chunk_kda_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
chunk_size = 64
|
||||
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
|
||||
@@ -1236,7 +1236,7 @@ def chunk_kda(
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if scale is None:
|
||||
|
||||
@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: torch.LongTensor | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
|
||||
@@ -119,7 +119,7 @@ def fi_chunk_gated_delta_rule(
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
from flashinfer.gdn_prefill import (
|
||||
@@ -214,7 +214,7 @@ class ChunkGatedDeltaRule(CustomOp):
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
return fi_chunk_gated_delta_rule(
|
||||
@@ -238,7 +238,7 @@ class ChunkGatedDeltaRule(CustomOp):
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
return fla_chunk_gated_delta_rule(
|
||||
@@ -755,8 +755,13 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
v = torch.randn(
|
||||
1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype
|
||||
)
|
||||
g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
|
||||
beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype)
|
||||
# NOTE: g and beta must have the same dtypes as during
|
||||
# inference, so we construct them with the same function
|
||||
# (fused_gdn_gating). dummy_a and dummy_b are throwaway
|
||||
# inputs required by that function.
|
||||
dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
|
||||
dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
|
||||
g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
|
||||
state = torch.zeros(
|
||||
1,
|
||||
num_v_heads,
|
||||
@@ -765,7 +770,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
device=device,
|
||||
dtype=state_dtype,
|
||||
)
|
||||
cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)
|
||||
cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.int32)
|
||||
|
||||
try:
|
||||
self.chunk_gated_delta_rule(
|
||||
@@ -775,7 +780,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
g=g,
|
||||
beta=beta,
|
||||
initial_state=state,
|
||||
output_final_state=False,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
@@ -795,7 +800,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
self.prefix,
|
||||
)
|
||||
finally:
|
||||
del q, k, v, g, beta, state, cu_seqlens
|
||||
del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
|
||||
|
||||
torch.accelerator.empty_cache()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user