[Bugfix] Get a specific type of layer from forward context (#17222)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
|
||||
compilation_time: float = PrivateAttr
|
||||
|
||||
# Per-model forward context
|
||||
# Map from layer name to the attention cls
|
||||
# Map from layer name to layer objects that need to be accessed outside
|
||||
# model code, e.g., Attention, FusedMOE when dp_size>1.
|
||||
static_forward_context: dict[str, Any] = PrivateAttr
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
|
||||
f"vLLM tried to hash some configs that may have Python objects ids "
|
||||
f"in them. This is a bug, please file an issue. "
|
||||
f"Text being hashed: {text}")
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_layers_from_vllm_config(vllm_config: VllmConfig,
|
||||
layer_type: type[T]) -> dict[str, T]:
|
||||
return {
|
||||
layer_name: layer
|
||||
for layer_name, layer in
|
||||
vllm_config.compilation_config.static_forward_context.items()
|
||||
if isinstance(layer, layer_type)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user