[Bugfix] Get a specific type of layer from forward context (#17222)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
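
Building the KV cache spec previously iterated over the entire static forward context, which can also hold non-attention modules such as FusedMoE, so the loop had to skip or assert around them. This change fetches only the layers of the requested type via get_layers_from_vllm_config. Below is a minimal sketch of the assumed behavior; layers_of_type is a hypothetical stand-in for illustration, not the actual helper in vllm.config:

    from typing import TypeVar

    T = TypeVar("T")

    def layers_of_type(vllm_config, layer_type: type[T]) -> dict[str, T]:
        # Hypothetical stand-in: keep only forward-context entries of the given type.
        forward_ctx = vllm_config.compilation_config.static_forward_context
        return {
            name: module
            for name, module in forward_ctx.items()
            if isinstance(module, layer_type)
        }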
@@ -12,13 +12,13 @@ import torch.nn as nn
 
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1733,17 +1733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            if isinstance(attn_module, FusedMoE):
-                continue
-
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
+            # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
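
With a type-filtered mapping, the FusedMoE skip and the isinstance assert removed above are no longer needed. A toy illustration of the filtering idea, using stand-in classes rather than vLLM's:

    class _ToyAttention: ...
    class _ToyFusedMoE: ...

    forward_ctx = {"layers.0.attn": _ToyAttention(), "layers.0.moe": _ToyFusedMoE()}
    # Filtering by type means MoE entries never reach the KV-cache loop.
    attn_only = {name: m for name, m in forward_ctx.items()
                 if isinstance(m, _ToyAttention)}
    assert list(attn_only) == ["layers.0.attn"]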
@@ -17,7 +17,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -429,11 +429,10 @@ class TPUModelRunner:
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
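
For completeness, a usage sketch mirroring the calls introduced in this diff; vllm_config is assumed to be an already-constructed VllmConfig, and only attributes the diff itself touches are relied on:

    from vllm.attention.layer import Attention
    from vllm.config import get_layers_from_vllm_config

    # Assumed usage: a {layer_name: Attention} mapping restricted to attention layers.
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for layer_name, attn_module in attn_layers.items():
        print(layer_name, attn_module.attn_type, attn_module.sliding_window)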