[PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] (#33291)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -138,11 +138,11 @@ def chunk_gated_delta_rule(
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||
Initial state of shape `[N, H, V, K]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
@@ -154,7 +154,7 @@ def chunk_gated_delta_rule(
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||
Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
@@ -168,7 +168,7 @@ def chunk_gated_delta_rule(
|
||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
|
||||
>>> o, ht = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -81,70 +81,70 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# [BK, BV]
|
||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
# [BV, BK]
|
||||
b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||
if K > 64:
|
||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||
if K > 128:
|
||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||
if K > 192:
|
||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
|
||||
|
||||
# calculate offset
|
||||
h += ((boh * H + i_h) * K * V).to(tl.int64)
|
||||
h += ((boh * H + i_h) * V * K).to(tl.int64)
|
||||
v += ((bos * H + i_h) * V).to(tl.int64)
|
||||
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
|
||||
w += ((bos * H + i_h) * K).to(tl.int64)
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new += ((bos * H + i_h) * V).to(tl.int64)
|
||||
stride_v = H * V
|
||||
stride_h = H * K * V
|
||||
stride_h = H * V * K
|
||||
stride_k = Hg * K
|
||||
stride_w = H * K
|
||||
if USE_INITIAL_STATE:
|
||||
h0 = h0 + i_nh * K * V
|
||||
h0 = h0 + i_nh * V * K
|
||||
if STORE_FINAL_STATE:
|
||||
ht = ht + i_nh * K * V
|
||||
ht = ht + i_nh * V * K
|
||||
|
||||
# load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
|
||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 64:
|
||||
p_h0_2 = tl.make_block_ptr(
|
||||
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
||||
h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
|
||||
)
|
||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 128:
|
||||
p_h0_3 = tl.make_block_ptr(
|
||||
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
||||
h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
|
||||
)
|
||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 192:
|
||||
p_h0_4 = tl.make_block_ptr(
|
||||
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
||||
h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
|
||||
)
|
||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
p_h1 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
|
||||
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
|
||||
)
|
||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_h2 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
||||
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
|
||||
)
|
||||
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_h3 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
||||
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
|
||||
)
|
||||
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_h4 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
||||
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
|
||||
)
|
||||
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
@@ -152,25 +152,25 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
|
||||
b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))
|
||||
if K > 64:
|
||||
p_w = tl.make_block_ptr(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
|
||||
b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))
|
||||
if K > 128:
|
||||
p_w = tl.make_block_ptr(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
|
||||
b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))
|
||||
if K > 192:
|
||||
p_w = tl.make_block_ptr(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
|
||||
b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))
|
||||
p_v = tl.make_block_ptr(
|
||||
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
@@ -207,7 +207,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
mask=(o_k1 < K),
|
||||
other=0.0,
|
||||
)
|
||||
b_h1 *= exp(b_gk_last1)[:, None]
|
||||
b_h1 *= exp(b_gk_last1)[None, :]
|
||||
if K > 64:
|
||||
o_k2 = 64 + o_k1
|
||||
b_gk_last2 = tl.load(
|
||||
@@ -215,7 +215,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
mask=(o_k2 < K),
|
||||
other=0.0,
|
||||
)
|
||||
b_h2 *= exp(b_gk_last2)[:, None]
|
||||
b_h2 *= exp(b_gk_last2)[None, :]
|
||||
if K > 128:
|
||||
o_k3 = 128 + o_k1
|
||||
b_gk_last3 = tl.load(
|
||||
@@ -223,7 +223,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
mask=(o_k3 < K),
|
||||
other=0.0,
|
||||
)
|
||||
b_h3 *= exp(b_gk_last3)[:, None]
|
||||
b_h3 *= exp(b_gk_last3)[None, :]
|
||||
if K > 192:
|
||||
o_k4 = 192 + o_k1
|
||||
b_gk_last4 = tl.load(
|
||||
@@ -231,49 +231,49 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
mask=(o_k4 < K),
|
||||
other=0.0,
|
||||
)
|
||||
b_h4 *= exp(b_gk_last4)[:, None]
|
||||
b_h4 *= exp(b_gk_last4)[None, :]
|
||||
b_v = b_v.to(k.dtype.element_ty)
|
||||
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h1 += tl.dot(b_k, b_v)
|
||||
b_h1 += tl.trans(tl.dot(b_k, b_v))
|
||||
if K > 64:
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h2 += tl.dot(b_k, b_v)
|
||||
b_h2 += tl.trans(tl.dot(b_k, b_v))
|
||||
if K > 128:
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h3 += tl.dot(b_k, b_v)
|
||||
b_h3 += tl.trans(tl.dot(b_k, b_v))
|
||||
if K > 192:
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h4 += tl.dot(b_k, b_v)
|
||||
b_h4 += tl.trans(tl.dot(b_k, b_v))
|
||||
# epilogue
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
|
||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_ht = tl.make_block_ptr(
|
||||
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
||||
ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
|
||||
)
|
||||
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_ht = tl.make_block_ptr(
|
||||
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
||||
ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
|
||||
)
|
||||
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_ht = tl.make_block_ptr(
|
||||
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
||||
ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
|
||||
)
|
||||
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
@@ -312,9 +312,9 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
h = k.new_empty(B, NT, H, V, K)
|
||||
final_state = (
|
||||
k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None
|
||||
)
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
|
||||
@@ -85,7 +85,7 @@ def chunk_fwd_kernel_o(
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
v += (bos * H + i_h) * V
|
||||
o += (bos * H + i_h) * V
|
||||
h += (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
h += (i_tg * H + i_h).to(tl.int64) * V * K
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
@@ -98,17 +98,17 @@ def chunk_fwd_kernel_o(
|
||||
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
|
||||
)
|
||||
p_h = tl.make_block_ptr(
|
||||
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
|
||||
h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)
|
||||
)
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
# [BV, BK]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
b_o += tl.dot(b_q, tl.trans(b_h))
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
|
||||
@@ -97,9 +97,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
mask_h = mask_v[:, None] & mask_k[None, :]
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
if IS_SPEC_DECODING:
|
||||
@@ -115,8 +115,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
return
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * K * V
|
||||
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
p_h0 = h0 + bos * HV * V * K
|
||||
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for i_t in range(0, T):
|
||||
@@ -128,24 +128,24 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
# [BV, BK]
|
||||
if not IS_KDA:
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
b_h *= exp(b_g)
|
||||
else:
|
||||
b_gk = tl.load(p_gk).to(tl.float32)
|
||||
b_h *= exp(b_gk[:, None])
|
||||
b_h *= exp(b_gk[None, :])
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
||||
b_v -= tl.sum(b_h * b_k[None, :], 1)
|
||||
if IS_BETA_HEADWISE:
|
||||
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
||||
else:
|
||||
b_beta = tl.load(p_beta).to(tl.float32)
|
||||
b_v *= b_beta
|
||||
# [BK, BV]
|
||||
b_h += b_k[:, None] * b_v[None, :]
|
||||
# [BV, BK]
|
||||
b_h += b_v[:, None] * b_k[None, :]
|
||||
# [BV]
|
||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
||||
b_o = tl.sum(b_h * b_q[None, :], 1)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
@@ -157,11 +157,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
# Only store if state index is valid (not PAD_SLOT_ID)
|
||||
if final_state_idx >= 0:
|
||||
p_ht = ht + final_state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
p_q += H * K
|
||||
@@ -202,7 +202,7 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
if inplace_final_state:
|
||||
final_state = initial_state
|
||||
else:
|
||||
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
|
||||
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
|
||||
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = final_state.stride(0)
|
||||
@@ -318,7 +318,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
|
||||
Initial state of shape `[N, HV, V, K]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
inplace_final_state: bool:
|
||||
@@ -336,7 +336,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, HV, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, HV, K, V]`.
|
||||
Final state of shape `[N, HV, V, K]`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
@@ -350,7 +350,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
||||
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
||||
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
|
||||
>>> h0 = torch.randn(B, HV, V, K, device='cuda')
|
||||
>>> o, ht = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -55,7 +55,7 @@ def fused_recurrent_kda_fwd(
|
||||
if inplace_final_state:
|
||||
final_state = initial_state
|
||||
else:
|
||||
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
|
||||
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
|
||||
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = final_state.stride(0)
|
||||
|
||||
@@ -191,8 +191,8 @@ class MambaStateShapeCalculator:
|
||||
|
||||
temporal_state_shape = (
|
||||
divide(num_v_heads, tp_world_size),
|
||||
head_k_dim,
|
||||
head_v_dim,
|
||||
head_k_dim,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
|
||||
Reference in New Issue
Block a user