[Bugfix] Fix the fp8_mqa_logits dim mismatch (#32652)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -686,7 +686,7 @@ def sparse_attn_indexer(
|
||||
fp8_mqa_logits_func = rocm_fp8_mqa_logits
|
||||
logits = fp8_mqa_logits_func(
|
||||
q_fp8[chunk.token_start : chunk.token_end],
|
||||
(k_fp8, k_scale.view(torch.float32)),
|
||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||
weights[chunk.token_start : chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
|
||||
@@ -249,8 +249,8 @@ def fp8_mqa_logits(
|
||||
q: Query tensor of shape [M, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
|
||||
[N, 1]) with dtype `torch.float32`.
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N])
|
||||
with dtype `torch.float32`.
|
||||
weights: weights of shape [M, H], dtype `torch.float32`.
|
||||
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
|
||||
Reference in New Issue
Block a user