[FlashMLA] Update FlashMLA to expose new arguments (#32810)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

.gitignore (vendored), 3 changes

@@ -7,6 +7,9 @@ vllm/vllm_flash_attn/*
 # OpenAI triton kernels copied from source
 vllm/third_party/triton_kernels/*
 
+# FlashMLA interface copied from source
+vllm/third_party/flashmla/flash_mla_interface.py
+
 # triton jit
 .triton
 

@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
         flashmla
         GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
-        GIT_TAG 526781394b33d9888e4c41952e692266267dd8bf
+        GIT_TAG c2afa9cb93e674d5a9120a170a6da57b89267208
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""
@@ -30,6 +30,24 @@ endif()
 FetchContent_MakeAvailable(flashmla)
 message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
 
+# Vendor FlashMLA interface into vLLM with torch-ops shim.
+set(FLASHMLA_VENDOR_DIR "${CMAKE_SOURCE_DIR}/vllm/third_party/flashmla")
+file(MAKE_DIRECTORY "${FLASHMLA_VENDOR_DIR}")
+file(READ "${flashmla_SOURCE_DIR}/flash_mla/flash_mla_interface.py"
+     FLASHMLA_INTERFACE_CONTENT)
+string(REPLACE "import flash_mla.cuda as flash_mla_cuda"
+       "import vllm._flashmla_C\nflash_mla_cuda = torch.ops._flashmla_C"
+       FLASHMLA_INTERFACE_CONTENT
+       "${FLASHMLA_INTERFACE_CONTENT}")
+file(WRITE "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
+     "${FLASHMLA_INTERFACE_CONTENT}")
+
+# Install the generated flash_mla_interface.py to the wheel
+# Use COMPONENT _flashmla_C to ensure it's installed with the C extension
+install(FILES "${FLASHMLA_VENDOR_DIR}/flash_mla_interface.py"
+        DESTINATION vllm/third_party/flashmla/
+        COMPONENT _flashmla_C)
+
 # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
 # Only build FlashMLA kernels if we are building for something compatible with
 # sm90a
@@ -79,7 +97,6 @@ if(FLASH_MLA_ARCHS)
 
       # sm100 dense prefill & backward
       ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
-      ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
 
       # sm100 sparse prefill
       ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu

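The string(REPLACE ...) above is the whole vendoring shim: the upstream interface binds its kernels through `import flash_mla.cuda as flash_mla_cuda`, and the vendored copy swaps that one line for the vLLM extension. A minimal sketch of what the generated file's binding ends up as (assuming the upstream module already imports torch at module scope):

    import torch  # assumed: already imported by the upstream interface module

    # upstream:  import flash_mla.cuda as flash_mla_cuda
    # vendored copy after the CMake REPLACE:
    import vllm._flashmla_C  # importing the extension registers its custom torch ops
    flash_mla_cuda = torch.ops._flashmla_C
    # every later flash_mla_cuda.<op>(...) call now dispatches to the vLLM-built kernels
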

setup.py, 10 changes

@@ -646,6 +646,9 @@ class precompiled_wheel_utils:
         triton_kernels_regex = re.compile(
             r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
         )
+        flashmla_regex = re.compile(
+            r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
+        )
         file_members = list(
             filter(lambda x: x.filename in files_to_copy, wheel.filelist)
         )
@@ -657,6 +660,9 @@ class precompiled_wheel_utils:
                 lambda x: triton_kernels_regex.match(x.filename), wheel.filelist
             )
         )
+        file_members += list(
+            filter(lambda x: flashmla_regex.match(x.filename), wheel.filelist)
+        )
 
         for file in file_members:
             print(f"[extract] {file.filename}")
@@ -925,6 +931,10 @@ if _is_cuda():
     ):
         # FA3 requires CUDA 12.3 or later
         ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or (
+        CUDA_HOME and get_nvcc_cuda_version() >= Version("12.9")
+    ):
+        # FlashMLA requires CUDA 12.9 or later
         # Optional since this doesn't get built (produce an .so file) when
         # not targeting a hopper system
         ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))

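The flashmla_regex added in precompiled_wheel_utils mirrors the existing triton_kernels pattern: it selects every non-hidden .py file under vllm/third_party/flashmla/ for extraction from the precompiled wheel. A quick self-contained check of that behavior (the file names are illustrative):

    import re

    flashmla_regex = re.compile(
        r"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
    )

    assert flashmla_regex.match("vllm/third_party/flashmla/flash_mla_interface.py")
    assert flashmla_regex.match("vllm/third_party/flashmla/__init__.py")
    assert not flashmla_regex.match("vllm/third_party/flashmla/.hidden.py")  # hidden files skipped
    assert not flashmla_regex.match("vllm/third_party/triton_kernels/foo.py")  # other trees untouched
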
@@ -53,7 +53,6 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
 
 def _float_to_e8m0_truncate(f: float) -> float:
     """Simulate SM100's float -> e8m0 -> bf16 scale conversion.
 
     e8m0 format only stores the exponent (power of 2).
     cudaRoundZero truncates toward zero, meaning we round down to the
     nearest power of 2.


vllm/third_party/flashmla/__init__.py (vendored, new file), 1 change

@@ -0,0 +1 @@
+# Sources copied from FlashMLA

@@ -32,8 +32,11 @@ from vllm.v1.attention.backends.utils import (
     reshape_query_for_spec_decode,
 )
 from vllm.v1.attention.ops.flashmla import (
+    FlashMLASchedMeta,
     flash_mla_with_kvcache,
+    flash_mla_with_kvcache_fp8,
     get_mla_metadata,
+    get_mla_metadata_dense_fp8,
     is_flashmla_dense_supported,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -93,8 +96,7 @@ class FlashMLABackend(MLACommonBackend):
 
 @dataclass
 class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
-    tile_scheduler_metadata: torch.Tensor
-    num_splits: torch.Tensor
+    scheduler_metadata: FlashMLASchedMeta
 
 
 @dataclass
@@ -158,46 +160,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         # we use the max but all should be the same due to uniform length requirement
         max_query_len = query_lens_cpu.max().item()
         num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
-        tile_scheduler_metadata, num_splits = get_mla_metadata(
+        scheduler_metadata, _ = get_mla_metadata(
             seq_lens_device,
             num_q_tokens_per_head_k,
             1,  # MQA for the decode path
             is_fp8_kvcache=self.is_fp8_kvcache,
         )
-        # TODO: we can disambiguate between decode and mixed-prefill decode here
-        # so we can only use the persistent buffer if a cudagraph is actually
-        # being used.
-        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
-            assert self.cg_buf_tile_scheduler_metadata is not None
-            assert self.cg_buf_num_splits is not None
-            sm_parts = tile_scheduler_metadata.size(0)
-            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
-            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
-            tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[
-                :sm_parts
-            ]
-            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
-            tile_scheduler_metadata = tile_scheduler_metadata_view
-
-            # Num splits is per-batch, varying size (batch_size,)
-            n = num_splits.size(0)
-            # make sure static buffer is large enough
-            assert n <= self.cg_buf_num_splits.size(0)
-            num_splits_view = self.cg_buf_num_splits[:n]
-            num_splits_view.copy_(num_splits)
-            # Num splits needs to monotonically increasing
-            # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
-            # it needs to monotonically increasing by 1)
-            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
-            num_splits = num_splits_view
+        if self.is_fp8_kvcache:
+            tile_scheduler_metadata, num_splits = get_mla_metadata_dense_fp8(
+                seq_lens_device,
+                num_q_tokens_per_head_k,
+                1,  # MQA for the decode path
+            )
+            scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
+            scheduler_metadata.num_splits = num_splits
 
         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
-            tile_scheduler_metadata=tile_scheduler_metadata,
-            num_splits=num_splits,
+            scheduler_metadata=scheduler_metadata,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )
 
@@ -272,9 +253,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         num_decodes = attn_metadata.num_decodes
         q = reshape_query_for_spec_decode(q, num_decodes)
 
-        tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
-        num_splits = attn_metadata.decode.num_splits
-        if vllm_is_batch_invariant():
+        scheduler_metadata = attn_metadata.decode.scheduler_metadata
+        if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
             device = q.device
             dtype = torch.int32
 
@@ -301,20 +281,35 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             # Non-split path ignores num_splits, but the API requires it:
             # zeros of length B+1
             num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
+            scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
+            scheduler_metadata.num_splits = num_splits
 
-        o, lse = flash_mla_with_kvcache(
-            q=q,
-            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
-            block_table=attn_metadata.decode.block_table,
-            cache_seqlens=attn_metadata.decode.seq_lens,
-            head_dim_v=self.kv_lora_rank,
-            tile_scheduler_metadata=tile_scheduler_metadata,
-            num_splits=num_splits,
-            softmax_scale=self.scale,
-            causal=True,
-            descale_q=layer._q_scale.reshape(1),
-            descale_k=layer._k_scale.reshape(1),
-        )
+        if self.kv_cache_dtype.startswith("fp8"):
+            o, lse = flash_mla_with_kvcache_fp8(
+                q=q,
+                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
+                block_table=attn_metadata.decode.block_table,
+                cache_seqlens=attn_metadata.decode.seq_lens,
+                head_dim_v=self.kv_lora_rank,
+                tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
+                num_splits=scheduler_metadata.num_splits,
+                softmax_scale=self.scale,
+                causal=True,
+                descale_q=layer._q_scale.reshape(1),
+                descale_k=layer._k_scale.reshape(1),
+            )
+        else:
+            o, lse = flash_mla_with_kvcache(
+                q=q,
+                k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
+                block_table=attn_metadata.decode.block_table,
+                cache_seqlens=attn_metadata.decode.seq_lens,
+                head_dim_v=self.kv_lora_rank,
+                tile_scheduler_metadata=scheduler_metadata,
+                softmax_scale=self.scale,
+                causal=True,
+                is_fp8_kvcache=False,
+            )
 
         o = reshape_attn_output_for_spec_decode(o)
 

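Pulled out of the diff noise, the decode dispatch in FlashMLAImpl now looks roughly like the sketch below (a condensed paraphrase, not the literal method body; tensor shapes and the surrounding class are assumptions): the dense-fp8 kernel consumes the raw tensors carried on the FlashMLASchedMeta, while the bf16 kernel takes the FlashMLASchedMeta object itself.

    from vllm.v1.attention.ops.flashmla import (
        flash_mla_with_kvcache,
        flash_mla_with_kvcache_fp8,
    )

    def _decode_sketch(q, kv_cache, decode_meta, kv_cache_dtype, scale, kv_lora_rank, layer):
        # decode_meta.scheduler_metadata is the FlashMLASchedMeta produced by the builder
        sched = decode_meta.scheduler_metadata
        common = dict(
            q=q,
            k_cache=kv_cache.unsqueeze(-2),  # add a head dim of 1
            block_table=decode_meta.block_table,
            cache_seqlens=decode_meta.seq_lens,
            head_dim_v=kv_lora_rank,
            softmax_scale=scale,
            causal=True,
        )
        if kv_cache_dtype.startswith("fp8"):
            # dense-fp8 kernel: raw scheduler tensors plus per-layer descale factors
            return flash_mla_with_kvcache_fp8(
                tile_scheduler_metadata=sched.tile_scheduler_metadata,
                num_splits=sched.num_splits,
                descale_q=layer._q_scale.reshape(1),
                descale_k=layer._k_scale.reshape(1),
                **common,
            )
        # bf16 kernel: the FlashMLASchedMeta object is passed through as-is
        return flash_mla_with_kvcache(
            tile_scheduler_metadata=sched, is_fp8_kvcache=False, **common
        )
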
@@ -33,7 +33,8 @@ from vllm.v1.attention.backends.utils import (
     split_prefill_chunks,
 )
 from vllm.v1.attention.ops.flashmla import (
-    flash_mla_sparse_prefill,
+    FlashMLASchedMeta,
+    flash_mla_sparse_fwd,
     flash_mla_with_kvcache,
     get_mla_metadata,
 )
@@ -142,8 +143,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
 
     @dataclass
     class FP8KernelMetadata:
-        scheduler_metadata: torch.Tensor | None
-        num_splits: torch.Tensor
+        scheduler_metadata: FlashMLASchedMeta
         dummy_block_table: torch.Tensor
         cache_lens: torch.Tensor
 
@@ -468,7 +468,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
         padded_heads = self.fp8_decode_padded_heads
 
         # Build metadata for all tokens as a single batch
-        tile_scheduler_metadata, num_splits = get_mla_metadata(
+        scheduler_metadata, _ = get_mla_metadata(
             cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
             num_q_tokens_per_head_k=num_tokens * padded_heads,
             topk=self.topk_tokens,
@@ -477,17 +477,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
             is_fp8_kvcache=True,
         )
 
-        num_sm_parts = tile_scheduler_metadata.size(0)
-        tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
-            :num_sm_parts
-        ]
-        tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-        num_splits_view = self.num_splits_buffer[:2]
-        num_splits_view.copy_(num_splits)
-
         fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
-            scheduler_metadata=tile_scheduler_metadata_buffer,
-            num_splits=num_splits_view,
+            scheduler_metadata=scheduler_metadata,
             cache_lens=self.max_model_len_tensor[:1],
             dummy_block_table=self.dummy_block_table[:1],
         )
@@ -620,7 +611,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
 
         # Use padded head count since that's what the kernel will see
         padded_heads = self.fp8_decode_padded_heads
-        tile_scheduler_metadata, num_splits = get_mla_metadata(
+        scheduler_metadata, _ = get_mla_metadata(
             cache_seqlens=self.topk_tokens_tensor[:num_decodes],
             num_q_tokens_per_head_k=decode_query_len * padded_heads,
             topk=self.topk_tokens,
@@ -629,19 +620,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
             is_fp8_kvcache=True,
         )
 
-        num_sm_parts = tile_scheduler_metadata.size(0)
-        # Copy to persistent buffer for full-CG support
-        tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
-            :num_sm_parts
-        ]
-        tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
-        # num_splits has size [num_decodes + 1]
-        num_splits_view = self.num_splits_buffer[: num_decodes + 1]
-        num_splits_view.copy_(num_splits)
-
         kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
-            scheduler_metadata=tile_scheduler_metadata_buffer,
-            num_splits=num_splits_view,
+            scheduler_metadata=scheduler_metadata,
             dummy_block_table=self.dummy_block_table[:num_decodes],
             cache_lens=self.max_model_len_tensor[:num_decodes],
         )
@@ -949,7 +929,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
             head_dim_v=512,
             cache_seqlens=kernel_metadata.cache_lens,
             tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
-            num_splits=kernel_metadata.num_splits,
             is_fp8_kvcache=True,
             indices=topk_indices,
             softmax_scale=self.softmax_scale,
@@ -985,7 +964,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
             q = q_padded
 
         topk_indices = topk_indices.view(num_tokens, 1, -1)
-        output = flash_mla_sparse_prefill(
+        output = flash_mla_sparse_fwd(
             q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
         )[0]
         output = output[:, : self.num_heads, :]

vllm/v1/attention/ops/flashmla.py

@@ -78,50 +78,49 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
     return True, None
 
 
-def get_mla_metadata(
-    cache_seqlens: torch.Tensor,
-    num_q_tokens_per_head_k: int,
-    num_heads_k: int,
-    num_heads_q: int | None = None,
-    is_fp8_kvcache: bool = False,
-    topk: int | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Arguments:
-    - cache_seqlens: (batch_size), dtype torch.int32.
-    - num_q_tokens_per_head_k:
-        Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
-    - num_heads_k: The number of k heads.
-    - num_heads_q:
-        The number of q heads.
-        This argument is optional when sparse attention is not enabled
-    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
-    - topk: If not None, sparse attention will be enabled,
-        and only tokens in the `indices` array
-        passed to `flash_mla_with_kvcache_sm90` will be attended to.
-
-    Returns:
-    - tile_scheduler_metadata:
-        (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
-    - num_splits: (batch_size + 1), dtype torch.int32.
-    """
-    if is_fp8_kvcache and topk is None:
-        return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
-            cache_seqlens,
-            num_q_tokens_per_head_k,
-            num_heads_k,
-        )
-    return torch.ops._flashmla_C.get_mla_decoding_metadata(
-        cache_seqlens,
-        num_q_tokens_per_head_k,
-        num_heads_k,
-        num_heads_q,
-        is_fp8_kvcache,
-        topk,
-    )
+def _raise_flashmla_unavailable(*_args, **_kwargs):
+    _, reason = _is_flashmla_available()
+    raise RuntimeError(reason or "FlashMLA is not available")
+
+
+if _is_flashmla_available()[0]:
+    from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
+        FlashMLASchedMeta,
+        flash_attn_varlen_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+        flash_mla_sparse_fwd,
+        flash_mla_with_kvcache,
+        get_mla_metadata,
+    )
+else:
+
+    class FlashMLASchedMeta:  # type: ignore[no-redef]
+        pass
+
+    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
+    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
+    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]
+
+
+def get_mla_metadata_dense_fp8(
+    cache_seqlens: torch.Tensor,
+    num_q_tokens_per_head_k: int,
+    num_heads_k: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if not _is_flashmla_available()[0]:
+        _raise_flashmla_unavailable()
+    return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
+        cache_seqlens,
+        num_q_tokens_per_head_k,
+        num_heads_k,
+    )
 
 
-def flash_mla_with_kvcache(
+def flash_mla_with_kvcache_fp8(
     q: torch.Tensor,
     k_cache: torch.Tensor,
     block_table: torch.Tensor,
@@ -133,114 +132,27 @@ def flash_mla_with_kvcache(
     causal: bool = False,
     descale_q: torch.Tensor | None = None,
     descale_k: torch.Tensor | None = None,
-    is_fp8_kvcache: bool = False,
-    indices: torch.Tensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Arguments:
-    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
-    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
-    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
-    - cache_seqlens: (batch_size), torch.int32.
-    - head_dim_v: Head dimension of v.
-    - tile_scheduler_metadata:
-        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
-        returned by get_mla_metadata.
-    - num_splits:
-        (batch_size + 1), torch.int32, returned by get_mla_metadata.
-    - softmax_scale: float.
-        The scale of QK^T before applying softmax.
-        Default to 1 / sqrt(head_dim).
-    - causal: bool. Whether to apply causal attention mask.
-    - descale_q: (batch_size),
-        torch.float32. Descaling factors for Q, used for fp8 quantization.
-    - descale_k: (batch_size),
-        torch.float32. Descaling factors for K, used for fp8 quantization.
-    - is_fp8_kvcache: bool.
-        Whether the k_cache and v_cache are in fp8 format.
-        For the format of FP8 KV cache, please refer to README.md
-    - indices: (batch_size, seq_len_q, topk), torch.int32.
-        If not None, sparse attention will be enabled,
-        and only tokens in the `indices` array will be attended to.
-        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
-        For details about how to set up `indices`, please refer to README.md.
-
-    Returns:
-    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
-    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
-    """
+    if not _is_flashmla_available()[0]:
+        _raise_flashmla_unavailable()
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
-    if indices is not None:
-        # NOTE (zyongye): sparse attention is also causal
-        # since it only attend to the tokens before
-        # but here `causal` should not be specified
-        assert not causal, "causal must be `false` if sparse attention is enabled."
-    assert (descale_q is None) == (descale_k is None), (
-        "descale_q and descale_k should be both None or both not None"
-    )
-
-    if indices is None and q.element_size() == 1:
-        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
-            q,
-            k_cache,
-            head_dim_v,
-            cache_seqlens,
-            block_table,
-            softmax_scale,
-            causal,
-            tile_scheduler_metadata,
-            num_splits,
-            descale_q,
-            descale_k,
-        )
-    else:
-        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
-            q,
-            k_cache,
-            head_dim_v,
-            cache_seqlens,
-            block_table,
-            softmax_scale,
-            causal,
-            tile_scheduler_metadata,
-            num_splits,
-            is_fp8_kvcache,
-            indices,
-        )
+    out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
+        q,
+        k_cache,
+        head_dim_v,
+        cache_seqlens,
+        block_table,
+        softmax_scale,
+        causal,
+        tile_scheduler_metadata,
+        num_splits,
+        descale_q,
+        descale_k,
+    )
     return out, softmax_lse
 
 
-def flash_mla_sparse_prefill(
-    q: torch.Tensor,
-    kv: torch.Tensor,
-    indices: torch.Tensor,
-    sm_scale: float,
-    d_v: int = 512,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Sparse attention prefill kernel
-
-    Args:
-    - q: [s_q, h_q, d_qk], bfloat16
-    - kv: [s_kv, h_kv, d_qk], bfloat16
-    - indices: [s_q, h_kv, topk], int32.
-        Invalid indices should be set to -1 or numbers >= s_kv
-    - sm_scale: float
-    - d_v: The dimension of value vectors. Can only be 512
-
-    Returns:
-    - (output, max_logits, lse)
-        About the definition of output,
-        max_logits and lse, please refer to README.md
-    - output: [s_q, h_q, d_v], bfloat16
-    - max_logits: [s_q, h_q], float
-    - lse: [s_q, h_q], float, 2-based log-sum-exp
-    """
-    results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
-    return results
-
-
 #
 # TODO: Add fake functions
 #
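End to end, the surface re-exported by vllm.v1.attention.ops.flashmla now includes FlashMLASchedMeta, get_mla_metadata, get_mla_metadata_dense_fp8, flash_mla_with_kvcache, flash_mla_with_kvcache_fp8 and flash_mla_sparse_fwd. A minimal sketch of driving the dense-fp8 decode path the same way the backend does (this assumes a Hopper GPU with the _flashmla extensions built, and that q, k_cache, block_table and the q/k scale tensors have already been prepared in FlashMLA's expected layouts):

    import torch

    from vllm.v1.attention.ops.flashmla import (
        flash_mla_with_kvcache_fp8,
        get_mla_metadata,
        get_mla_metadata_dense_fp8,
    )

    batch_size, num_q_heads, kv_lora_rank = 4, 16, 512
    seq_lens = torch.full((batch_size,), 1024, dtype=torch.int32, device="cuda")

    # One query token per sequence, MQA decode: num_q_tokens_per_head_k = 1 * num_q_heads // 1
    sched_meta, _ = get_mla_metadata(seq_lens, num_q_heads, 1, is_fp8_kvcache=True)

    # Dense-fp8 split scheduling: fetch the raw tensors and stash them on the object,
    # mirroring what FlashMLAMetadataBuilder does in this change.
    tile_sched, num_splits = get_mla_metadata_dense_fp8(seq_lens, num_q_heads, 1)
    sched_meta.tile_scheduler_metadata = tile_sched
    sched_meta.num_splits = num_splits

    o, lse = flash_mla_with_kvcache_fp8(
        q=q,                          # assumed (batch, 1, num_q_heads, head_dim) query
        k_cache=k_cache,              # assumed fp8 paged KV cache with a head dim of 1
        block_table=block_table,      # assumed (batch, max_blocks_per_seq), int32
        cache_seqlens=seq_lens,
        head_dim_v=kv_lora_rank,
        tile_scheduler_metadata=sched_meta.tile_scheduler_metadata,
        num_splits=sched_meta.num_splits,
        softmax_scale=576 ** -0.5,    # assumed head_dim of 576 (512 latent + 64 rope)
        causal=True,
        descale_q=q_scale.reshape(1), # assumed per-layer scale tensors
        descale_k=k_scale.reshape(1),
    )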