This PR implements DeepSeek V3 support by performing matrix absorption of the FP8 weights.

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
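For context, MLA "matrix absorption" folds the latent-to-key up-projection into the query and defers the latent-to-value up-projection until after attention, so decode can attend directly over the compressed KV cache as multi-query attention with a single latent "head". The snippet below is only a minimal sketch of that algebra under assumed DeepSeek-V3-like shapes (`kv_lora_rank=512`, per-head nope dim 128); the tensor names are illustrative, and it ignores the decoupled RoPE part (`q_pe`/`k_pe`) as well as the FP8 handling that this PR actually performs on the checkpoint weights.

```python
import torch

# Assumed illustrative dimensions (DeepSeek-V3-like); not taken from this PR.
num_heads, kv_lora_rank, nope_dim, v_dim, n_tokens = 128, 512, 128, 128, 1024

W_UK = torch.randn(num_heads, nope_dim, kv_lora_rank)  # latent -> key (per head)
W_UV = torch.randn(num_heads, kv_lora_rank, v_dim)     # latent -> value (per head)
q_nope = torch.randn(1, num_heads, nope_dim)           # query, non-RoPE part
kv_c = torch.randn(n_tokens, kv_lora_rank)             # cached compressed latents

# "Absorb" W_UK into the query: scores against the 512-dim latent cache are
# identical to first decompressing per-head keys, but no decompression happens.
q_latent = torch.einsum("bhd,hdk->bhk", q_nope, W_UK)   # (1, H, kv_lora_rank)
scores = torch.einsum("bhk,nk->bhn", q_latent, kv_c)    # (1, H, n_tokens)
attn = torch.softmax(scores * nope_dim**-0.5, dim=-1)

# The attention output stays in latent space; W_UV is applied once at the end
# (in the backend below this corresponds to the `_v_up_proj_and_o_proj` step).
o_latent = torch.einsum("bhn,nk->bhk", attn, kv_c)      # (1, H, kv_lora_rank)
o = torch.einsum("bhk,hkv->bhv", o_latent, W_UV)        # (1, H, v_dim)
```

The backend below implements the decode side of this: `_forward_decode` runs a Triton MQA kernel over the compressed `kv_c_and_k_pe` cache and then calls `_v_up_proj_and_o_proj`.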
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

from vllm.multimodal import MultiModalPlaceholderMap

try:
    from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeMlaWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionState, AttentionType)
from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)


class TritonMLABackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

    @staticmethod
    def get_impl_cls() -> Type["TritonMLAImpl"]:
        return TritonMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return TritonMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]:
        return TritonMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["TritonMLAState"]:
        return TritonMLAState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [576]


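# NOTE: The single KV "head" of size 576 handled by this backend is assumed to
# pack the compressed latent together with the decoupled rotary key, i.e. for
# DeepSeek-V3: kv_lora_rank (512) + qk_rope_head_dim (64) = 576, cached as
# (num_blocks, block_size, 576) with num_kv_heads == 1.

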
class TritonMLAState(AttentionState):

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        self._is_graph_capturing = True

        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)

        self._positions = torch.zeros((max_batch_size, ),
                                      dtype=torch.long,
                                      device=self.runner.device)

        yield

        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables
        del self._positions

    def graph_clone(self, batch_size: int):
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        assert self._is_graph_capturing

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
            input_positions=self._positions[:batch_size],
            head_dim=self.runner.model_config.get_head_size())

        if is_encoder_decoder_model:
            raise NotImplementedError(
                "TritonMLAState does not support encoder/decoder yet")

        return attn_metadata

    def get_graph_input_buffers(self,
                                attn_metadata,
                                is_encoder_decoder_model: bool = False):
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
            "input_positions": attn_metadata.decode_metadata.input_positions,
        }
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "TritonMLAState does not support encoder/decoder yet")

        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata,
                                    is_encoder_decoder_model: bool = False):
        input_positions = attn_metadata.input_positions
        num_positions = input_positions.shape[0]
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # CUDA graph buffer is padded so only perform a partial copy based on
        # num_positions
        input_buffers["input_positions"][:num_positions].copy_(
            input_positions, non_blocking=True)
        if is_encoder_decoder_model:
            raise NotImplementedError(
                "TritonMLAState does not support encoder/decoder yet")

    def begin_forward(self, model_input):
        return


@dataclass
class TritonMLAMetadata(MLACommonMetadata):
    """Metadata for the TritonMLA backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence lengths
    # are [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    _cached_prefill_metadata: Optional["TritonMLAMetadata"] = None
    _cached_decode_metadata: Optional["TritonMLAMetadata"] = None

    num_prefill_tokens: int

    num_kv_splits: int = 4  # TODO(lucas) add heuristic
    attn_logits: Optional[torch.Tensor] = None
    req_idx: Optional[torch.Tensor] = None

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    def __post_init__(self):
        supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    @property
    def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[:self.num_prefill_tokens])

        self._cached_prefill_metadata = TritonMLAMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            input_positions=input_positions,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_query_len=0,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            head_dim=self.head_dim)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["TritonMLAMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[self.num_prefill_tokens:])

        self._cached_decode_metadata = TritonMLAMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens: [3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            input_positions=input_positions,
            head_dim=self.head_dim)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries.
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size.
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.input_positions: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block, input_positions) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks,
                 inter_data.input_positions):
            self.input_positions.extend(input_positions)
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int,
            block_tables: List[List[int]]) -> torch.Tensor:
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        prefix_cache_hit = any([
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list
        ])
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        input_positions = async_tensor_h2d(self.input_positions, torch.long,
                                           device, self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return TritonMLAMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            input_positions=input_positions,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
            num_kv_splits=4,  # TODO(lucas) add heuristic
            head_dim=self.runner.model_config.get_head_size(),
        )


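# NOTE: The split between the two forward paths below: prefill reuses the
# common flash-attention path (_forward_prefill_flash from MLACommonImpl),
# while decode runs the Triton decode kernel (decode_attention_fwd) directly
# against the compressed kv_c_and_k_pe cache as single-KV-head MQA and only
# then up-projects the latent output via _v_up_proj_and_o_proj.

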
class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **kwargs) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **kwargs)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        attn_metadata: TritonMLAMetadata,
    ) -> torch.Tensor:
        assert isinstance(attn_metadata, TritonMLAMetadata)
        return self._forward_prefill_flash(q, kv_c_normed, k_pe,
                                           attn_metadata.seq_start_loc,
                                           attn_metadata.max_prefill_seq_len)

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: TritonMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        # TODO(lucas) Allocate ahead of time
        attn_logits = torch.empty(
            (
                B,
                self.num_heads,
                attn_metadata.num_kv_splits,
                # NOTE(lucas) idk why the +1 is here but sglang has it so we
                # just mirror that
                self.kv_lora_rank + 1,
            ),
            dtype=torch.float32,
            device=q.device,
        )

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

        # Run MQA
        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                             decode_meta.block_tables,
                             decode_meta.seq_lens_tensor, attn_logits,
                             attn_metadata.num_kv_splits, self.scale,
                             PAGE_SIZE)

        return self._v_up_proj_and_o_proj(o)