[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)

Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
yzds
2025-09-06 13:24:05 +08:00
committed by GitHub
parent 3c529fc994
commit ac201a0eaf
27 changed files with 999 additions and 230 deletions

View File

@@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import is_global_first_rank
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
@@ -323,6 +324,13 @@ class MLACommonPrefillMetadata:
seq_lens: torch.Tensor
workspace: torch.Tensor
# for mla DCP
cp_chunk_seq_lens: Optional[list[list[int]]] = None
origin_context_lens: Optional[list[int]] = None
cp_cu_seq_lens: Optional[torch.Tensor] = None
chunk_size: Optional[int] = None
cu_seq_lens_lst: Optional[list[list[int]]] = None
block_table: torch.Tensor
query_start_loc: torch.Tensor
max_query_len: int
@@ -444,6 +452,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
# Don't try to access the runner on AMD
if self.aot_schedule:
@@ -465,12 +480,27 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
if self.dcp_world_size > 1:
# Note(hc): the local KV cache is incomplete when DCP is enabled, so an
# additional KV cache all-gather across the DCP group is required; the
# workspace therefore has to be enlarged by an extra 1/dcp_world_size
# relative to the original TP allocation.
assert self.chunked_prefill_workspace_size % \
self.dcp_world_size == 0
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size +
self.chunked_prefill_workspace_size // self.dcp_world_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
else:
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
self._use_cudnn_prefill = use_cudnn_prefill()
self._use_fi_prefill = use_flashinfer_prefill()
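
The Note(hc) above is easier to see with concrete numbers. A minimal, self-contained sketch (the sizes are made up for illustration; the real chunked_prefill_workspace_size comes from the scheduler and cache config) of how the enlarged workspace is later carved up by _context_parallel_compute_prefill_context:

# Hedged sketch with assumed sizes: the DCP workspace holds one local-gather
# region of N rows plus an all-gather region of N * dcp_world_size rows.
dcp_world_size = 4
workspace_size = 128 * 1024                      # assumed chunked_prefill_workspace_size
assert workspace_size % dcp_world_size == 0
total_rows = workspace_size + workspace_size // dcp_world_size
allgather_offset = total_rows // (dcp_world_size + 1)
# |-- allgather_offset rows --|-- allgather_offset * dcp_world_size rows --|
# |  local cp_gather_cache    |     destination of the DCP all-gather      |
assert allgather_offset == workspace_size // dcp_world_size
assert allgather_offset * (dcp_world_size + 1) == total_rows
print(total_rows, allgather_offset)              # 163840 32768
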
@@ -631,6 +661,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
seq_lens[:num_decodes] = seq_lens[:num_decodes] \
// self.dcp_world_size + (self.dcp_rank <= \
(seq_lens[:num_decodes] - 1) % self.dcp_world_size)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
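
The seq_lens update above is per-rank bookkeeping for the sharded KV cache. A tiny worked example in plain Python that mirrors the arithmetic (values are made up, and it assumes tokens are distributed round-robin over the DCP group):

def local_decode_seq_len(seq_len: int, dcp_rank: int, dcp_world_size: int) -> int:
    # Mirrors: seq_lens // W + (rank <= (seq_lens - 1) % W)
    return seq_len // dcp_world_size + int(dcp_rank <= (seq_len - 1) % dcp_world_size)

# A decode request with 10 KV tokens sharded round-robin over 4 DCP ranks:
print([local_decode_seq_len(10, r, 4) for r in range(4)])     # [3, 3, 2, 2]
assert sum(local_decode_seq_len(10, r, 4) for r in range(4)) == 10
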
@@ -639,6 +675,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Note(hc): the context lengths from the perspective of DCP rank 0.
cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() /
self.dcp_world_size).int()
origin_context_lens = context_lens_cpu.tolist()
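
As a quick check of the note above, the ceiling division yields the largest per-rank shard, i.e. rank 0's share under a round-robin token layout (illustrative values only):

import torch
dcp_world_size = 4
context_lens_cpu = torch.tensor([10, 8, 1])
cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / dcp_world_size).int()
print(cp_context_lens_cpu.tolist())   # [3, 2, 1] -- rank 0 holds ceil(len / dcp_world_size) tokens
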
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
@@ -691,20 +731,66 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
if self.dcp_world_size > 1:
# Note(hc): the above max_context_chunk already enforces block_size
# alignment; DCP only needs block_size to be divisible by
# dcp_world_size, because DCP uses cp_gather_cache, which does not
# require `cp_chunk_starts` to be aligned to page_size.
assert max_context_chunk % self.dcp_world_size == 0
cp_max_context_chunk = max_context_chunk // \
self.dcp_world_size
cp_chunk_starts = \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) \
* cp_max_context_chunk
cp_chunk_ends = torch.min(
cp_context_lens_cpu.unsqueeze(0),
cp_chunk_starts + cp_max_context_chunk)
cp_chunk_seq_lens = (cp_chunk_ends -
cp_chunk_starts).clamp(min=0)
cp_cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(cp_chunk_seq_lens,
dim=1,
out=cp_cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata_cls = \
CudnnPrefillMetadata.ChunkedContextMetadata \
if self._use_cudnn_prefill else \
MLACommonPrefillMetadata.ChunkedContextMetadata
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
if self.dcp_world_size > 1:
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
starts=cp_chunk_starts.to(device, non_blocking=True),
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
origin_context_lens=origin_context_lens,
cp_cu_seq_lens=cp_cu_seq_lens_cpu \
.to(device, non_blocking=True),
chunk_size=max_context_chunk,
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
)
else:
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu \
.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
if self._use_cudnn_prefill:
chunked_context_metadata.seq_lens = chunk_seq_lens
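
To make the DCP chunk bookkeeping above concrete, here is a small self-contained sketch with made-up sizes (2 prefill requests with context lengths [10, 5], dcp_world_size = 2, max_context_chunk = 8); it simply re-runs the tensor arithmetic from the branch above on toy inputs:

import torch

dcp_world_size = 2
max_context_chunk = 8                                 # already block_size aligned upstream
num_prefills, num_chunks = 2, 2
context_lens = torch.tensor([10, 5])
cp_context_lens = torch.ceil(context_lens.float() / dcp_world_size).int()     # [5, 3]
cp_max_context_chunk = max_context_chunk // dcp_world_size                    # 4
cp_chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                   .unsqueeze(1).expand(-1, num_prefills) * cp_max_context_chunk)
cp_chunk_ends = torch.min(cp_context_lens.unsqueeze(0),
                          cp_chunk_starts + cp_max_context_chunk)
cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0)
cp_cu_seq_lens = torch.zeros(num_chunks, num_prefills + 1, dtype=torch.int32)
torch.cumsum(cp_chunk_seq_lens, dim=1, out=cp_cu_seq_lens[:, 1:], dtype=torch.int32)
print(cp_chunk_seq_lens.tolist())    # [[4, 3], [1, 0]]
print(cp_cu_seq_lens.tolist())       # [[0, 4, 7], [0, 1, 1]]
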
@@ -757,6 +843,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return attn_metadata
def reorg_kvcache(
allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
cp_chunk_seq_lens_lst: list[int],
origin_context_lens: list[int],
cp_world_size: int,
sum_seq_len: int,
max_seq_len: int,
chunk_size: int,
chunk_idx: int,
toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
reorg kvcache after cp local gather to tp layout for attn kernel.
Args:
cp_chunk_seq_lens_lst: chunk context lengths under CP.
origin_context_lens: origin full context lengths under CP.
cp_world_size: CP size.
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
max_seq_len: the max value of cp_chunk_seq_lens_lst.
chunk_size: equals to max_context_chunk from
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst,
origin_context_lens):
chunk_context_len = chunk_size
if cp_chunk_seq_len != 0:
chunk_context_len = min(
chunk_context_len, origin_context_len - chunk_size * chunk_idx)
cp_target_rank = (chunk_context_len - 1) % cp_world_size
cur_seq_len = 0
for rank in range(cp_world_size):
if rank > cp_target_rank and cp_chunk_seq_len:
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
else:
real_cp_chunk_seq_len = cp_chunk_seq_len
if real_cp_chunk_seq_len:
kv_c_segment = allgatered_kv_c_normed[rank * toks +
src_token_idx:rank *
toks + src_token_idx +
real_cp_chunk_seq_len]
k_pe_segment = allgatered_k_pe[rank * toks +
src_token_idx:rank * toks +
src_token_idx +
real_cp_chunk_seq_len]
kv_c_segments.append(kv_c_segment)
k_pe_segments.append(k_pe_segment)
cur_seq_len += real_cp_chunk_seq_len
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
src_token_idx += cp_chunk_seq_len
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
assert reorganized_k_pe.shape[0] == sum_seq_len
assert max_seq_len_check == max_seq_len
return reorganized_kv_c_normed, reorganized_k_pe
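
A toy illustration of the layout reorg_kvcache() produces (not the real code path): one prefill request with origin context length 5, cp_world_size = 2, chunk 0 of size 8, so each rank gathers ceil(5 / 2) = 3 rows and toks = 3. The row values below are global token ids rather than real KV vectors; rank 1 owns only 2 of the 5 tokens, so its third gathered row (-1) is a don't-care row that gets dropped:

import torch

cp_world_size, toks = 2, 3
# All-gathered workspace rows, flattened as [rank 0 rows | rank 1 rows].
allgathered = torch.tensor([0, 2, 4,      # rank 0 owns tokens 0, 2, 4
                            1, 3, -1])    # rank 1 owns tokens 1, 3
origin_context_len, chunk_size, chunk_idx, cp_chunk_seq_len = 5, 8, 0, 3
chunk_context_len = min(chunk_size, origin_context_len - chunk_size * chunk_idx)   # 5
cp_target_rank = (chunk_context_len - 1) % cp_world_size                           # 0
segments = []
for rank in range(cp_world_size):
    # Ranks past the target rank contributed one token fewer in this chunk.
    real_len = cp_chunk_seq_len - 1 if rank > cp_target_rank else cp_chunk_seq_len
    segments.append(allgathered[rank * toks: rank * toks + real_len])
print(torch.cat(segments).tolist())   # [0, 2, 4, 1, 3] -> all 5 context tokens

Note that each request's run ends up grouped by source rank rather than restored to the original token order; that is fine for the chunked context attention here, since every query token attends to every context token in the chunk.
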
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
@@ -836,6 +987,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
self.dcp_world_size: Optional[int] = None
def _flash_attn_varlen_diff_headdims(self,
q,
k,
@@ -1152,6 +1305,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return output, output_lse
def _context_parallel_compute_prefill_context(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
dcp_world_size: int,
):
assert k_scale is None, "DCP not support sacled kvcache now."
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
assert prefill_metadata.chunked_context.origin_context_lens is not None
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
assert prefill_metadata.chunked_context.chunk_size is not None
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.cp_gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
# workspace layout:
# |------- N tokens --------|--------- N*dcp_size tokens ----------|
# |<- used: local gather  ->|<--------- used: allgather ---------->|
allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
assert allgather_offset * (dcp_world_size +
1) == workspace.shape[0]
assert toks <= allgather_offset
local_gathered_kvcache = workspace[:toks]
cur_allgather_workspace = workspace[
allgather_offset:allgather_offset * (1 + dcp_world_size)]
assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
cur_allgather_kvcache = cur_allgather_workspace[:toks *
dcp_world_size]
cur_allgather_kvcache.copy_(get_dcp_group().all_gather(
local_gathered_kvcache, dim=0))
assert cur_allgather_kvcache.shape[
-1] == self.kv_lora_rank + self.qk_rope_head_dim
allgatered_kv_c_normed, allgatered_k_pe = \
cur_allgather_kvcache.unsqueeze(
1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed, k_pe = reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.
cp_chunk_seq_lens[i],
origin_context_lens=prefill_metadata.chunked_context.
origin_context_lens,
cp_world_size=dcp_world_size,
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
[-1],
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
chunk_idx=i,
q=q,
k=k,
v=v,
)
if output is None:
output = attn_output
output_lse = attn_softmax_lse
else:
output_tmp = torch.empty_like(output)
output_lse_tmp = torch.empty_like(output_lse)
merge_attn_states(
output=output_tmp,
output_lse=output_lse_tmp,
prefix_output=output,
prefix_lse=output_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
)
output = output_tmp
output_lse = output_lse_tmp
return output, output_lse
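
The per-chunk merge at the end of the loop relies on the standard log-sum-exp combination of partial attention results. A plain-PyTorch reference of what merge_attn_states() computes over two disjoint key sets (a sketch; the real kernel is fused and its exact shapes and numerics may differ):

import torch

def merge_partial_attn(o1, lse1, o2, lse2):
    # Combine attention over two disjoint key sets from their partial outputs
    # and log-sum-exps: out is the softmax-weighted sum, lse = logaddexp(lse1, lse2).
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

# Made-up shapes: 2 query tokens, 1 head, head_dim 4.
o1, o2 = torch.randn(2, 1, 4), torch.randn(2, 1, 4)
lse1, lse2 = torch.randn(2, 1), torch.randn(2, 1)
out, lse = merge_partial_attn(o1, lse1, o2, lse2)
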
def _forward_prefill(
self,
q: torch.Tensor,
@@ -1162,6 +1417,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_scale: torch.Tensor,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None
has_context = attn_metadata.prefill.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
@@ -1181,8 +1437,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
if self.dcp_world_size > 1:
context_output, context_lse = \
self._context_parallel_compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata,
k_scale=None, dcp_world_size=self.dcp_world_size)
else:
context_output, context_lse = \
self._compute_prefill_context(
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
@@ -1202,12 +1465,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
@abstractmethod
def _forward_decode(
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
raise NotImplementedError
def forward(
@@ -1235,6 +1497,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# same expert outputs.
return output.fill_(0)
if self.dcp_world_size is None:
self.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8")
num_actual_toks = attn_metadata.num_actual_tokens
@@ -1313,7 +1578,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
layer._q_scale)
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
decode_q = (decode_ql_nope, decode_q_pe)
if self.dcp_world_size > 1:
assert not fp8_attention, "DCP not support fp8 kvcache now."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
decode_q = torch.cat(decode_q, dim=-1)
# all-gather decode_q along the head dim.
decode_q = get_dcp_group().all_gather(decode_q, dim=1)
# call decode attn
attn_out, lse = self._forward_decode(decode_q, kv_cache,
attn_metadata, layer)
# correct the DCP partial attn_out using the lse.
if self.dcp_world_size > 1:
assert lse is not None, (
"For a mla backend want to enable"
"DCP, it is mandatory that the corresponding decode attn"
"kernel return the softmax lse.")
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
# v_up projection
output[:num_decode_tokens] = self._v_up_proj(attn_out)
return output_padded
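
To summarize the decode path under DCP: each rank holds only a shard of the KV cache, so after all-gathering the queries along the head dim, every rank produces a partial attention output plus its softmax LSE against its local shard, and cp_lse_ag_out_rs then combines those partials across the DCP group. The reading below of its semantics (LSE all-gather followed by an LSE-weighted reduce) is inferred from its name and usage here, not from its implementation. A single-process sketch of the correction math, with every "rank" just a list entry and made-up shapes:

import torch

world, tokens, heads, dim = 2, 3, 4, 8
partial_out = [torch.randn(tokens, heads, dim) for _ in range(world)]   # per-rank partial output
partial_lse = [torch.randn(tokens, heads) for _ in range(world)]        # per-rank softmax LSE
# "All-gather" the LSEs and merge them, then rescale and sum the partial outputs;
# in the real op the combined result is scattered back so each rank keeps its own heads.
lse_total = torch.logsumexp(torch.stack(partial_lse), dim=0)
corrected = sum(
    torch.exp(partial_lse[r] - lse_total).unsqueeze(-1) * partial_out[r]
    for r in range(world))
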

View File

@@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
self._workspace.get_buf(),
self.scale, self._num_kv_splits)
return self._v_up_proj(o)
return o
# TODO: Currently we leave it here only for backup in case something is
# wrong with the new SM100 CUTLASS MLA kernel
@@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)
return self._v_up_proj(o)
return o
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self._use_old_cutlass_mla:
# TODO: Remove the old cutlass MLA kernel after more extensive
# testing
return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata)
attn_metadata), None
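
The signature change above (and in the other backends below) boils down to a single contract: _forward_decode accepts either the (ql_nope, q_pe) pair or one pre-concatenated (and, under DCP, head-all-gathered) tensor, and returns (o, lse), where lse may be None for backends whose decode kernel cannot expose it; the common forward() asserts a non-None lse whenever dcp_world_size > 1. A simplified, hypothetical sketch of the pattern (the kernel call is a placeholder, not any backend's real code):

from typing import Optional, Union
import torch

def _forward_decode_sketch(
    q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
    kv_lora_rank: int,
    qk_rope_head_dim: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    if isinstance(q, tuple):
        # The caller passed (ql_nope, q_pe); concatenate along the last dim.
        q = torch.cat(q, dim=-1)
    assert q.shape[-1] == kv_lora_rank + qk_rope_head_dim
    o = torch.zeros(q.shape[0], q.shape[1], kv_lora_rank)   # stand-in for the decode kernel
    lse = None   # return the kernel's softmax LSE here instead to support DCP
    return o, lse
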

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
@@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashAttnMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if type(q) is tuple:
q_nope, q_pe = q
else:
q_nope, q_pe = torch.split(
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"FP8 FlashAttention MLA not yet supported")

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
@@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if type(q) is tuple:
q = torch.cat(q, dim=-1)
o, _ = flash_mla_with_kvcache(
q=q,
assert isinstance(q, torch.Tensor)
o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
@@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k=layer._k_scale.reshape(1),
)
return self._v_up_proj(o)
return o, lse

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
@@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
B = q_nope.shape[0]
if type(q) is tuple:
q = torch.cat(q, dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
@@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
return self._v_up_proj(o)
return o, None

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
@@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
if type(q) is tuple:
q = torch.cat(q, dim=-1)
q = torch.cat([q_nope, q_pe], dim=-1)
assert isinstance(q, torch.Tensor)
B = q.shape[0]
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
@@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj(o)
return o, None