[Model Runner V2] fix draft attention metadata generation (#37364)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@@ -30,7 +30,10 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
 
 
 def init_attn_backend(
-    kv_cache_config: KVCacheConfig, vllm_config: VllmConfig, device: torch.device
+    kv_cache_config: KVCacheConfig,
+    vllm_config: VllmConfig,
+    device: torch.device,
+    active_layer_names: set[str] | None = None,
 ):
     attn_backends: dict[str, type[AttentionBackend]] = {}
     attn_groups: list[list[AttentionGroup]] = []
@@ -39,6 +42,8 @@ def init_attn_backend(
         kv_cache_config.kv_cache_groups
     ):
         layer_names = kv_cache_group_spec.layer_names
+        if active_layer_names is not None:
+            layer_names = list(active_layer_names.intersection(layer_names))
 
         layer_type = cast(type[Any], AttentionLayerBase)
         attn_layers = get_layers_from_vllm_config(vllm_config, layer_type, layer_names)
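
A minimal sketch of the filtering step introduced above, with hypothetical layer names; filter_group_layers is an illustrative helper, not vLLM API:

def filter_group_layers(
    group_layer_names: list[str],
    active_layer_names: set[str] | None,
) -> list[str]:
    # None means "no filter": keep every layer in the KV-cache group,
    # matching the original init_attn_backend behavior.
    if active_layer_names is None:
        return list(group_layer_names)
    # Otherwise keep only the layers the caller owns, e.g. the draft
    # model's attention layers during speculative decoding.
    return list(active_layer_names.intersection(group_layer_names))

# Target layers plus one EAGLE draft layer in the same KV-cache group:
group = ["model.layers.0.attn", "model.layers.1.attn", "draft.attn"]
print(filter_group_layers(group, None))            # all three layers
print(filter_group_layers(group, {"draft.attn"}))  # ['draft.attn']
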
@@ -350,7 +350,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.speculator.set_attn(
             self.model_state,
             self.kv_cache_config,
-            self.attn_groups,
             self.block_tables,
         )
 
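With EagleSpeculator.set_attn now building its own attention groups for the draft layers (see the hunks below), the runner no longer forwards its target-model self.attn_groups.
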
@@ -5,15 +5,17 @@ from typing import Any
 import torch
 import torch.nn as nn
 
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.config.compilation import CUDAGraphMode
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.triton_utils import tl, triton
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
     build_slot_mappings_by_layer,
+    init_attn_backend,
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
@@ -22,7 +24,6 @@ from vllm.v1.worker.gpu.model_states.interface import ModelState
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
 from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
-from vllm.v1.worker.utils import AttentionGroup
 
 logger = init_logger(__name__)
 
@@ -87,18 +88,35 @@ class EagleSpeculator:
         )
 
     def load_model(self, target_model: nn.Module) -> None:
+        target_attn_layer_names = get_layers_from_vllm_config(
+            self.vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
+        ).keys()
+
         self.model = load_eagle_model(target_model, self.vllm_config)
 
+        all_attn_layers = get_layers_from_vllm_config(
+            self.vllm_config,
+            AttentionLayerBase,  # type: ignore[type-abstract]
+        ).keys()
+        self.draft_attn_layer_names = set(all_attn_layers) - set(
+            target_attn_layer_names
+        )
+
     def set_attn(
         self,
         model_state: ModelState,
         kv_cache_config: KVCacheConfig,
-        attn_groups: list[list[AttentionGroup]],
         block_tables: BlockTables,
     ) -> None:
         self.model_state = model_state
         self.kv_cache_config = kv_cache_config
-        self.attn_groups = attn_groups
+        _, self.attn_groups = init_attn_backend(
+            kv_cache_config,
+            self.vllm_config,
+            self.device,
+            active_layer_names=self.draft_attn_layer_names,
+        )
         self.block_tables = block_tables
 
     @torch.inference_mode()
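
A sketch of the layer bookkeeping in load_model above, with made-up layer names standing in for vLLM's layer registry: attention layers visible before load_eagle_model() belong to the target model, so anything that appears afterwards must be a draft layer.

# Illustrative names only, not real vLLM layer names.
target_attn_layer_names = {"model.layers.0.attn", "model.layers.1.attn"}

# After load_eagle_model(), the registry also contains the EAGLE head.
all_attn_layers = {"model.layers.0.attn", "model.layers.1.attn", "draft.attn"}

# Set difference isolates the draft layers; set_attn passes this as
# active_layer_names so init_attn_backend builds backends and metadata
# only for those layers.
draft_attn_layer_names = set(all_attn_layers) - set(target_attn_layer_names)
print(draft_attn_layer_names)  # {'draft.attn'}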