[FIXBUG ] Allow disabling rocm_aiter_fa backend for ROCm GPUs not compatible with AITER (#22795)
Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -2,7 +2,8 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import ast
|
import ast
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import Optional
|
from importlib.util import find_spec
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -20,8 +21,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import is_pin_memory_available
|
from vllm.utils import is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
|
||||||
AiterFlashAttentionMetadata)
|
|
||||||
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
||||||
TreeAttentionMetadataBuilder)
|
TreeAttentionMetadataBuilder)
|
||||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||||
@@ -34,6 +33,17 @@ logger = init_logger(__name__)
|
|||||||
PADDING_SLOT_ID = -1
|
PADDING_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
|
class EagleAttentionMetadata(Protocol):
|
||||||
|
# Required attributes
|
||||||
|
num_actual_tokens: int
|
||||||
|
max_query_len: int
|
||||||
|
query_start_loc: torch.Tensor
|
||||||
|
max_seq_len: int
|
||||||
|
seq_lens: torch.Tensor
|
||||||
|
block_table: torch.Tensor
|
||||||
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class EagleProposer:
|
class EagleProposer:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -97,6 +107,20 @@ class EagleProposer:
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
|
# Determine allowed attention backends once during initialization.
|
||||||
|
self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
|
||||||
|
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
|
||||||
|
if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
|
||||||
|
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||||
|
AiterFlashAttentionMetadata)
|
||||||
|
rocm_types.append(AiterFlashAttentionMetadata)
|
||||||
|
self.allowed_attn_types = tuple(rocm_types)
|
||||||
|
else:
|
||||||
|
self.allowed_attn_types = (FlashAttentionMetadata,
|
||||||
|
TreeAttentionMetadata)
|
||||||
|
|
||||||
# Parse the speculative token tree.
|
# Parse the speculative token tree.
|
||||||
spec_token_tree = self.speculative_config.speculative_token_tree
|
spec_token_tree = self.speculative_config.speculative_token_tree
|
||||||
self.tree_choices: list[tuple[int,
|
self.tree_choices: list[tuple[int,
|
||||||
@@ -225,19 +249,7 @@ class EagleProposer:
|
|||||||
# TODO: Currently, MTP module released by deepseek only has
|
# TODO: Currently, MTP module released by deepseek only has
|
||||||
# one layer. Adapt this code to support multiple layers once
|
# one layer. Adapt this code to support multiple layers once
|
||||||
# there's a multi-layer MTP module.
|
# there's a multi-layer MTP module.
|
||||||
|
assert isinstance(attn_metadata, self.allowed_attn_types)
|
||||||
# On ROCm, both AiterFlashAttention and TritonAttention
|
|
||||||
# support multi-token eagle spec decode.
|
|
||||||
if current_platform.is_rocm():
|
|
||||||
assert isinstance(
|
|
||||||
attn_metadata,
|
|
||||||
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
|
|
||||||
FlashAttentionMetadata))
|
|
||||||
else:
|
|
||||||
# Currently, only FlashAttention supports multi-token eagle spec
|
|
||||||
# decode. This is because the code below makes assumptions about
|
|
||||||
# attn_metadata attributes available.
|
|
||||||
assert isinstance(attn_metadata, FlashAttentionMetadata)
|
|
||||||
|
|
||||||
# Generate the remaining draft tokens.
|
# Generate the remaining draft tokens.
|
||||||
draft_token_ids_list = [draft_token_ids]
|
draft_token_ids_list = [draft_token_ids]
|
||||||
@@ -619,17 +631,15 @@ class EagleProposer:
|
|||||||
and self.model.model.embed_tokens.weight.shape \
|
and self.model.model.embed_tokens.weight.shape \
|
||||||
== target_language_model.model.embed_tokens.weight.shape:
|
== target_language_model.model.embed_tokens.weight.shape:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Assuming the EAGLE head shares the same vocab embedding" \
|
"Assuming the EAGLE head shares the same vocab embedding"
|
||||||
" with the target model."
|
" with the target model.")
|
||||||
)
|
|
||||||
del self.model.model.embed_tokens
|
del self.model.model.embed_tokens
|
||||||
self.model.model.embed_tokens = (
|
self.model.model.embed_tokens = (
|
||||||
target_language_model.model.embed_tokens)
|
target_language_model.model.embed_tokens)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"The EAGLE head's vocab embedding will be loaded separately" \
|
"The EAGLE head's vocab embedding will be loaded separately"
|
||||||
" from the target model."
|
" from the target model.")
|
||||||
)
|
|
||||||
|
|
||||||
# share lm_head with the target model if needed
|
# share lm_head with the target model if needed
|
||||||
# some model definition do not define lm_head explicitly
|
# some model definition do not define lm_head explicitly
|
||||||
|
|||||||
Reference in New Issue
Block a user