[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp (#29287)
Signed-off-by: ganyi <ygan@amd.com>
@@ -9,6 +9,10 @@ from torch._ops import OpOverload
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
    rocm_aiter_sparse_attn_indexer,
    rocm_aiter_sparse_attn_indexer_fake,
)

_FP8_DTYPE = current_platform.fp8_dtype()
@@ -1091,6 +1095,14 @@ class rocm_aiter_ops:
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_sparse_attn_indexer",
            op_func=rocm_aiter_sparse_attn_indexer,
            mutates_args=["topk_indices_buffer"],
            fake_impl=rocm_aiter_sparse_attn_indexer_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        _OPS_REGISTERED = True

    @staticmethod
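For context, `direct_register_custom_op` follows the standard `torch.library` registration pattern: a real implementation plus a fake (meta) implementation so the op can be traced and compiled. A minimal self-contained sketch of that underlying pattern, using a hypothetical `demo::scale` op rather than vLLM's helper:

import torch
from torch.library import Library, impl

# Hypothetical "demo::scale" op illustrating the registration pattern;
# this is a sketch of what the helper wraps, not vLLM's implementation.
_lib = Library("demo", "FRAGMENT")
_lib.define("scale(Tensor x, float s) -> Tensor")

@impl(_lib, "scale", "CompositeExplicitAutograd")
def _scale_impl(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

print(torch.ops.demo.scale(torch.ones(2), 3.0))  # tensor([3., 3.])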
@@ -611,6 +611,7 @@ class CompilationConfig:
            "vllm::gdn_attention_core",
            "vllm::kda_attention",
            "vllm::sparse_attn_indexer",
            "vllm::rocm_aiter_sparse_attn_indexer",
        ]

    def compute_hash(self) -> str:
vllm/model_executor/layers/sparse_attn_indexer.py (new file, 318 lines)
@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom Sparse Attention Indexer layers."""

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)
def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
    # Careful: this will be None in a dummy run.
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()

    if not isinstance(attn_metadata, dict):
        # Reserve workspace for the indexer during the profiling run.
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )

    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill

        # Get the full shared workspace buffers once (allocated on first use).
        workspace_manager = current_workspace_manager()
        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )
        for chunk in prefill_metadata.chunks:
            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            logits = fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32).flatten()),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]

            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # The kernel expects kv_cache shaped [num_block, block_size, n_head,
        # head_dim]; we only have [num_block, block_size, head_dim].
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pad in the edge case where a short chunked-prefill length is
            # smaller than decode_threshold, since prefill and decode are
            # split loosely by decode_threshold (currently set to
            # 1 + speculative tokens).
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize the logic below with Triton kernels.
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        logits = fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )

        num_rows = logits.shape[0]

        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # If padded, unpack the top-k indices, removing padded tokens.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer
def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    return topk_indices_buffer


direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)
@CustomOp.register("sparse_attn_indexer")
class SparseAttnIndexer(CustomOp):
    """Sparse Attention Indexer custom-op layer. This layer is extracted as a
    separate custom op because it involves heavy custom kernels such as
    `mqa_logits`, `paged_mqa_logits`, and `top_k_per_row`, which may require
    hardware-specific memory layouts or implementations to achieve optimal
    performance.

    For now, the default native path uses the CUDA backend. Other platforms
    may need to add the custom-op name `sparse_attn_indexer` to `custom_ops`
    in `CompilationConfig` to enable their platform-specific path.
    """

    def __init__(
        self,
        k_cache,
        quant_block_size: int,
        scale_fmt: str,
        topk_tokens: int,
        head_dim: int,
        max_model_len: int,
        max_total_seq_len: int,
        topk_indices_buffer: torch.Tensor,
    ):
        super().__init__()
        self.k_cache = k_cache
        self.quant_block_size = quant_block_size
        self.scale_fmt = scale_fmt
        self.topk_tokens = topk_tokens
        self.head_dim = head_dim
        self.max_model_len = max_model_len
        self.max_total_seq_len = max_total_seq_len
        self.topk_indices_buffer = topk_indices_buffer

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if current_platform.is_cuda():
            return self.forward_cuda(hidden_states, q_fp8, k, weights)
        elif current_platform.is_rocm():
            return self.forward_hip(hidden_states, q_fp8, k, weights)
        else:
            raise NotImplementedError(
                "SparseAttnIndexer native forward is only implemented for "
                "the CUDA and ROCm platforms."
            )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if rocm_aiter_ops.is_enabled():
            return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            raise RuntimeError(
                "The sparse attention indexer ROCm custom op requires ROCm "
                "AITER ops to be enabled."
            )
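The `custom_ops` switch referenced in the docstring lives on `CompilationConfig`. A minimal usage sketch, assuming the standard `LLM` entrypoint accepts a `compilation_config` dict as in recent vLLM releases (the model name is illustrative; on ROCm the platform change below appends this flag automatically):

from vllm import LLM

# Hypothetical usage sketch: "+sparse_attn_indexer" opts into the
# platform-specific CustomOp path for the indexer.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2",  # illustrative model name
    compilation_config={"custom_ops": ["+sparse_attn_indexer"]},
)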
@@ -43,7 +43,6 @@ from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -63,6 +62,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
@@ -74,16 +74,11 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
    DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager

from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .utils import (
@@ -94,11 +89,6 @@ from .utils import (
    maybe_prefix,
)

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)
@@ -599,213 +589,6 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
        return DeepseekV32IndexerBackend


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    # Careful: this will be None in a dummy run.
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()

    if not isinstance(attn_metadata, dict):
        # Reserve workspace for the indexer during the profiling run.
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )

        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )

    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill

        # Get the full shared workspace buffers once (allocated on first use).
        workspace_manager = current_workspace_manager()
        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )

        for chunk in prefill_metadata.chunks:
            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )
            fp8_mqa_logits_func = fp8_mqa_logits
            if current_platform.is_rocm():
                from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
                    rocm_fp8_mqa_logits,
                )

                fp8_mqa_logits_func = rocm_fp8_mqa_logits
            logits = fp8_mqa_logits_func(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32).flatten()),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # The kernel expects kv_cache shaped [num_block, block_size, n_head,
        # head_dim]; we only have [num_block, block_size, head_dim].
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pad in the edge case where a short chunked-prefill length is
            # smaller than decode_threshold, since prefill and decode are
            # split loosely by decode_threshold (currently set to
            # 1 + speculative tokens).

            # [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
            # [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
            padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens)
            # [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
            padded_weights = padded_weights.flatten(0, 1)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
            padded_weights = weights
        # TODO: move and optimize the logic below with Triton kernels.
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
        if current_platform.is_rocm():
            from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
                rocm_fp8_paged_mqa_logits,
            )

            fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
        logits = fp8_paged_mqa_logits_func(
            padded_q_fp8_decode_tokens,
            kv_cache,
            padded_weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )
        num_rows = logits.shape[0]
        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]

        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )
        if decode_metadata.requires_padding:
            # If padded, unpack the top-k indices, removing padded tokens.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    return topk_indices_buffer


direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)
class Indexer(nn.Module):
    def __init__(
        self,
@@ -870,6 +653,16 @@ class Indexer(nn.Module):
        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
@@ -892,6 +685,8 @@ class Indexer(nn.Module):
        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
        k_pe = k_pe.reshape(-1, 1, self.rope_dim)

        # `rotary_emb` is shape-preserving; `q_pe` is already
        # [num_tokens, n_head, rope_dim].
        q = torch.cat([q_pe, q_nope], dim=-1)
        # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
@@ -913,21 +708,7 @@ class Indexer(nn.Module):
        )
        weights = weights.squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )
        return self.indexer_op(hidden_states, q_fp8, k, weights)


class DeepseekV2MLAAttention(nn.Module):
@@ -480,6 +480,9 @@ class RocmPlatform(Platform):
        ):
            compilation_config.custom_ops.append("+grouped_topk")

        # Dispatch to ROCm's sparse_attn_indexer implementation by default.
        compilation_config.custom_ops.append("+sparse_attn_indexer")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        if model_arch in _ROCM_UNSUPPORTED_MODELS:
@@ -63,6 +63,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
    cu_seqlen_ks: torch.Tensor
    cu_seqlen_ke: torch.Tensor
    cu_seq_lens: torch.Tensor
    token_to_seq: torch.Tensor
    total_seq_lens: int
    token_start: int
    token_end: int
@@ -234,6 +235,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
        token_start = query_start_loc_cpu[reqs_start].item()
        token_end = query_start_loc_cpu[reqs_end].item()
        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
        seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
        token_to_seq = torch.repeat_interleave(
            seq_idx, seq_lens_cpu[reqs_start:reqs_end]
        ).to(self.device)
        assert total_seq_lens <= self.max_prefill_buffer_size
        cu_seq_lens = (
            torch.cat(
@@ -249,6 +254,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            token_to_seq=token_to_seq,
            total_seq_lens=total_seq_lens,
            block_table=block_table[reqs_start:reqs_end],
            token_start=token_start,
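As a quick illustration of the `token_to_seq` mapping built above: `repeat_interleave` expands per-sequence indices by sequence length, so each token row knows which request it belongs to. A small sketch with illustrative lengths:

import torch

seq_lens = torch.tensor([2, 3])  # two requests in the chunk
seq_idx = torch.arange(0, 2, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens)
# tensor([0, 0, 1, 1, 1], dtype=torch.int32)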
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
    MLACommonBaseImpl,
    get_mla_dims,
)
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
@@ -33,6 +34,48 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
@triton.jit
def fetch_id_to_ragged_kernel(
    in_tensor_ptr,  # [num_seq, topk]
    cumsum_ptr,  # [num_seq + 1]
    out_tensor_ptr,  # [max_num_seq * topk]
    in_tensor_ptr_stride,
    TOPK: tl.constexpr,
    TOKEN_NUM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    seq_id = tl.program_id(0)
    block_id = tl.program_id(1)
    offset = tl.arange(0, BLOCK_SIZE)
    token_start = tl.load(cumsum_ptr + seq_id)
    token_end = tl.load(cumsum_ptr + seq_id + 1)
    token_num = token_end - token_start
    row_offset = block_id * BLOCK_SIZE
    if row_offset >= token_num:
        return
    in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset
    in_tensor_mask = (row_offset + offset) < TOPK
    in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask)
    out_tensor_offset = token_start + row_offset + offset
    out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask
    tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask)
def fetch_id_to_ragged_triton(
    in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk
):
    num_tokens = in_tensor.size(0)
    block_size = 64
    num_block_per_row = triton.cdiv(topk, block_size)
    grid = (
        num_tokens,
        num_block_per_row,
    )
    fetch_id_to_ragged_kernel[grid](
        in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size
    )
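For clarity, a pure-PyTorch reference of what `fetch_id_to_ragged_triton` computes: each row of `in_tensor` holds one sequence's top-k ids, which are copied into a ragged output at offsets given by `cumsum`. A sketch, not the hot path:

import torch

def fetch_id_to_ragged_ref(
    in_tensor: torch.Tensor,   # [num_seq, topk]
    cumsum: torch.Tensor,      # [num_seq + 1], inclusive prefix sums
    out_tensor: torch.Tensor,  # flat ragged output, length >= cumsum[-1]
    topk: int,
) -> None:
    for seq_id in range(in_tensor.size(0)):
        start = int(cumsum[seq_id])
        end = int(cumsum[seq_id + 1])
        n = min(end - start, topk)  # valid entries for this sequence
        out_tensor[start : start + n] = in_tensor[seq_id, :n]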
class ROCMAiterMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True

@@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor

    qo_indptr: torch.Tensor
    paged_kv_last_page_len: torch.Tensor
    paged_kv_indices: torch.Tensor
    paged_kv_indptr: torch.Tensor
    paged_kv_indptr_rest: torch.Tensor

    block_size: int = 1
    topk_tokens: int = 2048
@@ -91,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
class ROCMAiterMLASparseMetadataBuilder(
    AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    def __init__(
        self,
@@ -104,6 +154,7 @@ class ROCMAiterMLASparseMetadataBuilder(
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device
        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
@@ -124,6 +175,23 @@ class ROCMAiterMLASparseMetadataBuilder(
            dtype=torch.int32,
            device=device,
        )
        self.qo_indptr = torch.arange(
            0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
        )
        self.paged_kv_last_page_len = torch.ones(
            max_num_batched_tokens, dtype=torch.int32, device=device
        )

        # These two need to be computed at runtime, but we still need to
        # preallocate the buffers.
        self.paged_kv_indices = torch.zeros(
            [max_num_batched_tokens * self.topk_tokens],
            dtype=torch.int32,
            device=device,
        )
        self.paged_kv_indptr = torch.zeros(
            [max_num_batched_tokens + 1], dtype=torch.int32, device=device
        )

    def build(
        self,
@@ -142,7 +210,15 @@ class ROCMAiterMLASparseMetadataBuilder(
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        self.paged_kv_indices.fill_(0)
        self.paged_kv_indptr.fill_(0)

        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
        qo_indptr = self.qo_indptr[: num_tokens + 1]
        paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
        paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
        paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
        paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]

        metadata = ROCMAiterMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
@@ -155,6 +231,11 @@ class ROCMAiterMLASparseMetadataBuilder(
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            qo_indptr=qo_indptr,
            paged_kv_last_page_len=paged_kv_last_page_len,
            paged_kv_indices=paged_kv_indices,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indptr_rest=paged_kv_indptr_rest,
        )
        return metadata
@@ -226,20 +307,39 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,  # [sq, heads, d_qk]
        kv_c_and_k_pe_cache: torch.Tensor,  # [blocks, heads, d_qk]
        topk_indices: torch.Tensor,  # [sq, topk]
        attn_metadata: ROCMAiterMLASparseMetadata,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )
        output = torch.empty(
            [num_tokens, self.num_heads, self.kv_lora_rank],
            dtype=q.dtype,
            device=q.device,
        )
        seq_len = (topk_indices != -1).sum(dim=-1)
        torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
        attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
        fetch_id_to_ragged_triton(
            topk_indices,
            attn_metadata.paged_kv_indptr,
            attn_metadata.paged_kv_indices,
            attn_metadata.topk_tokens,
        )

        rocm_aiter_ops.mla_decode_fwd(
            q,
            kv_c_and_k_pe_cache,
            output,
            self.scale,
            attn_metadata.qo_indptr,
            1,
            attn_metadata.paged_kv_indptr,
            attn_metadata.paged_kv_indices,
            attn_metadata.paged_kv_last_page_len,
        )

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = reference_mla_sparse_prefill(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512
        )[0]
        return output[:, : self.num_heads, :]

    def forward(
@@ -1,100 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import importlib
from functools import lru_cache

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton

logger = init_logger(__name__)
if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    k_fp8, scale = kv
    seq_len_kv = k_fp8.shape[0]
    k = k_fp8.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; the module check and reference
    # path will be removed once aiter merges this kernel into main.
    @lru_cache
    def has_mqa_logits_module():
        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None

    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
        from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits

        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)


@triton.jit
def _indexer_k_quant_and_cache_kernel(
    k_ptr,  # [num_tokens, head_dim]
    kv_cache_ptr,  # [n_blks, blk_size//tile_block, head_dim // 16B, tile_block, 16B]
    # or [n_blocks, blk_size, head_dim]
    kv_cache_scale_ptr,  # [n_blks, blk_size]
    slot_mapping_ptr,  # [num_tokens]
    kv_cache_scale_stride,
    kv_cache_value_stride,
    block_size,
    num_tokens,
    head_dim: tl.constexpr,
    LAYOUT: tl.constexpr,
    BLOCK_TILE_SIZE: tl.constexpr,
    HEAD_TILE_SIZE: tl.constexpr,
    IS_FNUZ: tl.constexpr,
    USE_UE8M0: tl.constexpr,
):
    tid = tl.program_id(0)
    offset = tl.arange(0, head_dim)
    if LAYOUT == "SHUFFLE":
        tile_offset = (
            offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE
            + offset % HEAD_TILE_SIZE
        )
    else:
        tile_offset = offset
    tile_store_offset = tile_offset
    src_ptr = k_ptr + tid * head_dim
    slot_id = tl.load(slot_mapping_ptr + tid)
    if slot_id < 0:
        return
    block_id = slot_id // block_size
    block_offset = slot_id % block_size
    tile_block_id = block_offset // BLOCK_TILE_SIZE
    tile_block_offset = block_offset % BLOCK_TILE_SIZE
    val = tl.load(src_ptr + offset)
    amax = tl.max(val.abs(), axis=-1).to(tl.float32)
    if IS_FNUZ:
        scale = tl.maximum(1e-4, amax) / 224.0
    else:
        scale = tl.maximum(1e-4, amax) / 448.0

    if USE_UE8M0:
        scale = tl.exp2(tl.ceil(tl.log2(scale)))

    fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
    if LAYOUT == "SHUFFLE":
        dst_ptr = (
            kv_cache_ptr
            + block_id * kv_cache_value_stride
            + tile_block_id * BLOCK_TILE_SIZE * head_dim
            + tile_block_offset * HEAD_TILE_SIZE
        )
    else:
        dst_ptr = (
            kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim
        )
    tl.store(dst_ptr + tile_store_offset, fp8_val)
    dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset
    tl.store(dst_scale_ptr, scale)
def indexer_k_quant_and_cache_triton(
    k: torch.Tensor,
    kv_cache: torch.Tensor,  # [num_blocks, block_size, head_dim + 4]
    slot_mapping: torch.Tensor,
    quant_block_size,
    scale_fmt,
    block_tile_size=16,
    head_tile_size=16,
):
    num_blocks = kv_cache.shape[0]
    head_dim = k.shape[-1]
    num_tokens = slot_mapping.shape[0]
    block_size = kv_cache.shape[1]
    # In the real layout, the first portion stores the KV-cache values
    # and the second portion stores the KV-cache scales.
    kv_cache = kv_cache.view(num_blocks, -1)
    kv_cache_value = kv_cache[:, : block_size * head_dim]
    kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32)
    head_tile_size = head_tile_size // kv_cache.element_size()
    grid = (num_tokens,)
    _indexer_k_quant_and_cache_kernel[grid](
        k,
        kv_cache_value,
        kv_cache_scale,
        slot_mapping,
        kv_cache_scale.stride(0),
        kv_cache_value.stride(0),
        block_size,
        num_tokens,
        head_dim,
        "NHD",
        block_tile_size,
        head_tile_size,
        IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
        USE_UE8M0=scale_fmt == "ue8m0",
    )
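The per-token quantization the kernel above performs can be written out in plain PyTorch for reference: take the per-row absmax, derive a scale against the FP8 maximum (224 for e4m3fnuz, 448 for e4m3fn), optionally round the scale up to a power of two for the ue8m0 scale format, then divide. A sketch under those assumptions:

import torch

def quant_row_ref(val: torch.Tensor, fnuz: bool = False, ue8m0: bool = False):
    # val: one token's K row, any float dtype.
    amax = val.abs().amax().to(torch.float32)
    scale = torch.clamp(amax, min=1e-4) / (224.0 if fnuz else 448.0)
    if ue8m0:
        # Round the scale up to a power of two (ue8m0 format).
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    fp8_val = val.to(torch.float32) / scale  # then cast to the fp8 dtype
    return fp8_val, scale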
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
    kv_cache_ptr,  # [n_blks, blk_size//tile_blk, head_dim//16B, tile_blk, 16B]
    # or [n_blks, blk_size, head_dim]
    kv_cache_scale_ptr,  # [n_blks, blk_size]
    k_fp8_ptr,  # [num_tokens, head_dim]
    k_scale_ptr,  # [num_tokens]
    block_table_ptr,  # [batch_size, block_table_stride]
    cu_seqlen_ptr,  # [batch_size + 1]
    token_to_seq_ptr,  # [num_tokens]
    block_size,
    block_table_stride,
    kv_cache_stride,
    kv_cache_scale_stride,
    LAYOUT: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_TILE_SIZE: tl.constexpr,
    HEAD_TILE_SIZE: tl.constexpr,
):
    tid = tl.program_id(0)
    offset = tl.arange(0, HEAD_DIM)
    batch_id = tl.load(token_to_seq_ptr + tid)
    batch_start = tl.load(cu_seqlen_ptr + batch_id)
    batch_end = tl.load(cu_seqlen_ptr + batch_id + 1)
    batch_offset = tid - batch_start
    if tid >= batch_end:
        return
    block_table_id = batch_offset // block_size
    block_offset = batch_offset % block_size
    block_table_offset = batch_id * block_table_stride + block_table_id
    block_id = tl.load(block_table_ptr + block_table_offset)
    tiled_block_id = block_offset // BLOCK_TILE_SIZE
    tiled_block_offset = block_offset % BLOCK_TILE_SIZE
    if LAYOUT == "SHUFFLE":
        src_cache_offset = (
            block_id * kv_cache_stride
            + tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE
            + tiled_block_offset * HEAD_TILE_SIZE
        )
    else:
        src_cache_offset = block_id * kv_cache_stride + block_offset * HEAD_DIM
    src_scale_offset = block_id * kv_cache_scale_stride + block_offset
    dst_offset = tid * HEAD_DIM
    src_scale_ptr = kv_cache_scale_ptr + src_scale_offset
    src_cache_ptr = kv_cache_ptr + src_cache_offset
    dst_k_ptr = k_fp8_ptr + dst_offset
    scale_val = tl.load(src_scale_ptr)
    tl.store(k_scale_ptr + tid, scale_val)
    if LAYOUT == "SHUFFLE":
        tiled_src_offset = (
            offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE
            + offset % HEAD_TILE_SIZE
        )
    else:
        tiled_src_offset = offset
    val = tl.load(src_cache_ptr + tiled_src_offset)
    tl.store(dst_k_ptr + offset, val)
def cp_gather_indexer_k_quant_cache_triton(
    k_cache: torch.Tensor,  # [num_blocks, block_size, head_dim + 4]
    k_fp8: torch.Tensor,
    k_fp8_scale: torch.Tensor,
    block_table: torch.Tensor,
    cu_seqlen: torch.Tensor,
    token_to_seq: torch.Tensor,
    block_tile_size: int = 16,
    head_tile_size: int = 16,
):
    num_tokens = k_fp8.size(0)
    block_size = k_cache.size(1)
    block_table_stride = block_table.stride(0)
    head_dim = k_fp8.shape[-1]
    num_blocks = k_cache.shape[0]
    # We assume the KV cache has already been split into two portions.
    k_cache = k_cache.view(num_blocks, -1)
    fp8_dtype = current_platform.fp8_dtype()
    k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype)
    k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32)
    grid = (num_tokens,)
    k_fp8_scale = k_fp8_scale.view(torch.float32)
    _cp_gather_indexer_quant_cache_kernel[grid](
        k_cache_value,
        k_cache_scale,
        k_fp8,
        k_fp8_scale,
        block_table,
        cu_seqlen,
        token_to_seq,
        block_size,
        block_table_stride,
        k_cache_value.stride(0),
        k_cache_scale.stride(0),
        "NHD",
        head_dim,
        block_tile_size,
        head_tile_size,
    )
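A pure-PyTorch reference of the gather the kernel above performs for the plain (non-shuffled) layout, included only as a readability aid; here `k_cache_value` is assumed viewed as [num_blocks, block_size, head_dim] and `k_cache_scale` as [num_blocks, block_size]:

import torch

def cp_gather_ref(k_cache_value, k_cache_scale, block_table, cu_seqlen, block_size):
    k_rows, s_rows = [], []
    for seq in range(cu_seqlen.numel() - 1):
        seq_len = int(cu_seqlen[seq + 1] - cu_seqlen[seq])
        for tok in range(seq_len):
            # Map the token's logical position to a physical cache slot.
            blk = int(block_table[seq, tok // block_size])
            off = tok % block_size
            k_rows.append(k_cache_value[blk, off])
            s_rows.append(k_cache_scale[blk, off])
    return torch.stack(k_rows), torch.stack(s_rows)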
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
@@ -183,10 +303,38 @@ def rocm_fp8_paged_mqa_logits(
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    from vllm._aiter_ops import rocm_aiter_ops

    @functools.lru_cache
    def paged_mqa_logits_module():
        paged_mqa_logits_module_path = None
        if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
            paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
        elif (
            importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits")
            is not None
        ):
            paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"

        if paged_mqa_logits_module_path is not None:
            try:
                module = importlib.import_module(paged_mqa_logits_module_path)
                return module
            except ImportError:
                return None
        return None

    aiter_paged_mqa_logits_module = None
    if rocm_aiter_ops.is_enabled():
        from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1
        aiter_paged_mqa_logits_module = paged_mqa_logits_module()
    # FIXME(ganyi): Temporarily disable the aiter path until the nightly docker
    # image updates aiter to include the fix PR.
    aiter_paged_mqa_logits_module = None

    if aiter_paged_mqa_logits_module is not None:
        deepgemm_fp8_paged_mqa_logits_stage1 = (
            aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
        )
    batch_size, next_n, heads, _ = q_fp8.shape
    out_qk = torch.full(
        (heads, batch_size * next_n, max_model_len),
@@ -208,3 +356,293 @@ def rocm_fp8_paged_mqa_logits(
    return fp8_paged_mqa_logits_torch(
        q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
    )


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; the module check and reference
    # path will be removed once aiter merges this kernel into main.
    from vllm._aiter_ops import rocm_aiter_ops

    @functools.lru_cache
    def mqa_logits_module():
        mqa_logits_module_path = None
        if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
            mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
        elif (
            importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
            is not None
        ):
            mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"

        if mqa_logits_module_path is not None:
            try:
                module = importlib.import_module(mqa_logits_module_path)
                return module
            except ImportError:
                return None
        return None

    aiter_mqa_logits_module = None
    if rocm_aiter_ops.is_enabled():
        aiter_mqa_logits_module = mqa_logits_module()

    if aiter_mqa_logits_module is not None:
        fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
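A quick shape sanity check for `rocm_fp8_mqa_logits` / `fp8_mqa_logits_torch` (illustrative sizes; needs a GPU since the torch reference builds its masks on `cuda`):

import torch

M, N, H, D = 4, 16, 2, 64
q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
k_scales = torch.ones(N, device="cuda", dtype=torch.float32)
weights = torch.rand(M, H, device="cuda", dtype=torch.float32)
ks = torch.zeros(M, dtype=torch.int32, device="cuda")  # every K visible from 0
ke = torch.full((M,), N, dtype=torch.int32, device="cuda")  # ... through N
logits = rocm_fp8_mqa_logits(q, (k_fp8, k_scales), weights, ks, ke)  # [M, N]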
def rocm_aiter_sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    # Profile run.
    # NOTE(Chen): create the largest possible flattened KV so that
    # profile_run measures the correct memory usage.
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    fp8_dtype = current_platform.fp8_dtype()
    _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer
def rocm_aiter_sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    # Careful: this will be None in a dummy run.
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()
    if not isinstance(attn_metadata, dict):
        return rocm_aiter_sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            k_fp8 = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=fp8_dtype,
            )
            k_scale = torch.empty(
                [chunk.total_seq_lens, 4],
                device=k.device,
                dtype=torch.uint8,
            )

            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            logits = rocm_fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]
            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # The kernel expects kv_cache shaped [num_block, block_size, n_head,
        # head_dim]; we only have [num_block, block_size, head_dim].
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pad in the edge case where a short chunked-prefill length is
            # smaller than decode_threshold, since prefill and decode are
            # split loosely by decode_threshold (currently set to
            # 1 + speculative tokens).
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize the logic below with Triton kernels.
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        logits = rocm_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )

        num_rows = logits.shape[0]
        assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # If padded, unpack the top-k indices, removing padded tokens.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer