[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)
Signed-off-by: hongchao <hongchao@msh.team> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: hongchao <hongchao@msh.team> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -201,10 +201,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.common import cp_lse_ag_out_rs
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.parallel_state import is_global_first_rank
|
||||
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase,
|
||||
@@ -323,6 +324,13 @@ class MLACommonPrefillMetadata:
|
||||
seq_lens: torch.Tensor
|
||||
workspace: torch.Tensor
|
||||
|
||||
# for mla DCP
|
||||
cp_chunk_seq_lens: Optional[list[list[int]]] = None
|
||||
origin_context_lens: Optional[list[int]] = None
|
||||
cp_cu_seq_lens: Optional[torch.Tensor] = None
|
||||
chunk_size: Optional[int] = None
|
||||
cu_seq_lens_lst: Optional[list[list[int]]] = None
|
||||
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
@@ -444,6 +452,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.aot_schedule = current_platform.is_cuda()
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
|
||||
# Dont try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
@@ -465,12 +480,27 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * cache_config.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
if self.dcp_world_size > 1:
|
||||
# Note(hc): The local kvcache is incomplete when DCP is triggered,
|
||||
# an additional kvcache allgather across the DCP group is therefore
|
||||
# required, so the workspace has to be enlarged by 1/DCP relative
|
||||
# to the original TP allocation.
|
||||
assert self.chunked_prefill_workspace_size % \
|
||||
self.dcp_world_size == 0
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size +
|
||||
self.chunked_prefill_workspace_size // self.dcp_world_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self._use_cudnn_prefill = use_cudnn_prefill()
|
||||
self._use_fi_prefill = use_flashinfer_prefill()
|
||||
@@ -631,6 +661,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
split_decodes_and_prefills(common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
|
||||
# Note(hc): update seq_lens of decode reqs under DCP.
|
||||
if self.dcp_world_size > 1:
|
||||
seq_lens[:num_decodes] = seq_lens[:num_decodes] \
|
||||
// self.dcp_world_size + (self.dcp_rank <= \
|
||||
(seq_lens[:num_decodes] - 1) % self.dcp_world_size)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
@@ -639,6 +675,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
reqs_start = num_decodes # prefill_start
|
||||
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
# Note(hc): The context lengths in the perspective of dcp rank0.
|
||||
cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() /
|
||||
self.dcp_world_size).int()
|
||||
origin_context_lens = context_lens_cpu.tolist()
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
@@ -691,20 +731,66 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
# Note(hc): The above max_context_chunk already enforces
|
||||
# block_size alignment, DCP just need the block_size can
|
||||
# be divisible by dcp_world_size, because DCP use
|
||||
# cp_gather_cache which not require `cp_chunk_starts`
|
||||
# aligned to page_size.
|
||||
assert max_context_chunk % self.dcp_world_size == 0
|
||||
cp_max_context_chunk = max_context_chunk // \
|
||||
self.dcp_world_size
|
||||
cp_chunk_starts = \
|
||||
torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, num_prefills) \
|
||||
* cp_max_context_chunk
|
||||
cp_chunk_ends = torch.min(
|
||||
cp_context_lens_cpu.unsqueeze(0),
|
||||
cp_chunk_starts + cp_max_context_chunk)
|
||||
cp_chunk_seq_lens = (cp_chunk_ends -
|
||||
cp_chunk_starts).clamp(min=0)
|
||||
|
||||
cp_cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(cp_chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cp_cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
|
||||
chunked_context_metadata_cls = \
|
||||
CudnnPrefillMetadata.ChunkedContextMetadata \
|
||||
if self._use_cudnn_prefill else \
|
||||
MLACommonPrefillMetadata.ChunkedContextMetadata
|
||||
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
if self.dcp_world_size > 1:
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
starts=cp_chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
|
||||
origin_context_lens=origin_context_lens,
|
||||
cp_cu_seq_lens=cp_cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
chunk_size=max_context_chunk,
|
||||
cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = \
|
||||
chunked_context_metadata_cls(
|
||||
cu_seq_lens=cu_seq_lens_cpu \
|
||||
.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
if self._use_cudnn_prefill:
|
||||
chunked_context_metadata.seq_lens = chunk_seq_lens
|
||||
@@ -757,6 +843,71 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
return attn_metadata
|
||||
|
||||
|
||||
def reorg_kvcache(
|
||||
allgatered_kv_c_normed: torch.Tensor,
|
||||
allgatered_k_pe: torch.Tensor,
|
||||
cp_chunk_seq_lens_lst: list[int],
|
||||
origin_context_lens: list[int],
|
||||
cp_world_size: int,
|
||||
sum_seq_len: int,
|
||||
max_seq_len: int,
|
||||
chunk_size: int,
|
||||
chunk_idx: int,
|
||||
toks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
reorg kvcache after cp local gather to tp layout for attn kernel.
|
||||
|
||||
Args:
|
||||
cp_chunk_seq_lens_lst: chunk context lengths under CP.
|
||||
origin_context_lens: origin full context lengths under CP.
|
||||
cp_world_size: CP size.
|
||||
sum_seq_len: the sum of cp_chunk_seq_lens_lst.
|
||||
max_seq_len: the max value of cp_chunk_seq_lens_lst.
|
||||
chunk_size: equals to max_context_chunk from
|
||||
chunked_context_metadata building.
|
||||
chunk_idx: chunk idx of chunked_prefill.
|
||||
toks: the number of tokens for local gather cache.
|
||||
"""
|
||||
kv_c_segments = []
|
||||
k_pe_segments = []
|
||||
src_token_idx = 0
|
||||
max_seq_len_check = 0
|
||||
for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst,
|
||||
origin_context_lens):
|
||||
chunk_context_len = chunk_size
|
||||
if cp_chunk_seq_len != 0:
|
||||
chunk_context_len = min(
|
||||
chunk_context_len, origin_context_len - chunk_size * chunk_idx)
|
||||
cp_target_rank = (chunk_context_len - 1) % cp_world_size
|
||||
cur_seq_len = 0
|
||||
for rank in range(cp_world_size):
|
||||
if rank > cp_target_rank and cp_chunk_seq_len:
|
||||
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
|
||||
else:
|
||||
real_cp_chunk_seq_len = cp_chunk_seq_len
|
||||
if real_cp_chunk_seq_len:
|
||||
kv_c_segment = allgatered_kv_c_normed[rank * toks +
|
||||
src_token_idx:rank *
|
||||
toks + src_token_idx +
|
||||
real_cp_chunk_seq_len]
|
||||
k_pe_segment = allgatered_k_pe[rank * toks +
|
||||
src_token_idx:rank * toks +
|
||||
src_token_idx +
|
||||
real_cp_chunk_seq_len]
|
||||
kv_c_segments.append(kv_c_segment)
|
||||
k_pe_segments.append(k_pe_segment)
|
||||
cur_seq_len += real_cp_chunk_seq_len
|
||||
max_seq_len_check = max(max_seq_len_check, cur_seq_len)
|
||||
src_token_idx += cp_chunk_seq_len
|
||||
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
|
||||
reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
|
||||
assert reorganized_kv_c_normed.shape[0] == sum_seq_len
|
||||
assert reorganized_k_pe.shape[0] == sum_seq_len
|
||||
assert max_seq_len_check == max_seq_len
|
||||
return reorganized_kv_c_normed, reorganized_k_pe
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
@@ -836,6 +987,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.vllm_flash_attn_version == 3
|
||||
and current_platform.get_device_capability()[0] == 9)
|
||||
|
||||
self.dcp_world_size: Optional[int] = None
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(self,
|
||||
q,
|
||||
k,
|
||||
@@ -1152,6 +1305,108 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
return output, output_lse
|
||||
|
||||
def _context_parallel_compute_prefill_context(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
dcp_world_size: int,
|
||||
):
|
||||
assert k_scale is None, "DCP not support sacled kvcache now."
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
assert prefill_metadata.chunked_context is not None
|
||||
assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
|
||||
assert prefill_metadata.chunked_context.origin_context_lens is not None
|
||||
assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
|
||||
assert prefill_metadata.chunked_context.chunk_size is not None
|
||||
assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
|
||||
|
||||
output = None
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
workspace = prefill_metadata.chunked_context.workspace
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
ops.cp_gather_cache(
|
||||
src_cache=kv_c_and_k_pe_cache,
|
||||
dst=workspace,
|
||||
block_table=prefill_metadata.block_table,
|
||||
cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
)
|
||||
# workspace
|
||||
# |------- N tokens --------|--------- N*dcp_size tokens ----------|
|
||||
# |<- use for loca_gather ->|<--------- use for allgather -------->|
|
||||
allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
|
||||
assert allgather_offset * (dcp_world_size +
|
||||
1) == workspace.shape[0]
|
||||
assert toks <= allgather_offset
|
||||
local_gathered_kvcache = workspace[:toks]
|
||||
cur_allgather_workspace = workspace[
|
||||
allgather_offset:allgather_offset * (1 + dcp_world_size)]
|
||||
assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
|
||||
cur_allgather_kvcache = cur_allgather_workspace[:toks *
|
||||
dcp_world_size]
|
||||
cur_allgather_kvcache.copy_(get_dcp_group().all_gather(
|
||||
local_gathered_kvcache, dim=0))
|
||||
assert cur_allgather_kvcache.shape[
|
||||
-1] == self.kv_lora_rank + self.qk_rope_head_dim
|
||||
allgatered_kv_c_normed, allgatered_k_pe = \
|
||||
cur_allgather_kvcache.unsqueeze(
|
||||
1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
kv_c_normed, k_pe = reorg_kvcache(
|
||||
allgatered_kv_c_normed,
|
||||
allgatered_k_pe,
|
||||
cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.
|
||||
cp_chunk_seq_lens[i],
|
||||
origin_context_lens=prefill_metadata.chunked_context.
|
||||
origin_context_lens,
|
||||
cp_world_size=dcp_world_size,
|
||||
sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i]
|
||||
[-1],
|
||||
max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||
chunk_size=prefill_metadata.chunked_context.chunk_size,
|
||||
chunk_idx=i,
|
||||
toks=toks)
|
||||
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
||||
dim=-1)
|
||||
|
||||
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
|
||||
prefill=prefill_metadata,
|
||||
chunk_idx=i,
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
)
|
||||
|
||||
if output is None:
|
||||
output = attn_output
|
||||
output_lse = attn_softmax_lse
|
||||
else:
|
||||
output_tmp = torch.empty_like(output)
|
||||
output_lse_tmp = torch.empty_like(output_lse)
|
||||
merge_attn_states(
|
||||
output=output_tmp,
|
||||
output_lse=output_lse_tmp,
|
||||
prefix_output=output,
|
||||
prefix_lse=output_lse,
|
||||
suffix_output=attn_output,
|
||||
suffix_lse=attn_softmax_lse,
|
||||
)
|
||||
output = output_tmp
|
||||
output_lse = output_lse_tmp
|
||||
|
||||
return output, output_lse
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@@ -1162,6 +1417,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
k_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata.prefill is not None
|
||||
assert self.dcp_world_size is not None
|
||||
|
||||
has_context = attn_metadata.prefill.chunked_context is not None
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
@@ -1181,8 +1437,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
|
||||
if has_context:
|
||||
suffix_output, suffix_lse = output
|
||||
context_output, context_lse = self._compute_prefill_context( \
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
||||
if self.dcp_world_size > 1:
|
||||
context_output, context_lse = \
|
||||
self._context_parallel_compute_prefill_context(
|
||||
q, kv_c_and_k_pe_cache, attn_metadata,
|
||||
k_scale=None, dcp_world_size=self.dcp_world_size)
|
||||
else:
|
||||
context_output, context_lse = \
|
||||
self._compute_prefill_context(
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale)
|
||||
|
||||
output = torch.empty_like(suffix_output)
|
||||
merge_attn_states(
|
||||
@@ -1202,12 +1465,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
self,
|
||||
ql_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: M,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
@@ -1235,6 +1497,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
if self.dcp_world_size is None:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
|
||||
fp8_attention = self.kv_cache_dtype.startswith("fp8")
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
@@ -1313,7 +1578,26 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
layer._q_scale)
|
||||
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
|
||||
|
||||
output[:num_decode_tokens] = self._forward_decode(
|
||||
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer)
|
||||
decode_q = (decode_ql_nope, decode_q_pe)
|
||||
if self.dcp_world_size > 1:
|
||||
assert not fp8_attention, "DCP not support fp8 kvcache now."
|
||||
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
|
||||
decode_q = torch.cat(decode_q, dim=-1)
|
||||
# decode_q do allgather in head dim.
|
||||
decode_q = get_dcp_group().all_gather(decode_q, dim=1)
|
||||
|
||||
# call decode attn
|
||||
attn_out, lse = self._forward_decode(decode_q, kv_cache,
|
||||
attn_metadata, layer)
|
||||
|
||||
# recorect dcp attn_out with lse.
|
||||
if self.dcp_world_size > 1:
|
||||
assert lse is not None, (
|
||||
"For a mla backend want to enable"
|
||||
"DCP, it is mandatory that the corresponding decode attn"
|
||||
"kernel return the softmax lse.")
|
||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
||||
|
||||
# v_up projection
|
||||
output[:num_decode_tokens] = self._v_up_proj(attn_out)
|
||||
return output_padded
|
||||
|
||||
@@ -232,7 +232,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
self._workspace.get_buf(),
|
||||
self.scale, self._num_kv_splits)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o
|
||||
|
||||
# TODO: Currently we leave it here only for backup in case something is
|
||||
# wrong with the new SM100 CUTLASS MLA kernel
|
||||
@@ -265,21 +265,25 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table, self.scale)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
if self._use_old_cutlass_mla:
|
||||
# TODO: Remove the old cutlass MLA kernel after more extensive
|
||||
# testing
|
||||
return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
attn_metadata), None
|
||||
|
||||
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
attn_metadata), None
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -154,15 +154,20 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttnMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError(
|
||||
"FP8 FlashAttention MLA not yet supported")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -169,20 +169,20 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
assert isinstance(q, torch.Tensor)
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
@@ -196,4 +196,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o, lse
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -220,18 +220,19 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
B = q_nope.shape[0]
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
@@ -249,4 +250,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o, None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -123,21 +123,22 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> torch.Tensor:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
if type(q) is tuple:
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
assert isinstance(q, torch.Tensor)
|
||||
B = q.shape[0]
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
@@ -171,4 +172,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
attn_metadata.decode.seq_lens, attn_logits,
|
||||
num_kv_splits, self.scale, PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
return o, None
|
||||
|
||||
Reference in New Issue
Block a user