[Optimization] Avoid repeated model architecture conversion for pooling models (#25261)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -165,7 +165,11 @@ def device_loading_context(module: torch.nn.Module,
|
||||
# New parameters or parameters already on target device are untouched
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
_MODEL_ARCH_BY_HASH = dict[str, tuple[type[nn.Module], str]]()
|
||||
"""Caches the outputs of `_get_model_architecture`."""
|
||||
|
||||
|
||||
def _get_model_architecture(
|
||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
|
||||
@@ -209,6 +213,17 @@ def get_model_architecture(
|
||||
return model_cls, arch
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
key = model_config.compute_hash()
|
||||
if key in _MODEL_ARCH_BY_HASH:
|
||||
return _MODEL_ARCH_BY_HASH[key]
|
||||
|
||||
model_arch = _get_model_architecture(model_config)
|
||||
_MODEL_ARCH_BY_HASH[key] = model_arch
|
||||
return model_arch
|
||||
|
||||
|
||||
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
|
||||
return get_model_architecture(model_config)[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user