[V1][Spec Decode] Eagle Model loading (#16035)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
@@ -414,7 +414,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
return ((source.prefix + name, tensor)
|
||||
for (name, tensor) in weights_iterator)
|
||||
|
||||
def _get_all_weights(
|
||||
def get_all_weights(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
@@ -453,7 +453,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self._get_all_weights(model_config, model))
|
||||
self.get_all_weights(model_config, model))
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
|
||||
Reference in New Issue
Block a user