[FlashMLA] Update FlashMLA to expose new arguments (#32810)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
.gitignore (vendored): 3 additions
@@ -7,6 +7,9 @@ vllm/vllm_flash_attn/*
 # OpenAI triton kernels copied from source
 vllm/third_party/triton_kernels/*
 
+# FlashMLA interface copied from source
+vllm/third_party/flashmla/flash_mla_interface.py
+
 # triton jit
 .triton
@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
     flashmla
     GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
-    GIT_TAG 526781394b33d9888e4c41952e692266267dd8bf
+    GIT_TAG c2afa9cb93e674d5a9120a170a6da57b89267208
     GIT_PROGRESS TRUE
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
@@ -30,6 +30,24 @@ endif()
 FetchContent_MakeAvailable(flashmla)
 message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
 
+# Vendor FlashMLA interface into vLLM with torch-ops shim.
+set(FLASHMLA_VENDOR_DIR "${CMAKE_SOURCE_DIR}/vllm/third_party/flashmla")
+file(MAKE_DIRECTORY "${FLASHMLA_VENDOR_DIR}")
+file(READ "${flashmla_SOURCE_DIR}/flash_mla/flash_mla_interface.py"
+     FLASHMLA_INTERFACE_CONTENT)
+string(REPLACE "import flash_mla.cuda as flash_mla_cuda"
+       "import vllm._flashmla_C\nflash_mla_cuda = torch.ops._flashmla_C"
+       FLASHMLA_INTERFACE_CONTENT
+       "${FLASHMLA_INTERFACE_CONTENT}")
+file(WRITE "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
+     "${FLASHMLA_INTERFACE_CONTENT}")
+
+# Install the generated flash_mla_interface.py to the wheel
+# Use COMPONENT _flashmla_C to ensure it's installed with the C extension
+install(FILES "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
+        DESTINATION vllm/third_party/flashmla/
+        COMPONENT _flashmla_C)
+
 # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
 # Only build FlashMLA kernels if we are building for something compatible with
 # sm90a
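Note: the net effect of the string(REPLACE ...) above is that the vendored copy of flash_mla_interface.py binds its kernel calls to vLLM's own extension instead of the upstream flash_mla.cuda package. A minimal sketch of what the rewritten import boils down to (assuming the upstream file only uses the flash_mla_cuda alias for its ops):

    # Upstream flash_mla/flash_mla_interface.py starts with:
    #     import flash_mla.cuda as flash_mla_cuda
    # After the string(REPLACE ...) step, the vendored copy instead does:
    import torch
    import vllm._flashmla_C  # importing registers the custom ops under torch.ops._flashmla_C

    flash_mla_cuda = torch.ops._flashmla_C  # every flash_mla_cuda.<op>(...) call now hits vLLM's build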
@@ -79,7 +97,6 @@ if(FLASH_MLA_ARCHS)
 
       # sm100 dense prefill & backward
       ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
-      ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
 
       # sm100 sparse prefill
       ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu
setup.py: 10 additions
@@ -646,6 +646,9 @@ class precompiled_wheel_utils:
         triton_kernels_regex = re.compile(
             r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
         )
+        flashmla_regex = re.compile(
+            r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
+        )
         file_members = list(
             filter(lambda x: x.filename in files_to_copy, wheel.filelist)
         )
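Note: the new flashmla_regex mirrors the triton_kernels one directly above: when extracting the precompiled wheel it selects non-hidden .py files at any depth under the vendored directory. A quick sanity check with made-up paths:

    import re

    flashmla_regex = re.compile(
        r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
    )

    assert flashmla_regex.match("vllm/third_party/flashmla/flash_mla_interface.py")
    assert flashmla_regex.match("vllm/third_party/flashmla/__init__.py")
    assert flashmla_regex.match("vllm/third_party/flashmla/subdir/helper.py")  # nested, non-hidden
    assert not flashmla_regex.match("vllm/third_party/flashmla/.hidden.py")    # hidden files are skipped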
@@ -657,6 +660,9 @@ class precompiled_wheel_utils:
                 lambda x: triton_kernels_regex.match(x.filename), wheel.filelist
             )
         )
+        file_members += list(
+            filter(lambda x: flashmla_regex.match(x.filename), wheel.filelist)
+        )
 
         for file in file_members:
             print(f"[extract] {file.filename}")
@@ -925,6 +931,10 @@ if _is_cuda():
     ):
         # FA3 requires CUDA 12.3 or later
         ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or (
+        CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
+    ):
+        # FlashMLA requires CUDA 12.9 or later
         # Optional since this doesn't get built (produce an .so file) when
         # not targeting a hopper system
         ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))
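Note: the 12.9 gate relies on numeric, per-component version comparison (Version here is presumably packaging.version.Version, as already used by setup.py), so 12.10 correctly counts as newer than 12.9:

    from packaging.version import Version

    assert Version("12.10") >= Version("12.9")   # component-wise numeric comparison
    assert not ("12.10" >= "12.9")               # a plain string comparison would get this wrong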
@@ -53,7 +53,6 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
 
 def _float_to_e8m0_truncate(f: float) -> float:
     """Simulate SM100's float -> e8m0 -> bf16 scale conversion.
 
     e8m0 format only stores the exponent (power of 2).
     cudaRoundZero truncates toward zero, meaning we round down to the
     nearest power of 2.
 
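Note: a minimal pure-Python sketch of the conversion this test helper describes (the helper's actual body is not shown in this hunk): dropping the mantissa rounds the magnitude down to the nearest power of 2.

    import math

    def e8m0_truncate_sketch(f: float) -> float:
        # e8m0 keeps only a power-of-2 exponent; cudaRoundZero drops the mantissa,
        # so the value rounds down to the nearest power of 2.
        # (Zero/negative handling is simplified here.)
        if f <= 0.0:
            return 0.0
        return 2.0 ** math.floor(math.log2(f))

    assert e8m0_truncate_sketch(3.0) == 2.0    # 2^1 <= 3.0 < 2^2
    assert e8m0_truncate_sketch(0.3) == 0.25   # 2^-2 <= 0.3 < 2^-1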
vllm/third_party/flashmla/__init__.py (vendored, new file): 1 addition
@@ -0,0 +1 @@
+# Sources copied from FlashMLA
@@ -32,8 +32,11 @@ from vllm.v1.attention.backends.utils import (
     reshape_query_for_spec_decode,
 )
 from vllm.v1.attention.ops.flashmla import (
+    FlashMLASchedMeta,
     flash_mla_with_kvcache,
+    flash_mla_with_kvcache_fp8,
     get_mla_metadata,
+    get_mla_metadata_dense_fp8,
     is_flashmla_dense_supported,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -93,8 +96,7 @@ class FlashMLABackend(MLACommonBackend):
 
 @dataclass
 class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
-    tile_scheduler_metadata: torch.Tensor
-    num_splits: torch.Tensor
+    scheduler_metadata: FlashMLASchedMeta
 
 
 @dataclass
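Note: FlashMLASchedMeta is defined in the vendored flash_mla_interface; judging by how it is used below (its tile_scheduler_metadata and num_splits attributes are read and overwritten), it bundles the two tensors this dataclass used to carry separately. A rough approximation of its shape, not the real class:

    from dataclasses import dataclass

    import torch

    @dataclass
    class FlashMLASchedMetaSketch:
        # (num_sm_parts, TileSchedulerMetaDataSize), torch.int32
        tile_scheduler_metadata: torch.Tensor | None = None
        # (batch_size + 1,), torch.int32
        num_splits: torch.Tensor | None = None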
@@ -158,46 +160,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         # we use the max but all should be the same due to uniform length requirement
         max_query_len = query_lens_cpu.max().item()
         num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
-        tile_scheduler_metadata, num_splits = get_mla_metadata(
+        scheduler_metadata, _ = get_mla_metadata(
             seq_lens_device,
             num_q_tokens_per_head_k,
             1,  # MQA for the decode path
             is_fp8_kvcache=self.is_fp8_kvcache,
         )
 
-        # TODO: we can disambiguate between decode and mixed-prefill decode here
-        # so we can only use the persistent buffer if a cudagraph is actually
-        # being used.
-        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-            assert self.cg_buf_tile_scheduler_metadata is not None
-            assert self.cg_buf_num_splits is not None
-
-            sm_parts = tile_scheduler_metadata.size(0)
-            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
-            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
-            tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
-                :sm_parts
-            ]
-            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
-            tile_scheduler_metadata = tile_scheduler_metadata_view
-
-            # Num splits is per-batch, varying size (batch_size,)
-            n = num_splits.size(0)
-            # make sure static buffer is large enough
-            assert n <= self.cg_buf_num_splits.size(0)
-            num_splits_view = self.cg_buf_num_splits[:n]
-            num_splits_view.copy_(num_splits)
-            # Num splits needs to monotonically increasing
-            # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
-            # it needs to monotonically increasing by 1)
-            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
-            num_splits = num_splits_view
+        if self.is_fp8_kvcache:
+            tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
+                seq_lens_device,
+                num_q_tokens_per_head_k,
+                1,  # MQA for the decode path
+            )
+            scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
+            scheduler_metadata.num_splits = num_splits
 
         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
-            tile_scheduler_metadata=tile_scheduler_metadata,
-            num_splits=num_splits,
+            scheduler_metadata=scheduler_metadata,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )
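Note: the hard-coded // 1 above is the MQA k-head count on the decode path; it matches the num_q_tokens_per_head_k definition in the ops-layer docstring further down (num_q_tokens_per_q_seq * num_heads_q // num_heads_k). A quick worked example with assumed values:

    max_query_len = 4        # uniform speculative-decode query length (assumed)
    num_q_heads = 128        # query heads (assumed)
    num_heads_k = 1          # MQA on the decode path
    num_q_tokens_per_head_k = max_query_len * num_q_heads // num_heads_k
    assert num_q_tokens_per_head_k == 512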
@@ -272,9 +253,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         num_decodes = attn_metadata.num_decodes
         q = reshape_query_for_spec_decode(q, num_decodes)
 
-        tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
-        num_splits = attn_metadata.decode.num_splits
-        if vllm_is_batch_invariant():
+        scheduler_metadata = attn_metadata.decode.scheduler_metadata
+        if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
             device = q.device
             dtype = torch.int32
||||
@@ -301,19 +281,34 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
# Non-split path ignores num_splits, but the API requires it:
|
||||
# zeros of length B+1
|
||||
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
|
||||
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
|
||||
scheduler_metadata.num_splits = num_splits
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
o, lse = flash_mla_with_kvcache_fp8(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
|
||||
num_splits=scheduler_metadata.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
descale_q=layer._q_scale.reshape(1),
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
else:
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
tile_scheduler_metadata=scheduler_metadata,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
descale_q=layer._q_scale.reshape(1),
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
is_fp8_kvcache=False,
|
||||
)
|
||||
|
||||
o = reshape_attn_output_for_spec_decode(o)
|
||||
|
||||
@@ -33,7 +33,8 @@ from vllm.v1.attention.backends.utils import (
     split_prefill_chunks,
 )
 from vllm.v1.attention.ops.flashmla import (
-    flash_mla_sparse_prefill,
+    FlashMLASchedMeta,
+    flash_mla_sparse_fwd,
     flash_mla_with_kvcache,
     get_mla_metadata,
 )
@@ -142,8 +143,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
 
     @dataclass
     class FP8KernelMetadata:
-        scheduler_metadata: torch.Tensor | None
-        num_splits: torch.Tensor
+        scheduler_metadata: FlashMLASchedMeta
         dummy_block_table: torch.Tensor
         cache_lens: torch.Tensor
 
@@ -468,7 +468,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
         padded_heads = self.fp8_decode_padded_heads
 
         # Build metadata for all tokens as a single batch
-        tile_scheduler_metadata, num_splits = get_mla_metadata(
+        scheduler_metadata, _ = get_mla_metadata(
             cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
             num_q_tokens_per_head_k=num_tokens * padded_heads,
             topk=self.topk_tokens,
@@ -477,17 +477,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
             is_fp8_kvcache=True,
         )
 
-        num_sm_parts = tile_scheduler_metadata.size(0)
-        tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
-            :num_sm_parts
-        ]
-        tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-        num_splits_view = self.num_splits_buffer[:2]
-        num_splits_view.copy_(num_splits)
-
         fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
-            scheduler_metadata=tile_scheduler_metadata_buffer,
-            num_splits=num_splits_view,
+            scheduler_metadata=scheduler_metadata,
             cache_lens=self.max_model_len_tensor[:1],
             dummy_block_table=self.dummy_block_table[:1],
         )
@@ -620,7 +611,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
 
         # Use padded head count since that's what the kernel will see
         padded_heads = self.fp8_decode_padded_heads
-        tile_scheduler_metadata, num_splits = get_mla_metadata(
+        scheduler_metadata, _ = get_mla_metadata(
             cache_seqlens=self.topk_tokens_tensor[:num_decodes],
             num_q_tokens_per_head_k=decode_query_len * padded_heads,
             topk=self.topk_tokens,
@@ -629,19 +620,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
             is_fp8_kvcache=True,
         )
 
-        num_sm_parts = tile_scheduler_metadata.size(0)
-        # Copy to persistent buffer for full-CG support
-        tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
-            :num_sm_parts
-        ]
-        tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-        # num_splits has size [num_decodes + 1]
-        num_splits_view = self.num_splits_buffer[: num_decodes + 1]
-        num_splits_view.copy_(num_splits)
-
         kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
-            scheduler_metadata=tile_scheduler_metadata_buffer,
-            num_splits=num_splits_view,
+            scheduler_metadata=scheduler_metadata,
             dummy_block_table=self.dummy_block_table[:num_decodes],
             cache_lens=self.max_model_len_tensor[:num_decodes],
         )
@@ -949,7 +929,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
             head_dim_v=512,
             cache_seqlens=kernel_metadata.cache_lens,
             tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
-            num_splits=kernel_metadata.num_splits,
             is_fp8_kvcache=True,
             indices=topk_indices,
             softmax_scale=self.softmax_scale,
@@ -985,7 +964,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
             q = q_padded
 
         topk_indices = topk_indices.view(num_tokens, 1, -1)
-        output = flash_mla_sparse_prefill(
+        output = flash_mla_sparse_fwd(
             q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
         )[0]
         output = output[:, : self.num_heads, :]
@@ -78,50 +78,49 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
     return True, None
 
 
-def get_mla_metadata(
+def _raise_flashmla_unavailable(*_args, **_kwargs):
+    _, reason = _is_flashmla_available()
+    raise RuntimeError(reason or "FlashMLA is not available")
+
+
+if _is_flashmla_available()[0]:
+    from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
+        FlashMLASchedMeta,
+        flash_attn_varlen_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+        flash_mla_sparse_fwd,
+        flash_mla_with_kvcache,
+        get_mla_metadata,
+    )
+else:
+
+    class FlashMLASchedMeta:  # type: ignore[no-redef]
+        pass
+
+    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
+    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]
+
+
+def get_mla_metadata_dense_fp8(
     cache_seqlens: torch.Tensor,
     num_q_tokens_per_head_k: int,
     num_heads_k: int,
-    num_heads_q: int | None = None,
-    is_fp8_kvcache: bool = False,
-    topk: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Arguments:
     - cache_seqlens: (batch_size), dtype torch.int32.
     - num_q_tokens_per_head_k:
         Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
     - num_heads_k: The number of k heads.
-    - num_heads_q:
-        The number of q heads.
-        This argument is optional when sparse attention is not enabled
-    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
-    - topk: If not None, sparse attention will be enabled,
-        and only tokens in the `indices` array
-        passed to `flash_mla_with_kvcache_sm90` will be attended to.
 
     Returns:
     - tile_scheduler_metadata:
         (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
     - num_splits: (batch_size + 1), dtype torch.int32.
     """
-    if is_fp8_kvcache and topk is None:
+    if not _is_flashmla_available()[0]:
+        _raise_flashmla_unavailable()
     return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
         cache_seqlens,
         num_q_tokens_per_head_k,
         num_heads_k,
     )
-    return torch.ops._flashmla_C.get_mla_decoding_metadata(
-        cache_seqlens,
-        num_q_tokens_per_head_k,
-        num_heads_k,
-        num_heads_q,
-        is_fp8_kvcache,
-        topk,
-    )
 
 
-def flash_mla_with_kvcache(
+def flash_mla_with_kvcache_fp8(
     q: torch.Tensor,
     k_cache: torch.Tensor,
     block_table: torch.Tensor,
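Note: a rough usage sketch for the dense-fp8 metadata helper above; shapes follow its docstring, the batch values are made up, and a built _flashmla extension on a supported GPU is assumed:

    import torch

    # Four decode requests with these KV-cache lengths.
    cache_seqlens = torch.tensor([1024, 512, 2048, 256], dtype=torch.int32, device="cuda")

    # num_q_tokens_per_head_k = num_q_tokens_per_q_seq * num_heads_q // num_heads_k,
    # e.g. 1 decode token * 128 q heads // 1 k head (MQA) = 128.
    tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
        cache_seqlens,
        num_q_tokens_per_head_k=128,
        num_heads_k=1,
    )
    # tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32
    # num_splits: (batch_size + 1,) == (5,), torch.int32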
@@ -133,54 +132,11 @@ def flash_mla_with_kvcache(
     causal: bool = False,
     descale_q: torch.Tensor | None = None,
     descale_k: torch.Tensor | None = None,
-    is_fp8_kvcache: bool = False,
-    indices: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Arguments:
-    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
-    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
-    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
-    - cache_seqlens: (batch_size), torch.int32.
-    - head_dim_v: Head dimension of v.
-    - tile_scheduler_metadata:
-        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
-        returned by get_mla_metadata.
-    - num_splits:
-        (batch_size + 1), torch.int32, returned by get_mla_metadata.
-    - softmax_scale: float.
-        The scale of QK^T before applying softmax.
-        Default to 1 / sqrt(head_dim).
-    - causal: bool. Whether to apply causal attention mask.
-    - descale_q: (batch_size),
-        torch.float32. Descaling factors for Q, used for fp8 quantization.
-    - descale_k: (batch_size),
-        torch.float32. Descaling factors for K, used for fp8 quantization.
-    - is_fp8_kvcache: bool.
-        Whether the k_cache and v_cache are in fp8 format.
-        For the format of FP8 KV cache, please refer to README.md
-    - indices: (batch_size, seq_len_q, topk), torch.int32.
-        If not None, sparse attention will be enabled,
-        and only tokens in the `indices` array will be attended to.
-        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
-        For details about how to set up `indices`, please refer to README.md.
-
-    Returns:
-    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
-    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
-    """
+    if not _is_flashmla_available()[0]:
+        _raise_flashmla_unavailable()
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
-    if indices is not None:
-        # NOTE (zyongye): sparse attention is also causal
-        # since it only attend to the tokens before
-        # but here `causal` should not be specified
-        assert not causal, "causal must be `false` if sparse attention is enabled."
     assert (descale_q is None) == (descale_k is None), (
         "descale_q and descale_k should be both None or both not None"
     )
 
-    if indices is None and q.element_size() == 1:
     out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
         q,
         k_cache,
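Note: the docstring removed above still documents the dense decode call shapes; a rough end-to-end sketch mirroring how the decode backend drives the vendored kernels (all sizes are illustrative, the return/argument conventions are inferred from the backend code above, and a supported GPU with the extension built is assumed):

    import torch

    batch_size, seq_len_q, num_heads_q, head_dim, head_dim_v = 4, 1, 128, 576, 512
    num_blocks, page_block_size, max_blocks_per_seq = 1024, 64, 32

    q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim, dtype=torch.bfloat16, device="cuda")
    k_cache = torch.randn(num_blocks, page_block_size, 1, head_dim, dtype=torch.bfloat16, device="cuda")
    # Illustrative only; a real table maps each request to its allocated cache blocks.
    block_table = torch.zeros(batch_size, max_blocks_per_seq, dtype=torch.int32, device="cuda")
    cache_seqlens = torch.tensor([1024, 512, 2048, 256], dtype=torch.int32, device="cuda")

    sched, _ = get_mla_metadata(cache_seqlens, seq_len_q * num_heads_q // 1, 1)
    out, lse = flash_mla_with_kvcache(
        q=q, k_cache=k_cache, block_table=block_table, cache_seqlens=cache_seqlens,
        head_dim_v=head_dim_v, tile_scheduler_metadata=sched,
        softmax_scale=head_dim ** -0.5, causal=True, is_fp8_kvcache=False,
    )
    # out: (batch_size, seq_len_q, num_heads_q, head_dim_v)
    # lse: (batch_size, num_heads_q, seq_len_q), torch.float32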
@@ -194,53 +150,9 @@ def flash_mla_with_kvcache(
         descale_q,
         descale_k,
     )
-    else:
-        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
-            q,
-            k_cache,
-            head_dim_v,
-            cache_seqlens,
-            block_table,
-            softmax_scale,
-            causal,
-            tile_scheduler_metadata,
-            num_splits,
-            is_fp8_kvcache,
-            indices,
-        )
     return out, softmax_lse
 
 
-def flash_mla_sparse_prefill(
-    q: torch.Tensor,
-    kv: torch.Tensor,
-    indices: torch.Tensor,
-    sm_scale: float,
-    d_v: int = 512,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Sparse attention prefill kernel
-
-    Args:
-    - q: [s_q, h_q, d_qk], bfloat16
-    - kv: [s_kv, h_kv, d_qk], bfloat16
-    - indices: [s_q, h_kv, topk], int32.
-        Invalid indices should be set to -1 or numbers >= s_kv
-    - sm_scale: float
-    - d_v: The dimension of value vectors. Can only be 512
-
-    Returns:
-    - (output, max_logits, lse)
-        About the definition of output,
-        max_logits and lse, please refer to README.md
-    - output: [s_q, h_q, d_v], bfloat16
-    - max_logits: [s_q, h_q], float
-    - lse: [s_q, h_q], float, 2-based log-sum-exp
-    """
-    results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
-    return results
 
 
 #
 # TODO: Add fake functions
 #
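Note: flash_mla_sparse_prefill is gone from this module; the sparse backend now imports flash_mla_sparse_fwd from the vendored interface instead. A rough usage sketch based on the removed docstring's shapes, assuming the vendored function keeps the same (output, max_logits, lse) return triple (sizes are illustrative; d_qk = 576 and d_v = 512 follow the usual MLA layout):

    import torch

    s_q, s_kv, h_q, h_kv, d_qk, topk = 256, 4096, 128, 1, 576, 2048

    q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
    kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")
    # Per-query-token indices into kv; invalid slots should be -1 (or >= s_kv).
    indices = torch.randint(0, s_kv, (s_q, h_kv, topk), dtype=torch.int32, device="cuda")

    output, max_logits, lse = flash_mla_sparse_fwd(q, kv, indices, d_qk ** -0.5)
    # output: [s_q, h_q, 512]; max_logits, lse: [s_q, h_q]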