[Bugfix] Get a specific type of layer from forward context (#17222)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-04-27 15:58:05 +08:00
committed by GitHub
parent 4283a28c2f
commit 838cedade7
5 changed files with 28 additions and 23 deletions

View File

@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
compilation_time: float = PrivateAttr
# Per-model forward context
# Map from layer name to the attention cls
# Map from layer name to layer objects that need to be accessed outside
# model code, e.g., Attention, FusedMOE when dp_size>1.
static_forward_context: dict[str, Any] = PrivateAttr
def compute_hash(self) -> str:
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
f"vLLM tried to hash some configs that may have Python objects ids "
f"in them. This is a bug, please file an issue. "
f"Text being hashed: {text}")
T = TypeVar("T")
def get_layers_from_vllm_config(vllm_config: VllmConfig,
layer_type: type[T]) -> dict[str, T]:
return {
layer_name: layer
for layer_name, layer in
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
}