diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 13bb3cbd0..e11088cde 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -686,7 +686,7 @@ def sparse_attn_indexer( fp8_mqa_logits_func = rocm_fp8_mqa_logits logits = fp8_mqa_logits_func( q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale.view(torch.float32)), + (k_fp8, k_scale.view(torch.float32).flatten()), weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 84e0fbb44..e1eeaa131 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -249,8 +249,8 @@ def fp8_mqa_logits( q: Query tensor of shape [M, H, D]. Casted to `torch.float8_e4m3fn` by caller. kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with - dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or - [N, 1]) with dtype `torch.float32`. + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] + with dtype `torch.float32`. weights: weights of shape [M, H], dtype `torch.float32`. cu_seqlen_ks: Start indices (inclusive) for valid K per query position, shape [M], dtype int32.