[ROCm][perf] fix Aiter sparse MLA with MTP>1 (#37887)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com> Signed-off-by: Stig-Arne Grönroos <sgronroo@amd.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
committed by
GitHub
parent
2e56975657
commit
31a719bcd3
@@ -257,19 +257,19 @@ class DFlashProposer(SpecDecodeBaseProposer):
|
||||
)
|
||||
|
||||
@override
|
||||
def build_per_layer_attn_metadata(
|
||||
def build_per_group_and_layer_attn_metadata(
|
||||
self, cad: CommonAttentionMetadata, draft_index: int = 0
|
||||
) -> dict[str, object]:
|
||||
per_layer_attention_metadata = super().build_per_layer_attn_metadata(
|
||||
) -> tuple[list[object], dict[str, object]]:
|
||||
per_group, per_layer = super().build_per_group_and_layer_attn_metadata(
|
||||
cad, draft_index
|
||||
)
|
||||
for layer_name, attn_metadata in per_layer_attention_metadata.items():
|
||||
for layer_name, attn_metadata in per_layer.items():
|
||||
assert getattr(attn_metadata, "causal", None) is False, (
|
||||
f"Attention metadata for layer {layer_name} does not have"
|
||||
" non-causal support, which is required for DFlash."
|
||||
" Consider using a different attention backend, such as FlashAttention."
|
||||
)
|
||||
return per_layer_attention_metadata
|
||||
return per_group, per_layer
|
||||
|
||||
@override
|
||||
def _get_eagle3_use_aux_hidden_state_from_config(self):
|
||||
|
||||
@@ -225,6 +225,9 @@ class SpecDecodeBaseProposer:
|
||||
# Determine allowed attention backends once during initialization.
|
||||
self.allowed_attn_types: tuple | None = None
|
||||
if current_platform.is_rocm():
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import (
|
||||
ROCMAiterMLASparseMetadata,
|
||||
)
|
||||
@@ -234,6 +237,7 @@ class SpecDecodeBaseProposer:
|
||||
TritonAttentionMetadata,
|
||||
RocmAttentionMetadata,
|
||||
ROCMAiterMLASparseMetadata,
|
||||
DeepseekV32IndexerMetadata,
|
||||
]
|
||||
# ROCM_AITER_FA is an optional backend
|
||||
# We check is_enabled() here to avoid importing the backend module during
|
||||
@@ -444,8 +448,8 @@ class SpecDecodeBaseProposer:
|
||||
)
|
||||
)
|
||||
|
||||
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
|
||||
common_attn_metadata
|
||||
per_group_attn_metadata, per_layer_attn_metadata = (
|
||||
self.build_per_group_and_layer_attn_metadata(common_attn_metadata)
|
||||
)
|
||||
|
||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||
@@ -486,10 +490,7 @@ class SpecDecodeBaseProposer:
|
||||
positions = self.positions[token_indices_to_sample]
|
||||
hidden_states = hidden_states[token_indices_to_sample]
|
||||
|
||||
if any(
|
||||
isinstance(attn_metadata, TreeAttentionMetadata)
|
||||
for attn_metadata in per_layer_attn_metadata.values()
|
||||
):
|
||||
if any(isinstance(md, TreeAttentionMetadata) for md in per_group_attn_metadata):
|
||||
# Draft using tree attention - requires full logits for top-k
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
draft_token_ids_list = self.propose_tree(
|
||||
@@ -505,16 +506,15 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||
|
||||
for attn_metadata in per_layer_attn_metadata.values():
|
||||
if self.allowed_attn_types is not None and not isinstance(
|
||||
attn_metadata, self.allowed_attn_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported attention metadata type for speculative "
|
||||
"decoding with num_speculative_tokens > 1: "
|
||||
f"{type(attn_metadata)}. Supported types are: "
|
||||
f"{self.allowed_attn_types}"
|
||||
)
|
||||
if self.allowed_attn_types is not None:
|
||||
for group_md in per_group_attn_metadata:
|
||||
if not isinstance(group_md, self.allowed_attn_types):
|
||||
raise ValueError(
|
||||
f"Unsupported attention metadata type for speculative "
|
||||
"decoding with num_speculative_tokens > 1: "
|
||||
f"{type(group_md)}. Supported types are: "
|
||||
f"{self.allowed_attn_types}"
|
||||
)
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
@@ -595,7 +595,7 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._num_computed_tokens_cpu += 1
|
||||
|
||||
# Rebuild attention metadata
|
||||
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
|
||||
_, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(
|
||||
common_attn_metadata, draft_index=token_index + 1
|
||||
)
|
||||
|
||||
@@ -809,17 +809,19 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
return model_kwargs, num_input_tokens
|
||||
|
||||
def build_per_layer_attn_metadata(
|
||||
def build_per_group_and_layer_attn_metadata(
|
||||
self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
|
||||
) -> dict[str, object]:
|
||||
) -> tuple[list[object], dict[str, object]]:
|
||||
per_group_attn_metadata: list[object] = []
|
||||
per_layer_attn_metadata: dict[str, object] = {}
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=draft_index
|
||||
)
|
||||
per_group_attn_metadata.append(attn_metadata)
|
||||
for layer_name in attn_group.layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
return per_layer_attn_metadata
|
||||
return per_group_attn_metadata, per_layer_attn_metadata
|
||||
|
||||
def model_returns_tuple(self) -> bool:
|
||||
return self.method not in ("mtp", "draft_model", "dflash")
|
||||
|
||||
Reference in New Issue
Block a user