[GDN] Eliminate GPU->CPU sync in prepare_chunk_indices during prefill (#38361)
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
This commit is contained in:
Committed by: GitHub
Parent: bf8b022e60
Commit: cb10b7e80b
@@ -16,7 +16,7 @@ from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
|||||||
from .cumsum import chunk_local_cumsum
|
from .cumsum import chunk_local_cumsum
|
||||||
from .l2norm import l2norm_fwd
|
from .l2norm import l2norm_fwd
|
||||||
from .solve_tril import solve_tril
|
from .solve_tril import solve_tril
|
||||||
from .utils import SUPPRESS_LEVEL, input_guard
|
from .utils import FLA_CHUNK_SIZE, SUPPRESS_LEVEL, input_guard
|
||||||
from .wy_fast import recompute_w_u_fwd
|
from .wy_fast import recompute_w_u_fwd
|
||||||
|
|
||||||
|
|
||||||
@@ -30,13 +30,24 @@ def chunk_gated_delta_rule_fwd(
|
|||||||
initial_state: torch.Tensor,
|
initial_state: torch.Tensor,
|
||||||
output_final_state: bool,
|
output_final_state: bool,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
|
chunk_offsets: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
g = chunk_local_cumsum(
|
||||||
|
g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices
|
||||||
|
)
|
||||||
# obtain WY representation. u is actually the new v.
|
# obtain WY representation. u is actually the new v.
|
||||||
A = chunk_scaled_dot_kkt_fwd(
|
A = chunk_scaled_dot_kkt_fwd(
|
||||||
k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
|
k=k,
|
||||||
|
beta=beta,
|
||||||
|
g=g,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
|
output_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
A = solve_tril(
|
||||||
|
A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, output_dtype=k.dtype
|
||||||
)
|
)
|
||||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
|
||||||
w, u = recompute_w_u_fwd(
|
w, u = recompute_w_u_fwd(
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
@@ -44,6 +55,7 @@ def chunk_gated_delta_rule_fwd(
|
|||||||
A=A,
|
A=A,
|
||||||
g_cumsum=g,
|
g_cumsum=g,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
)
|
)
|
||||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||||
k=k,
|
k=k,
|
||||||
@@ -53,6 +65,8 @@ def chunk_gated_delta_rule_fwd(
|
|||||||
initial_state=initial_state,
|
initial_state=initial_state,
|
||||||
output_final_state=output_final_state,
|
output_final_state=output_final_state,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
|
chunk_offsets=chunk_offsets,
|
||||||
)
|
)
|
||||||
o = chunk_fwd_o(
|
o = chunk_fwd_o(
|
||||||
q=q,
|
q=q,
|
||||||
@@ -62,6 +76,7 @@ def chunk_gated_delta_rule_fwd(
|
|||||||
g=g,
|
g=g,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
)
|
)
|
||||||
if SUPPRESS_LEVEL < 3:
|
if SUPPRESS_LEVEL < 3:
|
||||||
return g, o, A, final_state, None, None, None
|
return g, o, A, final_state, None, None, None
|
||||||
@@ -84,6 +99,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
|||||||
initial_state: torch.Tensor,
|
initial_state: torch.Tensor,
|
||||||
output_final_state: bool,
|
output_final_state: bool,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
|
chunk_offsets: torch.Tensor | None = None,
|
||||||
use_qk_l2norm_in_kernel: bool = False,
|
use_qk_l2norm_in_kernel: bool = False,
|
||||||
):
|
):
|
||||||
if use_qk_l2norm_in_kernel:
|
if use_qk_l2norm_in_kernel:
|
||||||
@@ -100,6 +117,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
|||||||
initial_state=initial_state,
|
initial_state=initial_state,
|
||||||
output_final_state=output_final_state,
|
output_final_state=output_final_state,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
|
chunk_offsets=chunk_offsets,
|
||||||
)
|
)
|
||||||
ctx.scale = scale
|
ctx.scale = scale
|
||||||
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
||||||
@@ -117,6 +136,8 @@ def chunk_gated_delta_rule(
|
|||||||
initial_state: torch.Tensor = None,
|
initial_state: torch.Tensor = None,
|
||||||
output_final_state: bool = False,
|
output_final_state: bool = False,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
|
chunk_offsets: torch.Tensor | None = None,
|
||||||
use_qk_l2norm_in_kernel: bool = False,
|
use_qk_l2norm_in_kernel: bool = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -206,6 +227,8 @@ def chunk_gated_delta_rule(
|
|||||||
initial_state,
|
initial_state,
|
||||||
output_final_state,
|
output_final_state,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
|
chunk_indices,
|
||||||
|
chunk_offsets,
|
||||||
use_qk_l2norm_in_kernel,
|
use_qk_l2norm_in_kernel,
|
||||||
)
|
)
|
||||||
return o, final_state
|
return o, final_state
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from vllm.triton_utils import tl, triton
|
|||||||
|
|
||||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||||
from .op import exp
|
from .op import exp
|
||||||
from .utils import use_cuda_graph
|
from .utils import FLA_CHUNK_SIZE, use_cuda_graph
|
||||||
|
|
||||||
NUM_WARPS = [2, 4, 8, 16]
|
NUM_WARPS = [2, 4, 8, 16]
|
||||||
|
|
||||||
@@ -286,9 +286,11 @@ def chunk_gated_delta_rule_fwd_h(
|
|||||||
gk: torch.Tensor | None = None,
|
gk: torch.Tensor | None = None,
|
||||||
initial_state: torch.Tensor | None = None,
|
initial_state: torch.Tensor | None = None,
|
||||||
output_final_state: bool = False,
|
output_final_state: bool = False,
|
||||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
chunk_size: int = FLA_CHUNK_SIZE,
|
||||||
save_new_value: bool = True,
|
save_new_value: bool = True,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
|
chunk_offsets: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||||
@@ -296,20 +298,15 @@ def chunk_gated_delta_rule_fwd_h(
|
|||||||
H = u.shape[-2]
|
H = u.shape[-2]
|
||||||
BT = chunk_size
|
BT = chunk_size
|
||||||
|
|
||||||
chunk_indices = (
|
if chunk_indices is None and cu_seqlens is not None:
|
||||||
prepare_chunk_indices(cu_seqlens, chunk_size)
|
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||||
if cu_seqlens is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||||
if cu_seqlens is None:
|
if cu_seqlens is None:
|
||||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||||
else:
|
else:
|
||||||
N, NT, chunk_offsets = (
|
N, NT = len(cu_seqlens) - 1, len(chunk_indices)
|
||||||
len(cu_seqlens) - 1,
|
if chunk_offsets is None:
|
||||||
len(chunk_indices),
|
chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
|
||||||
prepare_chunk_offsets(cu_seqlens, BT),
|
|
||||||
)
|
|
||||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||||
|
|
||||||
h = k.new_empty(B, NT, H, V, K)
|
h = k.new_empty(B, NT, H, V, K)
|
||||||
|
|||||||
@@ -146,14 +146,14 @@ def chunk_fwd_o(
|
|||||||
g: torch.Tensor | None = None, # cumsum of log decay
|
g: torch.Tensor | None = None, # cumsum of log decay
|
||||||
scale: float | None = None,
|
scale: float | None = None,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
chunk_size: int = FLA_CHUNK_SIZE,
|
chunk_size: int = FLA_CHUNK_SIZE,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||||
H = v.shape[-2]
|
H = v.shape[-2]
|
||||||
BT = chunk_size
|
BT = chunk_size
|
||||||
chunk_indices = (
|
if chunk_indices is None and cu_seqlens is not None:
|
||||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||||
)
|
|
||||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||||
if scale is None:
|
if scale is None:
|
||||||
scale = k.shape[-1] ** -0.5
|
scale = k.shape[-1] ** -0.5
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from vllm.triton_utils import tl, triton
|
|||||||
|
|
||||||
from .index import prepare_chunk_indices
|
from .index import prepare_chunk_indices
|
||||||
from .op import exp
|
from .op import exp
|
||||||
|
from .utils import FLA_CHUNK_SIZE
|
||||||
|
|
||||||
|
|
||||||
@triton.heuristics(
|
@triton.heuristics(
|
||||||
@@ -103,7 +104,8 @@ def chunk_scaled_dot_kkt_fwd(
|
|||||||
g: torch.Tensor | None = None,
|
g: torch.Tensor | None = None,
|
||||||
beta: torch.Tensor | None = None,
|
beta: torch.Tensor | None = None,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
chunk_size: int = 64,
|
chunk_indices: torch.Tensor | None = None,
|
||||||
|
chunk_size: int = FLA_CHUNK_SIZE,
|
||||||
output_dtype: torch.dtype = torch.float32,
|
output_dtype: torch.dtype = torch.float32,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
@@ -119,6 +121,9 @@ def chunk_scaled_dot_kkt_fwd(
|
|||||||
cu_seqlens (torch.Tensor):
|
cu_seqlens (torch.Tensor):
|
||||||
The cumulative sequence lengths of the input tensor.
|
The cumulative sequence lengths of the input tensor.
|
||||||
Default: None
|
Default: None
|
||||||
|
chunk_indices (torch.Tensor):
|
||||||
|
Pre-computed chunk indices. If None and cu_seqlens is provided,
|
||||||
|
computed internally. Default: None
|
||||||
chunk_size (int):
|
chunk_size (int):
|
||||||
The chunk size. Default: 64.
|
The chunk size. Default: 64.
|
||||||
output_dtype (torch.dtype):
|
output_dtype (torch.dtype):
|
||||||
@@ -132,9 +137,8 @@ def chunk_scaled_dot_kkt_fwd(
|
|||||||
B, T, Hg, K = k.shape
|
B, T, Hg, K = k.shape
|
||||||
H = beta.shape[-1]
|
H = beta.shape[-1]
|
||||||
BT = chunk_size
|
BT = chunk_size
|
||||||
chunk_indices = (
|
if chunk_indices is None and cu_seqlens is not None:
|
||||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||||
)
|
|
||||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||||
|
|
||||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ def chunk_local_cumsum_scalar(
|
|||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
reverse: bool = False,
|
reverse: bool = False,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
head_first: bool = False,
|
head_first: bool = False,
|
||||||
output_dtype: torch.dtype | None = torch.float,
|
output_dtype: torch.dtype | None = torch.float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -172,10 +173,9 @@ def chunk_local_cumsum_scalar(
|
|||||||
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
|
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
|
||||||
"chunk_size must be a power of 2"
|
"chunk_size must be a power of 2"
|
||||||
)
|
)
|
||||||
|
if chunk_indices is None and cu_seqlens is not None:
|
||||||
|
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||||
BT = chunk_size
|
BT = chunk_size
|
||||||
chunk_indices = (
|
|
||||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
|
||||||
)
|
|
||||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||||
grid = (NT, B * H)
|
grid = (NT, B * H)
|
||||||
@@ -199,6 +199,7 @@ def chunk_local_cumsum_vector(
|
|||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
reverse: bool = False,
|
reverse: bool = False,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
head_first: bool = False,
|
head_first: bool = False,
|
||||||
output_dtype: torch.dtype | None = torch.float,
|
output_dtype: torch.dtype | None = torch.float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -206,16 +207,13 @@ def chunk_local_cumsum_vector(
|
|||||||
B, H, T, S = g.shape
|
B, H, T, S = g.shape
|
||||||
else:
|
else:
|
||||||
B, T, H, S = g.shape
|
B, T, H, S = g.shape
|
||||||
BT = chunk_size
|
|
||||||
chunk_indices = (
|
|
||||||
prepare_chunk_indices(cu_seqlens, chunk_size)
|
|
||||||
if cu_seqlens is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
|
||||||
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
|
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
|
||||||
"chunk_size must be a power of 2"
|
"chunk_size must be a power of 2"
|
||||||
)
|
)
|
||||||
|
if chunk_indices is None and cu_seqlens is not None:
|
||||||
|
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||||
|
BT = chunk_size
|
||||||
|
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||||
|
|
||||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||||
|
|
||||||
@@ -247,6 +245,7 @@ def chunk_local_cumsum(
|
|||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
reverse: bool = False,
|
reverse: bool = False,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
head_first: bool = False,
|
head_first: bool = False,
|
||||||
output_dtype: torch.dtype | None = torch.float,
|
output_dtype: torch.dtype | None = torch.float,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -257,11 +256,23 @@ def chunk_local_cumsum(
|
|||||||
)
|
)
|
||||||
if len(g.shape) == 3:
|
if len(g.shape) == 3:
|
||||||
return chunk_local_cumsum_scalar(
|
return chunk_local_cumsum_scalar(
|
||||||
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
|
g,
|
||||||
|
chunk_size,
|
||||||
|
reverse,
|
||||||
|
cu_seqlens,
|
||||||
|
chunk_indices,
|
||||||
|
head_first,
|
||||||
|
output_dtype,
|
||||||
)
|
)
|
||||||
elif len(g.shape) == 4:
|
elif len(g.shape) == 4:
|
||||||
return chunk_local_cumsum_vector(
|
return chunk_local_cumsum_vector(
|
||||||
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
|
g,
|
||||||
|
chunk_size,
|
||||||
|
reverse,
|
||||||
|
cu_seqlens,
|
||||||
|
chunk_indices,
|
||||||
|
head_first,
|
||||||
|
output_dtype,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from .index import prepare_chunk_indices
|
|||||||
from .l2norm import l2norm_fwd
|
from .l2norm import l2norm_fwd
|
||||||
from .op import exp, log
|
from .op import exp, log
|
||||||
from .solve_tril import solve_tril
|
from .solve_tril import solve_tril
|
||||||
from .utils import is_amd
|
from .utils import FLA_CHUNK_SIZE, is_amd
|
||||||
|
|
||||||
BT_LIST_AUTOTUNE = [32, 64, 128]
|
BT_LIST_AUTOTUNE = [32, 64, 128]
|
||||||
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
|
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
|
||||||
@@ -721,7 +721,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
|||||||
beta: torch.Tensor | None = None,
|
beta: torch.Tensor | None = None,
|
||||||
scale: float | None = None,
|
scale: float | None = None,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
chunk_size: int = 64,
|
chunk_size: int = FLA_CHUNK_SIZE,
|
||||||
output_dtype: torch.dtype = torch.float32,
|
output_dtype: torch.dtype = torch.float32,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1178,7 +1178,7 @@ def chunk_kda_fwd(
|
|||||||
output_final_state: bool,
|
output_final_state: bool,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
chunk_size = 64
|
chunk_size = FLA_CHUNK_SIZE
|
||||||
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
|
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
|
||||||
# the intra Aqk is kept in fp32
|
# the intra Aqk is kept in fp32
|
||||||
# the computation has very marginal effect on the entire throughput
|
# the computation has very marginal effect on the entire throughput
|
||||||
@@ -1189,6 +1189,7 @@ def chunk_kda_fwd(
|
|||||||
beta=beta,
|
beta=beta,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_size=chunk_size,
|
||||||
output_dtype=torch.float32,
|
output_dtype=torch.float32,
|
||||||
)
|
)
|
||||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||||
|
|||||||
@@ -507,6 +507,7 @@ def merge_16x16_to_64x64_inverse_kernel(
|
|||||||
def solve_tril(
|
def solve_tril(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
output_dtype: torch.dtype = torch.float,
|
output_dtype: torch.dtype = torch.float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -518,6 +519,8 @@ def solve_tril(
|
|||||||
[B, T, H, BT], where BT should only be 16, 32, or 64.
|
[B, T, H, BT], where BT should only be 16, 32, or 64.
|
||||||
cu_seqlens (torch.Tensor):
|
cu_seqlens (torch.Tensor):
|
||||||
The cumulative sequence lengths of the input tensor. Default: `None`.
|
The cumulative sequence lengths of the input tensor. Default: `None`.
|
||||||
|
chunk_indices (torch.Tensor):
|
||||||
|
Pre-computed chunk indices. Default: `None`.
|
||||||
output_dtype (torch.dtype):
|
output_dtype (torch.dtype):
|
||||||
The dtype of the output tensor. Default: `torch.float`.
|
The dtype of the output tensor. Default: `torch.float`.
|
||||||
If `None`, the output dtype will be the same as the input dtype.
|
If `None`, the output dtype will be the same as the input dtype.
|
||||||
@@ -529,9 +532,8 @@ def solve_tril(
|
|||||||
output_dtype = A.dtype if output_dtype is None else output_dtype
|
output_dtype = A.dtype if output_dtype is None else output_dtype
|
||||||
|
|
||||||
B, T, H, BT = A.shape
|
B, T, H, BT = A.shape
|
||||||
chunk_indices = (
|
if chunk_indices is None and cu_seqlens is not None:
|
||||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||||
)
|
|
||||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
||||||
|
|
||||||
Ai = torch.zeros_like(A, dtype=output_dtype)
|
Ai = torch.zeros_like(A, dtype=output_dtype)
|
||||||
|
|||||||
@@ -123,14 +123,14 @@ def recompute_w_u_fwd(
|
|||||||
g_cumsum: torch.Tensor,
|
g_cumsum: torch.Tensor,
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor | None,
|
cu_seqlens: torch.Tensor | None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||||
H = v.shape[-2]
|
H = v.shape[-2]
|
||||||
BT = A.shape[-1]
|
BT = A.shape[-1]
|
||||||
|
|
||||||
chunk_indices = (
|
if chunk_indices is None and cu_seqlens is not None:
|
||||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||||
)
|
|
||||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||||
BK = 64
|
BK = 64
|
||||||
BV = 64
|
BV = 64
|
||||||
|
|||||||
@@ -163,6 +163,8 @@ class ChunkGatedDeltaRule(CustomOp):
|
|||||||
initial_state: torch.Tensor,
|
initial_state: torch.Tensor,
|
||||||
output_final_state: bool,
|
output_final_state: bool,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
|
chunk_offsets: torch.Tensor | None = None,
|
||||||
use_qk_l2norm_in_kernel: bool = True,
|
use_qk_l2norm_in_kernel: bool = True,
|
||||||
):
|
):
|
||||||
return fi_chunk_gated_delta_rule(
|
return fi_chunk_gated_delta_rule(
|
||||||
@@ -187,6 +189,8 @@ class ChunkGatedDeltaRule(CustomOp):
|
|||||||
initial_state: torch.Tensor,
|
initial_state: torch.Tensor,
|
||||||
output_final_state: bool,
|
output_final_state: bool,
|
||||||
cu_seqlens: torch.Tensor | None = None,
|
cu_seqlens: torch.Tensor | None = None,
|
||||||
|
chunk_indices: torch.Tensor | None = None,
|
||||||
|
chunk_offsets: torch.Tensor | None = None,
|
||||||
use_qk_l2norm_in_kernel: bool = True,
|
use_qk_l2norm_in_kernel: bool = True,
|
||||||
):
|
):
|
||||||
return fla_chunk_gated_delta_rule(
|
return fla_chunk_gated_delta_rule(
|
||||||
@@ -198,6 +202,8 @@ class ChunkGatedDeltaRule(CustomOp):
|
|||||||
initial_state=initial_state,
|
initial_state=initial_state,
|
||||||
output_final_state=output_final_state,
|
output_final_state=output_final_state,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
|
chunk_offsets=chunk_offsets,
|
||||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -959,6 +965,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
|||||||
initial_state=initial_state,
|
initial_state=initial_state,
|
||||||
output_final_state=True,
|
output_final_state=True,
|
||||||
cu_seqlens=non_spec_query_start_loc,
|
cu_seqlens=non_spec_query_start_loc,
|
||||||
|
chunk_indices=attn_metadata.chunk_indices,
|
||||||
|
chunk_offsets=attn_metadata.chunk_offsets,
|
||||||
use_qk_l2norm_in_kernel=False,
|
use_qk_l2norm_in_kernel=False,
|
||||||
)
|
)
|
||||||
# Init cache
|
# Init cache
|
||||||
|
|||||||
@@ -63,6 +63,10 @@ class GDNAttentionMetadata:
|
|||||||
|
|
||||||
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
|
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
|
||||||
|
|
||||||
|
# Pre-computed FLA chunk metadata (avoids GPU->CPU sync in prepare_chunk_indices)
|
||||||
|
chunk_indices: torch.Tensor | None = None
|
||||||
|
chunk_offsets: torch.Tensor | None = None
|
||||||
|
|
||||||
# The following attributes are for triton implementation of causal_conv1d
|
# The following attributes are for triton implementation of causal_conv1d
|
||||||
nums_dict: dict | None = None
|
nums_dict: dict | None = None
|
||||||
batch_ptr: torch.Tensor | None = None
|
batch_ptr: torch.Tensor | None = None
|
||||||
@@ -305,6 +309,26 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
assert num_accepted_tokens is not None
|
assert num_accepted_tokens is not None
|
||||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||||
|
|
||||||
|
chunk_indices: torch.Tensor | None = None
|
||||||
|
chunk_offsets: torch.Tensor | None = None
|
||||||
|
if num_prefills > 0:
|
||||||
|
# Only prefill batches use FLA chunk ops.
|
||||||
|
# Pre-compute on CPU and async-copy to GPU to avoid
|
||||||
|
# GPU→CPU sync (.tolist()) in prepare_chunk_indices.
|
||||||
|
from vllm.model_executor.layers.fla.ops.index import (
|
||||||
|
prepare_chunk_indices,
|
||||||
|
prepare_chunk_offsets,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fla.ops.utils import FLA_CHUNK_SIZE
|
||||||
|
|
||||||
|
gpu_device = query_start_loc.device
|
||||||
|
chunk_indices = prepare_chunk_indices(
|
||||||
|
non_spec_query_start_loc_cpu, FLA_CHUNK_SIZE
|
||||||
|
).to(device=gpu_device, non_blocking=True)
|
||||||
|
chunk_offsets = prepare_chunk_offsets(
|
||||||
|
non_spec_query_start_loc_cpu, FLA_CHUNK_SIZE
|
||||||
|
).to(device=gpu_device, non_blocking=True)
|
||||||
|
|
||||||
if num_prefills > 0:
|
if num_prefills > 0:
|
||||||
has_initial_state = context_lens_tensor > 0
|
has_initial_state = context_lens_tensor > 0
|
||||||
if spec_sequence_masks is not None:
|
if spec_sequence_masks is not None:
|
||||||
@@ -405,6 +429,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
num_spec_decode_tokens=num_spec_decode_tokens,
|
num_spec_decode_tokens=num_spec_decode_tokens,
|
||||||
num_actual_tokens=m.num_actual_tokens,
|
num_actual_tokens=m.num_actual_tokens,
|
||||||
has_initial_state=has_initial_state,
|
has_initial_state=has_initial_state,
|
||||||
|
chunk_indices=chunk_indices,
|
||||||
|
chunk_offsets=chunk_offsets,
|
||||||
spec_query_start_loc=spec_query_start_loc,
|
spec_query_start_loc=spec_query_start_loc,
|
||||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||||
spec_state_indices_tensor=spec_state_indices_tensor,
|
spec_state_indices_tensor=spec_state_indices_tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user