|
|
|
|
@@ -275,17 +275,47 @@ class MLACommonBackend(AttentionBackend):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class MLACommonMetadata:
|
|
|
|
|
class MLACommonPrefillMetadata:
|
|
|
|
|
""" Prefill Specific Metadata """
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class ChunkedContextMetadata:
|
|
|
|
|
# New for MLA (compared to FlashAttention)
|
|
|
|
|
# For handling chunked prefill
|
|
|
|
|
cu_seq_lens: torch.Tensor
|
|
|
|
|
starts: torch.Tensor
|
|
|
|
|
seq_tot: list[int]
|
|
|
|
|
max_seq_lens: list[int]
|
|
|
|
|
workspace: torch.Tensor
|
|
|
|
|
|
|
|
|
|
# Input positions for rotrary embeddings since for MLA the rotary
|
|
|
|
|
# position embeddings are applied inside the attention backend
|
|
|
|
|
input_positions: torch.Tensor
|
|
|
|
|
block_table: torch.Tensor
|
|
|
|
|
query_start_loc: torch.Tensor
|
|
|
|
|
max_query_len: int
|
|
|
|
|
chunked_context: Optional[ChunkedContextMetadata] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class MLACommonDecodeMetadata:
|
|
|
|
|
# Input positions for rotrary embeddings since for MLA the rotary
|
|
|
|
|
# position embeddings are applied inside the attention backend
|
|
|
|
|
input_positions: torch.Tensor
|
|
|
|
|
block_table: torch.Tensor
|
|
|
|
|
seq_lens: torch.Tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
D = TypeVar("D", bound=MLACommonDecodeMetadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class MLACommonMetadata(Generic[D]):
|
|
|
|
|
"""Metadata for MLACommon.
|
|
|
|
|
|
|
|
|
|
NOTE: Please read the comment at the top of the file before trying to
|
|
|
|
|
understand this class
|
|
|
|
|
"""
|
|
|
|
|
# New for MLA (compared to FlashAttention)
|
|
|
|
|
# Input positions for rotrary embeddings since for MLA the rotary
|
|
|
|
|
# position embeddings are applied inside the attention backend
|
|
|
|
|
input_positions: torch.Tensor
|
|
|
|
|
|
|
|
|
|
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
|
|
|
|
# |---------- N-1 iteration --------|
|
|
|
|
|
# |---------------- N iteration ---------------------|
|
|
|
|
|
@@ -295,30 +325,23 @@ class MLACommonMetadata:
|
|
|
|
|
# |-- query_len ---|
|
|
|
|
|
|
|
|
|
|
num_actual_tokens: int # Number of tokens excluding padding.
|
|
|
|
|
max_query_len: int
|
|
|
|
|
query_start_loc: torch.Tensor
|
|
|
|
|
max_seq_len: int
|
|
|
|
|
seq_lens: torch.Tensor
|
|
|
|
|
block_table: torch.Tensor
|
|
|
|
|
slot_mapping: torch.Tensor
|
|
|
|
|
|
|
|
|
|
# New for MLA (compared to FlashAttention)
|
|
|
|
|
# For handling prefill decode split
|
|
|
|
|
num_decodes: int
|
|
|
|
|
num_decode_tokens: int
|
|
|
|
|
num_prefills: int
|
|
|
|
|
|
|
|
|
|
# For logging.
|
|
|
|
|
num_input_tokens: int = 0 # Number of tokens including padding.
|
|
|
|
|
|
|
|
|
|
# The dimension of the attention heads
|
|
|
|
|
head_dim: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
# New for MLA (compared to FlashAttention)
|
|
|
|
|
# For chunked prefill
|
|
|
|
|
num_decodes: Optional[int] = None
|
|
|
|
|
num_decode_tokens: Optional[int] = None
|
|
|
|
|
num_prefills: Optional[int] = None
|
|
|
|
|
has_context: bool = False
|
|
|
|
|
context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
|
|
|
|
|
context_chunk_starts: Optional[torch.Tensor] = None
|
|
|
|
|
context_chunk_seq_tot: Optional[list[int]] = None
|
|
|
|
|
context_chunk_max_seq_lens: Optional[list[int]] = None
|
|
|
|
|
chunked_prefill_workspace: Optional[torch.Tensor] = None
|
|
|
|
|
decode: Optional[D] = None
|
|
|
|
|
prefill: Optional[MLACommonPrefillMetadata] = None
|
|
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
|
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
|
|
|
|
|
@@ -329,10 +352,10 @@ class MLACommonMetadata:
|
|
|
|
|
f"received {self.head_dim}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
T = TypeVar("T", bound=MLACommonMetadata)
|
|
|
|
|
M = TypeVar("M", bound=MLACommonMetadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MLACommonMetadataBuilder(Generic[T]):
|
|
|
|
|
class MLACommonMetadataBuilder(Generic[M]):
|
|
|
|
|
"""
|
|
|
|
|
NOTE: Please read the comment at the top of the file before trying to
|
|
|
|
|
understand this class
|
|
|
|
|
@@ -340,8 +363,9 @@ class MLACommonMetadataBuilder(Generic[T]):
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
runner: "GPUModelRunner",
|
|
|
|
|
cls: Optional[type[T]] = None):
|
|
|
|
|
self.cls = cls if cls is not None else MLACommonMetadata
|
|
|
|
|
metadata_cls: Optional[type[M]] = None):
|
|
|
|
|
self.metadata_cls = metadata_cls \
|
|
|
|
|
if metadata_cls is not None else MLACommonMetadata
|
|
|
|
|
self.runner = runner
|
|
|
|
|
scheduler_config = runner.scheduler_config
|
|
|
|
|
model_config = runner.model_config
|
|
|
|
|
@@ -375,7 +399,7 @@ class MLACommonMetadataBuilder(Generic[T]):
|
|
|
|
|
self.page_size = self.runner.block_size
|
|
|
|
|
|
|
|
|
|
def reorder_batch(self, input_batch: "InputBatch",
|
|
|
|
|
scheduler_output: "SchedulerOutput"):
|
|
|
|
|
scheduler_output: "SchedulerOutput") -> bool:
|
|
|
|
|
# We now want to reorder the batch so that the "decode" requests are and
|
|
|
|
|
# the front and the "prefill" requests are at the using the least amount
|
|
|
|
|
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
|
|
|
|
@@ -413,6 +437,7 @@ class MLACommonMetadataBuilder(Generic[T]):
|
|
|
|
|
num_decodes = len(decodes)
|
|
|
|
|
num_prefills = len(prefills)
|
|
|
|
|
first_prefill = 0
|
|
|
|
|
modified_batch = False
|
|
|
|
|
|
|
|
|
|
for i in range(1, min(num_decodes, num_prefills) + 1):
|
|
|
|
|
# If the decode is at the "back" of the batch, i, we can swap it
|
|
|
|
|
@@ -421,6 +446,7 @@ class MLACommonMetadataBuilder(Generic[T]):
|
|
|
|
|
input_batch.swap_states(prefills[first_prefill],
|
|
|
|
|
decodes[num_decodes - i])
|
|
|
|
|
first_prefill += 1
|
|
|
|
|
modified_batch = True
|
|
|
|
|
else:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
@@ -432,10 +458,21 @@ class MLACommonMetadataBuilder(Generic[T]):
|
|
|
|
|
self._num_decode_tokens = num_decode_tokens
|
|
|
|
|
self._num_prefill_tokens = num_prefill_tokens
|
|
|
|
|
|
|
|
|
|
return modified_batch
|
|
|
|
|
|
|
|
|
|
def _build_decode(self, input_positions: torch.Tensor,
|
|
|
|
|
block_table: torch.Tensor, seq_lens: torch.Tensor):
|
|
|
|
|
return MLACommonDecodeMetadata(
|
|
|
|
|
input_positions=input_positions,
|
|
|
|
|
block_table=block_table,
|
|
|
|
|
seq_lens=seq_lens,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
|
|
|
|
common_prefix_len: int) -> T:
|
|
|
|
|
common_prefix_len: int) -> M:
|
|
|
|
|
assert self._num_decodes + self._num_prefills == num_reqs
|
|
|
|
|
|
|
|
|
|
device = self.runner.device
|
|
|
|
|
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
|
|
|
|
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
|
|
|
|
|
device, non_blocking=True)
|
|
|
|
|
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
|
|
|
|
|
@@ -447,85 +484,103 @@ class MLACommonMetadataBuilder(Generic[T]):
|
|
|
|
|
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
|
|
|
|
device, non_blocking=True).long()
|
|
|
|
|
|
|
|
|
|
context_chunk_cu_seq_lens = None
|
|
|
|
|
context_chunk_starts = None
|
|
|
|
|
context_chunk_seq_tot = None
|
|
|
|
|
context_chunk_max_seq_lens = None
|
|
|
|
|
prefill_metadata = None
|
|
|
|
|
if self._num_prefills > 0:
|
|
|
|
|
reqs_start = self._num_decodes # prefill_start
|
|
|
|
|
tokens_start = self._num_decode_tokens
|
|
|
|
|
|
|
|
|
|
num_computed_tokens_cpu_tensor = \
|
|
|
|
|
self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
|
|
|
|
|
context_lens_tensor = \
|
|
|
|
|
num_computed_tokens_cpu_tensor.to(device, non_blocking=True)
|
|
|
|
|
context_lens_cpu = self.runner.input_batch.\
|
|
|
|
|
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
|
|
|
|
context_lens = context_lens_cpu.to(device, non_blocking=True)
|
|
|
|
|
|
|
|
|
|
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
|
|
|
|
and context_lens_tensor[self._num_decodes:].max() > 0:
|
|
|
|
|
# NOTE: it is recommend you read the `Chunked Prefill` section in
|
|
|
|
|
# the comment at the top of the file before trying to understand
|
|
|
|
|
# the following code
|
|
|
|
|
chunked_context_metadata = None
|
|
|
|
|
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
|
|
|
|
and context_lens.max() > 0:
|
|
|
|
|
# NOTE: it is recommend you read the `Chunked Prefill` section
|
|
|
|
|
# in the comment at the top of the file before trying to
|
|
|
|
|
# understand the following code
|
|
|
|
|
|
|
|
|
|
self.has_context = True
|
|
|
|
|
num_prefills_with_context = (context_lens > 0).sum().item()
|
|
|
|
|
|
|
|
|
|
num_prefills_with_context = \
|
|
|
|
|
(context_lens_tensor[self._num_decodes:] > 0).sum().item()
|
|
|
|
|
# currently we allocate an equal amount of workspace for each
|
|
|
|
|
# prefill in the batch, we could probably use a more advanced
|
|
|
|
|
# algorithm here and allocate more workspace to prefills with
|
|
|
|
|
# longer context lengths
|
|
|
|
|
max_context_chunk = \
|
|
|
|
|
self.chunked_prefill_workspace_size \
|
|
|
|
|
// num_prefills_with_context
|
|
|
|
|
|
|
|
|
|
# currently we allocate an equal amount of workspace for each
|
|
|
|
|
# prefill in the batch, we could probably use a more advanced
|
|
|
|
|
# algorithm here and allocate more workspace to prefills with
|
|
|
|
|
# longer context lengths
|
|
|
|
|
max_context_chunk = \
|
|
|
|
|
self.chunked_prefill_workspace_size // num_prefills_with_context
|
|
|
|
|
# align max_context_chunk to page_size by rounding down,
|
|
|
|
|
# currently the `gather_cache` kernel cannot handle
|
|
|
|
|
# `context_chunk_starts` that are not aligned to page_size
|
|
|
|
|
max_context_chunk = round_down(max_context_chunk,
|
|
|
|
|
self.page_size)
|
|
|
|
|
|
|
|
|
|
# align max_context_chunk to page_size by rounding down,
|
|
|
|
|
# currently the `gather_cache` kernel cannot handle
|
|
|
|
|
# `context_chunk_starts` that are not aligned to page_size
|
|
|
|
|
max_context_chunk = round_down(max_context_chunk, self.page_size)
|
|
|
|
|
assert max_context_chunk > 0
|
|
|
|
|
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
|
|
|
|
|
assert max_context_chunk > 0
|
|
|
|
|
num_chunks = cdiv(context_lens.max(), max_context_chunk)
|
|
|
|
|
|
|
|
|
|
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
|
|
|
|
# `num_prefills_with_context = 4`, create a tensor that looks like
|
|
|
|
|
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
|
|
|
|
context_chunk_starts = \
|
|
|
|
|
torch.arange(num_chunks, device=device, dtype=torch.int32) \
|
|
|
|
|
.unsqueeze(1).expand(-1, self._num_prefills) \
|
|
|
|
|
* max_context_chunk
|
|
|
|
|
chunk_ends = torch.min(context_lens_tensor[self._num_decodes:] \
|
|
|
|
|
.unsqueeze(0), context_chunk_starts + max_context_chunk)
|
|
|
|
|
chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
|
|
|
|
|
_context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
|
|
|
|
|
torch.int32)
|
|
|
|
|
zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \
|
|
|
|
|
.unsqueeze(-1)
|
|
|
|
|
context_chunk_cu_seq_lens = \
|
|
|
|
|
torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
|
|
|
|
|
context_chunk_max_seq_lens = \
|
|
|
|
|
chunk_seq_lens.max(dim=1).values.tolist()
|
|
|
|
|
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
|
|
|
|
|
assert max(context_chunk_seq_tot) <= \
|
|
|
|
|
self.chunked_prefill_workspace_size
|
|
|
|
|
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
|
|
|
|
# `num_prefills_with_context = 4`, create a tensor that looks
|
|
|
|
|
# like
|
|
|
|
|
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
|
|
|
|
chunk_starts = \
|
|
|
|
|
torch.arange(num_chunks, device=device, dtype=torch.int32) \
|
|
|
|
|
.unsqueeze(1).expand(-1, self._num_prefills) \
|
|
|
|
|
* max_context_chunk
|
|
|
|
|
chunk_ends = torch.min(context_lens.unsqueeze(0),
|
|
|
|
|
chunk_starts + max_context_chunk)
|
|
|
|
|
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
|
|
|
|
_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
|
|
|
|
|
torch.int32)
|
|
|
|
|
zero = torch.zeros(num_chunks,
|
|
|
|
|
dtype=torch.int32,
|
|
|
|
|
device=device).unsqueeze(-1)
|
|
|
|
|
|
|
|
|
|
return self.cls(
|
|
|
|
|
input_positions=input_positions,
|
|
|
|
|
chunked_context_metadata = \
|
|
|
|
|
MLACommonPrefillMetadata.ChunkedContextMetadata(
|
|
|
|
|
cu_seq_lens=torch.cat(
|
|
|
|
|
[zero, _chunk_cu_seq_lens], dim=1),
|
|
|
|
|
starts=chunk_starts,
|
|
|
|
|
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
|
|
|
|
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
|
|
|
|
workspace=self.chunked_prefill_workspace,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert max(chunked_context_metadata.max_seq_lens) <= \
|
|
|
|
|
self.chunked_prefill_workspace_size
|
|
|
|
|
|
|
|
|
|
prefill_metadata = MLACommonPrefillMetadata(
|
|
|
|
|
input_positions=input_positions[tokens_start:],
|
|
|
|
|
block_table=block_table[reqs_start:, ...],
|
|
|
|
|
query_start_loc=query_start_loc[reqs_start:] -
|
|
|
|
|
query_start_loc[reqs_start],
|
|
|
|
|
max_query_len=seq_lens[reqs_start:].max().item(),
|
|
|
|
|
chunked_context=chunked_context_metadata,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
decode_metadata = None
|
|
|
|
|
if self._num_decodes > 0:
|
|
|
|
|
decode_metadata = self._build_decode(
|
|
|
|
|
input_positions=input_positions[:self._num_decode_tokens],
|
|
|
|
|
block_table=block_table[:self._num_decodes, ...],
|
|
|
|
|
seq_lens=seq_lens[:self._num_decodes],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return self.metadata_cls(
|
|
|
|
|
num_actual_tokens=num_actual_tokens,
|
|
|
|
|
max_query_len=max_query_len,
|
|
|
|
|
query_start_loc=query_start_loc,
|
|
|
|
|
max_seq_len=max_seq_len,
|
|
|
|
|
seq_lens=seq_lens,
|
|
|
|
|
block_table=block_table,
|
|
|
|
|
slot_mapping=slot_mapping,
|
|
|
|
|
head_dim=self.runner.model_config.get_head_size(),
|
|
|
|
|
# MLACommonMetadata Chunk prefill specific
|
|
|
|
|
num_decodes=self._num_decodes,
|
|
|
|
|
num_decode_tokens=self._num_decode_tokens,
|
|
|
|
|
num_prefills=self._num_prefills,
|
|
|
|
|
context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
|
|
|
|
|
context_chunk_starts=context_chunk_starts,
|
|
|
|
|
context_chunk_seq_tot=context_chunk_seq_tot,
|
|
|
|
|
context_chunk_max_seq_lens=context_chunk_max_seq_lens,
|
|
|
|
|
prefill=prefill_metadata,
|
|
|
|
|
decode=decode_metadata,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|
|
|
|
"""
|
|
|
|
|
NOTE: Please read the comment at the top of the file before trying to
|
|
|
|
|
understand this class
|
|
|
|
|
@@ -798,28 +853,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
kv_c_and_k_pe_cache: torch.Tensor,
|
|
|
|
|
attn_metadata: MLACommonMetadata,
|
|
|
|
|
):
|
|
|
|
|
assert attn_metadata.num_prefills is not None
|
|
|
|
|
assert attn_metadata.context_chunk_seq_tot is not None
|
|
|
|
|
assert attn_metadata.context_chunk_cu_seq_lens is not None
|
|
|
|
|
assert attn_metadata.context_chunk_starts is not None
|
|
|
|
|
assert attn_metadata.context_chunk_max_seq_lens is not None
|
|
|
|
|
assert attn_metadata.prefill is not None
|
|
|
|
|
prefill_metadata = attn_metadata.prefill
|
|
|
|
|
assert prefill_metadata.chunked_context is not None
|
|
|
|
|
|
|
|
|
|
output = None
|
|
|
|
|
iters = len(attn_metadata.context_chunk_seq_tot)
|
|
|
|
|
|
|
|
|
|
assert attn_metadata.chunked_prefill_workspace is not None
|
|
|
|
|
workspace = attn_metadata.chunked_prefill_workspace
|
|
|
|
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
|
|
|
|
workspace = prefill_metadata.chunked_context.workspace
|
|
|
|
|
|
|
|
|
|
for i in range(iters):
|
|
|
|
|
toks = attn_metadata.context_chunk_seq_tot[i]
|
|
|
|
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
|
|
|
|
|
|
|
|
|
ops.gather_cache(
|
|
|
|
|
src_cache=kv_c_and_k_pe_cache,
|
|
|
|
|
dst=workspace,
|
|
|
|
|
block_table=attn_metadata.block_table,
|
|
|
|
|
cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i],
|
|
|
|
|
block_table=prefill_metadata.block_table,
|
|
|
|
|
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
|
|
|
|
|
batch_size=attn_metadata.num_prefills,
|
|
|
|
|
seq_starts=attn_metadata.context_chunk_starts[i],
|
|
|
|
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
kv_c_normed = workspace[:toks]\
|
|
|
|
|
@@ -845,10 +896,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
q=q,
|
|
|
|
|
k=k,
|
|
|
|
|
v=v_padded,
|
|
|
|
|
cu_seqlens_q=attn_metadata.query_start_loc,
|
|
|
|
|
cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i],
|
|
|
|
|
max_seqlen_q=attn_metadata.max_query_len,
|
|
|
|
|
max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i],
|
|
|
|
|
cu_seqlens_q=prefill_metadata.query_start_loc,
|
|
|
|
|
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
|
|
|
|
|
max_seqlen_q=prefill_metadata.max_query_len,
|
|
|
|
|
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
|
|
|
|
|
softmax_scale=self.scale,
|
|
|
|
|
causal=False, # Context is unmasked
|
|
|
|
|
return_softmax_lse=True,
|
|
|
|
|
@@ -881,7 +932,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
kv_c_and_k_pe_cache: torch.Tensor,
|
|
|
|
|
attn_metadata: MLACommonMetadata,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
has_context = attn_metadata.has_context
|
|
|
|
|
assert attn_metadata.prefill is not None
|
|
|
|
|
|
|
|
|
|
has_context = attn_metadata.prefill.chunked_context is not None
|
|
|
|
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
|
|
|
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
|
|
|
|
k_nope, v = kv_nope\
|
|
|
|
|
@@ -898,10 +951,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
q=q,
|
|
|
|
|
k=k,
|
|
|
|
|
v=v_padded,
|
|
|
|
|
cu_seqlens_q=attn_metadata.query_start_loc,
|
|
|
|
|
cu_seqlens_k=attn_metadata.query_start_loc,
|
|
|
|
|
max_seqlen_q=attn_metadata.max_query_len,
|
|
|
|
|
max_seqlen_k=attn_metadata.max_seq_len,
|
|
|
|
|
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
|
|
|
|
|
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
|
|
|
|
|
max_seqlen_q=attn_metadata.prefill.max_query_len,
|
|
|
|
|
max_seqlen_k=attn_metadata.prefill.max_query_len,
|
|
|
|
|
softmax_scale=self.scale,
|
|
|
|
|
causal=True,
|
|
|
|
|
return_softmax_lse=has_context,
|
|
|
|
|
@@ -934,7 +987,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
q_nope: torch.Tensor,
|
|
|
|
|
q_pe: torch.Tensor,
|
|
|
|
|
kv_c_and_k_pe_cache: torch.Tensor,
|
|
|
|
|
attn_metadata: T,
|
|
|
|
|
attn_metadata: M,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
@@ -945,7 +998,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
k_c_normed: torch.Tensor, # key in unified attn
|
|
|
|
|
k_pe: torch.Tensor, # value in unified attn
|
|
|
|
|
kv_cache: torch.Tensor,
|
|
|
|
|
attn_metadata: T,
|
|
|
|
|
attn_metadata: M,
|
|
|
|
|
output: Optional[torch.Tensor] = None,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
@@ -966,7 +1019,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
|
|
|
|
|
# Restore head dim (for rotary embedding)
|
|
|
|
|
k_pe = k_pe.unsqueeze(1)
|
|
|
|
|
assert hasattr(attn_metadata, "input_positions")
|
|
|
|
|
|
|
|
|
|
assert attn_metadata.num_decodes is not None and \
|
|
|
|
|
attn_metadata.num_prefills is not None and \
|
|
|
|
|
@@ -978,28 +1030,27 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
|
|
|
|
|
|
|
|
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
|
|
|
|
decode_k_pe = k_pe[:num_decode_tokens]
|
|
|
|
|
decode_input_positions = \
|
|
|
|
|
attn_metadata.input_positions[:num_decode_tokens]
|
|
|
|
|
|
|
|
|
|
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
|
|
|
|
prefill_k_pe = k_pe[num_decode_tokens:]
|
|
|
|
|
prefill_input_positions = \
|
|
|
|
|
attn_metadata.input_positions[num_decode_tokens:]
|
|
|
|
|
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
|
|
|
|
|
|
|
|
|
if has_decode:
|
|
|
|
|
assert attn_metadata.decode is not None
|
|
|
|
|
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
|
|
|
|
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
|
|
|
|
|
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
|
|
|
|
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
|
|
|
|
decode_input_positions, decode_q_pe, decode_k_pe)
|
|
|
|
|
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
|
|
|
|
|
|
|
|
|
|
if has_prefill:
|
|
|
|
|
assert attn_metadata.prefill is not None
|
|
|
|
|
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
|
|
|
|
.view(-1, self.num_heads, self.qk_head_dim)
|
|
|
|
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
|
|
|
|
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
|
|
|
|
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
|
|
|
|
attn_metadata.prefill.input_positions, prefill_q_pe,
|
|
|
|
|
prefill_k_pe)
|
|
|
|
|
|
|
|
|
|
# write the latent and rope to kv cache
|
|
|
|
|
if kv_cache.numel() > 0:
|
|
|
|
|
|