[Core] Refactor model loading code (#4097)

This commit is contained in:
Antoni Baum
2024-04-16 11:34:39 -07:00
committed by GitHub
parent 05434764cd
commit 69e1d2fb69
67 changed files with 1054 additions and 963 deletions

View File

@@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size):
100000,
100000,
enable_chunked_prefill=False)
- model_runner = ModelRunner(None, None, scheduler_config, None, None)
+ model_runner = ModelRunner(model_config=None,
+ parallel_config=None,
+ scheduler_config=scheduler_config,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
model_runner.set_block_size(16)
prompt_lens = []
@@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size):
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
@@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size):
100000,
100000,
enable_chunked_prefill=False)
- model_runner = ModelRunner(model_config, None, scheduler_config, None,
- None)
+ model_runner = ModelRunner(model_config=model_config,
+ parallel_config=None,
+ scheduler_config=scheduler_config,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
model_runner.set_block_size(16)
prompt_lens = []
@@ -205,14 +212,17 @@ def test_empty_seq_group():
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
- model_runner = ModelRunner(model_config, None, None, None, None)
+ model_runner = ModelRunner(model_config=model_config,
+ parallel_config=None,
+ scheduler_config=None,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
model_runner.set_block_size(16)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
@@ -251,8 +261,6 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
@@ -262,11 +270,12 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
100000,
100000,
enable_chunked_prefill=True)
- model_runner = ModelRunner(model_config,
- None,
- scheduler_config,
- None,
- None,
+ model_runner = ModelRunner(model_config=model_config,
+ parallel_config=None,
+ scheduler_config=scheduler_config,
+ device_config=None,
+ load_config=None,
+ lora_config=None,
is_driver_worker=True)
model_runner.set_block_size(16)