[Bugfix] Get a specific type of layer from forward context (#17222)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -14,7 +14,8 @@ import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config import (VllmConfig, get_current_vllm_config,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
|
||||
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
|
||||
to use during `plan`.
|
||||
"""
|
||||
|
||||
layers = vllm_config.compilation_config.static_forward_context
|
||||
layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
per_layer_params: dict[str, PerLayerParameters] = {}
|
||||
|
||||
for key, layer in layers.items():
|
||||
assert isinstance(layer, Attention)
|
||||
|
||||
impl = layer.impl
|
||||
assert isinstance(impl, FlashInferImpl)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user