|
|
|
|
@@ -2,6 +2,7 @@
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
import functools
|
|
|
|
|
import importlib
|
|
|
|
|
from importlib.util import find_spec
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
@@ -276,11 +277,9 @@ def fp8_paged_mqa_logits_torch(
|
|
|
|
|
@functools.lru_cache
|
|
|
|
|
def paged_mqa_logits_module():
|
|
|
|
|
paged_mqa_logits_module_path = None
|
|
|
|
|
if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
|
|
|
|
|
if find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
|
|
|
|
|
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
|
|
|
|
|
elif (
|
|
|
|
|
importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None
|
|
|
|
|
):
|
|
|
|
|
elif find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None:
|
|
|
|
|
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
|
|
|
|
|
|
|
|
|
|
if paged_mqa_logits_module_path is not None:
|
|
|
|
|
@@ -380,9 +379,9 @@ def fp8_mqa_logits_torch(
|
|
|
|
|
Returns:
|
|
|
|
|
Logits tensor of shape [M, N], dtype `torch.float32`.
|
|
|
|
|
"""
|
|
|
|
|
kv, scale = kv
|
|
|
|
|
seq_len_kv = kv.shape[0]
|
|
|
|
|
k = kv.to(torch.bfloat16)
|
|
|
|
|
k_fp8, scale = kv
|
|
|
|
|
seq_len_kv = k_fp8.shape[0]
|
|
|
|
|
k = k_fp8.to(torch.bfloat16)
|
|
|
|
|
q = q.to(torch.bfloat16)
|
|
|
|
|
|
|
|
|
|
mask_lo = (
|
|
|
|
|
@@ -403,12 +402,9 @@ def fp8_mqa_logits_torch(
|
|
|
|
|
@functools.lru_cache
|
|
|
|
|
def mqa_logits_module():
|
|
|
|
|
mqa_logits_module_path = None
|
|
|
|
|
if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
|
|
|
|
|
if find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
|
|
|
|
|
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
|
|
|
|
|
elif (
|
|
|
|
|
importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
|
|
|
|
|
is not None
|
|
|
|
|
):
|
|
|
|
|
elif find_spec("aiter.ops.triton.attention.fp8_mqa_logits") is not None:
|
|
|
|
|
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
|
|
|
|
|
|
|
|
|
|
if mqa_logits_module_path is not None:
|
|
|
|
|
@@ -455,8 +451,8 @@ def rocm_fp8_mqa_logits(
|
|
|
|
|
|
|
|
|
|
if aiter_mqa_logits_module is not None:
|
|
|
|
|
fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
|
|
|
|
|
kv, scale = kv
|
|
|
|
|
return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
|
|
|
|
|
k_fp8, scale = kv
|
|
|
|
|
return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
|
|
|
|
|
else:
|
|
|
|
|
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
|
|
|
|
|
|
|
|
|
@@ -523,12 +519,14 @@ def rocm_aiter_sparse_attn_indexer(
|
|
|
|
|
total_seq_lens,
|
|
|
|
|
topk_indices_buffer,
|
|
|
|
|
)
|
|
|
|
|
attn_metadata = attn_metadata[k_cache_prefix]
|
|
|
|
|
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
|
|
|
|
|
slot_mapping = attn_metadata.slot_mapping
|
|
|
|
|
has_decode = attn_metadata.num_decodes > 0
|
|
|
|
|
has_prefill = attn_metadata.num_prefills > 0
|
|
|
|
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
|
|
|
|
layer_attn_metadata = attn_metadata[k_cache_prefix]
|
|
|
|
|
assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
|
|
|
|
|
assert topk_indices_buffer is not None
|
|
|
|
|
assert scale_fmt is not None
|
|
|
|
|
slot_mapping = layer_attn_metadata.slot_mapping
|
|
|
|
|
has_decode = layer_attn_metadata.num_decodes > 0
|
|
|
|
|
has_prefill = layer_attn_metadata.num_prefills > 0
|
|
|
|
|
num_decode_tokens = layer_attn_metadata.num_decode_tokens
|
|
|
|
|
|
|
|
|
|
ops.indexer_k_quant_and_cache(
|
|
|
|
|
k,
|
|
|
|
|
@@ -540,7 +538,8 @@ def rocm_aiter_sparse_attn_indexer(
|
|
|
|
|
|
|
|
|
|
topk_indices_buffer[: hidden_states.shape[0]] = -1
|
|
|
|
|
if has_prefill:
|
|
|
|
|
prefill_metadata = attn_metadata.prefill
|
|
|
|
|
prefill_metadata = layer_attn_metadata.prefill
|
|
|
|
|
assert prefill_metadata is not None
|
|
|
|
|
for chunk in prefill_metadata.chunks:
|
|
|
|
|
k_fp8 = torch.empty(
|
|
|
|
|
[chunk.total_seq_lens, head_dim],
|
|
|
|
|
@@ -585,7 +584,8 @@ def rocm_aiter_sparse_attn_indexer(
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if has_decode:
|
|
|
|
|
decode_metadata = attn_metadata.decode
|
|
|
|
|
decode_metadata = layer_attn_metadata.decode
|
|
|
|
|
assert decode_metadata is not None
|
|
|
|
|
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
|
|
|
|
|
# we only have [num_block, block_size, head_dim],
|
|
|
|
|
kv_cache = kv_cache.unsqueeze(-2)
|
|
|
|
|
|