[PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] (#33291)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -138,11 +138,11 @@ def chunk_gated_delta_rule(
|
|||||||
Scale factor for the RetNet attention scores.
|
Scale factor for the RetNet attention scores.
|
||||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||||
initial_state (Optional[torch.Tensor]):
|
initial_state (Optional[torch.Tensor]):
|
||||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
Initial state of shape `[N, H, V, K]` for `N` input sequences.
|
||||||
For equal-length input sequences, `N` equals the batch size `B`.
|
For equal-length input sequences, `N` equals the batch size `B`.
|
||||||
Default: `None`.
|
Default: `None`.
|
||||||
output_final_state (Optional[bool]):
|
output_final_state (Optional[bool]):
|
||||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
|
||||||
cu_seqlens (torch.LongTensor):
|
cu_seqlens (torch.LongTensor):
|
||||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||||
consistent with the FlashAttention API.
|
consistent with the FlashAttention API.
|
||||||
@@ -154,7 +154,7 @@ def chunk_gated_delta_rule(
|
|||||||
o (torch.Tensor):
|
o (torch.Tensor):
|
||||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||||
final_state (torch.Tensor):
|
final_state (torch.Tensor):
|
||||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
>>> import torch
|
>>> import torch
|
||||||
@@ -168,7 +168,7 @@ def chunk_gated_delta_rule(
|
|||||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
>>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
|
||||||
>>> o, ht = chunk_gated_delta_rule(
|
>>> o, ht = chunk_gated_delta_rule(
|
||||||
q, k, v, g, beta,
|
q, k, v, g, beta,
|
||||||
initial_state=h0,
|
initial_state=h0,
|
||||||
|
|||||||
@@ -81,70 +81,70 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
|||||||
NT = tl.cdiv(T, BT)
|
NT = tl.cdiv(T, BT)
|
||||||
boh = i_n * NT
|
boh = i_n * NT
|
||||||
|
|
||||||
# [BK, BV]
|
# [BV, BK]
|
||||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||||
if K > 64:
|
if K > 64:
|
||||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||||
if K > 128:
|
if K > 128:
|
||||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||||
if K > 192:
|
if K > 192:
|
||||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||||
|
|
||||||
# calculate offset
|
# calculate offset
|
||||||
h += ((boh * H + i_h) * K * V).to(tl.int64)
|
h += ((boh * H + i_h) * V * K).to(tl.int64)
|
||||||
v += ((bos * H + i_h) * V).to(tl.int64)
|
v += ((bos * H + i_h) * V).to(tl.int64)
|
||||||
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
|
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
|
||||||
w += ((bos * H + i_h) * K).to(tl.int64)
|
w += ((bos * H + i_h) * K).to(tl.int64)
|
||||||
if SAVE_NEW_VALUE:
|
if SAVE_NEW_VALUE:
|
||||||
v_new += ((bos * H + i_h) * V).to(tl.int64)
|
v_new += ((bos * H + i_h) * V).to(tl.int64)
|
||||||
stride_v = H * V
|
stride_v = H * V
|
||||||
stride_h = H * K * V
|
stride_h = H * V * K
|
||||||
stride_k = Hg * K
|
stride_k = Hg * K
|
||||||
stride_w = H * K
|
stride_w = H * K
|
||||||
if USE_INITIAL_STATE:
|
if USE_INITIAL_STATE:
|
||||||
h0 = h0 + i_nh * K * V
|
h0 = h0 + i_nh * V * K
|
||||||
if STORE_FINAL_STATE:
|
if STORE_FINAL_STATE:
|
||||||
ht = ht + i_nh * K * V
|
ht = ht + i_nh * V * K
|
||||||
|
|
||||||
# load initial state
|
# load initial state
|
||||||
if USE_INITIAL_STATE:
|
if USE_INITIAL_STATE:
|
||||||
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
|
||||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||||
if K > 64:
|
if K > 64:
|
||||||
p_h0_2 = tl.make_block_ptr(
|
p_h0_2 = tl.make_block_ptr(
|
||||||
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||||
if K > 128:
|
if K > 128:
|
||||||
p_h0_3 = tl.make_block_ptr(
|
p_h0_3 = tl.make_block_ptr(
|
||||||
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||||
if K > 192:
|
if K > 192:
|
||||||
p_h0_4 = tl.make_block_ptr(
|
p_h0_4 = tl.make_block_ptr(
|
||||||
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||||
|
|
||||||
# main recurrence
|
# main recurrence
|
||||||
for i_t in range(NT):
|
for i_t in range(NT):
|
||||||
p_h1 = tl.make_block_ptr(
|
p_h1 = tl.make_block_ptr(
|
||||||
h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
|
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||||
if K > 64:
|
if K > 64:
|
||||||
p_h2 = tl.make_block_ptr(
|
p_h2 = tl.make_block_ptr(
|
||||||
h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
||||||
if K > 128:
|
if K > 128:
|
||||||
p_h3 = tl.make_block_ptr(
|
p_h3 = tl.make_block_ptr(
|
||||||
h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
||||||
if K > 192:
|
if K > 192:
|
||||||
p_h4 = tl.make_block_ptr(
|
p_h4 = tl.make_block_ptr(
|
||||||
h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
@@ -152,25 +152,25 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
|||||||
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
|
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
|
||||||
)
|
)
|
||||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||||
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
|
b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))
|
||||||
if K > 64:
|
if K > 64:
|
||||||
p_w = tl.make_block_ptr(
|
p_w = tl.make_block_ptr(
|
||||||
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
|
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
|
||||||
)
|
)
|
||||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||||
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
|
b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))
|
||||||
if K > 128:
|
if K > 128:
|
||||||
p_w = tl.make_block_ptr(
|
p_w = tl.make_block_ptr(
|
||||||
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
|
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
|
||||||
)
|
)
|
||||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||||
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
|
b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))
|
||||||
if K > 192:
|
if K > 192:
|
||||||
p_w = tl.make_block_ptr(
|
p_w = tl.make_block_ptr(
|
||||||
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
|
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
|
||||||
)
|
)
|
||||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||||
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
|
b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))
|
||||||
p_v = tl.make_block_ptr(
|
p_v = tl.make_block_ptr(
|
||||||
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||||
)
|
)
|
||||||
@@ -207,7 +207,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
|||||||
mask=(o_k1 < K),
|
mask=(o_k1 < K),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
b_h1 *= exp(b_gk_last1)[:, None]
|
b_h1 *= exp(b_gk_last1)[None, :]
|
||||||
if K > 64:
|
if K > 64:
|
||||||
o_k2 = 64 + o_k1
|
o_k2 = 64 + o_k1
|
||||||
b_gk_last2 = tl.load(
|
b_gk_last2 = tl.load(
|
||||||
@@ -215,7 +215,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
|||||||
mask=(o_k2 < K),
|
mask=(o_k2 < K),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
b_h2 *= exp(b_gk_last2)[:, None]
|
b_h2 *= exp(b_gk_last2)[None, :]
|
||||||
if K > 128:
|
if K > 128:
|
||||||
o_k3 = 128 + o_k1
|
o_k3 = 128 + o_k1
|
||||||
b_gk_last3 = tl.load(
|
b_gk_last3 = tl.load(
|
||||||
@@ -223,7 +223,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
|||||||
mask=(o_k3 < K),
|
mask=(o_k3 < K),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
b_h3 *= exp(b_gk_last3)[:, None]
|
b_h3 *= exp(b_gk_last3)[None, :]
|
||||||
if K > 192:
|
if K > 192:
|
||||||
o_k4 = 192 + o_k1
|
o_k4 = 192 + o_k1
|
||||||
b_gk_last4 = tl.load(
|
b_gk_last4 = tl.load(
|
||||||
@@ -231,49 +231,49 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
|||||||
mask=(o_k4 < K),
|
mask=(o_k4 < K),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
b_h4 *= exp(b_gk_last4)[:, None]
|
b_h4 *= exp(b_gk_last4)[None, :]
|
||||||
b_v = b_v.to(k.dtype.element_ty)
|
b_v = b_v.to(k.dtype.element_ty)
|
||||||
|
|
||||||
p_k = tl.make_block_ptr(
|
p_k = tl.make_block_ptr(
|
||||||
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
|
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
|
||||||
)
|
)
|
||||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||||
b_h1 += tl.dot(b_k, b_v)
|
b_h1 += tl.trans(tl.dot(b_k, b_v))
|
||||||
if K > 64:
|
if K > 64:
|
||||||
p_k = tl.make_block_ptr(
|
p_k = tl.make_block_ptr(
|
||||||
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
|
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
|
||||||
)
|
)
|
||||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||||
b_h2 += tl.dot(b_k, b_v)
|
b_h2 += tl.trans(tl.dot(b_k, b_v))
|
||||||
if K > 128:
|
if K > 128:
|
||||||
p_k = tl.make_block_ptr(
|
p_k = tl.make_block_ptr(
|
||||||
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
|
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
|
||||||
)
|
)
|
||||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||||
b_h3 += tl.dot(b_k, b_v)
|
b_h3 += tl.trans(tl.dot(b_k, b_v))
|
||||||
if K > 192:
|
if K > 192:
|
||||||
p_k = tl.make_block_ptr(
|
p_k = tl.make_block_ptr(
|
||||||
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
|
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
|
||||||
)
|
)
|
||||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||||
b_h4 += tl.dot(b_k, b_v)
|
b_h4 += tl.trans(tl.dot(b_k, b_v))
|
||||||
# epilogue
|
# epilogue
|
||||||
if STORE_FINAL_STATE:
|
if STORE_FINAL_STATE:
|
||||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
|
||||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||||
if K > 64:
|
if K > 64:
|
||||||
p_ht = tl.make_block_ptr(
|
p_ht = tl.make_block_ptr(
|
||||||
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||||
if K > 128:
|
if K > 128:
|
||||||
p_ht = tl.make_block_ptr(
|
p_ht = tl.make_block_ptr(
|
||||||
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||||
if K > 192:
|
if K > 192:
|
||||||
p_ht = tl.make_block_ptr(
|
p_ht = tl.make_block_ptr(
|
||||||
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
|
||||||
)
|
)
|
||||||
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||||
|
|
||||||
@@ -312,9 +312,9 @@ def chunk_gated_delta_rule_fwd_h(
|
|||||||
)
|
)
|
||||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||||
|
|
||||||
h = k.new_empty(B, NT, H, K, V)
|
h = k.new_empty(B, NT, H, V, K)
|
||||||
final_state = (
|
final_state = (
|
||||||
k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
|
k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None
|
||||||
)
|
)
|
||||||
|
|
||||||
v_new = torch.empty_like(u) if save_new_value else None
|
v_new = torch.empty_like(u) if save_new_value else None
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ def chunk_fwd_kernel_o(
|
|||||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||||
v += (bos * H + i_h) * V
|
v += (bos * H + i_h) * V
|
||||||
o += (bos * H + i_h) * V
|
o += (bos * H + i_h) * V
|
||||||
h += (i_tg * H + i_h).to(tl.int64) * K * V
|
h += (i_tg * H + i_h).to(tl.int64) * V * K
|
||||||
|
|
||||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||||
@@ -98,17 +98,17 @@ def chunk_fwd_kernel_o(
|
|||||||
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
|
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
|
||||||
)
|
)
|
||||||
p_h = tl.make_block_ptr(
|
p_h = tl.make_block_ptr(
|
||||||
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
|
h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)
|
||||||
)
|
)
|
||||||
# [BT, BK]
|
# [BT, BK]
|
||||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||||
# [BK, BT]
|
# [BK, BT]
|
||||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||||
# [BK, BV]
|
# [BV, BK]
|
||||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||||
|
|
||||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||||
b_o += tl.dot(b_q, b_h)
|
b_o += tl.dot(b_q, tl.trans(b_h))
|
||||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||||
b_A += tl.dot(b_q, b_k)
|
b_A += tl.dot(b_q, b_k)
|
||||||
|
|
||||||
|
|||||||
@@ -97,9 +97,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
|
|
||||||
mask_k = o_k < K
|
mask_k = o_k < K
|
||||||
mask_v = o_v < V
|
mask_v = o_v < V
|
||||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
mask_h = mask_v[:, None] & mask_k[None, :]
|
||||||
|
|
||||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||||
if USE_INITIAL_STATE:
|
if USE_INITIAL_STATE:
|
||||||
if IS_CONTINUOUS_BATCHING:
|
if IS_CONTINUOUS_BATCHING:
|
||||||
if IS_SPEC_DECODING:
|
if IS_SPEC_DECODING:
|
||||||
@@ -115,8 +115,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
return
|
return
|
||||||
p_h0 = h0 + state_idx * stride_init_state_token
|
p_h0 = h0 + state_idx * stride_init_state_token
|
||||||
else:
|
else:
|
||||||
p_h0 = h0 + bos * HV * K * V
|
p_h0 = h0 + bos * HV * V * K
|
||||||
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||||
|
|
||||||
for i_t in range(0, T):
|
for i_t in range(0, T):
|
||||||
@@ -128,24 +128,24 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||||
b_q = b_q * scale
|
b_q = b_q * scale
|
||||||
# [BK, BV]
|
# [BV, BK]
|
||||||
if not IS_KDA:
|
if not IS_KDA:
|
||||||
b_g = tl.load(p_g).to(tl.float32)
|
b_g = tl.load(p_g).to(tl.float32)
|
||||||
b_h *= exp(b_g)
|
b_h *= exp(b_g)
|
||||||
else:
|
else:
|
||||||
b_gk = tl.load(p_gk).to(tl.float32)
|
b_gk = tl.load(p_gk).to(tl.float32)
|
||||||
b_h *= exp(b_gk[:, None])
|
b_h *= exp(b_gk[None, :])
|
||||||
# [BV]
|
# [BV]
|
||||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
b_v -= tl.sum(b_h * b_k[None, :], 1)
|
||||||
if IS_BETA_HEADWISE:
|
if IS_BETA_HEADWISE:
|
||||||
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
||||||
else:
|
else:
|
||||||
b_beta = tl.load(p_beta).to(tl.float32)
|
b_beta = tl.load(p_beta).to(tl.float32)
|
||||||
b_v *= b_beta
|
b_v *= b_beta
|
||||||
# [BK, BV]
|
# [BV, BK]
|
||||||
b_h += b_k[:, None] * b_v[None, :]
|
b_h += b_v[:, None] * b_k[None, :]
|
||||||
# [BV]
|
# [BV]
|
||||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
b_o = tl.sum(b_h * b_q[None, :], 1)
|
||||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||||
|
|
||||||
# keep the states for multi-query tokens
|
# keep the states for multi-query tokens
|
||||||
@@ -157,11 +157,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
# Only store if state index is valid (not PAD_SLOT_ID)
|
# Only store if state index is valid (not PAD_SLOT_ID)
|
||||||
if final_state_idx >= 0:
|
if final_state_idx >= 0:
|
||||||
p_ht = ht + final_state_idx * stride_final_state_token
|
p_ht = ht + final_state_idx * stride_final_state_token
|
||||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||||
else:
|
else:
|
||||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||||
|
|
||||||
p_q += H * K
|
p_q += H * K
|
||||||
@@ -202,7 +202,7 @@ def fused_recurrent_gated_delta_rule_fwd(
|
|||||||
if inplace_final_state:
|
if inplace_final_state:
|
||||||
final_state = initial_state
|
final_state = initial_state
|
||||||
else:
|
else:
|
||||||
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
|
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
|
||||||
|
|
||||||
stride_init_state_token = initial_state.stride(0)
|
stride_init_state_token = initial_state.stride(0)
|
||||||
stride_final_state_token = final_state.stride(0)
|
stride_final_state_token = final_state.stride(0)
|
||||||
@@ -318,7 +318,7 @@ def fused_recurrent_gated_delta_rule(
|
|||||||
Scale factor for the RetNet attention scores.
|
Scale factor for the RetNet attention scores.
|
||||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||||
initial_state (Optional[torch.Tensor]):
|
initial_state (Optional[torch.Tensor]):
|
||||||
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
|
Initial state of shape `[N, HV, V, K]` for `N` input sequences.
|
||||||
For equal-length input sequences, `N` equals the batch size `B`.
|
For equal-length input sequences, `N` equals the batch size `B`.
|
||||||
Default: `None`.
|
Default: `None`.
|
||||||
inplace_final_state: bool:
|
inplace_final_state: bool:
|
||||||
@@ -336,7 +336,7 @@ def fused_recurrent_gated_delta_rule(
|
|||||||
o (torch.Tensor):
|
o (torch.Tensor):
|
||||||
Outputs of shape `[B, T, HV, V]`.
|
Outputs of shape `[B, T, HV, V]`.
|
||||||
final_state (torch.Tensor):
|
final_state (torch.Tensor):
|
||||||
Final state of shape `[N, HV, K, V]`.
|
Final state of shape `[N, HV, V, K]`.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
>>> import torch
|
>>> import torch
|
||||||
@@ -350,7 +350,7 @@ def fused_recurrent_gated_delta_rule(
|
|||||||
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
||||||
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
||||||
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
||||||
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
|
>>> h0 = torch.randn(B, HV, V, K, device='cuda')
|
||||||
>>> o, ht = fused_gated_recurrent_delta_rule(
|
>>> o, ht = fused_gated_recurrent_delta_rule(
|
||||||
q, k, v, g, beta,
|
q, k, v, g, beta,
|
||||||
initial_state=h0,
|
initial_state=h0,
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def fused_recurrent_kda_fwd(
|
|||||||
if inplace_final_state:
|
if inplace_final_state:
|
||||||
final_state = initial_state
|
final_state = initial_state
|
||||||
else:
|
else:
|
||||||
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
|
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
|
||||||
|
|
||||||
stride_init_state_token = initial_state.stride(0)
|
stride_init_state_token = initial_state.stride(0)
|
||||||
stride_final_state_token = final_state.stride(0)
|
stride_final_state_token = final_state.stride(0)
|
||||||
|
|||||||
@@ -191,8 +191,8 @@ class MambaStateShapeCalculator:
|
|||||||
|
|
||||||
temporal_state_shape = (
|
temporal_state_shape = (
|
||||||
divide(num_v_heads, tp_world_size),
|
divide(num_v_heads, tp_world_size),
|
||||||
head_k_dim,
|
|
||||||
head_v_dim,
|
head_v_dim,
|
||||||
|
head_k_dim,
|
||||||
)
|
)
|
||||||
return conv_state_shape, temporal_state_shape
|
return conv_state_shape, temporal_state_shape
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user