[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp (#29287)

Signed-off-by: ganyi <ygan@amd.com>
Pleaplusone
2026-01-21 23:16:30 +08:00
committed by GitHub
parent 85f55c943c
commit 6c20e89c02
8 changed files with 989 additions and 330 deletions

View File

@@ -63,6 +63,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int
token_start: int
token_end: int
@@ -234,6 +235,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
).to(self.device)
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = (
torch.cat(
@@ -249,6 +254,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,

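To make the new token_to_seq mapping concrete, here is a minimal sketch (sequence lengths made up) of what the builder computes for a chunk of three sequences:

import torch

seq_lens_cpu = torch.tensor([2, 3, 1])  # per-sequence KV lengths in the chunk
seq_idx = torch.arange(0, 3, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu)
# token_to_seq == tensor([0, 0, 1, 1, 1, 2], dtype=torch.int32)
# i.e. each position of the flattened KV sequence maps back to its sequence id.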
View File

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims,
)
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -33,6 +34,48 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
@triton.jit
def fetch_id_to_ragged_kernel(
in_tensor_ptr, # [num_seq, topk]
cumsum_ptr, # [num_seq + 1]
out_tensor_ptr, # [max_num_seq * topk]
in_tensor_ptr_stride,
TOPK: tl.constexpr,
TOKEN_NUM: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
seq_id = tl.program_id(0)
block_id = tl.program_id(1)
offset = tl.arange(0, BLOCK_SIZE)
token_start = tl.load(cumsum_ptr + seq_id)
token_end = tl.load(cumsum_ptr + seq_id + 1)
token_num = token_end - token_start
row_offset = block_id * BLOCK_SIZE
if row_offset >= token_num:
return
in_tensor_offset = seq_id * in_tensor_ptr_stride + row_offset + offset
in_tensor_mask = (row_offset + offset) < TOPK
in_tensor_val = tl.load(in_tensor_ptr + in_tensor_offset, mask=in_tensor_mask)
out_tensor_offset = token_start + row_offset + offset
out_tensor_mask = (out_tensor_offset < token_end) & in_tensor_mask
tl.store(out_tensor_ptr + out_tensor_offset, in_tensor_val, mask=out_tensor_mask)
def fetch_id_to_ragged_triton(
in_tensor: torch.Tensor, cumsum: torch.Tensor, out_tensor: torch.Tensor, topk
):
num_tokens = in_tensor.size(0)
block_size = 64
num_block_per_row = triton.cdiv(topk, block_size)
grid = (
num_tokens,
num_block_per_row,
)
fetch_id_to_ragged_kernel[grid](
in_tensor, cumsum, out_tensor, in_tensor.stride(0), topk, num_tokens, block_size
)
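For reference, the kernel above has the same semantics as this plain-PyTorch loop (a sketch for clarity, not part of the diff): each row of in_tensor contributes its first cumsum[i + 1] - cumsum[i] entries, written contiguously into the ragged output.

def fetch_id_to_ragged_torch(in_tensor, cumsum, out_tensor, topk):
    # Reference semantics of fetch_id_to_ragged_kernel.
    for i in range(in_tensor.size(0)):
        start, end = int(cumsum[i]), int(cumsum[i + 1])
        n = min(end - start, topk)
        out_tensor[start : start + n] = in_tensor[i, :n]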
class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -83,6 +126,13 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
block_table: torch.Tensor
req_id_per_token: torch.Tensor
qo_indptr: torch.Tensor
paged_kv_last_page_len: torch.Tensor
paged_kv_indices: torch.Tensor
paged_kv_indptr: torch.Tensor
paged_kv_indptr_rest: torch.Tensor
block_size: int = 1
topk_tokens: int = 2048
@@ -91,7 +141,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
class ROCMAiterMLASparseMetadataBuilder(
AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
def __init__(
self,
@@ -104,6 +154,7 @@ class ROCMAiterMLASparseMetadataBuilder(
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.device = device
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
@@ -124,6 +175,23 @@ class ROCMAiterMLASparseMetadataBuilder(
dtype=torch.int32,
device=device,
)
self.qo_indptr = torch.arange(
0, max_num_batched_tokens + 1, dtype=torch.int32, device=device
)
self.paged_kv_last_page_len = torch.ones(
max_num_batched_tokens, dtype=torch.int32, device=device
)
# These two need to be computed at runtime,
# but we still need to allocate the buffers up front
self.paged_kv_indices = torch.zeros(
[max_num_batched_tokens * self.topk_tokens],
dtype=torch.int32,
device=device,
)
self.paged_kv_indptr = torch.zeros(
[max_num_batched_tokens + 1], dtype=torch.int32, device=device
)
def build(
self,
@@ -142,7 +210,15 @@ class ROCMAiterMLASparseMetadataBuilder(
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
self.paged_kv_indices.fill_(0)
self.paged_kv_indptr.fill_(0)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
qo_indptr = self.qo_indptr[: num_tokens + 1]
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_tokens]
paged_kv_indices = self.paged_kv_indices[: num_tokens * self.topk_tokens]
paged_kv_indptr = self.paged_kv_indptr[: num_tokens + 1]
paged_kv_indptr_rest = self.paged_kv_indptr[num_tokens + 1 :]
metadata = ROCMAiterMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
@@ -155,6 +231,11 @@ class ROCMAiterMLASparseMetadataBuilder(
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
qo_indptr=qo_indptr,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indices=paged_kv_indices,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indptr_rest=paged_kv_indptr_rest,
)
return metadata
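The paged_kv_* buffers sliced above follow a CSR-style ragged layout; as a reading aid (numbers illustrative):

# For token i, its selected KV slots live at
#   paged_kv_indices[paged_kv_indptr[i] : paged_kv_indptr[i + 1]]
# With block_size == 1, each page holds a single token, so
# paged_kv_last_page_len stays all-ones.
# e.g. paged_kv_indptr == [0, 3, 5] means token 0 attends to 3 KV entries
# and token 1 attends to 2.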
@@ -226,20 +307,39 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
def _forward_bf16_kv(
self,
q: torch.Tensor, # [sq, heads, d_qk]
kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk]
topk_indices: torch.Tensor, # [sq, topk]
attn_metadata: ROCMAiterMLASparseMetadata,
) -> torch.Tensor:
num_tokens = q.shape[0]
output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device,
)
seq_len = (topk_indices != -1).sum(dim=-1)
torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
fetch_id_to_ragged_triton(
topk_indices,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.topk_tokens,
)
rocm_aiter_ops.mla_decode_fwd(
q,
kv_c_and_k_pe_cache,
output,
self.scale,
attn_metadata.qo_indptr,
1,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len,
)
return output[:, : self.num_heads, :]
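In plain PyTorch, the indptr/indices construction above amounts to the following (a sketch, assuming valid ids are left-packed per row and unused slots hold -1):

def build_paged_kv_csr(topk_indices: torch.Tensor):
    num_tokens = topk_indices.size(0)
    valid = topk_indices != -1  # [num_tokens, topk]
    indptr = torch.zeros(
        num_tokens + 1, dtype=torch.int32, device=topk_indices.device
    )
    indptr[1:] = valid.sum(dim=-1).cumsum(dim=0)  # prefix sum of valid counts
    indices = topk_indices[valid].to(torch.int32)  # flattened ragged ids
    return indptr, indices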
def forward(

View File

@@ -1,100 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import importlib
import torch
from vllm.logger import init_logger
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
logger = init_logger(__name__)
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
@triton.jit
def _indexer_k_quant_and_cache_kernel(
k_ptr, # [num_tokens, head_dim]
kv_cache_ptr,  # SHUFFLE layout: [n_blks, blk_size//tile_block, head_dim//16B, tile_block, 16B]
# NHD layout: [n_blocks, blk_size, head_dim]
kv_cache_scale_ptr, # [n_blks, blk_size]
slot_mapping_ptr, # [num_tokens]
kv_cache_scale_stride,
kv_cache_value_stride,
block_size,
num_tokens,
head_dim: tl.constexpr,
LAYOUT: tl.constexpr,
BLOCK_TILE_SIZE: tl.constexpr,
HEAD_TILE_SIZE: tl.constexpr,
IS_FNUZ: tl.constexpr,
USE_UE8M0: tl.constexpr,
):
tid = tl.program_id(0)
offset = tl.arange(0, head_dim)
if LAYOUT == "SHUFFLE":
tile_offset = (
offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE
+ offset % HEAD_TILE_SIZE
)
else:
tile_offset = offset
tile_store_offset = tile_offset
# for idx in tl.range(tid, num_tokens, n_program):
src_ptr = k_ptr + tid * head_dim
slot_id = tl.load(slot_mapping_ptr + tid)
if slot_id < 0:
return
block_id = slot_id // block_size
block_offset = slot_id % block_size
tile_block_id = block_offset // BLOCK_TILE_SIZE
tile_block_offset = block_offset % BLOCK_TILE_SIZE
val = tl.load(src_ptr + offset)
amax = tl.max(val.abs(), axis=-1).to(tl.float32)
if IS_FNUZ:
scale = tl.maximum(1e-4, amax) / 224.0
else:
scale = tl.maximum(1e-4, amax) / 448.0
if USE_UE8M0:
scale = tl.exp2(tl.ceil(tl.log2(scale)))
fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
if LAYOUT == "SHUFFLE":
dst_ptr = (
kv_cache_ptr
+ block_id * kv_cache_value_stride
+ tile_block_id * BLOCK_TILE_SIZE * head_dim
+ tile_block_offset * HEAD_TILE_SIZE
)
else:
dst_ptr = (
kv_cache_ptr + block_id * kv_cache_value_stride + block_offset * head_dim
)
tl.store(dst_ptr + tile_store_offset, fp8_val)
dst_scale_ptr = kv_cache_scale_ptr + block_id * kv_cache_scale_stride + block_offset
tl.store(dst_scale_ptr, scale)
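The SHUFFLE address arithmetic above is easier to follow in plain Python; a minimal sketch of the destination offset for one element (tile sizes are the defaults used below):

def shuffle_dst_offset(offset, block_offset, head_dim,
                       BLOCK_TILE_SIZE=16, HEAD_TILE_SIZE=16):
    # Mirrors the SHUFFLE branch: elements are regrouped as
    # [blk_size // tile_block, head_dim // tile_head, tile_block, tile_head].
    tile_block_id = block_offset // BLOCK_TILE_SIZE
    tile_block_offset = block_offset % BLOCK_TILE_SIZE
    tile_offset = (
        offset // HEAD_TILE_SIZE * BLOCK_TILE_SIZE * HEAD_TILE_SIZE
        + offset % HEAD_TILE_SIZE
    )
    return (
        tile_block_id * BLOCK_TILE_SIZE * head_dim
        + tile_block_offset * HEAD_TILE_SIZE
        + tile_offset
    )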
def indexer_k_quant_and_cache_triton(
k: torch.Tensor,
kv_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4]
slot_mapping: torch.Tensor,
quant_block_size,
scale_fmt,
block_tile_size=16,
head_tile_size=16,
):
num_blocks = kv_cache.shape[0]
head_dim = k.shape[-1]
num_tokens = slot_mapping.shape[0]
block_size = kv_cache.shape[1]
# In the real layout, the first portion stores the kv cache values
# and the second portion stores the kv cache scales
kv_cache = kv_cache.view(num_blocks, -1)
kv_cache_value = kv_cache[:, : block_size * head_dim]
kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32)
head_tile_size = head_tile_size // kv_cache.element_size()
grid = (num_tokens,)
_indexer_k_quant_and_cache_kernel[grid](
k,
kv_cache_value,
kv_cache_scale,
slot_mapping,
kv_cache_scale.stride(0),
kv_cache_value.stride(0),
block_size,
num_tokens,
head_dim,
"NHD",
block_tile_size,
head_tile_size,
IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
USE_UE8M0=scale_fmt == "ue8m0",
)
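Functionally, the per-token quantization done by the kernel matches this torch reference (a sketch; 448.0 is the e4m3fn max and 224.0 the e4m3fnuz max, as in the kernel above):

def quant_one_token(val: torch.Tensor, fp8_max: float = 448.0,
                    use_ue8m0: bool = False):
    # Per-token dynamic scale: amax over head_dim, floored at 1e-4.
    scale = torch.clamp(val.abs().max().float(), min=1e-4) / fp8_max
    if use_ue8m0:
        # ue8m0 keeps a power-of-two scale (exponent only).
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    fp8_val = (val.float() / scale).to(torch.float8_e4m3fn)
    return fp8_val, scale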
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(
kv_cache_ptr,  # SHUFFLE layout: [n_blks, blk_size//tile_blk, head_dim//16B, tile_blk, 16B]
# NHD layout: [n_blks, blk_size, head_dim]
kv_cache_scale_ptr, # [n_blks, blk_size]
k_fp8_ptr, # [num_tokens, head_dim]
k_scale_ptr, # [num_tokens]
block_table_ptr, # [batch_size, block_table_stride]
cu_seqlen_ptr, # [batch_size + 1]
token_to_seq_ptr, # [num_tokens]
block_size,
block_table_stride,
kv_cache_stride,
kv_cache_scale_stride,
LAYOUT: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_TILE_SIZE: tl.constexpr,
HEAD_TILE_SIZE: tl.constexpr,
):
tid = tl.program_id(0)
offset = tl.arange(0, HEAD_DIM)
batch_id = tl.load(token_to_seq_ptr + tid)
batch_start = tl.load(cu_seqlen_ptr + batch_id)
batch_end = tl.load(cu_seqlen_ptr + batch_id + 1)
batch_offset = tid - batch_start
if tid >= batch_end:
return
block_table_id = batch_offset // block_size
block_offset = batch_offset % block_size
block_table_offset = batch_id * block_table_stride + block_table_id
block_id = tl.load(block_table_ptr + block_table_offset)
tiled_block_id = block_offset // BLOCK_TILE_SIZE
tiled_block_offset = block_offset % BLOCK_TILE_SIZE
if LAYOUT == "SHUFFLE":
src_cache_offset = (
block_id * kv_cache_stride
+ tiled_block_id * HEAD_DIM * BLOCK_TILE_SIZE
+ tiled_block_offset * HEAD_TILE_SIZE
)
else:
src_cache_offset = block_id * kv_cache_stride + block_offset * HEAD_DIM
src_scale_offset = block_id * kv_cache_scale_stride + block_offset
dst_offset = tid * HEAD_DIM
src_scale_ptr = kv_cache_scale_ptr + src_scale_offset
src_cache_ptr = kv_cache_ptr + src_cache_offset
dst_k_ptr = k_fp8_ptr + dst_offset
scale_val = tl.load(src_scale_ptr)
tl.store(k_scale_ptr + tid, scale_val)
if LAYOUT == "SHUFFLE":
tiled_src_offset = (
offset // HEAD_TILE_SIZE * HEAD_TILE_SIZE * BLOCK_TILE_SIZE
+ offset % HEAD_TILE_SIZE
)
else:
tiled_src_offset = offset
val = tl.load(src_cache_ptr + tiled_src_offset)
tl.store(dst_k_ptr + offset, val)
def cp_gather_indexer_k_quant_cache_triton(
k_cache: torch.Tensor, # [num_blocks, block_size, head_dim + 4]
k_fp8: torch.Tensor,
k_fp8_scale: torch.Tensor,
block_table: torch.Tensor,
cu_seqlen: torch.Tensor,
token_to_seq: torch.Tensor,
block_tile_size: int = 16,
head_tile_size: int = 16,
):
num_tokens = k_fp8.size(0)
block_size = k_cache.size(1)
block_table_stride = block_table.stride(0)
head_dim = k_fp8.shape[-1]
num_blocks = k_cache.shape[0]
# we assume the kv cache has already been split into two portions
k_cache = k_cache.view(num_blocks, -1)
fp8_dtype = current_platform.fp8_dtype()
k_cache_value = k_cache[:, : block_size * head_dim].view(fp8_dtype)
k_cache_scale = k_cache[:, block_size * head_dim :].view(torch.float32)
grid = (num_tokens,)
k_fp8_scale = k_fp8_scale.view(torch.float32)
_cp_gather_indexer_quant_cache_kernel[grid](
k_cache_value,
k_cache_scale,
k_fp8,
k_fp8_scale,
block_table,
cu_seqlen,
token_to_seq,
block_size,
block_table_stride,
k_cache_value.stride(0),
k_cache_scale.stride(0),
"NHD",
head_dim,
block_tile_size,
head_tile_size,
)
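A typical call site would mirror the indexer prefill path later in this diff, which allocates the flattened buffers and gathers per chunk (a sketch using the chunk fields added above):

# k_fp8 = torch.empty([chunk.total_seq_lens, head_dim], dtype=fp8_dtype, device=dev)
# k_scale = torch.empty([chunk.total_seq_lens, 4], dtype=torch.uint8, device=dev)
# cp_gather_indexer_k_quant_cache_triton(
#     kv_cache, k_fp8, k_scale, chunk.block_table,
#     chunk.cu_seq_lens, chunk.token_to_seq,
# )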
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
@@ -183,10 +303,38 @@ def rocm_fp8_paged_mqa_logits(
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
from vllm._aiter_ops import rocm_aiter_ops
@functools.lru_cache
def paged_mqa_logits_module():
paged_mqa_logits_module_path = None
if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
elif (
importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits")
is not None
):
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
if paged_mqa_logits_module_path is not None:
try:
module = importlib.import_module(paged_mqa_logits_module_path)
return module
except ImportError:
return None
return None
aiter_paged_mqa_logits_module = None
if rocm_aiter_ops.is_enabled():
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
# FIXME(ganyi): Temporarily disable the aiter path until the nightly docker
# updates aiter to include the fix PR.
aiter_paged_mqa_logits_module = None
if aiter_paged_mqa_logits_module is not None:
deepgemm_fp8_paged_mqa_logits_stage1 = (
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
)
batch_size, next_n, heads, _ = q_fp8.shape
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
@@ -208,3 +356,293 @@ def rocm_fp8_paged_mqa_logits(
return fp8_paged_mqa_logits_torch(
q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
)
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
kv, scale = kv
seq_len_kv = kv.shape[0]
k = kv.to(torch.bfloat16)
q = q.to(torch.bfloat16)
mask_lo = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
)
mask_hi = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
)
mask = mask_lo & mask_hi
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))
return logits
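The cu_seqlen_ks/cu_seqlen_ke pair defines a half-open window of valid keys per query position; a small made-up illustration:

# With seq_len_kv = 5, cu_seqlen_ks = [0, 0, 2] and cu_seqlen_ke = [3, 4, 5]:
#   query 0 sees keys 0..2, query 1 sees keys 0..3, query 2 sees keys 2..4;
# every other logit is masked to -inf before top-k selection.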
def rocm_fp8_mqa_logits(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
# TODO(ganyi): Temporary workaround; we will remove the module check and the
# reference path after aiter merges this kernel into main
from vllm._aiter_ops import rocm_aiter_ops
@functools.lru_cache
def mqa_logits_module():
mqa_logits_module_path = None
if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
elif (
importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
is not None
):
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
if mqa_logits_module_path is not None:
try:
module = importlib.import_module(mqa_logits_module_path)
return module
except ImportError:
return None
return None
aiter_mqa_logits_module = None
if rocm_aiter_ops.is_enabled():
aiter_mqa_logits_module = mqa_logits_module()
if aiter_mqa_logits_module is not None:
fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
kv, scale = kv
return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
else:
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
def rocm_aiter_sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
# profile run
# NOTE(Chen): create the max possible flattened_kv so that
# profile_run can measure the correct memory usage.
_flattened_kv = torch.empty(
[total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
fp8_dtype = current_platform.fp8_dtype()
_k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
def rocm_aiter_sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
# careful! this will be None during a dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
return rocm_aiter_sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
kv_cache,
q_fp8,
k,
weights,
quant_block_size,
scale_fmt,
topk_tokens,
head_dim,
max_model_len,
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty(
[chunk.total_seq_lens, head_dim],
device=k.device,
dtype=fp8_dtype,
)
k_scale = torch.empty(
[chunk.total_seq_lens, 4],
device=k.device,
dtype=torch.uint8,
)
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = rocm_fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if has_decode:
decode_metadata = attn_metadata.decode
# the kernel expects a kv_cache of shape [num_block, block_size, n_head, head_dim],
# but we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)
decode_lens = decode_metadata.decode_lens
if decode_metadata.requires_padding:
# pad in the edge case where a short chunked-prefill length is
# smaller than decode_threshold, since we split prefill and decode
# loosely by decode_threshold
# (currently set to 1 + number of speculative tokens)
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens
)
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:]
)
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
logits = rocm_fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack the topk indices,
# removing the padded tokens
topk_indices = unpack_seq_triton(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens,
)
topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
topk_indices
)
return topk_indices_buffer
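These two functions pair up as the real and fake (meta) implementations of a custom op; a minimal registration sketch, assuming vLLM's usual direct_register_custom_op helper (the actual registration lives elsewhere in this PR):

from vllm.utils import direct_register_custom_op

direct_register_custom_op(
    op_name="rocm_aiter_sparse_attn_indexer",
    op_func=rocm_aiter_sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=rocm_aiter_sparse_attn_indexer_fake,
)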