Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -24,7 +24,7 @@ SSM_MODELS = [
|
||||
"tiiuae/falcon-mamba-tiny-dev",
|
||||
# mamba2-codestral in transformers is broken pending:
|
||||
# https://github.com/huggingface/transformers/pull/40861
|
||||
#"yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
# "yujiepan/mamba2-codestral-v0.1-tiny-random",
|
||||
]
|
||||
|
||||
HYBRID_MODELS = [
|
||||
@@ -65,7 +65,6 @@ def test_models(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -75,11 +74,13 @@ def test_models(
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
@@ -109,13 +110,14 @@ def test_batching(
|
||||
for_loop_outputs = []
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
for prompt in example_prompts:
|
||||
single_output, = vllm_model.generate_greedy_logprobs([prompt],
|
||||
max_tokens,
|
||||
num_logprobs)
|
||||
(single_output,) = vllm_model.generate_greedy_logprobs(
|
||||
[prompt], max_tokens, num_logprobs
|
||||
)
|
||||
for_loop_outputs.append(single_output)
|
||||
|
||||
batched_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=for_loop_outputs,
|
||||
@@ -134,8 +136,8 @@ def test_chunked_prefill_with_parallel_sampling(
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Tests chunked prefill in conjunction with n > 1.
|
||||
|
||||
Tests chunked prefill in conjunction with n > 1.
|
||||
|
||||
In this case, prefill is populated with decoding tokens and
|
||||
we test that it doesn't fail.
|
||||
|
||||
@@ -143,16 +145,13 @@ def test_chunked_prefill_with_parallel_sampling(
|
||||
decoding steps inside a chunked prefill forward pass
|
||||
(where we have both prefill and decode together)
|
||||
"""
|
||||
sampling_params = SamplingParams(n=3,
|
||||
temperature=1,
|
||||
seed=0,
|
||||
max_tokens=max_tokens)
|
||||
sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
|
||||
with vllm_runner(
|
||||
model,
|
||||
enable_chunked_prefill=True,
|
||||
# forces prefill chunks with decoding
|
||||
max_num_batched_tokens=MAX_NUM_SEQS * 3,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
model,
|
||||
enable_chunked_prefill=True,
|
||||
# forces prefill chunks with decoding
|
||||
max_num_batched_tokens=MAX_NUM_SEQS * 3,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
) as vllm_model:
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
@@ -170,10 +169,8 @@ def test_mamba_cache_cg_padding(
|
||||
batch size. If it's not, a torch RuntimeError will be raised because
|
||||
tensor dimensions aren't compatible.
|
||||
"""
|
||||
vllm_config = EngineArgs(model=model,
|
||||
trust_remote_code=True).create_engine_config()
|
||||
while len(example_prompts) == vllm_config.pad_for_cudagraph(
|
||||
len(example_prompts)):
|
||||
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
|
||||
while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)):
|
||||
example_prompts.append(example_prompts[0])
|
||||
|
||||
try:
|
||||
@@ -183,7 +180,8 @@ def test_mamba_cache_cg_padding(
|
||||
pytest.fail(
|
||||
"Couldn't run batch size which is not equal to a Cuda Graph "
|
||||
"captured batch size. "
|
||||
"Could be related to mamba cache not padded correctly")
|
||||
"Could be related to mamba cache not padded correctly"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@@ -205,8 +203,10 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
|
||||
except ValueError:
|
||||
pytest.fail("Hybrid inner state wasn't cleaned up properly between"
|
||||
"steps finished requests registered unnecessarily ")
|
||||
pytest.fail(
|
||||
"Hybrid inner state wasn't cleaned up properly between"
|
||||
"steps finished requests registered unnecessarily "
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
|
||||
@@ -215,10 +215,10 @@ def test_state_cleanup(
|
||||
example_prompts,
|
||||
model: str,
|
||||
) -> None:
|
||||
"""
|
||||
"""
|
||||
This test is for verifying that the Hybrid state is cleaned up between
|
||||
steps.
|
||||
|
||||
|
||||
If it's not cleaned, an error would be expected.
|
||||
"""
|
||||
try:
|
||||
@@ -226,8 +226,10 @@ def test_state_cleanup(
|
||||
for _ in range(10):
|
||||
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
|
||||
except ValueError:
|
||||
pytest.fail("Hybrid inner state wasn't cleaned up between states, "
|
||||
"could be related to finished_requests_ids")
|
||||
pytest.fail(
|
||||
"Hybrid inner state wasn't cleaned up between states, "
|
||||
"could be related to finished_requests_ids"
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@@ -241,15 +243,19 @@ def test_distributed_correctness(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
with vllm_runner(model, tensor_parallel_size=1,
|
||||
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
with vllm_runner(
|
||||
model, tensor_parallel_size=1, max_num_seqs=MAX_NUM_SEQS
|
||||
) as vllm_model:
|
||||
vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
with vllm_runner(model, tensor_parallel_size=2,
|
||||
max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
with vllm_runner(
|
||||
model, tensor_parallel_size=2, max_num_seqs=MAX_NUM_SEQS
|
||||
) as vllm_model:
|
||||
vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=vllm_outputs_tp_1,
|
||||
@@ -271,7 +277,6 @@ def test_full_cuda_graph(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -281,11 +286,13 @@ def test_full_cuda_graph(
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
@@ -298,8 +305,9 @@ def test_full_cuda_graph(
|
||||
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("cache_dtype_param",
|
||||
["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
|
||||
@pytest.mark.parametrize(
|
||||
"cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]
|
||||
)
|
||||
def test_fp32_cache_state(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
@@ -310,7 +318,6 @@ def test_fp32_cache_state(
|
||||
num_logprobs: int,
|
||||
cache_dtype_param: str,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -320,13 +327,15 @@ def test_fp32_cache_state(
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
**{cache_dtype_param: "float32"}) as vllm_model:
|
||||
with vllm_runner(
|
||||
model, max_num_seqs=MAX_NUM_SEQS, **{cache_dtype_param: "float32"}
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
example_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
@@ -339,21 +348,23 @@ def test_fp32_cache_state(
|
||||
# Helper functions for the APC tests
|
||||
def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1):
|
||||
return {
|
||||
'model_name': model,
|
||||
'enable_prefix_caching': False,
|
||||
'max_model_len': max_model_len,
|
||||
'tensor_parallel_size': tensor_parallel_size,
|
||||
'gpu_memory_utilization': 0.4
|
||||
"model_name": model,
|
||||
"enable_prefix_caching": False,
|
||||
"max_model_len": max_model_len,
|
||||
"tensor_parallel_size": tensor_parallel_size,
|
||||
"gpu_memory_utilization": 0.4,
|
||||
}
|
||||
|
||||
|
||||
def _get_vLLM_output(vllm_runner,
|
||||
kwargs,
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
num_repetitions=1,
|
||||
vllm_model=None):
|
||||
def _get_vLLM_output(
|
||||
vllm_runner,
|
||||
kwargs,
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
num_repetitions=1,
|
||||
vllm_model=None,
|
||||
):
|
||||
outs = []
|
||||
if vllm_model is None:
|
||||
vllm_model = vllm_runner(**kwargs)
|
||||
@@ -362,7 +373,8 @@ def _get_vLLM_output(vllm_runner,
|
||||
vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
|
||||
else:
|
||||
vllm_output = vllm_model.generate_greedy_logprobs(
|
||||
prompts, max_tokens, num_logprobs)
|
||||
prompts, max_tokens, num_logprobs
|
||||
)
|
||||
outs.append(vllm_output)
|
||||
|
||||
return outs, vllm_model
|
||||
@@ -387,7 +399,6 @@ def test_apc_single_prompt(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -395,29 +406,33 @@ def test_apc_single_prompt(
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
compare_operator: Callable = (
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
generated_prompts = [MULTIPLE * example_prompts[0]]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size
|
||||
)
|
||||
vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs, n_repetitions)
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
n_repetitions,
|
||||
)
|
||||
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
@@ -450,7 +465,6 @@ def test_apc_single_prompt_block_align_alignment(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -458,30 +472,29 @@ def test_apc_single_prompt_block_align_alignment(
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
compare_operator: Callable = (
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts. This custom prompt is used, as it causes the most issues
|
||||
generated_prompts = ["The president of the United States is " * MULTIPLE]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size
|
||||
)
|
||||
vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
|
||||
# Retrieve the default mamba state block size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
|
||||
mamba_block_size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size
|
||||
|
||||
# In case the hybrid model does not have the
|
||||
# "mamba_block_size" assume a fixed constant
|
||||
@@ -489,18 +502,18 @@ def test_apc_single_prompt_block_align_alignment(
|
||||
mamba_block_size = 512
|
||||
|
||||
mamba_block_size_multiplier = 10
|
||||
for offsets in [
|
||||
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
|
||||
]:
|
||||
|
||||
vllm_runner_kwargs[
|
||||
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
|
||||
mamba_block_size - offsets
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens, num_logprobs,
|
||||
n_repetitions)
|
||||
for offsets in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
|
||||
vllm_runner_kwargs["max_num_batched_tokens"] = (
|
||||
mamba_block_size_multiplier * mamba_block_size - offsets
|
||||
)
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
n_repetitions,
|
||||
)
|
||||
|
||||
# Check alignment of the output logits when using APC
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
@@ -534,7 +547,6 @@ def test_apc_multiple_prompts_all_cached_outputs(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -542,30 +554,34 @@ def test_apc_multiple_prompts_all_cached_outputs(
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
compare_operator: Callable = (
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size
|
||||
)
|
||||
vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs, n_repetitions)
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
n_repetitions,
|
||||
)
|
||||
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
@@ -598,7 +614,6 @@ def test_apc_multiple_prompts_block_align_alignment(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -606,34 +621,31 @@ def test_apc_multiple_prompts_block_align_alignment(
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
compare_operator: Callable = (
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts. This custom prompt is used, as it causes the most issues
|
||||
prompt_text = "The president of the United States is "
|
||||
prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
|
||||
generated_prompts = [
|
||||
prompt_text[offset:] * MULTIPLE for offset in prompt_offsets
|
||||
]
|
||||
generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(model, max_model_len,
|
||||
tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size
|
||||
)
|
||||
vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
with vllm_runner(**vllm_runner_kwargs) as vllm_model:
|
||||
# Retrieve the default mamba state block size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config. \
|
||||
mamba_block_size
|
||||
mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size
|
||||
|
||||
# In case the hybrid model does not have the
|
||||
# "mamba_block_size" assume a fixed constant
|
||||
@@ -641,18 +653,18 @@ def test_apc_multiple_prompts_block_align_alignment(
|
||||
mamba_block_size = 512
|
||||
|
||||
mamba_block_size_multiplier = 10
|
||||
for offsets in [
|
||||
-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3
|
||||
]:
|
||||
|
||||
vllm_runner_kwargs[
|
||||
'max_num_batched_tokens'] = mamba_block_size_multiplier * \
|
||||
mamba_block_size - offsets
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens, num_logprobs,
|
||||
n_repetitions)
|
||||
for offsets in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
|
||||
vllm_runner_kwargs["max_num_batched_tokens"] = (
|
||||
mamba_block_size_multiplier * mamba_block_size - offsets
|
||||
)
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
n_repetitions,
|
||||
)
|
||||
|
||||
# Check alignment of the output logits when using APC
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
@@ -686,7 +698,6 @@ def test_apc_multiple_prompts_partial_cached_outputs(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -694,30 +705,30 @@ def test_apc_multiple_prompts_partial_cached_outputs(
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
compare_operator: Callable = check_logprobs_close \
|
||||
if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
compare_operator: Callable = (
|
||||
check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore
|
||||
)
|
||||
|
||||
MULTIPLE = 300
|
||||
|
||||
# Sample prompts.
|
||||
generated_prompts = [MULTIPLE * prompt for prompt in example_prompts]
|
||||
|
||||
max_model_len = max(
|
||||
len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
|
||||
vllm_runner_kwargs = _get_vllm_runner_params(
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size)
|
||||
vllm_runner_kwargs['mamba_ssm_cache_dtype'] = "float32"
|
||||
model, max_model_len, tensor_parallel_size=tensor_parallel_size
|
||||
)
|
||||
vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
|
||||
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
vllm_outputs_no_cache, _ = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
# Cache only part of all the prompts
|
||||
vllm_runner_kwargs['enable_prefix_caching'] = True
|
||||
vllm_runner_kwargs["enable_prefix_caching"] = True
|
||||
vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens,
|
||||
num_logprobs)
|
||||
vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
compare_operator(
|
||||
outputs_0_lst=vllm_outputs_no_cache[0][:3],
|
||||
@@ -726,13 +737,15 @@ def test_apc_multiple_prompts_partial_cached_outputs(
|
||||
name_1="vllm_partial_cache",
|
||||
)
|
||||
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
n_repetitions,
|
||||
vllm_model=vllm_model)
|
||||
vllm_outputs_cache_rep, _ = _get_vLLM_output(
|
||||
vllm_runner,
|
||||
vllm_runner_kwargs,
|
||||
generated_prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
n_repetitions,
|
||||
vllm_model=vllm_model,
|
||||
)
|
||||
|
||||
for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
|
||||
# In the first repetition, the caches are filled
|
||||
|
||||
Reference in New Issue
Block a user