[Core] Refactor model loading code (#4097)
This commit is contained in:
@@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size):
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(None, None, scheduler_config, None, None)
|
||||
model_runner = ModelRunner(model_config=None,
|
||||
parallel_config=None,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
@@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(model_config, None, scheduler_config, None,
|
||||
None)
|
||||
model_runner = ModelRunner(model_config=model_config,
|
||||
parallel_config=None,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@@ -205,14 +212,17 @@ def test_empty_seq_group():
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
enforce_eager=False,
|
||||
)
|
||||
model_runner = ModelRunner(model_config, None, None, None, None)
|
||||
model_runner = ModelRunner(model_config=model_config,
|
||||
parallel_config=None,
|
||||
scheduler_config=None,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None)
|
||||
model_runner.set_block_size(16)
|
||||
seq_group_metadata_list = []
|
||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
||||
@@ -251,8 +261,6 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
@@ -262,11 +270,12 @@ def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=True)
|
||||
model_runner = ModelRunner(model_config,
|
||||
None,
|
||||
scheduler_config,
|
||||
None,
|
||||
None,
|
||||
model_runner = ModelRunner(model_config=model_config,
|
||||
parallel_config=None,
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=None,
|
||||
load_config=None,
|
||||
lora_config=None,
|
||||
is_driver_worker=True)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user