[Feat] Add CUDA torch fallbacks for fp8_mqa_logits/fp8_paged_mqa_logits_torch function (#35271)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -9,8 +9,13 @@ from vllm.forward_context import get_forward_context
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
|
from vllm.utils.deep_gemm import (
|
||||||
from vllm.utils.import_utils import has_deep_gemm
|
fp8_mqa_logits,
|
||||||
|
fp8_mqa_logits_torch,
|
||||||
|
fp8_paged_mqa_logits,
|
||||||
|
fp8_paged_mqa_logits_torch,
|
||||||
|
is_deep_gemm_supported,
|
||||||
|
)
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.mla.indexer import (
|
from vllm.v1.attention.backends.mla.indexer import (
|
||||||
DeepseekV32IndexerMetadata,
|
DeepseekV32IndexerMetadata,
|
||||||
@@ -102,15 +107,23 @@ def sparse_attn_indexer(
|
|||||||
chunk.block_table,
|
chunk.block_table,
|
||||||
chunk.cu_seq_lens,
|
chunk.cu_seq_lens,
|
||||||
)
|
)
|
||||||
|
if is_deep_gemm_supported():
|
||||||
logits = fp8_mqa_logits(
|
logits = fp8_mqa_logits(
|
||||||
q_fp8[chunk.token_start : chunk.token_end],
|
q_fp8[chunk.token_start : chunk.token_end],
|
||||||
(k_fp8, k_scale.view(torch.float32).flatten()),
|
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||||
weights[chunk.token_start : chunk.token_end],
|
weights[chunk.token_start : chunk.token_end],
|
||||||
chunk.cu_seqlen_ks,
|
chunk.cu_seqlen_ks,
|
||||||
chunk.cu_seqlen_ke,
|
chunk.cu_seqlen_ke,
|
||||||
clean_logits=False,
|
clean_logits=False,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logits = fp8_mqa_logits_torch(
|
||||||
|
q_fp8[chunk.token_start : chunk.token_end],
|
||||||
|
(k_fp8, k_scale.view(torch.float32).flatten()),
|
||||||
|
weights[chunk.token_start : chunk.token_end],
|
||||||
|
chunk.cu_seqlen_ks,
|
||||||
|
chunk.cu_seqlen_ke,
|
||||||
|
)
|
||||||
num_rows = logits.shape[0]
|
num_rows = logits.shape[0]
|
||||||
|
|
||||||
topk_indices = topk_indices_buffer[
|
topk_indices = topk_indices_buffer[
|
||||||
@@ -159,18 +172,26 @@ def sparse_attn_indexer(
|
|||||||
next_n = padded_q_fp8_decode_tokens.shape[1]
|
next_n = padded_q_fp8_decode_tokens.shape[1]
|
||||||
assert batch_size == decode_metadata.seq_lens.shape[0]
|
assert batch_size == decode_metadata.seq_lens.shape[0]
|
||||||
num_padded_tokens = batch_size * next_n
|
num_padded_tokens = batch_size * next_n
|
||||||
|
if is_deep_gemm_supported():
|
||||||
logits = fp8_paged_mqa_logits(
|
logits = fp8_paged_mqa_logits(
|
||||||
padded_q_fp8_decode_tokens,
|
padded_q_fp8_decode_tokens,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
weights[:num_padded_tokens],
|
weights[:num_padded_tokens],
|
||||||
decode_metadata.seq_lens,
|
decode_metadata.seq_lens,
|
||||||
decode_metadata.block_table,
|
decode_metadata.block_table,
|
||||||
decode_metadata.schedule_metadata,
|
decode_metadata.schedule_metadata,
|
||||||
max_model_len=max_model_len,
|
max_model_len=max_model_len,
|
||||||
clean_logits=False,
|
clean_logits=False,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logits = fp8_paged_mqa_logits_torch(
|
||||||
|
padded_q_fp8_decode_tokens,
|
||||||
|
kv_cache,
|
||||||
|
weights[:num_padded_tokens],
|
||||||
|
decode_metadata.seq_lens,
|
||||||
|
decode_metadata.block_table,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
)
|
||||||
num_rows = logits.shape[0]
|
num_rows = logits.shape[0]
|
||||||
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
|
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
|
||||||
|
|
||||||
@@ -278,9 +299,12 @@ class SparseAttnIndexer(CustomOp):
|
|||||||
self.max_model_len = max_model_len
|
self.max_model_len = max_model_len
|
||||||
self.max_total_seq_len = max_total_seq_len
|
self.max_total_seq_len = max_total_seq_len
|
||||||
self.topk_indices_buffer = topk_indices_buffer
|
self.topk_indices_buffer = topk_indices_buffer
|
||||||
if current_platform.is_cuda() and not has_deep_gemm():
|
if current_platform.is_cuda() and not is_deep_gemm_supported():
|
||||||
raise RuntimeError(
|
logger.warning_once(
|
||||||
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
|
"DeepGEMM is not supported or available. SparseAttnIndexer will use a "
|
||||||
|
"less efficient PyTorch implementation. "
|
||||||
|
"Please make sure you have the required hardware and software setup "
|
||||||
|
"for DeepGEMM to achieve optimal performance."
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
|
|||||||
@@ -418,6 +418,125 @@ def should_use_deepgemm_for_fp8_linear(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Pure PyTorch fallback used on CUDA when DeepGEMM is not available.
    For query row ``m`` and key ``n`` the result is
    ``sum_h relu((q[m, h] . k[n]) * k_scales[n]) * weights[m, h]``; keys
    outside ``[cu_seqlen_ks[m], cu_seqlen_ke[m])`` are set to ``-inf``.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv_fp8, scale = kv
    seq_len_kv = kv_fp8.shape[0]
    # Flatten the per-key scales so that both documented layouts ([N] and
    # [N, 1]) broadcast along the key axis of the [H, M, N] score tensor.
    # An [N, 1] scale would otherwise broadcast against the query axis.
    scale = scale.reshape(-1)
    k = kv_fp8.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    # Key positions as a [1, N] row, compared against per-query [M, 1] bounds.
    k_positions = torch.arange(seq_len_kv, device=q.device)[None, :]
    mask = (k_positions >= cu_seqlen_ks[:, None]) & (
        k_positions < cu_seqlen_ke[:, None]
    )

    # [H, M, N] scores, dequantized with the per-key scale.
    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    # weights: [M, H] -> [H, M, 1] so each head is weighted per query row
    # before the heads are summed away.
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    return logits.masked_fill(~mask, float("-inf"))
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache (CUDA fallback).

    This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
    The cache's last dimension is D + 4 bytes (e.g. 128 + 4 = 132): D FP8
    values followed by one float32 dequantization scale.

    Args:
        q: Query tensor of shape [B, next_n, H, D].
        kv_cache: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`. Positions that are not causally visible to a query
        (or were never written) hold ``-inf``.
    """
    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()

    # Split the packed bytes: the first `dim` bytes are FP8 values, the
    # trailing bytes are reinterpreted as a single float32 scale.
    kv_bytes, scale_bytes = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale_bytes.contiguous().view(torch.float)
    q = q.float()
    kv = kv_bytes.view(fp8_dtype).float() * scale
    _, block_size, _, _ = kv.size()

    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    for i in range(batch_size):
        context_len = context_lens[i].item()
        row_start, row_end = i * next_n, (i + 1) * next_n
        # The next_n query tokens occupy the last next_n context positions.
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        # [next_n, H] -> [H, next_n] so it lines up with the per-head scores.
        weight_slice = weights[row_start:row_end, :].transpose(0, 1).contiguous()
        # [next_n, H, D] -> [H, next_n, D]; invariant across blocks.
        q_heads = q[i].transpose(0, 1)
        for block_idx in range(cdiv(context_len, block_size)):
            block_start = block_idx * block_size
            # Clamp to the output width: a block that crosses max_model_len
            # would otherwise be assigned into a narrower slice and fail
            # with a shape mismatch.
            block_end = min(block_start + block_size, max_model_len)
            if block_end <= block_start:
                break

            block_id = block_tables[i][block_idx]
            # [block_size, 1, D] -> [1, D, block_size]
            k_block = kv[block_id].transpose(0, 1).transpose(1, 2)
            k_offsets = torch.arange(
                block_start, block_start + block_size, device=q.device
            )
            causal = k_offsets[None, :] <= q_offsets[:, None]
            mask = (k_offsets[None, :] < context_len) & causal
            s = torch.where(
                mask[None, :, :],
                (q_heads @ k_block).to(logits.dtype),
                float("-inf"),
            )
            # relu maps masked (-inf) scores to 0, so they add nothing to the
            # head sum; the causal mask is re-applied afterwards.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            s = torch.where(causal, s, float("-inf"))
            logits[row_start:row_end, block_start:block_end] = s[
                :, : block_end - block_start
            ]
    return logits
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"calc_diff",
|
"calc_diff",
|
||||||
"DeepGemmQuantScaleFMT",
|
"DeepGemmQuantScaleFMT",
|
||||||
@@ -425,7 +544,9 @@ __all__ = [
|
|||||||
"m_grouped_fp8_gemm_nt_contiguous",
|
"m_grouped_fp8_gemm_nt_contiguous",
|
||||||
"fp8_m_grouped_gemm_nt_masked",
|
"fp8_m_grouped_gemm_nt_masked",
|
||||||
"fp8_mqa_logits",
|
"fp8_mqa_logits",
|
||||||
|
"fp8_mqa_logits_torch",
|
||||||
"fp8_paged_mqa_logits",
|
"fp8_paged_mqa_logits",
|
||||||
|
"fp8_paged_mqa_logits_torch",
|
||||||
"get_paged_mqa_logits_metadata",
|
"get_paged_mqa_logits_metadata",
|
||||||
"per_block_cast_to_fp8",
|
"per_block_cast_to_fp8",
|
||||||
"is_deep_gemm_e8m0_used",
|
"is_deep_gemm_e8m0_used",
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ import torch
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, has_deep_gemm
|
from vllm.utils.deep_gemm import (
|
||||||
|
get_paged_mqa_logits_metadata,
|
||||||
|
is_deep_gemm_supported,
|
||||||
|
)
|
||||||
from vllm.utils.platform_utils import num_compute_units
|
from vllm.utils.platform_utils import num_compute_units
|
||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
@@ -344,7 +347,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||||
|
|
||||||
# DeepGEMM is required for the paged MQA logits on CUDA devices
|
# The paged MQA logits scheduler metadata is only prepared when DeepGEMM is supported on CUDA devices
|
||||||
if current_platform.is_cuda() and has_deep_gemm():
|
if current_platform.is_cuda() and is_deep_gemm_supported():
|
||||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||||
seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user