diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 826caa5d3..f4ce6fca8 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -9,8 +9,13 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
-from vllm.utils.import_utils import has_deep_gemm
+from vllm.utils.deep_gemm import (
+    fp8_mqa_logits,
+    fp8_mqa_logits_torch,
+    fp8_paged_mqa_logits,
+    fp8_paged_mqa_logits_torch,
+    is_deep_gemm_supported,
+)
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backends.mla.indexer import (
     DeepseekV32IndexerMetadata,
@@ -102,15 +107,23 @@ def sparse_attn_indexer(
                 chunk.block_table,
                 chunk.cu_seq_lens,
             )
-
-            logits = fp8_mqa_logits(
-                q_fp8[chunk.token_start : chunk.token_end],
-                (k_fp8, k_scale.view(torch.float32).flatten()),
-                weights[chunk.token_start : chunk.token_end],
-                chunk.cu_seqlen_ks,
-                chunk.cu_seqlen_ke,
-                clean_logits=False,
-            )
+            if is_deep_gemm_supported():
+                logits = fp8_mqa_logits(
+                    q_fp8[chunk.token_start : chunk.token_end],
+                    (k_fp8, k_scale.view(torch.float32).flatten()),
+                    weights[chunk.token_start : chunk.token_end],
+                    chunk.cu_seqlen_ks,
+                    chunk.cu_seqlen_ke,
+                    clean_logits=False,
+                )
+            else:
+                logits = fp8_mqa_logits_torch(
+                    q_fp8[chunk.token_start : chunk.token_end],
+                    (k_fp8, k_scale.view(torch.float32).flatten()),
+                    weights[chunk.token_start : chunk.token_end],
+                    chunk.cu_seqlen_ks,
+                    chunk.cu_seqlen_ke,
+                )
 
             num_rows = logits.shape[0]
             topk_indices = topk_indices_buffer[
@@ -159,18 +172,26 @@
         next_n = padded_q_fp8_decode_tokens.shape[1]
         assert batch_size == decode_metadata.seq_lens.shape[0]
         num_padded_tokens = batch_size * next_n
-
-        logits = fp8_paged_mqa_logits(
-            padded_q_fp8_decode_tokens,
-            kv_cache,
-            weights[:num_padded_tokens],
-            decode_metadata.seq_lens,
-            decode_metadata.block_table,
-            decode_metadata.schedule_metadata,
-            max_model_len=max_model_len,
-            clean_logits=False,
-        )
-
+        if is_deep_gemm_supported():
+            logits = fp8_paged_mqa_logits(
+                padded_q_fp8_decode_tokens,
+                kv_cache,
+                weights[:num_padded_tokens],
+                decode_metadata.seq_lens,
+                decode_metadata.block_table,
+                decode_metadata.schedule_metadata,
+                max_model_len=max_model_len,
+                clean_logits=False,
+            )
+        else:
+            logits = fp8_paged_mqa_logits_torch(
+                padded_q_fp8_decode_tokens,
+                kv_cache,
+                weights[:num_padded_tokens],
+                decode_metadata.seq_lens,
+                decode_metadata.block_table,
+                max_model_len=max_model_len,
+            )
         num_rows = logits.shape[0]
         topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
 
@@ -278,9 +299,12 @@ class SparseAttnIndexer(CustomOp):
         self.max_model_len = max_model_len
         self.max_total_seq_len = max_total_seq_len
         self.topk_indices_buffer = topk_indices_buffer
-        if current_platform.is_cuda() and not has_deep_gemm():
-            raise RuntimeError(
-                "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
+        if current_platform.is_cuda() and not is_deep_gemm_supported():
+            logger.warning_once(
+                "DeepGEMM is not supported or available. SparseAttnIndexer will use a "
+                "less efficient PyTorch implementation. "
+                "Please make sure you have the required hardware and software setup "
+                "for DeepGEMM to achieve optimal performance."
             )
 
     def forward_native(
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 8f664cc7d..ee104a6cc 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -418,6 +418,125 @@ def should_use_deepgemm_for_fp8_linear(
     )
 
 
+def fp8_mqa_logits_torch(
+    q: torch.Tensor,
+    kv: tuple[torch.Tensor, torch.Tensor],
+    weights: torch.Tensor,
+    cu_seqlen_ks: torch.Tensor,
+    cu_seqlen_ke: torch.Tensor,
+) -> torch.Tensor:
+    """Compute FP8 MQA logits for a single sequence without KV paging (CUDA fallback).
+
+    This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
+
+    Args:
+        q: Query tensor of shape [M, H, D]. Casted to
+            `torch.float8_e4m3fn` by caller.
+        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
+            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
+            [N, 1]) with dtype `torch.float32`.
+        weights: weights of shape [M, H], dtype `torch.float32`.
+        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
+            shape [M], dtype int32.
+        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
+            shape [M], dtype int32.
+
+    Returns:
+        Logits tensor of shape [M, N], dtype `torch.float32`.
+    """
+    kv_fp8, scale = kv
+    seq_len_kv = kv_fp8.shape[0]
+    k = kv_fp8.to(torch.bfloat16)
+    q = q.to(torch.bfloat16)
+
+    mask_lo = (
+        torch.arange(0, seq_len_kv, device=q.device)[None, :] >= cu_seqlen_ks[:, None]
+    )
+    mask_hi = (
+        torch.arange(0, seq_len_kv, device=q.device)[None, :] < cu_seqlen_ke[:, None]
+    )
+    mask = mask_lo & mask_hi
+
+    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
+    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
+    logits = logits.masked_fill(~mask, float("-inf"))
+
+    return logits
+
+
+def fp8_paged_mqa_logits_torch(
+    q: torch.Tensor,
+    kv_cache: torch.Tensor,
+    weights: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    max_model_len: int,
+) -> torch.Tensor:
+    """Compute FP8 MQA logits using paged KV-cache (CUDA fallback).
+
+    This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
+    Handles head_dim = 132 (128 + 4 for RoPE).
+
+    Args:
+        q: Query tensor of shape [B, next_n, H, D].
+        kv_cache: Paged KV-cache in packed FP8+scale layout with shape
+            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
+            4 bytes per (block,pos) store the `float` dequant scale.
+        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
+        context_lens: Tensor of shape [B], dtype int32; effective context length
+            for each batch element.
+        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
+            block indices to physical blocks in the paged cache.
+        max_model_len: Maximum sequence length used to size the logits output.
+
+    Returns:
+        Logits tensor of shape [B * next_n, max_model_len], dtype
+        `torch.float32`.
+    """
+    fp8_dtype = current_platform.fp8_dtype()
+    batch_size, next_n, heads, dim = q.size()
+    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
+    scale = scale.contiguous().view(torch.float)
+    q = q.float()
+    kv_cache = kv_cache.view(fp8_dtype).float() * scale
+    num_blocks, block_size, _, dim = kv_cache.size()
+    logits = torch.full(
+        [batch_size * next_n, max_model_len],
+        float("-inf"),
+        device=q.device,
+        dtype=torch.float32,
+    )
+    for i in range(batch_size):
+        context_len = context_lens[i].item()
+        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
+        weight_slice = (
+            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
+        )
+        for block_idx in range(cdiv(context_len, block_size)):
+            block_id = block_tables[i][block_idx]
+            qx, kx = q[i], kv_cache[block_id]
+            k_offsets = torch.arange(
+                block_idx * block_size, (block_idx + 1) * block_size, device=q.device
+            )
+            mask = (k_offsets[None, :] < context_len) & (
+                k_offsets[None, :] <= q_offsets[:, None]
+            )
+            s = torch.where(
+                mask[None, :, :],
+                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
+                    logits.dtype
+                ),
+                float("-inf"),
+            )
+            s = torch.relu(s) * weight_slice[..., None]
+            s = s.sum(dim=0)
+            logits[
+                i * next_n : (i + 1) * next_n,
+                block_idx * block_size : (block_idx + 1) * block_size,
+            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
+    return logits
+
+
 __all__ = [
     "calc_diff",
     "DeepGemmQuantScaleFMT",
@@ -425,7 +544,9 @@ __all__ = [
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
     "fp8_mqa_logits",
+    "fp8_mqa_logits_torch",
     "fp8_paged_mqa_logits",
+    "fp8_paged_mqa_logits_torch",
     "get_paged_mqa_logits_metadata",
     "per_block_cast_to_fp8",
     "is_deep_gemm_e8m0_used",
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 3c56f9fd0..7c81a4359 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -8,7 +8,10 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
+from vllm.utils.deep_gemm import (
+    get_paged_mqa_logits_metadata,
+    is_deep_gemm_supported,
+)
 from vllm.utils.platform_utils import num_compute_units
 from vllm.v1.attention.backend import (
     AttentionBackend,
@@ -344,7 +347,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         seq_lens = common_attn_metadata.seq_lens[:num_decodes]
 
         # DeepGEMM is required for the paged MQA logits on CUDA devices
-        if current_platform.is_cuda() and has_deep_gemm():
+        if current_platform.is_cuda() and is_deep_gemm_supported():
             self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                 seq_lens, self.kv_cache_spec.block_size, self.num_sms
             )