[PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] (#33291)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-02-04 15:20:52 +04:00
committed by GitHub
parent 8e32690869
commit 824058076c
6 changed files with 61 additions and 61 deletions

View File

@@ -138,11 +138,11 @@ def chunk_gated_delta_rule(
Scale factor for the RetNet attention scores. Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]): initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences. Initial state of shape `[N, H, V, K]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`. For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`. Default: `None`.
output_final_state (Optional[bool]): output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
cu_seqlens (torch.LongTensor): cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training, Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API. consistent with the FlashAttention API.
@@ -154,7 +154,7 @@ def chunk_gated_delta_rule(
o (torch.Tensor): o (torch.Tensor):
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
final_state (torch.Tensor): final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
Examples:: Examples::
>>> import torch >>> import torch
@@ -168,7 +168,7 @@ def chunk_gated_delta_rule(
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') >>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule( >>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta, q, k, v, g, beta,
initial_state=h0, initial_state=h0,

View File

@@ -81,70 +81,70 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
NT = tl.cdiv(T, BT) NT = tl.cdiv(T, BT)
boh = i_n * NT boh = i_n * NT
# [BK, BV] # [BV, BK]
b_h1 = tl.zeros([64, BV], dtype=tl.float32) b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 64: if K > 64:
b_h2 = tl.zeros([64, BV], dtype=tl.float32) b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 128: if K > 128:
b_h3 = tl.zeros([64, BV], dtype=tl.float32) b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 192: if K > 192:
b_h4 = tl.zeros([64, BV], dtype=tl.float32) b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
# calculate offset # calculate offset
h += ((boh * H + i_h) * K * V).to(tl.int64) h += ((boh * H + i_h) * V * K).to(tl.int64)
v += ((bos * H + i_h) * V).to(tl.int64) v += ((bos * H + i_h) * V).to(tl.int64)
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64) w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE: if SAVE_NEW_VALUE:
v_new += ((bos * H + i_h) * V).to(tl.int64) v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V stride_v = H * V
stride_h = H * K * V stride_h = H * V * K
stride_k = Hg * K stride_k = Hg * K
stride_w = H * K stride_w = H * K
if USE_INITIAL_STATE: if USE_INITIAL_STATE:
h0 = h0 + i_nh * K * V h0 = h0 + i_nh * V * K
if STORE_FINAL_STATE: if STORE_FINAL_STATE:
ht = ht + i_nh * K * V ht = ht + i_nh * V * K
# load initial state # load initial state
if USE_INITIAL_STATE: if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64: if K > 64:
p_h0_2 = tl.make_block_ptr( p_h0_2 = tl.make_block_ptr(
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
) )
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128: if K > 128:
p_h0_3 = tl.make_block_ptr( p_h0_3 = tl.make_block_ptr(
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
) )
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192: if K > 192:
p_h0_4 = tl.make_block_ptr( p_h0_4 = tl.make_block_ptr(
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
) )
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# main recurrence # main recurrence
for i_t in range(NT): for i_t in range(NT):
p_h1 = tl.make_block_ptr( p_h1 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0) h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
) )
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64: if K > 64:
p_h2 = tl.make_block_ptr( p_h2 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
) )
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128: if K > 128:
p_h3 = tl.make_block_ptr( p_h3 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
) )
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192: if K > 192:
p_h4 = tl.make_block_ptr( p_h4 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
) )
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
@@ -152,25 +152,25 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))
if K > 64: if K > 64:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))
if K > 128: if K > 128:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))
if K > 192: if K > 192:
p_w = tl.make_block_ptr( p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
) )
b_w = tl.load(p_w, boundary_check=(0, 1)) b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))
p_v = tl.make_block_ptr( p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
) )
@@ -207,7 +207,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k1 < K), mask=(o_k1 < K),
other=0.0, other=0.0,
) )
b_h1 *= exp(b_gk_last1)[:, None] b_h1 *= exp(b_gk_last1)[None, :]
if K > 64: if K > 64:
o_k2 = 64 + o_k1 o_k2 = 64 + o_k1
b_gk_last2 = tl.load( b_gk_last2 = tl.load(
@@ -215,7 +215,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k2 < K), mask=(o_k2 < K),
other=0.0, other=0.0,
) )
b_h2 *= exp(b_gk_last2)[:, None] b_h2 *= exp(b_gk_last2)[None, :]
if K > 128: if K > 128:
o_k3 = 128 + o_k1 o_k3 = 128 + o_k1
b_gk_last3 = tl.load( b_gk_last3 = tl.load(
@@ -223,7 +223,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k3 < K), mask=(o_k3 < K),
other=0.0, other=0.0,
) )
b_h3 *= exp(b_gk_last3)[:, None] b_h3 *= exp(b_gk_last3)[None, :]
if K > 192: if K > 192:
o_k4 = 192 + o_k1 o_k4 = 192 + o_k1
b_gk_last4 = tl.load( b_gk_last4 = tl.load(
@@ -231,49 +231,49 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k4 < K), mask=(o_k4 < K),
other=0.0, other=0.0,
) )
b_h4 *= exp(b_gk_last4)[:, None] b_h4 *= exp(b_gk_last4)[None, :]
b_v = b_v.to(k.dtype.element_ty) b_v = b_v.to(k.dtype.element_ty)
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v) b_h1 += tl.trans(tl.dot(b_k, b_v))
if K > 64: if K > 64:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v) b_h2 += tl.trans(tl.dot(b_k, b_v))
if K > 128: if K > 128:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v) b_h3 += tl.trans(tl.dot(b_k, b_v))
if K > 192: if K > 192:
p_k = tl.make_block_ptr( p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1) k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
) )
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v) b_h4 += tl.trans(tl.dot(b_k, b_v))
# epilogue # epilogue
if STORE_FINAL_STATE: if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64: if K > 64:
p_ht = tl.make_block_ptr( p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
) )
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128: if K > 128:
p_ht = tl.make_block_ptr( p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
) )
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192: if K > 192:
p_ht = tl.make_block_ptr( p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
) )
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@@ -312,9 +312,9 @@ def chunk_gated_delta_rule_fwd_h(
) )
assert K <= 256, "current kernel does not support head dimension larger than 256." assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V) h = k.new_empty(B, NT, H, V, K)
final_state = ( final_state = (
k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None
) )
v_new = torch.empty_like(u) if save_new_value else None v_new = torch.empty_like(u) if save_new_value else None

View File

@@ -85,7 +85,7 @@ def chunk_fwd_kernel_o(
k += (bos * Hg + i_h // (H // Hg)) * K k += (bos * Hg + i_h // (H // Hg)) * K
v += (bos * H + i_h) * V v += (bos * H + i_h) * V
o += (bos * H + i_h) * V o += (bos * H + i_h) * V
h += (i_tg * H + i_h).to(tl.int64) * K * V h += (i_tg * H + i_h).to(tl.int64) * V * K
b_o = tl.zeros([BT, BV], dtype=tl.float32) b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32) b_A = tl.zeros([BT, BT], dtype=tl.float32)
@@ -98,17 +98,17 @@ def chunk_fwd_kernel_o(
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1) k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
) )
p_h = tl.make_block_ptr( p_h = tl.make_block_ptr(
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0) h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)
) )
# [BT, BK] # [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT] # [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV] # [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1)) b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BK] @ [BK, BV] -> [BT, BV] # [BT, BK] @ [BK, BV] -> [BT, BV]
b_o += tl.dot(b_q, b_h) b_o += tl.dot(b_q, tl.trans(b_h))
# [BT, BK] @ [BK, BT] -> [BT, BT] # [BT, BK] @ [BK, BT] -> [BT, BT]
b_A += tl.dot(b_q, b_k) b_A += tl.dot(b_q, b_k)

View File

@@ -97,9 +97,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
mask_k = o_k < K mask_k = o_k < K
mask_v = o_v < V mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :] mask_h = mask_v[:, None] & mask_k[None, :]
b_h = tl.zeros([BK, BV], dtype=tl.float32) b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE: if USE_INITIAL_STATE:
if IS_CONTINUOUS_BATCHING: if IS_CONTINUOUS_BATCHING:
if IS_SPEC_DECODING: if IS_SPEC_DECODING:
@@ -115,8 +115,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
return return
p_h0 = h0 + state_idx * stride_init_state_token p_h0 = h0 + state_idx * stride_init_state_token
else: else:
p_h0 = h0 + bos * HV * K * V p_h0 = h0 + bos * HV * V * K
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i_t in range(0, T): for i_t in range(0, T):
@@ -128,24 +128,24 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale b_q = b_q * scale
# [BK, BV] # [BV, BK]
if not IS_KDA: if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32) b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g) b_h *= exp(b_g)
else: else:
b_gk = tl.load(p_gk).to(tl.float32) b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None]) b_h *= exp(b_gk[None, :])
# [BV] # [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0) b_v -= tl.sum(b_h * b_k[None, :], 1)
if IS_BETA_HEADWISE: if IS_BETA_HEADWISE:
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
else: else:
b_beta = tl.load(p_beta).to(tl.float32) b_beta = tl.load(p_beta).to(tl.float32)
b_v *= b_beta b_v *= b_beta
# [BK, BV] # [BV, BK]
b_h += b_k[:, None] * b_v[None, :] b_h += b_v[:, None] * b_k[None, :]
# [BV] # [BV]
b_o = tl.sum(b_h * b_q[:, None], 0) b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# keep the states for multi-query tokens # keep the states for multi-query tokens
@@ -157,11 +157,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# Only store if state index is valid (not PAD_SLOT_ID) # Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0: if final_state_idx >= 0:
p_ht = ht + final_state_idx * stride_final_state_token p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
else: else:
p_ht = ht + (bos + i_t) * stride_final_state_token p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
p_q += H * K p_q += H * K
@@ -202,7 +202,7 @@ def fused_recurrent_gated_delta_rule_fwd(
if inplace_final_state: if inplace_final_state:
final_state = initial_state final_state = initial_state
else: else:
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0) stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0) stride_final_state_token = final_state.stride(0)
@@ -318,7 +318,7 @@ def fused_recurrent_gated_delta_rule(
Scale factor for the RetNet attention scores. Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]): initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, HV, K, V]` for `N` input sequences. Initial state of shape `[N, HV, V, K]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`. For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`. Default: `None`.
inplace_final_state: bool: inplace_final_state: bool:
@@ -336,7 +336,7 @@ def fused_recurrent_gated_delta_rule(
o (torch.Tensor): o (torch.Tensor):
Outputs of shape `[B, T, HV, V]`. Outputs of shape `[B, T, HV, V]`.
final_state (torch.Tensor): final_state (torch.Tensor):
Final state of shape `[N, HV, K, V]`. Final state of shape `[N, HV, V, K]`.
Examples:: Examples::
>>> import torch >>> import torch
@@ -350,7 +350,7 @@ def fused_recurrent_gated_delta_rule(
>>> v = torch.randn(B, T, HV, V, device='cuda') >>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
>>> h0 = torch.randn(B, HV, K, V, device='cuda') >>> h0 = torch.randn(B, HV, V, K, device='cuda')
>>> o, ht = fused_recurrent_gated_delta_rule( >>> o, ht = fused_recurrent_gated_delta_rule(
q, k, v, g, beta, q, k, v, g, beta,
initial_state=h0, initial_state=h0,

View File

@@ -55,7 +55,7 @@ def fused_recurrent_kda_fwd(
if inplace_final_state: if inplace_final_state:
final_state = initial_state final_state = initial_state
else: else:
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0) stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0) stride_final_state_token = final_state.stride(0)

View File

@@ -191,8 +191,8 @@ class MambaStateShapeCalculator:
temporal_state_shape = ( temporal_state_shape = (
divide(num_v_heads, tp_world_size), divide(num_v_heads, tp_world_size),
head_k_dim,
head_v_dim, head_v_dim,
head_k_dim,
) )
return conv_state_shape, temporal_state_shape return conv_state_shape, temporal_state_shape