[FIXBUG ] Allow disabling rocm_aiter_fa backend for ROCm GPUs not compatible with AITER (#22795)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: JartX
Date: 2025-08-20 18:08:29 +02:00
Committed by: GitHub
Parent: d6d13bd49e
Commit: 3b11b26b50

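Before this change, AiterFlashAttentionMetadata was imported unconditionally at module level, so the EAGLE proposer could not even be imported on ROCm GPUs without a compatible AITER installation. The diff below guards that import with importlib.util.find_spec, collects the allowed attention-metadata classes once at initialization, and checks against that tuple afterwards. A minimal, self-contained sketch of the same guarded-import pattern (numpy is used here as a stand-in optional dependency; the names are illustrative, not vLLM's):

from importlib.util import find_spec

# Types that are always importable form the baseline.
allowed_types: list[type] = [list, tuple]

# find_spec returns None instead of raising ImportError when the module
# cannot be resolved, so the optional dependency is imported only if present.
if find_spec("numpy") is not None:
    import numpy as np
    allowed_types.append(np.ndarray)

ALLOWED_TYPES = tuple(allowed_types)
# One isinstance check against the precomputed tuple replaces per-platform branching.
print(isinstance([1, 2, 3], ALLOWED_TYPES))  # True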

@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import ast
 from dataclasses import replace
-from typing import Optional
+from importlib.util import find_spec
+from typing import Optional, Protocol
 
 import numpy as np
 import torch
@@ -20,8 +21,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.rocm_aiter_fa import (
-    AiterFlashAttentionMetadata)
 from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                   TreeAttentionMetadataBuilder)
 from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
@@ -34,6 +33,17 @@ logger = init_logger(__name__)
 
 PADDING_SLOT_ID = -1
 
+
+class EagleAttentionMetadata(Protocol):
+    # Required attributes
+    num_actual_tokens: int
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    block_table: torch.Tensor
+    slot_mapping: torch.Tensor
+
 
 class EagleProposer:
 
     def __init__(
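The EagleAttentionMetadata Protocol above only declares the attributes the proposer actually reads. Because typing.Protocol uses structural typing, each backend's metadata class satisfies it without inheriting from it, so the annotation tuple[type[EagleAttentionMetadata], ...] stays backend-agnostic while the runtime isinstance check uses the concrete classes stored in the tuple. A small sketch of that structural-typing behavior, with purely illustrative names:

from typing import Protocol


class HasTokenCount(Protocol):
    # Structural interface: any object exposing this attribute qualifies.
    num_actual_tokens: int


class SomeBackendMetadata:
    # No inheritance from HasTokenCount; the matching attribute is enough
    # for a static type checker to accept it where HasTokenCount is expected.
    def __init__(self) -> None:
        self.num_actual_tokens = 8


def consume(meta: HasTokenCount) -> int:
    return meta.num_actual_tokens


print(consume(SomeBackendMetadata()))  # 8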
@@ -97,6 +107,20 @@ class EagleProposer:
             dtype=self.dtype,
             device=device)
 
+        # Determine allowed attention backends once during initialization.
+        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        if current_platform.is_rocm():
+            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
+            # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
+            if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
+                from vllm.v1.attention.backends.rocm_aiter_fa import (
+                    AiterFlashAttentionMetadata)
+                rocm_types.append(AiterFlashAttentionMetadata)
+            self.allowed_attn_types = tuple(rocm_types)
+        else:
+            self.allowed_attn_types = (FlashAttentionMetadata,
+                                       TreeAttentionMetadata)
+
         # Parse the speculative token tree.
         spec_token_tree = self.speculative_config.speculative_token_tree
         self.tree_choices: list[tuple[int,
@@ -225,19 +249,7 @@ class EagleProposer:
         # TODO: Currently, MTP module released by deepseek only has
         # one layer. Adapt this code to support multiple layers once
         # there's a multi-layer MTP module.
-
-        # On ROCm, both AiterFlashAttention and TritonAttention
-        # support multi-token eagle spec decode.
-        if current_platform.is_rocm():
-            assert isinstance(
-                attn_metadata,
-                (TritonAttentionMetadata, AiterFlashAttentionMetadata,
-                 FlashAttentionMetadata))
-        else:
-            # Currently, only FlashAttention supports multi-token eagle spec
-            # decode. This is because the code below makes assumptions about
-            # attn_metadata attributes available.
-            assert isinstance(attn_metadata, FlashAttentionMetadata)
+        assert isinstance(attn_metadata, self.allowed_attn_types)
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
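Design note: the platform branch that used to run on every proposal call is now resolved once in __init__, and the per-call path keeps a single tuple-based isinstance assertion. A tiny sketch of this resolve-once, check-cheaply pattern, with hypothetical names:

class Proposer:
    def __init__(self, on_rocm: bool) -> None:
        # Resolve the allowed metadata types a single time at construction.
        self.allowed: tuple[type, ...] = (list, tuple) if on_rocm else (list,)

    def propose(self, metadata: object) -> None:
        # Hot path: one check against the precomputed tuple of types.
        assert isinstance(metadata, self.allowed)


Proposer(on_rocm=True).propose((1, 2))  # passes: tuple is an allowed type here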
@@ -619,17 +631,15 @@ class EagleProposer:
                 and self.model.model.embed_tokens.weight.shape \
                     == target_language_model.model.embed_tokens.weight.shape:
             logger.info(
-                "Assuming the EAGLE head shares the same vocab embedding" \
-                " with the target model."
-            )
+                "Assuming the EAGLE head shares the same vocab embedding"
+                " with the target model.")
             del self.model.model.embed_tokens
             self.model.model.embed_tokens = (
                 target_language_model.model.embed_tokens)
         else:
             logger.info(
-                "The EAGLE head's vocab embedding will be loaded separately" \
-                " from the target model."
-            )
+                "The EAGLE head's vocab embedding will be loaded separately"
+                " from the target model.")
 
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly