[Perf] Improve MLA on V1 (#14540)
Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
@@ -223,6 +223,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
scaled_quantize)
|
scaled_quantize)
|
||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import cdiv, round_down
|
from vllm.utils import cdiv, round_down
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -471,18 +472,23 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
common_prefix_len: int) -> M:
|
common_prefix_len: int) -> M:
|
||||||
assert self._num_decodes + self._num_prefills == num_reqs
|
assert self._num_decodes + self._num_prefills == num_reqs
|
||||||
|
|
||||||
|
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||||
|
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||||
|
# it blocks on all previous kernels.
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
|
|
||||||
device, non_blocking=True)
|
|
||||||
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
|
|
||||||
non_blocking=True)
|
|
||||||
block_table = (
|
block_table = (
|
||||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||||
|
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
|
||||||
|
device, non_blocking=True)
|
||||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||||
device, non_blocking=True).long()
|
device, non_blocking=True).long()
|
||||||
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
||||||
device, non_blocking=True).long()
|
device, non_blocking=True).long()
|
||||||
|
|
||||||
|
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
||||||
|
seq_lens = seq_lens_cpu.to(device, non_blocking=True)
|
||||||
|
max_query_len = seq_lens_cpu.max().item()
|
||||||
|
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
if self._num_prefills > 0:
|
if self._num_prefills > 0:
|
||||||
reqs_start = self._num_decodes # prefill_start
|
reqs_start = self._num_decodes # prefill_start
|
||||||
@@ -490,24 +496,22 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
|
|
||||||
context_lens_cpu = self.runner.input_batch.\
|
context_lens_cpu = self.runner.input_batch.\
|
||||||
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
||||||
context_lens = context_lens_cpu.to(device, non_blocking=True)
|
max_context_len_cpu = context_lens_cpu.max().item()
|
||||||
|
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||||
|
|
||||||
chunked_context_metadata = None
|
chunked_context_metadata = None
|
||||||
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
||||||
and context_lens.max() > 0:
|
and max_context_len_cpu > 0:
|
||||||
# NOTE: it is recommend you read the `Chunked Prefill` section
|
# NOTE: it is recommend you read the `Chunked Prefill` section
|
||||||
# in the comment at the top of the file before trying to
|
# in the comment at the top of the file before trying to
|
||||||
# understand the following code
|
# understand the following code
|
||||||
|
|
||||||
num_prefills_with_context = (context_lens > 0).sum().item()
|
|
||||||
|
|
||||||
# currently we allocate an equal amount of workspace for each
|
# currently we allocate an equal amount of workspace for each
|
||||||
# prefill in the batch, we could probably use a more advanced
|
# prefill in the batch, we could probably use a more advanced
|
||||||
# algorithm here and allocate more workspace to prefills with
|
# algorithm here and allocate more workspace to prefills with
|
||||||
# longer context lengths
|
# longer context lengths
|
||||||
max_context_chunk = \
|
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||||
self.chunked_prefill_workspace_size \
|
num_prefills_with_context_cpu)
|
||||||
// num_prefills_with_context
|
|
||||||
|
|
||||||
# align max_context_chunk to page_size by rounding down,
|
# align max_context_chunk to page_size by rounding down,
|
||||||
# currently the `gather_cache` kernel cannot handle
|
# currently the `gather_cache` kernel cannot handle
|
||||||
@@ -516,30 +520,35 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
self.page_size)
|
self.page_size)
|
||||||
|
|
||||||
assert max_context_chunk > 0
|
assert max_context_chunk > 0
|
||||||
num_chunks = cdiv(context_lens.max(), max_context_chunk)
|
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||||
|
|
||||||
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
||||||
# `num_prefills_with_context = 4`, create a tensor that looks
|
# `num_prefills_with_context = 4`, create a tensor that looks
|
||||||
# like
|
# like
|
||||||
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
||||||
|
# Note(simon): this is done in CPU because of downstream's
|
||||||
|
# of `to_list`.
|
||||||
chunk_starts = \
|
chunk_starts = \
|
||||||
torch.arange(num_chunks, device=device, dtype=torch.int32) \
|
torch.arange(num_chunks, dtype=torch.int32) \
|
||||||
.unsqueeze(1).expand(-1, self._num_prefills) \
|
.unsqueeze(1).expand(-1, self._num_prefills) \
|
||||||
* max_context_chunk
|
* max_context_chunk
|
||||||
chunk_ends = torch.min(context_lens.unsqueeze(0),
|
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||||
chunk_starts + max_context_chunk)
|
chunk_starts + max_context_chunk)
|
||||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||||
_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
|
|
||||||
torch.int32)
|
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||||
zero = torch.zeros(num_chunks,
|
self._num_prefills + 1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device).unsqueeze(-1)
|
pin_memory=True)
|
||||||
|
torch.cumsum(chunk_seq_lens,
|
||||||
|
dim=1,
|
||||||
|
out=cu_seq_lens_cpu[:, 1:],
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
chunked_context_metadata = \
|
chunked_context_metadata = \
|
||||||
MLACommonPrefillMetadata.ChunkedContextMetadata(
|
MLACommonPrefillMetadata.ChunkedContextMetadata(
|
||||||
cu_seq_lens=torch.cat(
|
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||||
[zero, _chunk_cu_seq_lens], dim=1),
|
starts=chunk_starts.to(device, non_blocking=True),
|
||||||
starts=chunk_starts,
|
|
||||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||||
workspace=self.chunked_prefill_workspace,
|
workspace=self.chunked_prefill_workspace,
|
||||||
@@ -553,7 +562,7 @@ class MLACommonMetadataBuilder(Generic[M]):
|
|||||||
block_table=block_table[reqs_start:, ...],
|
block_table=block_table[reqs_start:, ...],
|
||||||
query_start_loc=query_start_loc[reqs_start:] -
|
query_start_loc=query_start_loc[reqs_start:] -
|
||||||
query_start_loc[reqs_start],
|
query_start_loc[reqs_start],
|
||||||
max_query_len=seq_lens[reqs_start:].max().item(),
|
max_query_len=max_query_len,
|
||||||
chunked_context=chunked_context_metadata,
|
chunked_context=chunked_context_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -629,7 +638,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
# already inside an attention custom op), pull out the forward
|
# already inside an attention custom op), pull out the forward
|
||||||
# method from the rotary embedding and call it directly
|
# method from the rotary embedding and call it directly
|
||||||
# TODO(lucas): we should probably find a cleaner way to do this
|
# TODO(lucas): we should probably find a cleaner way to do this
|
||||||
self.rotary_emb = rotary_emb._forward_method
|
self.rotary_emb = rotary_emb.forward_native
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
self.rotary_emb = rotary_emb.forward_cuda
|
||||||
|
|
||||||
self.q_proj = q_proj
|
self.q_proj = q_proj
|
||||||
self.kv_b_proj = kv_b_proj
|
self.kv_b_proj = kv_b_proj
|
||||||
@@ -1043,17 +1054,20 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||||
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
|
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
|
||||||
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
||||||
|
|
||||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||||
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
|
attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
|
||||||
|
decode_k_pe)
|
||||||
|
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)
|
.view(-1, self.num_heads, self.qk_head_dim)
|
||||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||||
|
|
||||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||||
attn_metadata.prefill.input_positions, prefill_q_pe,
|
attn_metadata.prefill.input_positions,
|
||||||
prefill_k_pe)
|
prefill_q_pe.contiguous(), prefill_k_pe)
|
||||||
|
|
||||||
# write the latent and rope to kv cache
|
# write the latent and rope to kv cache
|
||||||
if kv_cache.numel() > 0:
|
if kv_cache.numel() > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user