[Perf] [Bugfix] Fix Triton autotuning in inference for Qwen3.5 (#37338)
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
54ab804e87
commit
a16133a0f1
@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
Returns:
|
||||
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
|
||||
@@ -145,7 +145,7 @@ def chunk_fwd_o(
|
||||
h: torch.Tensor,
|
||||
g: torch.Tensor | None = None, # cumsum of log decay
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
|
||||
@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
g: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
|
||||
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -489,7 +489,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -521,7 +521,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -549,7 +549,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
inplace_final_state: bool:
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
@@ -583,7 +583,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -191,7 +191,7 @@ def fused_sigmoid_gating_delta_rule_update(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
|
||||
@@ -15,14 +15,12 @@ from .utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
return torch.cat(
|
||||
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
|
||||
).cumsum(-1)
|
||||
|
||||
@@ -38,7 +38,7 @@ def fused_recurrent_kda_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -116,7 +116,7 @@ def fused_recurrent_kda(
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.LongTensor | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -720,7 +720,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
||||
gk: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -734,7 +734,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
gk (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
@@ -964,7 +964,7 @@ def recompute_w_u_fwd(
|
||||
A: torch.Tensor,
|
||||
q: torch.Tensor | None = None,
|
||||
gk: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
BT = A.shape[-1]
|
||||
@@ -1132,7 +1132,7 @@ def chunk_gla_fwd_o_gk(
|
||||
h: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
scale: float,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
):
|
||||
B, T, H, K, V = *q.shape, v.shape[-1]
|
||||
@@ -1176,7 +1176,7 @@ def chunk_kda_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
chunk_size = 64
|
||||
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
|
||||
@@ -1236,7 +1236,7 @@ def chunk_kda(
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if scale is None:
|
||||
|
||||
@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: torch.LongTensor | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
|
||||
Reference in New Issue
Block a user