[Bugfix] Get a specific type of layer from forward context (#17222)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-04-27 15:58:05 +08:00
committed by GitHub
parent 4283a28c2f
commit 838cedade7
5 changed files with 28 additions and 23 deletions

View File

@@ -14,7 +14,8 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import (VllmConfig, get_current_vllm_config,
get_layers_from_vllm_config)
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
to use during `plan`.
"""
layers = vllm_config.compilation_config.static_forward_context
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
assert isinstance(layer, Attention)
impl = layer.impl
assert isinstance(impl, FlashInferImpl)