[ROCm][perf] fix Aiter sparse MLA with MTP>1 (#37887)

Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
Signed-off-by: Stig-Arne Grönroos <sgronroo@amd.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Stig-Arne Grönroos
Date: 2026-04-01 02:22:23 +03:00 (committed by GitHub)
Parent: 2e56975657
Commit: 31a719bcd3
2 changed files with 27 additions and 25 deletions
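The change in one picture: draft attention metadata is now built once per attention group and returned both as a per-group list (used for validation) and a per-layer dict (used by the forward pass). Layers in a group share a single metadata object, so per-group validation checks each object exactly once instead of once per layer. A minimal sketch of that shape, assuming a hypothetical `AttnGroup` stand-in — only the function name and return type come from this diff:

```python
from dataclasses import dataclass, field


# Hypothetical stand-in for vLLM's attention group; illustration only.
@dataclass
class AttnGroup:
    metadata: object  # one metadata object per backend group
    layer_names: list[str] = field(default_factory=list)


def build_per_group_and_layer_attn_metadata(
    groups: list[AttnGroup],
) -> tuple[list[object], dict[str, object]]:
    per_group: list[object] = []
    per_layer: dict[str, object] = {}
    for group in groups:
        per_group.append(group.metadata)  # validated once per group
        for name in group.layer_names:
            per_layer[name] = group.metadata  # same object fanned out per layer
    return per_group, per_layer


# Example: two layers sharing one group's metadata object.
groups = [AttnGroup(metadata={"backend": "aiter_mla_sparse"},
                    layer_names=["layers.0.attn", "layers.1.attn"])]
per_group, per_layer = build_per_group_and_layer_attn_metadata(groups)
assert len(per_group) == 1 and len(per_layer) == 2
assert per_layer["layers.0.attn"] is per_group[0]
```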


```diff
@@ -257,19 +257,19 @@ class DFlashProposer(SpecDecodeBaseProposer):
         )
 
     @override
-    def build_per_layer_attn_metadata(
+    def build_per_group_and_layer_attn_metadata(
         self, cad: CommonAttentionMetadata, draft_index: int = 0
-    ) -> dict[str, object]:
-        per_layer_attention_metadata = super().build_per_layer_attn_metadata(
+    ) -> tuple[list[object], dict[str, object]]:
+        per_group, per_layer = super().build_per_group_and_layer_attn_metadata(
             cad, draft_index
         )
-        for layer_name, attn_metadata in per_layer_attention_metadata.items():
+        for layer_name, attn_metadata in per_layer.items():
             assert getattr(attn_metadata, "causal", None) is False, (
                 f"Attention metadata for layer {layer_name} does not have"
                 " non-causal support, which is required for DFlash."
                 " Consider using a different attention backend, such as FlashAttention."
             )
-        return per_layer_attention_metadata
+        return per_group, per_layer
 
     @override
     def _get_eagle3_use_aux_hidden_state_from_config(self):
```
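A detail worth noting in the assertion above: `getattr(attn_metadata, "causal", None) is False` passes only for metadata that explicitly declares `causal = False`; a metadata class that never sets the attribute fails too, rather than slipping through as falsy. A toy illustration — both classes below are hypothetical, only the `getattr` pattern comes from the diff:

```python
# Toy classes, hypothetical; only the getattr pattern comes from the diff.
class DeclaresNonCausal:
    causal = False  # backend explicitly supports non-causal attention

class NeverDeclared:
    pass  # no `causal` attribute at all

for md in (DeclaresNonCausal(), NeverDeclared()):
    ok = getattr(md, "causal", None) is False
    print(type(md).__name__, "->", ok)  # True for the first, False for the second
```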


```diff
@@ -225,6 +225,9 @@ class SpecDecodeBaseProposer:
         # Determine allowed attention backends once during initialization.
         self.allowed_attn_types: tuple | None = None
         if current_platform.is_rocm():
+            from vllm.v1.attention.backends.mla.indexer import (
+                DeepseekV32IndexerMetadata,
+            )
             from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import (
                 ROCMAiterMLASparseMetadata,
             )
@@ -234,6 +237,7 @@ class SpecDecodeBaseProposer:
                 TritonAttentionMetadata,
                 RocmAttentionMetadata,
                 ROCMAiterMLASparseMetadata,
+                DeepseekV32IndexerMetadata,
             ]
             # ROCM_AITER_FA is an optional backend
             # We check is_enabled() here to avoid importing the backend module during
@@ -444,8 +448,8 @@ class SpecDecodeBaseProposer:
             )
         )
 
-        per_layer_attn_metadata = self.build_per_layer_attn_metadata(
-            common_attn_metadata
+        per_group_attn_metadata, per_layer_attn_metadata = (
+            self.build_per_group_and_layer_attn_metadata(common_attn_metadata)
         )
 
         cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
@@ -486,10 +490,7 @@ class SpecDecodeBaseProposer:
         positions = self.positions[token_indices_to_sample]
         hidden_states = hidden_states[token_indices_to_sample]
 
-        if any(
-            isinstance(attn_metadata, TreeAttentionMetadata)
-            for attn_metadata in per_layer_attn_metadata.values()
-        ):
+        if any(isinstance(md, TreeAttentionMetadata) for md in per_group_attn_metadata):
             # Draft using tree attention - requires full logits for top-k
             logits = self.model.compute_logits(sample_hidden_states)
             draft_token_ids_list = self.propose_tree(
@@ -505,14 +506,13 @@ class SpecDecodeBaseProposer:
             draft_token_ids = self._greedy_sample(sample_hidden_states)
 
-        for attn_metadata in per_layer_attn_metadata.values():
-            if self.allowed_attn_types is not None and not isinstance(
-                attn_metadata, self.allowed_attn_types
-            ):
-                raise ValueError(
-                    f"Unsupported attention metadata type for speculative "
-                    "decoding with num_speculative_tokens > 1: "
-                    f"{type(attn_metadata)}. Supported types are: "
-                    f"{self.allowed_attn_types}"
-                )
+        if self.allowed_attn_types is not None:
+            for group_md in per_group_attn_metadata:
+                if not isinstance(group_md, self.allowed_attn_types):
+                    raise ValueError(
+                        f"Unsupported attention metadata type for speculative "
+                        "decoding with num_speculative_tokens > 1: "
+                        f"{type(group_md)}. Supported types are: "
+                        f"{self.allowed_attn_types}"
+                    )
@@ -595,7 +595,7 @@ class SpecDecodeBaseProposer:
             common_attn_metadata._num_computed_tokens_cpu += 1
 
             # Rebuild attention metadata
-            per_layer_attn_metadata = self.build_per_layer_attn_metadata(
+            _, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(
                 common_attn_metadata, draft_index=token_index + 1
             )
@@ -809,17 +809,19 @@ class SpecDecodeBaseProposer:
         return model_kwargs, num_input_tokens
 
-    def build_per_layer_attn_metadata(
+    def build_per_group_and_layer_attn_metadata(
         self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
-    ) -> dict[str, object]:
+    ) -> tuple[list[object], dict[str, object]]:
+        per_group_attn_metadata: list[object] = []
         per_layer_attn_metadata: dict[str, object] = {}
         for attn_group in self.draft_attn_groups:
             attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
                 common_attn_metadata=common_attn_metadata, draft_index=draft_index
             )
+            per_group_attn_metadata.append(attn_metadata)
             for layer_name in attn_group.layer_names:
                 per_layer_attn_metadata[layer_name] = attn_metadata
-        return per_layer_attn_metadata
+        return per_group_attn_metadata, per_layer_attn_metadata
 
     def model_returns_tuple(self) -> bool:
         return self.method not in ("mtp", "draft_model", "dflash")
```
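Since every layer name in a group maps to the same metadata object, iterating the per-group list reproduces the old per-layer loop's behavior while running each `isinstance` check once per group. A hedged usage sketch of the new validation path — the metadata classes below are stand-ins, not vLLM types:

```python
# Stand-in metadata classes; the real tuple holds vLLM backend metadata types.
class SupportedMD: ...
class UnsupportedMD: ...

allowed_attn_types = (SupportedMD,)

def validate(per_group_attn_metadata: list[object]) -> None:
    # Mirrors the diff's check: one isinstance test per attention group.
    for group_md in per_group_attn_metadata:
        if not isinstance(group_md, allowed_attn_types):
            raise ValueError(
                "Unsupported attention metadata type for speculative "
                f"decoding with num_speculative_tokens > 1: {type(group_md)}. "
                f"Supported types are: {allowed_attn_types}"
            )

validate([SupportedMD()])        # passes silently
try:
    validate([UnsupportedMD()])  # raises, as the real code would
except ValueError as e:
    print(e)
```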