[V1][Spec Decode] Eagle Model loading (#16035)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
Lily Liu
2025-04-10 11:21:48 -07:00
committed by GitHub
parent 9665313c39
commit e8224f3dca
9 changed files with 251 additions and 28 deletions

View File

@@ -414,7 +414,7 @@ class DefaultModelLoader(BaseModelLoader):
return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator)
def _get_all_weights(
def get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
@@ -453,7 +453,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",