Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
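
For orientation, the reformatting is mechanical: yapf's parenthesis-aligned continuation lines and single-quoted strings become ruff format's black-style hanging indents, magic trailing commas, and double quotes, and import sorting moves from isort to ruff. The sketch below illustrates the typical before/after shape of such a change; the function and argument names are hypothetical and not taken from any file in this diff.

import torch

# Before (yapf): arguments aligned under the opening parenthesis, single quotes.
def chunk_fwd_yapf(q: torch.Tensor,
                   k: torch.Tensor,
                   scale: float = 1.0,
                   mode: str = 'chunk'):
    return q * scale

# After (ruff format): four-space hanging indent, trailing comma, double quotes.
def chunk_fwd_ruff(
    q: torch.Tensor,
    k: torch.Tensor,
    scale: float = 1.0,
    mode: str = "chunk",
):
    return q * scale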


@@ -23,22 +23,22 @@ from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=torch.float32)
A = chunk_scaled_dot_kkt_fwd(
k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
@@ -73,21 +73,22 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False):
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
@@ -109,17 +110,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False):
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
@@ -184,42 +187,55 @@ def chunk_gated_delta_rule(q: torch.Tensor,
)
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(
beta.shape
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
assert q.dtype != torch.float32, (
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
)
assert len(beta.shape) == 3, (
"beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
)
if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2)
stacklevel=2,
)
q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
(q, k, v, beta, g))
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
use_qk_l2norm_in_kernel)
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first:
o = rearrange(o, 'b t h ... -> b h t ...')
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state


@@ -20,22 +20,26 @@ from .utils import is_nvidia_hopper, use_cuda_graph
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4] for num_stages in [2, 3, 4] for BV in [32, 64]
triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4]
for num_stages in [2, 3, 4]
for BV in [32, 64]
],
key=['H', 'K', 'V', 'BT', 'USE_G'],
key=["H", "K", "V", "BT", "USE_G"],
use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
v,
@@ -63,8 +67,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
@@ -100,87 +106,98 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
# load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV),
(1, 0))
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV),
(64, BV), (1, 0))
p_h0_2 = tl.make_block_ptr(
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
)
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV),
(64, BV), (1, 0))
p_h0_3 = tl.make_block_ptr(
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
)
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV),
(64, BV), (1, 0))
p_h0_4 = tl.make_block_ptr(
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
)
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# main recurrence
for i_t in range(NT):
p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
(0, i_v * BV), (64, BV), (1, 0))
p_h1 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
(64, i_v * BV), (64, BV), (1, 0))
tl.store(p_h2,
b_h2.to(p_h2.dtype.element_ty),
boundary_check=(0, 1))
p_h2 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
(128, i_v * BV), (64, BV), (1, 0))
tl.store(p_h3,
b_h3.to(p_h3.dtype.element_ty),
boundary_check=(0, 1))
p_h3 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
(192, i_v * BV), (64, BV), (1, 0))
tl.store(p_h4,
b_h4.to(p_h4.dtype.element_ty),
boundary_check=(0, 1))
p_h4 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1),
(i_t * BT, i_v * BV), (BT, BV),
(1, 0)) if SAVE_NEW_VALUE else None
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_v_new = (
tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
if SAVE_NEW_VALUE
else None
)
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0),
(BT, 64), (1, 0))
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64),
(BT, 64), (1, 0))
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128),
(BT, 64), (1, 0))
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192),
(BT, 64), (1, 0))
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1),
(i_t * BT, i_v * BV), (BT, BV), (1, 0))
tl.store(p_v_new,
b_v_new.to(p_v_new.dtype.element_ty),
boundary_check=(0, 1))
p_v_new = tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
tl.store(
p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
)
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
@@ -191,49 +208,49 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
if K > 192:
b_h4 = b_h4 * b_g_last
b_v_new = b_v_new.to(k.dtype.element_ty)
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT),
(64, BT), (0, 1))
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v_new)
if K > 64:
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT),
(64, BT), (0, 1))
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v_new)
if K > 128:
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT),
(64, BT), (0, 1))
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v_new)
if K > 192:
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT),
(64, BT), (0, 1))
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v_new)
# epilogue
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV),
(1, 0))
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV),
(64, BV), (1, 0))
tl.store(p_ht,
b_h2.to(p_ht.dtype.element_ty),
boundary_check=(0, 1))
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV),
(64, BV), (1, 0))
tl.store(p_ht,
b_h3.to(p_ht.dtype.element_ty),
boundary_check=(0, 1))
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV),
(64, BV), (1, 0))
tl.store(p_ht,
b_h4.to(p_ht.dtype.element_ty),
boundary_check=(0, 1))
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
)
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
@@ -251,24 +268,31 @@ def chunk_gated_delta_rule_fwd_h(
H = u.shape[-2]
BT = chunk_size
chunk_indices = prepare_chunk_indices(
cu_seqlens, chunk_size) if cu_seqlens is not None else None
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(
chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
N, NT, chunk_offsets = (
len(cu_seqlens) - 1,
len(chunk_indices),
prepare_chunk_offsets(cu_seqlens, BT),
)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V)
final_state = k.new_empty(
N, H, K, V, dtype=torch.float32) if output_final_state else None
final_state = (
k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
)
v_new = torch.empty_like(u) if save_new_value else None
def grid(meta):
return (triton.cdiv(V, meta['BV']), N * H)
return (triton.cdiv(V, meta["BV"]), N * H)
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
k=k,
@@ -286,5 +310,6 @@ def chunk_gated_delta_rule_fwd_h(
Hg=Hg,
K=K,
V=V,
BT=BT)
BT=BT,
)
return h, v_new, final_state


@@ -23,24 +23,23 @@ BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({
'BK': BK,
'BV': BV
},
num_warps=num_warps,
num_stages=num_stages) for BK in BKV_LIST
for BV in BKV_LIST for num_warps in NUM_WARPS
triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
for BK in BKV_LIST
for BV in BKV_LIST
for num_warps in NUM_WARPS
for num_stages in [2, 3, 4]
],
key=['H', 'K', 'V', 'BT'],
key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
@@ -67,10 +66,14 @@ def chunk_fwd_kernel_o(
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
@@ -89,12 +92,15 @@ def chunk_fwd_kernel_o(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
(BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
(BK, BT), (0, 1))
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
(BK, BV), (1, 0))
p_q = tl.make_block_ptr(
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
@@ -109,8 +115,8 @@ def chunk_fwd_kernel_o(
if USE_G:
g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
@@ -119,10 +125,12 @@ def chunk_fwd_kernel_o(
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_v = tl.make_block_ptr(
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_o = tl.make_block_ptr(
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
@@ -132,30 +140,32 @@ def chunk_fwd_kernel_o(
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64) -> torch.Tensor:
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
if FLA_GDN_FIX_BT:
BT = 64
else:
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
def grid(meta):
return (triton.cdiv(V, meta['BV']), NT, B * H)
return (triton.cdiv(V, meta["BV"]), NT, B * H)
chunk_fwd_kernel_o[grid](
q,


@@ -17,19 +17,22 @@ from .index import prepare_chunk_indices
from .op import exp
@triton.heuristics({
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
'USE_G': lambda args: args['g_cumsum'] is not None
})
@triton.heuristics(
{
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"USE_G": lambda args: args["g_cumsum"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64, 128] for num_warps in [2, 4, 8]
triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64, 128]
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=['H', 'K', 'BT', 'IS_VARLEN'],
key=["H", "K", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
beta,
@@ -49,50 +52,63 @@ def chunk_scaled_dot_kkt_fwd_kernel(
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
b_beta = tl.load(p_beta, boundary_check=(0, ))
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
(Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
(1, 0))
p_k = tl.make_block_ptr(
k + (bos * Hg + i_h // (H // Hg)) * K,
(T, K),
(Hg * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None]
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
if USE_G:
p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(
g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * exp(b_g_diff)
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""
Compute beta * K * K^T.
@@ -120,8 +136,9 @@ def chunk_scaled_dot_kkt_fwd(
H = beta.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](


@@ -20,12 +20,12 @@ from .utils import check_shared_mem, input_guard
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(configs=[
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]
],
key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'])
@triton.jit(do_not_specialize=['T'])
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
s,
o,
@@ -42,40 +42,47 @@ def chunk_local_cumsum_scalar_kernel(
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if HEAD_FIRST:
p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ),
(i_t * BT, ), (BT, ), (0, ))
p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ),
(i_t * BT, ), (BT, ), (0, ))
p_s = tl.make_block_ptr(
s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
)
p_o = tl.make_block_ptr(
o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
)
else:
p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
(BT, ), (0, ))
p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
(BT, ), (0, ))
p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
# [BT]
b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32)
b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
b_o = tl.cumsum(b_s, axis=0)
if REVERSE:
b_z = tl.sum(b_s, axis=0)
b_o = -b_o + b_z[None] + b_s
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, ))
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.autotune(configs=[
triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST
for num_warps in [2, 4, 8]
],
key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'])
@triton.jit(do_not_specialize=['T'])
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({"BS": BS}, num_warps=num_warps)
for BS in BS_LIST
for num_warps in [2, 4, 8]
],
key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
s,
o,
@@ -94,30 +101,58 @@ def chunk_local_cumsum_vector_kernel(
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, BT)
if REVERSE:
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
else:
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)
if HEAD_FIRST:
p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1),
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1),
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_s = tl.make_block_ptr(
s + (bos * H + i_h * T) * S,
(T, S),
(S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
p_o = tl.make_block_ptr(
o + (bos * H + i_h * T) * S,
(T, S),
(S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
else:
p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1),
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1),
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
p_s = tl.make_block_ptr(
s + (bos * H + i_h) * S,
(T, S),
(H * S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
p_o = tl.make_block_ptr(
o + (bos * H + i_h) * S,
(T, S),
(H * S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
# [BT, BS]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.dot(m_s, b_s, allow_tf32=False)
@@ -125,102 +160,122 @@ def chunk_local_cumsum_vector_kernel(
def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
if head_first:
B, H, T = g.shape
else:
B, T, H = g.shape
assert chunk_size == 2**(chunk_size.bit_length() -
1), "chunk_size must be a power of 2"
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
BT = chunk_size
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B * H)
chunk_local_cumsum_scalar_kernel[grid](g_org,
g,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
BT=BT,
HEAD_FIRST=head_first,
REVERSE=reverse)
chunk_local_cumsum_scalar_kernel[grid](
g_org,
g,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
BT=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
)
return g
def chunk_local_cumsum_vector(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
if head_first:
B, H, T, S = g.shape
else:
B, T, H, S = g.shape
BT = chunk_size
chunk_indices = prepare_chunk_indices(
cu_seqlens, chunk_size) if cu_seqlens is not None else None
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
assert chunk_size == 2**(chunk_size.bit_length() -
1), "chunk_size must be a power of 2"
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
def grid(meta):
return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)
# keep cumulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_local_cumsum_vector_kernel[grid](g_org,
g,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
S=S,
BT=BT,
HEAD_FIRST=head_first,
REVERSE=reverse)
chunk_local_cumsum_vector_kernel[grid](
g_org,
g,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
S=S,
BT=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
)
return g
@input_guard
def chunk_local_cumsum(g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs) -> torch.Tensor:
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
head_first: bool = False,
output_dtype: Optional[torch.dtype] = torch.float,
**kwargs,
) -> torch.Tensor:
if not head_first and g.shape[1] < g.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
assert g.shape[
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
assert g.shape[0] == 1, (
"Only batch size 1 is supported when cu_seqlens are provided"
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens,
head_first, output_dtype)
return chunk_local_cumsum_scalar(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
)
elif len(g.shape) == 4:
return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens,
head_first, output_dtype)
return chunk_local_cumsum_vector(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
)
else:
raise ValueError(f"Unsupported input shape {g.shape}. "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise")
raise ValueError(
f"Unsupported input shape {g.shape}. "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise"
)


@@ -16,17 +16,15 @@ from vllm.triton_utils import tl, triton
from .op import exp
@triton.heuristics({
'USE_INITIAL_STATE':
lambda args: args['h0'] is not None,
'IS_VARLEN':
lambda args: args['cu_seqlens'] is not None,
"IS_CONTINUOUS_BATCHING":
lambda args: args['ssm_state_indices'] is not None,
"IS_SPEC_DECODING":
lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
}
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
q,
k,
@@ -55,8 +53,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
stride_indices_tok: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
IS_BETA_HEADWISE: tl.
constexpr, # whether beta is headwise vector or scalar,
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
@@ -66,8 +63,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
@@ -102,8 +101,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_init_state_token
p_h0 = (
h0
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
* stride_init_state_token
)
else:
p_h0 = h0 + bos * HV * K * V
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
@@ -136,8 +140,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
i_t).to(tl.int64) * stride_final_state_token
p_ht = (
ht
+ tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
* stride_final_state_token
)
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
@@ -228,21 +237,22 @@ def fused_recurrent_gated_delta_rule_fwd(
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False):
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
o, final_state = fused_recurrent_gated_delta_rule_fwd(
q=q.contiguous(),
k=k.contiguous(),
@@ -342,9 +352,10 @@ def fused_recurrent_gated_delta_rule(
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
f"Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
if beta is None:


@@ -20,20 +20,22 @@ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
@tensor_cache
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
chunk_size: int) -> torch.LongTensor:
indices = torch.cat([
torch.arange(n)
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
])
return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
1).to(cu_seqlens)
def prepare_chunk_indices(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
indices = torch.cat(
[
torch.arange(n)
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
]
)
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
@tensor_cache
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
chunk_size: int) -> torch.LongTensor:
return torch.cat([
cu_seqlens.new_tensor([0]),
triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
]).cumsum(-1)
def prepare_chunk_offsets(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
return torch.cat(
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
).cumsum(-1)


@@ -19,11 +19,12 @@ BT_LIST = [8, 16, 32, 64, 128]
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@triton.autotune(configs=[
triton.Config({}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16, 32]
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
],
key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
x,
@@ -47,11 +48,14 @@ def l2norm_fwd_kernel1(
tl.store(y + cols, b_y, mask=mask)
@triton.autotune(configs=[
triton.Config({'BT': BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({"BT": BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16]
for BT in BT_LIST
],
key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
x,
@@ -85,9 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
def l2norm_fwd(
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
x_shape_og = x.shape
x = x.view(-1, x.shape[-1])
# allocate output
@@ -107,7 +111,7 @@ def l2norm_fwd(x: torch.Tensor,
if not USE_DEFAULT_FLA_NORM:
MBLOCK = 32
# M, N = x.shape
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)](
x,
y,
eps,
@@ -120,7 +124,7 @@ def l2norm_fwd(x: torch.Tensor,
NB = triton.cdiv(T, 2048)
def grid(meta):
return (triton.cdiv(T, meta['BT']), )
return (triton.cdiv(T, meta["BT"]),)
l2norm_fwd_kernel[grid](
x,
@@ -132,7 +136,7 @@ def l2norm_fwd(x: torch.Tensor,
BD=BD,
)
else:
l2norm_fwd_kernel1[(T, )](
l2norm_fwd_kernel1[(T,)](
x,
y,
eps=eps,


@@ -25,14 +25,16 @@ from vllm.triton_utils import tl, triton
from .utils import input_guard
def rms_norm_ref(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True):
def rms_norm_ref(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True,
):
dtype = x.dtype
weight = weight.float()
bias = bias.float() if bias is not None else None
@@ -43,12 +45,10 @@ def rms_norm_ref(x,
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
weight)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
eps)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
@@ -57,10 +57,12 @@ def rms_norm_ref(x,
return out.to(dtype)
@triton.heuristics({
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
})
@triton.heuristics(
{
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
}
)
@triton.jit
def layer_norm_fwd_kernel(
X, # pointer to the input
@@ -97,17 +99,17 @@ def layer_norm_fwd_kernel(
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
@@ -145,64 +147,68 @@ def layer_norm_fwd(
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N, )
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N, )
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
layer_norm_fwd_kernel[grid](x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps)
layer_norm_fwd_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
class LayerNormFn(torch.autograd.Function):
@input_guard
@staticmethod
def forward(ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
def forward(
ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
@@ -236,31 +242,30 @@ class LayerNormFn(torch.autograd.Function):
return y.reshape(x_shape_og)
def layernorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, is_rms_norm)
def layernorm_fn(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
)
def rmsnorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, True)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNormGated(nn.Module):
def __init__(
self,
hidden_size,
@@ -288,19 +293,19 @@ class LayerNormGated(nn.Module):
torch.nn.init.zeros_(self.bias)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return layernorm_fn(x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return layernorm_fn(
x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
)
class RMSNormGated(nn.Module):
def __init__(
self,
hidden_size,
@@ -326,12 +331,13 @@ class RMSNormGated(nn.Module):
torch.nn.init.ones_(self.weight)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return rmsnorm_fn(x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)


@@ -11,7 +11,7 @@ import os
from vllm.triton_utils import tl, tldevice, triton
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
div = tldevice.fast_dividef
exp = tldevice.fast_expf
log = tldevice.fast_logf
@@ -28,7 +28,7 @@ else:
log2 = tl.log2
if not hasattr(tl, 'gather'):
if not hasattr(tl, "gather"):
@triton.jit
def gather(src, index, axis, _builder=None):


@@ -17,15 +17,16 @@ from .index import prepare_chunk_indices
from .utils import input_guard
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5]
for num_warps in [1, 2, 4, 8]
for num_stages in [2, 3, 4, 5]
],
key=['BT'],
key=["BT"],
)
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
A,
Ad,
@@ -39,10 +40,14 @@ def solve_tril_16x16_kernel(
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
@@ -51,13 +56,12 @@ def solve_tril_16x16_kernel(
Ad = Ad + (bos * H + i_h) * 16
offset = (i_t * 16) % BT
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset),
(16, 16), (1, 0))
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16),
(1, 0))
p_A = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
)
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
b_A = -tl.where(
tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
o_i = tl.arange(0, 16)
for i in range(1, min(16, T - i_t * 16)):
@@ -66,30 +70,45 @@ def solve_tril_16x16_kernel(
mask = o_i == i
b_A = tl.where(mask[:, None], b_a, b_A)
b_A += o_i[:, None] == o_i[None, :]
tl.store(p_Ai,
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(
p_Ai,
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5]
for num_warps in [1, 2, 4, 8]
for num_stages in [2, 3, 4, 5]
],
key=['H', 'BT', 'IS_VARLEN'],
key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
T, H: tl.constexpr, BT: tl.constexpr,
IS_VARLEN: tl.constexpr):
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ad,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
@@ -98,55 +117,80 @@ def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
Ad += (bos * H + i_h) * 16
Ai += (bos * H + i_h) * 32
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
(16, 16), (1, 0))
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0),
(16, 16), (1, 0))
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0),
(16, 16), (1, 0))
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0),
(16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16),
(16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
(16, 16), (1, 0))
p_A_21 = tl.make_block_ptr(
A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
)
p_Ad_11 = tl.make_block_ptr(
Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)
)
p_Ad_22 = tl.make_block_ptr(
Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
)
p_Ai_11 = tl.make_block_ptr(
Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)
)
p_Ai_22 = tl.make_block_ptr(
Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)
)
p_Ai_21 = tl.make_block_ptr(
Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
)
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'),
Ai_11,
input_precision='ieee')
tl.store(p_Ai_11,
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_22,
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_21,
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
Ai_21 = -tl.dot(
tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
)
tl.store(
p_Ai_11,
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_22,
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_21,
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8] for num_stages in [2, 3, 4, 5]
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4, 5]
],
key=['H', 'BT', 'IS_VARLEN'],
key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
T, H: tl.constexpr, BT: tl.constexpr,
IS_VARLEN: tl.constexpr):
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
A,
Ad,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
@@ -155,26 +199,36 @@ def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
Ad += (bos * H + i_h) * 16
Ai += (bos * H + i_h) * 64
p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0),
(16, 16), (1, 0))
p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16),
(16, 16), (1, 0))
p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0),
(16, 16), (1, 0))
p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32),
(16, 16), (1, 0))
p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16),
(16, 16), (1, 0))
p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0),
(16, 16), (1, 0))
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0),
(16, 16), (1, 0))
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0),
(16, 16), (1, 0))
p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0),
(16, 16), (1, 0))
p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0),
(16, 16), (1, 0))
p_A_21 = tl.make_block_ptr(
A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
)
p_A_32 = tl.make_block_ptr(
A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
)
p_A_31 = tl.make_block_ptr(
A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
)
p_A_43 = tl.make_block_ptr(
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
)
p_A_42 = tl.make_block_ptr(
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
)
p_A_41 = tl.make_block_ptr(
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
)
p_Ad_11 = tl.make_block_ptr(
Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)
)
p_Ad_22 = tl.make_block_ptr(
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
)
p_Ad_33 = tl.make_block_ptr(
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
)
p_Ad_44 = tl.make_block_ptr(
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
)
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
@@ -188,124 +242,174 @@ def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)
Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'),
Ai_11,
input_precision='ieee')
Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'),
Ai_22,
input_precision='ieee')
Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'),
Ai_33,
input_precision='ieee')
Ai_21 = -tl.dot(
tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
)
Ai_32 = -tl.dot(
tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee"
)
Ai_43 = -tl.dot(
tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee"
)
Ai_31 = -tl.dot(Ai_33,
tl.dot(A_31, Ai_11, input_precision='ieee') +
tl.dot(A_32, Ai_21, input_precision='ieee'),
input_precision='ieee')
Ai_42 = -tl.dot(Ai_44,
tl.dot(A_42, Ai_22, input_precision='ieee') +
tl.dot(A_43, Ai_32, input_precision='ieee'),
input_precision='ieee')
Ai_41 = -tl.dot(Ai_44,
tl.dot(A_41, Ai_11, input_precision='ieee') +
tl.dot(A_42, Ai_21, input_precision='ieee') +
tl.dot(A_43, Ai_31, input_precision='ieee'),
input_precision='ieee')
Ai_31 = -tl.dot(
Ai_33,
tl.dot(A_31, Ai_11, input_precision="ieee")
+ tl.dot(A_32, Ai_21, input_precision="ieee"),
input_precision="ieee",
)
Ai_42 = -tl.dot(
Ai_44,
tl.dot(A_42, Ai_22, input_precision="ieee")
+ tl.dot(A_43, Ai_32, input_precision="ieee"),
input_precision="ieee",
)
Ai_41 = -tl.dot(
Ai_44,
tl.dot(A_41, Ai_11, input_precision="ieee")
+ tl.dot(A_42, Ai_21, input_precision="ieee")
+ tl.dot(A_43, Ai_31, input_precision="ieee"),
input_precision="ieee",
)
p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0),
(16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16),
(16, 16), (1, 0))
p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32),
(16, 16), (1, 0))
p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48),
(16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0),
(16, 16), (1, 0))
p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0),
(16, 16), (1, 0))
p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16),
(16, 16), (1, 0))
p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0),
(16, 16), (1, 0))
p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16),
(16, 16), (1, 0))
p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32),
(16, 16), (1, 0))
tl.store(p_Ai_11,
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_22,
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_33,
Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_44,
Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_21,
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_31,
Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_32,
Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_41,
Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_42,
Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_43,
Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
p_Ai_11 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)
)
p_Ai_22 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)
)
p_Ai_33 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)
)
p_Ai_44 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)
)
p_Ai_21 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
)
p_Ai_31 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
)
p_Ai_32 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
)
p_Ai_41 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
)
p_Ai_42 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
)
p_Ai_43 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
)
tl.store(
p_Ai_11,
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_22,
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_33,
Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_44,
Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_21,
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_31,
Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_32,
Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_41,
Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_42,
Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_43,
Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
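    # Ai is allocated uninitialized (torch.empty), so the strictly upper 16x16 blocks
    # of each 64x64 inverse are explicitly zero-filled below to keep the output lower
    # triangular.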
fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16),
(16, 16), (1, 0))
p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32),
(16, 16), (1, 0))
p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48),
(16, 16), (1, 0))
p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32),
(16, 16), (1, 0))
p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48),
(16, 16), (1, 0))
p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48),
(16, 16), (1, 0))
tl.store(p_Ai_12,
fill_zeros.to(p_Ai_12.dtype.element_ty,
fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_13,
fill_zeros.to(p_Ai_13.dtype.element_ty,
fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_14,
fill_zeros.to(p_Ai_14.dtype.element_ty,
fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_23,
fill_zeros.to(p_Ai_23.dtype.element_ty,
fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_24,
fill_zeros.to(p_Ai_24.dtype.element_ty,
fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
tl.store(p_Ai_34,
fill_zeros.to(p_Ai_34.dtype.element_ty,
fp_downcast_rounding="rtne"),
boundary_check=(0, 1))
p_Ai_12 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)
)
p_Ai_13 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)
)
p_Ai_14 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)
)
p_Ai_23 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)
)
p_Ai_24 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)
)
p_Ai_34 = tl.make_block_ptr(
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)
)
tl.store(
p_Ai_12,
fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_13,
fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_14,
fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_23,
fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_24,
fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_34,
fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
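# Typical use (sketch): given a strictly lower-triangular chunk matrix A of shape
# [B, T, H, BT] with BT in {16, 32, 64}, e.g.
#   Ai = solve_tril(A, cu_seqlens=cu_seqlens, output_dtype=A.dtype)
# the 16x16 diagonal blocks are inverted first and, for BT > 16, merged by the
# kernels above into the full block inverse.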
@input_guard
def solve_tril(A: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
output_dtype: torch.dtype = torch.float) -> torch.Tensor:
def solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
"""
    Compute the inverse of the lower triangular matrix.
    A should be strictly lower triangular, i.e., A.triu() == 0.
@@ -325,15 +429,13 @@ def solve_tril(A: torch.Tensor,
assert A.shape[-1] in [16, 32, 64]
B, T, H, BT = A.shape
Ad = torch.empty(B,
T,
H,
16,
device=A.device,
dtype=torch.float if BT != 16 else output_dtype)
Ad = torch.empty(
B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
)
chunk_indices = prepare_chunk_indices(
cu_seqlens, 16) if cu_seqlens is not None else None
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
)
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
solve_tril_16x16_kernel[NT, B * H](
A=A,
@@ -348,9 +450,14 @@ def solve_tril(A: torch.Tensor,
return Ad
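    # For larger block sizes (BT of 32 or 64), the 16x16 inverses in Ad are merged
    # into the full BT x BT block inverse Ai by the merge kernels above.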
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
merge_fn = (
merge_16x16_to_32x32_inverse_kernel
if BT == 32
else merge_16x16_to_64x64_inverse_kernel
)
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
merge_fn[NT, B * H](
A=A,

View File

@@ -27,8 +27,7 @@ FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
def tensor_cache(
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent results of a function with tensor inputs.
@@ -52,12 +51,19 @@ def tensor_cache(
nonlocal cache_entries, cache_size
for i, entry in enumerate(cache_entries):
last_args, last_kwargs, last_result = entry
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \
and all(a is b for a, b in zip(args, last_args)) \
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [
(args, kwargs, last_result)
]
if (
len(args) == len(last_args)
and len(kwargs) == len(last_kwargs)
and all(a is b for a, b in zip(args, last_args))
and all(
k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
)
):
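                # Cache hit: move this entry to the back (most recently used) and
                # return the stored result without re-evaluating fn.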
cache_entries = (
cache_entries[:i]
+ cache_entries[i + 1 :]
+ [(args, kwargs, last_result)]
)
return last_result
result = fn(*args, **kwargs)
@@ -70,16 +76,16 @@ def tensor_cache(
return wrapper
def input_guard(
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
    A decorator that ensures all input tensors are contiguous and sets the device based on the input tensors.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (i if not isinstance(i, torch.Tensor) else
i.contiguous() for i in args)
contiguous_args = (
i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
)
contiguous_kwargs = {
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
for k, v in kwargs.items()
@@ -112,11 +118,11 @@ def get_available_device() -> str:
try:
return triton.runtime.driver.active.get_current_target().backend
except BaseException:
return 'cpu'
return "cpu"
@functools.cache
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
device = get_available_device()
mapping = {
"cuda": "nvidia",
@@ -130,27 +136,28 @@ def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()
is_amd = (device_platform == 'amd')
is_intel = (device_platform == 'intel')
is_nvidia = (device_platform == 'nvidia')
is_intel_alchemist = (is_intel
and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
is_nvidia_hopper = (is_nvidia
and ('NVIDIA H' in torch.cuda.get_device_name(0)
or torch.cuda.get_device_capability()[0] >= 9))
use_cuda_graph = (is_nvidia
and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
is_amd = device_platform == "amd"
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
"NVIDIA H" in torch.cuda.get_device_name(0)
or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
def get_all_max_shared_mem():
try:
return [
triton.runtime.driver.active.utils.get_device_properties(i)
['max_shared_mem'] for i in range(device_torch_lib.device_count())
triton.runtime.driver.active.utils.get_device_properties(i)[
"max_shared_mem"
]
for i in range(device_torch_lib.device_count())
]
except BaseException:
return [-1]

View File

@@ -17,56 +17,100 @@ from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8] for num_stages in [2, 3, 4]
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=['T'])
def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
T, H: tl.constexpr, Hg: tl.constexpr,
K: tl.constexpr, V: tl.constexpr,
BT: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, IS_VARLEN: tl.constexpr):
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
k,
v,
beta,
w,
u,
A,
g,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
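    # Recomputes the WY-representation factors for one chunk:
    #   u = A @ (v * beta)   and   w = A @ (k * beta * exp(g)),
    # tiled over the value and key dimensions respectively.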
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ),
(BT, ), (0, ))
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
b_beta = tl.load(p_beta, boundary_check=(0, ))
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.load(p_A, boundary_check=(0, 1))
b_g = tl.exp(tl.load(p_g, boundary_check=(0, )))
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1),
(i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1),
(i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
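    # Second pass, over key blocks: k may carry fewer heads (Hg) than v/beta (H);
    # the integer division i_h // (H // Hg) maps each value head onto its shared key head.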
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
(Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
(1, 0))
p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1),
(i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(
k + (bos * Hg + i_h // (H // Hg)) * K,
(T, K),
(Hg * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb)
@@ -85,8 +129,9 @@ def recompute_w_u_fwd(
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64