[ROCm][perf] fix Aiter sparse MLA with MTP>1 (#37887)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
Signed-off-by: Stig-Arne Grönroos <sgronroo@amd.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Committed by: GitHub
Parent: 2e56975657
Commit: 31a719bcd3
@@ -257,19 +257,19 @@ class DFlashProposer(SpecDecodeBaseProposer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def build_per_layer_attn_metadata(
|
def build_per_group_and_layer_attn_metadata(
|
||||||
self, cad: CommonAttentionMetadata, draft_index: int = 0
|
self, cad: CommonAttentionMetadata, draft_index: int = 0
|
||||||
) -> dict[str, object]:
|
) -> tuple[list[object], dict[str, object]]:
|
||||||
per_layer_attention_metadata = super().build_per_layer_attn_metadata(
|
per_group, per_layer = super().build_per_group_and_layer_attn_metadata(
|
||||||
cad, draft_index
|
cad, draft_index
|
||||||
)
|
)
|
||||||
for layer_name, attn_metadata in per_layer_attention_metadata.items():
|
for layer_name, attn_metadata in per_layer.items():
|
||||||
assert getattr(attn_metadata, "causal", None) is False, (
|
assert getattr(attn_metadata, "causal", None) is False, (
|
||||||
f"Attention metadata for layer {layer_name} does not have"
|
f"Attention metadata for layer {layer_name} does not have"
|
||||||
" non-causal support, which is required for DFlash."
|
" non-causal support, which is required for DFlash."
|
||||||
" Consider using a different attention backend, such as FlashAttention."
|
" Consider using a different attention backend, such as FlashAttention."
|
||||||
)
|
)
|
||||||
return per_layer_attention_metadata
|
return per_group, per_layer
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_eagle3_use_aux_hidden_state_from_config(self):
|
def _get_eagle3_use_aux_hidden_state_from_config(self):
|
||||||
|
|||||||
@@ -225,6 +225,9 @@ class SpecDecodeBaseProposer:
|
|||||||
# Determine allowed attention backends once during initialization.
|
# Determine allowed attention backends once during initialization.
|
||||||
self.allowed_attn_types: tuple | None = None
|
self.allowed_attn_types: tuple | None = None
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
|
from vllm.v1.attention.backends.mla.indexer import (
|
||||||
|
DeepseekV32IndexerMetadata,
|
||||||
|
)
|
||||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import (
|
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import (
|
||||||
ROCMAiterMLASparseMetadata,
|
ROCMAiterMLASparseMetadata,
|
||||||
)
|
)
|
||||||
@@ -234,6 +237,7 @@ class SpecDecodeBaseProposer:
|
|||||||
TritonAttentionMetadata,
|
TritonAttentionMetadata,
|
||||||
RocmAttentionMetadata,
|
RocmAttentionMetadata,
|
||||||
ROCMAiterMLASparseMetadata,
|
ROCMAiterMLASparseMetadata,
|
||||||
|
DeepseekV32IndexerMetadata,
|
||||||
]
|
]
|
||||||
# ROCM_AITER_FA is an optional backend
|
# ROCM_AITER_FA is an optional backend
|
||||||
# We check is_enabled() here to avoid importing the backend module during
|
# We check is_enabled() here to avoid importing the backend module during
|
||||||
@@ -444,8 +448,8 @@ class SpecDecodeBaseProposer:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
|
per_group_attn_metadata, per_layer_attn_metadata = (
|
||||||
common_attn_metadata
|
self.build_per_group_and_layer_attn_metadata(common_attn_metadata)
|
||||||
)
|
)
|
||||||
|
|
||||||
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
|
||||||
@@ -486,10 +490,7 @@ class SpecDecodeBaseProposer:
|
|||||||
positions = self.positions[token_indices_to_sample]
|
positions = self.positions[token_indices_to_sample]
|
||||||
hidden_states = hidden_states[token_indices_to_sample]
|
hidden_states = hidden_states[token_indices_to_sample]
|
||||||
|
|
||||||
if any(
|
if any(isinstance(md, TreeAttentionMetadata) for md in per_group_attn_metadata):
|
||||||
isinstance(attn_metadata, TreeAttentionMetadata)
|
|
||||||
for attn_metadata in per_layer_attn_metadata.values()
|
|
||||||
):
|
|
||||||
# Draft using tree attention - requires full logits for top-k
|
# Draft using tree attention - requires full logits for top-k
|
||||||
logits = self.model.compute_logits(sample_hidden_states)
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
draft_token_ids_list = self.propose_tree(
|
draft_token_ids_list = self.propose_tree(
|
||||||
@@ -505,16 +506,15 @@ class SpecDecodeBaseProposer:
|
|||||||
|
|
||||||
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
draft_token_ids = self._greedy_sample(sample_hidden_states)
|
||||||
|
|
||||||
for attn_metadata in per_layer_attn_metadata.values():
|
if self.allowed_attn_types is not None:
|
||||||
if self.allowed_attn_types is not None and not isinstance(
|
for group_md in per_group_attn_metadata:
|
||||||
attn_metadata, self.allowed_attn_types
|
if not isinstance(group_md, self.allowed_attn_types):
|
||||||
):
|
raise ValueError(
|
||||||
raise ValueError(
|
f"Unsupported attention metadata type for speculative "
|
||||||
f"Unsupported attention metadata type for speculative "
|
"decoding with num_speculative_tokens > 1: "
|
||||||
"decoding with num_speculative_tokens > 1: "
|
f"{type(group_md)}. Supported types are: "
|
||||||
f"{type(attn_metadata)}. Supported types are: "
|
f"{self.allowed_attn_types}"
|
||||||
f"{self.allowed_attn_types}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Generate the remaining draft tokens.
|
# Generate the remaining draft tokens.
|
||||||
draft_token_ids_list = [draft_token_ids]
|
draft_token_ids_list = [draft_token_ids]
|
||||||
@@ -595,7 +595,7 @@ class SpecDecodeBaseProposer:
|
|||||||
common_attn_metadata._num_computed_tokens_cpu += 1
|
common_attn_metadata._num_computed_tokens_cpu += 1
|
||||||
|
|
||||||
# Rebuild attention metadata
|
# Rebuild attention metadata
|
||||||
per_layer_attn_metadata = self.build_per_layer_attn_metadata(
|
_, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata(
|
||||||
common_attn_metadata, draft_index=token_index + 1
|
common_attn_metadata, draft_index=token_index + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -809,17 +809,19 @@ class SpecDecodeBaseProposer:
|
|||||||
|
|
||||||
return model_kwargs, num_input_tokens
|
return model_kwargs, num_input_tokens
|
||||||
|
|
||||||
def build_per_layer_attn_metadata(
|
def build_per_group_and_layer_attn_metadata(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
|
self, common_attn_metadata: CommonAttentionMetadata, draft_index: int = 0
|
||||||
) -> dict[str, object]:
|
) -> tuple[list[object], dict[str, object]]:
|
||||||
|
per_group_attn_metadata: list[object] = []
|
||||||
per_layer_attn_metadata: dict[str, object] = {}
|
per_layer_attn_metadata: dict[str, object] = {}
|
||||||
for attn_group in self.draft_attn_groups:
|
for attn_group in self.draft_attn_groups:
|
||||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||||
common_attn_metadata=common_attn_metadata, draft_index=draft_index
|
common_attn_metadata=common_attn_metadata, draft_index=draft_index
|
||||||
)
|
)
|
||||||
|
per_group_attn_metadata.append(attn_metadata)
|
||||||
for layer_name in attn_group.layer_names:
|
for layer_name in attn_group.layer_names:
|
||||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
return per_layer_attn_metadata
|
return per_group_attn_metadata, per_layer_attn_metadata
|
||||||
|
|
||||||
def model_returns_tuple(self) -> bool:
|
def model_returns_tuple(self) -> bool:
|
||||||
return self.method not in ("mtp", "draft_model", "dflash")
|
return self.method not in ("mtp", "draft_model", "dflash")
|
||||||
|
|||||||
Reference in New Issue
Block a user