New weight loader without np copy (#52)

Author: Zhuohan Li
Date: 2023-05-03 15:32:04 +08:00
Committed by: GitHub
Parent: 4858f3bb45
Commit: 27f1410d06
12 changed files with 284 additions and 352 deletions


@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
 import torch
@@ -28,8 +28,9 @@ class Worker:
         distributed_init_method: str,
         rank: int,
         world_size: int,
-        model_path: str,
+        cache_dir: Optional[str],
         use_dummy_weights: bool,
+        use_np_cache: bool,
         max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
@@ -45,7 +46,8 @@ class Worker:
         # Initialize the model.
         self.model, self.dtype = get_model(
-            model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
+            model_name, dtype=dtype, cache_dir=cache_dir,
+            use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         initialize_all_reduce_launcher(
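
For orientation, here is a minimal sketch of the get_model signature this commit threads through Worker, followed by a call mirroring the last hunk. Only the parameter names (model_name, dtype, cache_dir, use_dummy_weights, use_np_cache) come from the diff itself; the stub body, type annotations, and example values are assumptions, not the repository's actual code.

from typing import Optional, Tuple

import torch
import torch.nn as nn


def get_model(
    model_name: str,
    dtype: str,
    cache_dir: Optional[str],
    use_dummy_weights: bool,
    use_np_cache: bool,
) -> Tuple[nn.Module, torch.dtype]:
    # Hypothetical body: the real loader builds model_name's architecture and
    # fills it with dummy or checkpoint weights, optionally staging them
    # through an on-disk numpy cache (use_np_cache=True) instead of loading
    # tensors directly.
    torch_dtype = torch.half if dtype == "half" else torch.float
    model = nn.Linear(8, 8).to(torch_dtype)  # stand-in for the real model
    return model, torch_dtype


# Call shape matching the updated Worker code in the last hunk:
model, dtype = get_model(
    "facebook/opt-125m",   # assumed example model name
    dtype="half",
    cache_dir=None,        # assumption: None falls back to the default cache dir
    use_dummy_weights=False,
    use_np_cache=False,    # the direct, copy-free load path this commit enables
)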