[FlashMLA] Update FlashMLA to expose new arguments (#32810)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson
2026-01-21 22:02:39 -07:00
committed by GitHub
parent 49d9653852
commit 889722f3bf
8 changed files with 132 additions and 216 deletions

3 .gitignore vendored
View File

@@ -7,6 +7,9 @@ vllm/vllm_flash_attn/*
# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*
# FlashMLA interface copied from source
vllm/third_party/flashmla/flash_mla_interface.py
# triton jit
.triton

View File

@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 526781394b33d9888e4c41952e692266267dd8bf
GIT_TAG c2afa9cb93e674d5a9120a170a6da57b89267208
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@@ -30,6 +30,24 @@ endif()
FetchContent_MakeAvailable(flashmla)
message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
# Vendor the FlashMLA interface into vLLM with a torch-ops shim.
set(FLASHMLA_VENDOR_DIR "${CMAKE_SOURCE_DIR}/vllm/third_party/flashmla")
file(MAKE_DIRECTORY "${FLASHMLA_VENDOR_DIR}")
file(READ "${flashmla_SOURCE_DIR}/flash_mla/flash_mla_interface.py"
FLASHMLA_INTERFACE_CONTENT)
string(REPLACE "import flash_mla.cuda as flash_mla_cuda"
"import vllm._flashmla_C\nflash_mla_cuda = torch.ops._flashmla_C"
FLASHMLA_INTERFACE_CONTENT
"${FLASHMLA_INTERFACE_CONTENT}")
file(WRITE "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
"${FLASHMLA_INTERFACE_CONTENT}")
# Install the generated flash_mla_interface.py to the wheel
# Use COMPONENT _flashmla_C to ensure it's installed with the C extension
install(FILES "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
DESTINATION vllm/third_party/flashmla/
COMPONENT _flashmla_C)
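
Note: a standalone illustration of the string(REPLACE ...) step above, done in Python so it can be run as-is; the CMake code applies the same rewrite to the upstream flash_mla_interface.py before vendoring it under vllm/third_party/flashmla/.

upstream_line = "import flash_mla.cuda as flash_mla_cuda"
vendored = upstream_line.replace(
    "import flash_mla.cuda as flash_mla_cuda",
    "import vllm._flashmla_C\nflash_mla_cuda = torch.ops._flashmla_C",
)
print(vendored)
# import vllm._flashmla_C
# flash_mla_cuda = torch.ops._flashmla_C
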
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
@@ -79,7 +97,6 @@ if(FLASH_MLA_ARCHS)
# sm100 dense prefill & backward
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
# sm100 sparse prefill
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu

View File

@@ -646,6 +646,9 @@ class precompiled_wheel_utils:
triton_kernels_regex = re.compile(
r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
)
flashmla_regex = re.compile(
r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
)
file_members = list(
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
)
@@ -657,6 +660,9 @@ class precompiled_wheel_utils:
lambda x: triton_kernels_regex.match(x.filename), wheel.filelist
)
)
file_members += list(
filter(lambda x: flashmla_regex.match(x.filename), wheel.filelist)
)
for file in file_members:
print(f"[extract] {file.filename}")
@@ -925,6 +931,10 @@ if _is_cuda():
):
# FA3 requires CUDA 12.3 or later
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
if envs.VLLM_USE_PRECOMPILED or (
CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
):
# FlashMLA requires CUDA 12.9 or later
# Optional since this doesn't get built (i.e. no .so file is produced)
# when not targeting a Hopper system
ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))

View File

@@ -53,7 +53,6 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
def _float_to_e8m0_truncate(f: float) -> float:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
e8m0 format only stores the exponent (power of 2).
cudaRoundZero truncates toward zero, meaning we round down to the
nearest power of 2.
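
Note: a minimal sketch of the truncation the docstring describes; this is not necessarily the helper in the test file, and it assumes positive, finite scale values.

import math

def e8m0_truncate_sketch(f: float) -> float:
    # e8m0 stores only a power-of-two exponent; cudaRoundZero truncates
    # toward zero, i.e. rounds the value down to the nearest power of 2.
    return 2.0 ** math.floor(math.log2(f))

assert e8m0_truncate_sketch(3.0) == 2.0
assert e8m0_truncate_sketch(0.75) == 0.5
assert e8m0_truncate_sketch(4.0) == 4.0
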

1 vllm/third_party/flashmla/__init__.py vendored Normal file
View File

@@ -0,0 +1 @@
# Sources copied from FlashMLA

View File

@@ -32,8 +32,11 @@ from vllm.v1.attention.backends.utils import (
reshape_query_for_spec_decode,
)
from vllm.v1.attention.ops.flashmla import (
FlashMLASchedMeta,
flash_mla_with_kvcache,
flash_mla_with_kvcache_fp8,
get_mla_metadata,
get_mla_metadata_dense_fp8,
is_flashmla_dense_supported,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -93,8 +96,7 @@ class FlashMLABackend(MLACommonBackend):
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
scheduler_metadata: FlashMLASchedMeta
@dataclass
@@ -158,46 +160,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# we use the max but all should be the same due to uniform length requirement
max_query_len = query_lens_cpu.max().item()
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
tile_scheduler_metadata, num_splits = get_mla_metadata(
scheduler_metadata, _ = get_mla_metadata(
seq_lens_device,
num_q_tokens_per_head_k,
1, # MQA for the decode path
is_fp8_kvcache=self.is_fp8_kvcache,
)
# TODO: we can disambiguate between decode and mixed-prefill decode here
# so we can only use the persistent buffer if a cudagraph is actually
# being used.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
assert self.cg_buf_tile_scheduler_metadata is not None
assert self.cg_buf_num_splits is not None
sm_parts = tile_scheduler_metadata.size(0)
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
:sm_parts
]
tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
tile_scheduler_metadata = tile_scheduler_metadata_view
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
# Num splits needs to be monotonically increasing
# (with https://github.com/vllm-project/FlashMLA/pull/3; otherwise it
# needs to increase monotonically by exactly 1)
self.cg_buf_num_splits[n:].fill_(num_splits[-1])
num_splits = num_splits_view
if self.is_fp8_kvcache:
tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
seq_lens_device,
num_q_tokens_per_head_k,
1, # MQA for the decode path
)
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
scheduler_metadata=scheduler_metadata,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
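
Note: the net effect of this hunk is that the builder now threads a single FlashMLASchedMeta object instead of two loose tensors. Below is a hypothetical stand-in sketching only the two attributes this diff assigns; the real class lives in the vendored flash_mla_interface.py and may carry more fields.

from dataclasses import dataclass
import torch

@dataclass
class SchedMetaSketch:
    # Hypothetical stand-in for FlashMLASchedMeta.
    tile_scheduler_metadata: torch.Tensor | None = None
    num_splits: torch.Tensor | None = None

meta = SchedMetaSketch()
# The dense-fp8 path overwrites both fields in place, exactly as the builder
# does above before handing the object to FlashMLADecodeMetadata.
meta.tile_scheduler_metadata = torch.zeros((4, 8), dtype=torch.int32)
meta.num_splits = torch.zeros((3,), dtype=torch.int32)  # (batch_size + 1,)
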
@@ -272,9 +253,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_decodes = attn_metadata.num_decodes
q = reshape_query_for_spec_decode(q, num_decodes)
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
num_splits = attn_metadata.decode.num_splits
if vllm_is_batch_invariant():
scheduler_metadata = attn_metadata.decode.scheduler_metadata
if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
device = q.device
dtype = torch.int32
@@ -301,20 +281,35 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# Non-split path ignores num_splits, but the API requires it:
# zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
if self.kv_cache_dtype.startswith("fp8"):
o, lse = flash_mla_with_kvcache_fp8(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
num_splits=scheduler_metadata.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
else:
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata,
softmax_scale=self.scale,
causal=True,
is_fp8_kvcache=False,
)
o = reshape_attn_output_for_spec_decode(o)
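
Note: the unsqueeze(-2) on the KV cache in both call sites above only inserts the num_heads_k == 1 dimension the kernel expects. A shape-only illustration follows; the 576 head dim is an assumption (512 latent + 64 rope for DeepSeek-style MLA), not something this diff specifies.

import torch

# (num_blocks, page_block_size, head_dim)
kv_c_and_k_pe_cache = torch.zeros(8, 64, 576, dtype=torch.bfloat16)
k_cache = kv_c_and_k_pe_cache.unsqueeze(-2)  # add head dim of 1
assert k_cache.shape == (8, 64, 1, 576)      # (num_blocks, page_block_size, num_heads_k, head_dim)
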

View File

@@ -33,7 +33,8 @@ from vllm.v1.attention.backends.utils import (
split_prefill_chunks,
)
from vllm.v1.attention.ops.flashmla import (
flash_mla_sparse_prefill,
FlashMLASchedMeta,
flash_mla_sparse_fwd,
flash_mla_with_kvcache,
get_mla_metadata,
)
@@ -142,8 +143,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
@dataclass
class FP8KernelMetadata:
scheduler_metadata: torch.Tensor | None
num_splits: torch.Tensor
scheduler_metadata: FlashMLASchedMeta
dummy_block_table: torch.Tensor
cache_lens: torch.Tensor
@@ -468,7 +468,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
padded_heads = self.fp8_decode_padded_heads
# Build metadata for all tokens as a single batch
tile_scheduler_metadata, num_splits = get_mla_metadata(
scheduler_metadata, _ = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
num_q_tokens_per_head_k=num_tokens * padded_heads,
topk=self.topk_tokens,
@@ -477,17 +477,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
is_fp8_kvcache=True,
)
num_sm_parts = tile_scheduler_metadata.size(0)
tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
num_splits_view = self.num_splits_buffer[:2]
num_splits_view.copy_(num_splits)
fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=num_splits_view,
scheduler_metadata=scheduler_metadata,
cache_lens=self.max_model_len_tensor[:1],
dummy_block_table=self.dummy_block_table[:1],
)
@@ -620,7 +611,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
# Use padded head count since that's what the kernel will see
padded_heads = self.fp8_decode_padded_heads
tile_scheduler_metadata, num_splits = get_mla_metadata(
scheduler_metadata, _ = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * padded_heads,
topk=self.topk_tokens,
@@ -629,19 +620,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
is_fp8_kvcache=True,
)
num_sm_parts = tile_scheduler_metadata.size(0)
# Copy to persistent buffer for full-CG support
tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
# num_splits has size [num_decodes + 1]
num_splits_view = self.num_splits_buffer[: num_decodes + 1]
num_splits_view.copy_(num_splits)
kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=num_splits_view,
scheduler_metadata=scheduler_metadata,
dummy_block_table=self.dummy_block_table[:num_decodes],
cache_lens=self.max_model_len_tensor[:num_decodes],
)
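
Note: for context, the removed lines above (and the similar block removed from the dense backend) implemented a persistent-buffer pattern for full CUDA graph support: allocate a worst-case buffer once, copy each step's variable-size metadata into a leading view, and pad the tail. A minimal standalone sketch of that pattern; the buffer size and values are made up.

import torch

MAX_NUM_DECODES = 16
num_splits_buffer = torch.zeros(MAX_NUM_DECODES + 1, dtype=torch.int32)

# Per-step metadata of varying size (num_decodes + 1 entries).
num_splits = torch.tensor([0, 1, 3, 4], dtype=torch.int32)

view = num_splits_buffer[: num_splits.numel()]
view.copy_(num_splits)  # copy into the graph-captured buffer
num_splits_buffer[num_splits.numel():].fill_(int(num_splits[-1]))  # keep the tail monotone
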
@@ -949,7 +929,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
head_dim_v=512,
cache_seqlens=kernel_metadata.cache_lens,
tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
num_splits=kernel_metadata.num_splits,
is_fp8_kvcache=True,
indices=topk_indices,
softmax_scale=self.softmax_scale,
@@ -985,7 +964,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q = q_padded
topk_indices = topk_indices.view(num_tokens, 1, -1)
output = flash_mla_sparse_prefill(
output = flash_mla_sparse_fwd(
q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
)[0]
output = output[:, : self.num_heads, :]

View File

@@ -78,50 +78,49 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
return True, None
def get_mla_metadata(
def _raise_flashmla_unavailable(*_args, **_kwargs):
_, reason = _is_flashmla_available()
raise RuntimeError(reason or "FlashMLA is not available")
if _is_flashmla_available()[0]:
from vllm.third_party.flashmla.flash_mla_interface import ( # noqa: F401
FlashMLASchedMeta,
flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_mla_sparse_fwd,
flash_mla_with_kvcache,
get_mla_metadata,
)
else:
class FlashMLASchedMeta: # type: ignore[no-redef]
pass
flash_attn_varlen_func = _raise_flashmla_unavailable # type: ignore[assignment]
flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment]
flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment]
flash_mla_sparse_fwd = _raise_flashmla_unavailable # type: ignore[assignment]
flash_mla_with_kvcache = _raise_flashmla_unavailable # type: ignore[assignment]
get_mla_metadata = _raise_flashmla_unavailable # type: ignore[assignment]
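
Note: the block above is an import-or-stub pattern: when the compiled extension is unavailable, every public name is bound to a callable that raises with the detection reason. A self-contained miniature of the pattern; the names and the error message below are illustrative, not the vLLM ones.

def _is_available() -> tuple[bool, str | None]:
    # Stand-in for _is_flashmla_available(); pretend the extension is missing.
    return False, "FlashMLA extension is not built on this platform"

def _raise_unavailable(*_args, **_kwargs):
    _, reason = _is_available()
    raise RuntimeError(reason or "FlashMLA is not available")

if _is_available()[0]:
    pass  # the real imports from the vendored interface would go here
else:
    get_mla_metadata = _raise_unavailable

try:
    get_mla_metadata(None, 1, 1)
except RuntimeError as exc:
    print(f"expected failure: {exc}")
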
def get_mla_metadata_dense_fp8(
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: int | None = None,
is_fp8_kvcache: bool = False,
topk: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k:
Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q:
The number of q heads.
This argument is optional when sparse attention is not enabled
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention will be enabled,
and only tokens in the `indices` array
passed to `flash_mla_with_kvcache_sm90` will be attended to.
Returns:
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if is_fp8_kvcache and topk is None:
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
)
return torch.ops._flashmla_C.get_mla_decoding_metadata(
if not _is_flashmla_available()[0]:
_raise_flashmla_unavailable()
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
num_heads_q,
is_fp8_kvcache,
topk,
)
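
Note: a worked example of the num_q_tokens_per_head_k formula from the docstring, plus the cache_seqlens shape it expects. Pure arithmetic, no kernel call; the concrete numbers are illustrative.

import torch

num_q_tokens_per_q_seq = 4   # e.g. query length with speculative decoding
num_heads_q = 128
num_heads_k = 1              # MQA on the MLA decode path

num_q_tokens_per_head_k = num_q_tokens_per_q_seq * num_heads_q // num_heads_k
assert num_q_tokens_per_head_k == 512

cache_seqlens = torch.tensor([1024, 2048], dtype=torch.int32)  # (batch_size,)
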
def flash_mla_with_kvcache(
def flash_mla_with_kvcache_fp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
@@ -133,114 +132,27 @@ def flash_mla_with_kvcache(
causal: bool = False,
descale_q: torch.Tensor | None = None,
descale_k: torch.Tensor | None = None,
is_fp8_kvcache: bool = False,
indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
returned by get_mla_metadata.
- num_splits:
(batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float.
The scale of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
- descale_q: (batch_size),
torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size),
torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool.
Whether the k_cache and v_cache are in fp8 format.
For the format of FP8 KV cache, please refer to README.md
- indices: (batch_size, seq_len_q, topk), torch.int32.
If not None, sparse attention will be enabled,
and only tokens in the `indices` array will be attended to.
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
For details about how to set up `indices`, please refer to README.md.
Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if not _is_flashmla_available()[0]:
_raise_flashmla_unavailable()
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:
# NOTE (zyongye): sparse attention is also causal,
# since it only attends to preceding tokens,
# but here `causal` should not be specified
assert not causal, "causal must be `false` if sparse attention is enabled."
assert (descale_q is None) == (descale_k is None), (
"descale_q and descale_k should be both None or both not None"
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
if indices is None and q.element_size() == 1:
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices,
)
return out, softmax_lse
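
Note: dummy tensors with the shapes the docstring above documents; no kernel is called, and head_dim 576 (512 latent + 64 rope, DeepSeek-style MLA) is an assumption for illustration.

import torch

batch_size, seq_len_q, num_heads_q, head_dim = 2, 1, 128, 576
num_blocks, page_block_size, num_heads_k = 8, 64, 1
head_dim_v = 512

q = torch.zeros(batch_size, seq_len_q, num_heads_q, head_dim, dtype=torch.bfloat16)
k_cache = torch.zeros(num_blocks, page_block_size, num_heads_k, head_dim, dtype=torch.bfloat16)
block_table = torch.zeros(batch_size, 4, dtype=torch.int32)
cache_seqlens = torch.tensor([100, 220], dtype=torch.int32)
# Returned: out (batch_size, seq_len_q, num_heads_q, head_dim_v) and
#           softmax_lse (batch_size, num_heads_q, seq_len_q), float32.
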
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
About the definition of output,
max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
return results
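
Note: an illustration of the indices convention from the sparse docstring above: entries are positions into the kv sequence, and invalid slots are padded with -1 (anything >= s_kv is also treated as invalid). Shapes and values are illustrative.

import torch

s_q, h_kv, topk, s_kv = 4, 1, 8, 100
indices = torch.full((s_q, h_kv, topk), -1, dtype=torch.int32)      # all slots invalid
indices[0, 0, :3] = torch.tensor([10, 42, 99], dtype=torch.int32)   # three valid kv positions

valid = (indices >= 0) & (indices < s_kv)
assert int(valid.sum()) == 3
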
#
# TODO: Add fake functions
#