[FlashMLA] Update FlashMLA to expose new arguments (#32810)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson
2026-01-21 22:02:39 -07:00
committed by GitHub
parent 49d9653852
commit 889722f3bf
8 changed files with 132 additions and 216 deletions

.gitignore vendored

@@ -7,6 +7,9 @@ vllm/vllm_flash_attn/*
 # OpenAI triton kernels copied from source
 vllm/third_party/triton_kernels/*
+# FlashMLA interface copied from source
+vllm/third_party/flashmla/flash_mla_interface.py
 # triton jit
 .triton


@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
     flashmla
     GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
-    GIT_TAG 526781394b33d9888e4c41952e692266267dd8bf
+    GIT_TAG c2afa9cb93e674d5a9120a170a6da57b89267208
     GIT_PROGRESS TRUE
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
@@ -30,6 +30,24 @@ endif()
 FetchContent_MakeAvailable(flashmla)
 message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
+# Vendor FlashMLA interface into vLLM with torch-ops shim.
+set(FLASHMLA_VENDOR_DIR "${CMAKE_SOURCE_DIR}/vllm/third_party/flashmla")
+file(MAKE_DIRECTORY "${FLASHMLA_VENDOR_DIR}")
+file(READ "${flashmla_SOURCE_DIR}/flash_mla/flash_mla_interface.py"
+     FLASHMLA_INTERFACE_CONTENT)
+string(REPLACE "import flash_mla.cuda as flash_mla_cuda"
+       "import vllm._flashmla_C\nflash_mla_cuda = torch.ops._flashmla_C"
+       FLASHMLA_INTERFACE_CONTENT
+       "${FLASHMLA_INTERFACE_CONTENT}")
+file(WRITE "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
+     "${FLASHMLA_INTERFACE_CONTENT}")
+# Install the generated flash_mla_interface.py to the wheel
+# Use COMPONENT _flashmla_C to ensure it's installed with the C extension
+install(FILES "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
+        DESTINATION vllm/third_party/flashmla/
+        COMPONENT _flashmla_C)
 # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
 # Only build FlashMLA kernels if we are building for something compatible with
 # sm90a
@@ -79,7 +97,6 @@ if(FLASH_MLA_ARCHS)
     # sm100 dense prefill & backward
     ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
-    ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
     # sm100 sparse prefill
     ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu
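The string(REPLACE ...) call above swaps the one upstream import in the vendored copy for vLLM's registered torch ops. A rough sketch of what the generated vllm/third_party/flashmla/flash_mla_interface.py ends up containing at its top (illustrative, not the literal file; the upstream module already imports torch):

    # Upstream: "import flash_mla.cuda as flash_mla_cuda"
    # Vendored copy after the CMake rewrite (sketch):
    import torch

    import vllm._flashmla_C  # noqa: F401  -- loads the extension, registering torch.ops._flashmla_C

    flash_mla_cuda = torch.ops._flashmla_C  # kernel calls in the interface now dispatch through torch ops

Only the import changes; the rest of the upstream interface is copied through unmodified and shipped with the wheel by the install() rule.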


@@ -646,6 +646,9 @@ class precompiled_wheel_utils:
     triton_kernels_regex = re.compile(
         r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
     )
+    flashmla_regex = re.compile(
+        r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
+    )
     file_members = list(
         filter(lambda x: x.filename in files_to_copy, wheel.filelist)
     )
@@ -657,6 +660,9 @@ class precompiled_wheel_utils:
             lambda x: triton_kernels_regex.match(x.filename), wheel.filelist
         )
     )
+    file_members += list(
+        filter(lambda x: flashmla_regex.match(x.filename), wheel.filelist)
+    )
     for file in file_members:
         print(f"[extract] {file.filename}")
@@ -925,6 +931,10 @@ if _is_cuda():
     ):
         # FA3 requires CUDA 12.3 or later
         ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or (
+        CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
+    ):
+        # FlashMLA requires CUDA 12.9 or later
         # Optional since this doesn't get built (produce an .so file) when
         # not targeting a hopper system
         ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))


@@ -53,7 +53,6 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
 def _float_to_e8m0_truncate(f: float) -> float:
     """Simulate SM100's float -> e8m0 -> bf16 scale conversion.
-    e8m0 format only stores the exponent (power of 2).
     cudaRoundZero truncates toward zero, meaning we round down to the
     nearest power of 2.
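The behaviour the docstring describes, written out as a minimal sketch (my restatement, not the test's actual helper):

    import math

    def e8m0_truncate_sketch(f: float) -> float:
        # e8m0 keeps only the exponent, so the stored value is a pure power of
        # two; cudaRoundZero drops the mantissa, i.e. rounds toward zero.
        assert f > 0.0
        return 2.0 ** math.floor(math.log2(f))

    assert e8m0_truncate_sketch(3.7) == 2.0   # 2^1 <= 3.7 < 2^2
    assert e8m0_truncate_sketch(0.3) == 0.25  # 2^-2 <= 0.3 < 2^-1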

vllm/third_party/flashmla/__init__.py vendored Normal file

@@ -0,0 +1 @@
+# Sources copied from FlashMLA


@@ -32,8 +32,11 @@ from vllm.v1.attention.backends.utils import (
     reshape_query_for_spec_decode,
 )
 from vllm.v1.attention.ops.flashmla import (
+    FlashMLASchedMeta,
     flash_mla_with_kvcache,
+    flash_mla_with_kvcache_fp8,
     get_mla_metadata,
+    get_mla_metadata_dense_fp8,
     is_flashmla_dense_supported,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -93,8 +96,7 @@ class FlashMLABackend(MLACommonBackend):
 @dataclass
 class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
-    tile_scheduler_metadata: torch.Tensor
-    num_splits: torch.Tensor
+    scheduler_metadata: FlashMLASchedMeta

 @dataclass
@@ -158,46 +160,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     # we use the max but all should be the same due to uniform length requirement
     max_query_len = query_lens_cpu.max().item()
     num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
-    tile_scheduler_metadata, num_splits = get_mla_metadata(
+    scheduler_metadata, _ = get_mla_metadata(
         seq_lens_device,
         num_q_tokens_per_head_k,
         1,  # MQA for the decode path
         is_fp8_kvcache=self.is_fp8_kvcache,
     )
-    # TODO: we can disambiguate between decode and mixed-prefill decode here
-    # so we can only use the persistent buffer if a cudagraph is actually
-    # being used.
-    if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-        assert self.cg_buf_tile_scheduler_metadata is not None
-        assert self.cg_buf_num_splits is not None
-        sm_parts = tile_scheduler_metadata.size(0)
-        # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
-        assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
-        tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
-            :sm_parts
-        ]
-        tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
-        tile_scheduler_metadata = tile_scheduler_metadata_view
-        # Num splits is per-batch, varying size (batch_size,)
-        n = num_splits.size(0)
-        # make sure static buffer is large enough
-        assert n <= self.cg_buf_num_splits.size(0)
-        num_splits_view = self.cg_buf_num_splits[:n]
-        num_splits_view.copy_(num_splits)
-        # Num splits needs to monotonically increasing
-        # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
-        # it needs to monotonically increasing by 1)
-        self.cg_buf_num_splits[n:].fill_(num_splits[-1])
-        num_splits = num_splits_view
+    if self.is_fp8_kvcache:
+        tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
+            seq_lens_device,
+            num_q_tokens_per_head_k,
+            1,  # MQA for the decode path
+        )
+        scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
+        scheduler_metadata.num_splits = num_splits
     return FlashMLADecodeMetadata(
         block_table=block_table_tensor,
         seq_lens=seq_lens_device,
-        tile_scheduler_metadata=tile_scheduler_metadata,
-        num_splits=num_splits,
+        scheduler_metadata=scheduler_metadata,
         dcp_tot_seq_lens=dcp_tot_seq_lens_device,
     )
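The builder now threads a single FlashMLASchedMeta through the decode metadata instead of two loose tensors. Judging only from the assignments in this hunk and the shapes in the docstrings removed later in this diff, the object carries at least the two fields below; a hedged stand-in (the real class is re-exported from the vendored flash_mla_interface.py):

    from dataclasses import dataclass

    import torch

    @dataclass
    class SchedMetaSketch:  # illustrative stand-in for FlashMLASchedMeta
        # (num_sm_parts, TileSchedulerMetaDataSize), torch.int32
        tile_scheduler_metadata: torch.Tensor | None = None
        # (batch_size + 1,), torch.int32
        num_splits: torch.Tensor | None = None

Callers that still need the raw tensors (the dense-fp8 path above and the batch-invariant override below) read and write them through these attributes.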
@@ -272,9 +253,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         num_decodes = attn_metadata.num_decodes
         q = reshape_query_for_spec_decode(q, num_decodes)
-        tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
-        num_splits = attn_metadata.decode.num_splits
-        if vllm_is_batch_invariant():
+        scheduler_metadata = attn_metadata.decode.scheduler_metadata
+        if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
             device = q.device
             dtype = torch.int32
@@ -301,20 +281,35 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
     # Non-split path ignores num_splits, but the API requires it:
     # zeros of length B+1
     num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
+    scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
+    scheduler_metadata.num_splits = num_splits
-o, lse = flash_mla_with_kvcache(
-    q=q,
-    k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
-    block_table=attn_metadata.decode.block_table,
-    cache_seqlens=attn_metadata.decode.seq_lens,
-    head_dim_v=self.kv_lora_rank,
-    tile_scheduler_metadata=tile_scheduler_metadata,
-    num_splits=num_splits,
-    softmax_scale=self.scale,
-    causal=True,
-    descale_q=layer._q_scale.reshape(1),
-    descale_k=layer._k_scale.reshape(1),
-)
+if self.kv_cache_dtype.startswith("fp8"):
+    o, lse = flash_mla_with_kvcache_fp8(
+        q=q,
+        k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
+        block_table=attn_metadata.decode.block_table,
+        cache_seqlens=attn_metadata.decode.seq_lens,
+        head_dim_v=self.kv_lora_rank,
+        tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
+        num_splits=scheduler_metadata.num_splits,
+        softmax_scale=self.scale,
+        causal=True,
+        descale_q=layer._q_scale.reshape(1),
+        descale_k=layer._k_scale.reshape(1),
+    )
+else:
+    o, lse = flash_mla_with_kvcache(
+        q=q,
+        k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
+        block_table=attn_metadata.decode.block_table,
+        cache_seqlens=attn_metadata.decode.seq_lens,
+        head_dim_v=self.kv_lora_rank,
+        tile_scheduler_metadata=scheduler_metadata,
+        softmax_scale=self.scale,
+        causal=True,
+        is_fp8_kvcache=False,
+    )
 o = reshape_attn_output_for_spec_decode(o)


@@ -33,7 +33,8 @@ from vllm.v1.attention.backends.utils import (
     split_prefill_chunks,
 )
 from vllm.v1.attention.ops.flashmla import (
-    flash_mla_sparse_prefill,
+    FlashMLASchedMeta,
+    flash_mla_sparse_fwd,
     flash_mla_with_kvcache,
     get_mla_metadata,
 )
@@ -142,8 +143,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
 @dataclass
 class FP8KernelMetadata:
-    scheduler_metadata: torch.Tensor | None
-    num_splits: torch.Tensor
+    scheduler_metadata: FlashMLASchedMeta
     dummy_block_table: torch.Tensor
     cache_lens: torch.Tensor
@@ -468,7 +468,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
     padded_heads = self.fp8_decode_padded_heads
     # Build metadata for all tokens as a single batch
-    tile_scheduler_metadata, num_splits = get_mla_metadata(
+    scheduler_metadata, _ = get_mla_metadata(
         cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
         num_q_tokens_per_head_k=num_tokens * padded_heads,
         topk=self.topk_tokens,
@@ -477,17 +477,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
         is_fp8_kvcache=True,
     )
-    num_sm_parts = tile_scheduler_metadata.size(0)
-    tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
-        :num_sm_parts
-    ]
-    tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-    num_splits_view = self.num_splits_buffer[:2]
-    num_splits_view.copy_(num_splits)
     fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
-        scheduler_metadata=tile_scheduler_metadata_buffer,
-        num_splits=num_splits_view,
+        scheduler_metadata=scheduler_metadata,
         cache_lens=self.max_model_len_tensor[:1],
         dummy_block_table=self.dummy_block_table[:1],
     )
@@ -620,7 +611,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
     # Use padded head count since that's what the kernel will see
     padded_heads = self.fp8_decode_padded_heads
-    tile_scheduler_metadata, num_splits = get_mla_metadata(
+    scheduler_metadata, _ = get_mla_metadata(
         cache_seqlens=self.topk_tokens_tensor[:num_decodes],
         num_q_tokens_per_head_k=decode_query_len * padded_heads,
         topk=self.topk_tokens,
@@ -629,19 +620,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
         is_fp8_kvcache=True,
     )
-    num_sm_parts = tile_scheduler_metadata.size(0)
-    # Copy to persistent buffer for full-CG support
-    tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
-        :num_sm_parts
-    ]
-    tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-    # num_splits has size [num_decodes + 1]
-    num_splits_view = self.num_splits_buffer[: num_decodes + 1]
-    num_splits_view.copy_(num_splits)
     kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
-        scheduler_metadata=tile_scheduler_metadata_buffer,
-        num_splits=num_splits_view,
+        scheduler_metadata=scheduler_metadata,
         dummy_block_table=self.dummy_block_table[:num_decodes],
         cache_lens=self.max_model_len_tensor[:num_decodes],
     )
@@ -949,7 +929,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         head_dim_v=512,
         cache_seqlens=kernel_metadata.cache_lens,
         tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
-        num_splits=kernel_metadata.num_splits,
         is_fp8_kvcache=True,
         indices=topk_indices,
         softmax_scale=self.softmax_scale,
@@ -985,7 +964,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
     q = q_padded
     topk_indices = topk_indices.view(num_tokens, 1, -1)
-    output = flash_mla_sparse_prefill(
+    output = flash_mla_sparse_fwd(
         q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
     )[0]
     output = output[:, : self.num_heads, :]
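The only change in this call is the rename from flash_mla_sparse_prefill to flash_mla_sparse_fwd; the calling convention documented for the removed wrapper (see the interface hunk at the end of this diff) is assumed to carry over. A hedged sketch:

    import torch

    from vllm.v1.attention.ops.flashmla import flash_mla_sparse_fwd

    def sparse_fwd_sketch(q: torch.Tensor, kv: torch.Tensor,
                          indices: torch.Tensor, sm_scale: float) -> torch.Tensor:
        # q: [s_q, h_q, d_qk] bf16; kv: [s_kv, h_kv, d_qk] bf16;
        # indices: [s_q, h_kv, topk] int32 (invalid entries: -1 or >= s_kv).
        # Returns only the attention output, mirroring the "[0]" in the call
        # above; the full tuple is (output, max_logits, lse) per the removed docstring.
        return flash_mla_sparse_fwd(q, kv, indices, sm_scale)[0]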


@@ -78,50 +78,49 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
     return True, None
-def get_mla_metadata(
+def _raise_flashmla_unavailable(*_args, **_kwargs):
+    _, reason = _is_flashmla_available()
+    raise RuntimeError(reason or "FlashMLA is not available")
+
+if _is_flashmla_available()[0]:
+    from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
+        FlashMLASchedMeta,
+        flash_attn_varlen_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+        flash_mla_sparse_fwd,
+        flash_mla_with_kvcache,
+        get_mla_metadata,
+    )
+else:
+    class FlashMLASchedMeta:  # type: ignore[no-redef]
+        pass
+
+    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
+    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]
+
+def get_mla_metadata_dense_fp8(
     cache_seqlens: torch.Tensor,
     num_q_tokens_per_head_k: int,
     num_heads_k: int,
-    num_heads_q: int | None = None,
-    is_fp8_kvcache: bool = False,
-    topk: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Arguments:
-    - cache_seqlens: (batch_size), dtype torch.int32.
-    - num_q_tokens_per_head_k:
-        Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
-    - num_heads_k: The number of k heads.
-    - num_heads_q:
-        The number of q heads.
-        This argument is optional when sparse attention is not enabled
-    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
-    - topk: If not None, sparse attention will be enabled,
-        and only tokens in the `indices` array
-        passed to `flash_mla_with_kvcache_sm90` will be attended to.
-    Returns:
-    - tile_scheduler_metadata:
-        (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
-    - num_splits: (batch_size + 1), dtype torch.int32.
-    """
-    if is_fp8_kvcache and topk is None:
-        return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
-            cache_seqlens,
-            num_q_tokens_per_head_k,
-            num_heads_k,
-        )
-    return torch.ops._flashmla_C.get_mla_decoding_metadata(
+    if not _is_flashmla_available()[0]:
+        _raise_flashmla_unavailable()
+    return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
         cache_seqlens,
         num_q_tokens_per_head_k,
         num_heads_k,
-        num_heads_q,
-        is_fp8_kvcache,
-        topk,
     )
-def flash_mla_with_kvcache(
+def flash_mla_with_kvcache_fp8(
     q: torch.Tensor,
     k_cache: torch.Tensor,
     block_table: torch.Tensor,
@@ -133,114 +132,27 @@ def flash_mla_with_kvcache(
     causal: bool = False,
     descale_q: torch.Tensor | None = None,
     descale_k: torch.Tensor | None = None,
-    is_fp8_kvcache: bool = False,
-    indices: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Arguments:
-    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
-    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
-    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
-    - cache_seqlens: (batch_size), torch.int32.
-    - head_dim_v: Head dimension of v.
-    - tile_scheduler_metadata:
-        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
-        returned by get_mla_metadata.
-    - num_splits:
-        (batch_size + 1), torch.int32, returned by get_mla_metadata.
-    - softmax_scale: float.
-        The scale of QK^T before applying softmax.
-        Default to 1 / sqrt(head_dim).
-    - causal: bool. Whether to apply causal attention mask.
-    - descale_q: (batch_size),
-        torch.float32. Descaling factors for Q, used for fp8 quantization.
-    - descale_k: (batch_size),
-        torch.float32. Descaling factors for K, used for fp8 quantization.
-    - is_fp8_kvcache: bool.
-        Whether the k_cache and v_cache are in fp8 format.
-        For the format of FP8 KV cache, please refer to README.md
-    - indices: (batch_size, seq_len_q, topk), torch.int32.
-        If not None, sparse attention will be enabled,
-        and only tokens in the `indices` array will be attended to.
-        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
-        For details about how to set up `indices`, please refer to README.md.
-    Returns:
-    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
-    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
-    """
+    if not _is_flashmla_available()[0]:
+        _raise_flashmla_unavailable()
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
-    if indices is not None:
-        # NOTE (zyongye): sparse attention is also causal
-        # since it only attend to the tokens before
-        # but here `causal` should not be specified
-        assert not causal, "causal must be `false` if sparse attention is enabled."
-    assert (descale_q is None) == (descale_k is None), (
-        "descale_q and descale_k should be both None or both not None"
-    )
-    if indices is None and q.element_size() == 1:
-        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
-            q,
-            k_cache,
-            head_dim_v,
-            cache_seqlens,
-            block_table,
-            softmax_scale,
-            causal,
-            tile_scheduler_metadata,
-            num_splits,
-            descale_q,
-            descale_k,
-        )
-    else:
-        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
-            q,
-            k_cache,
-            head_dim_v,
-            cache_seqlens,
-            block_table,
-            softmax_scale,
-            causal,
-            tile_scheduler_metadata,
-            num_splits,
-            is_fp8_kvcache,
-            indices,
-        )
+    out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
+        q,
+        k_cache,
+        head_dim_v,
+        cache_seqlens,
+        block_table,
+        softmax_scale,
+        causal,
+        tile_scheduler_metadata,
+        num_splits,
+        descale_q,
+        descale_k,
+    )
     return out, softmax_lse
-def flash_mla_sparse_prefill(
-    q: torch.Tensor,
-    kv: torch.Tensor,
-    indices: torch.Tensor,
-    sm_scale: float,
-    d_v: int = 512,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Sparse attention prefill kernel
-    Args:
-    - q: [s_q, h_q, d_qk], bfloat16
-    - kv: [s_kv, h_kv, d_qk], bfloat16
-    - indices: [s_q, h_kv, topk], int32.
-        Invalid indices should be set to -1 or numbers >= s_kv
-    - sm_scale: float
-    - d_v: The dimension of value vectors. Can only be 512
-    Returns:
-    - (output, max_logits, lse)
-        About the definition of output,
-        max_logits and lse, please refer to README.md
-        - output: [s_q, h_q, d_v], bfloat16
-        - max_logits: [s_q, h_q], float
-        - lse: [s_q, h_q], float, 2-based log-sum-exp
-    """
-    results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
-    return results
 #
 # TODO: Add fake functions
 #
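Taken together, the wrapper module now re-exports the upstream interface from the vendored copy and keeps only the dense-fp8 decode helpers locally. A hedged sketch of how the FlashMLA backend above strings them together (argument names mirror the backend hunk, shapes follow the removed docstrings; this is illustrative, not vLLM code, and needs an sm90a build of vllm._flashmla_C to actually run):

    import torch

    from vllm.v1.attention.ops.flashmla import (
        flash_mla_with_kvcache_fp8,
        get_mla_metadata_dense_fp8,
    )

    def fp8_mla_decode_sketch(
        q: torch.Tensor,              # (batch, seq_len_q, num_heads_q, head_dim)
        k_cache: torch.Tensor,        # fp8 KV cache pages, (num_blocks, page_block_size, 1, head_dim)
        block_table: torch.Tensor,    # (batch, max_num_blocks_per_seq), int32
        cache_seqlens: torch.Tensor,  # (batch,), int32
        kv_lora_rank: int,
        q_scale: torch.Tensor,
        k_scale: torch.Tensor,
        softmax_scale: float | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # MQA decode: one k head, so tokens-per-head-k is seq_len_q * num_heads_q.
        num_q_tokens_per_head_k = q.shape[1] * q.shape[2]
        tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
            cache_seqlens, num_q_tokens_per_head_k, 1
        )
        return flash_mla_with_kvcache_fp8(
            q=q,
            k_cache=k_cache,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=kv_lora_rank,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
            softmax_scale=softmax_scale,  # the wrapper defaults this to head_dim ** -0.5
            causal=True,
            descale_q=q_scale.reshape(1),
            descale_k=k_scale.reshape(1),
        )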