[ci][distributed] fix device count call

[ci][distributed] fix some cuda init that makes it necessary to use spawn (#5991)
This commit is contained in:
youkaichao
2024-06-30 01:06:13 -07:00
committed by GitHub
parent 9d47f64eb6
commit 2be6955a3f
6 changed files with 85 additions and 53 deletions

View File

@@ -96,7 +96,34 @@ def run_test(
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model_id,
max_model_len=2048,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=True,
distributed_executor_backend=distributed_executor_backend,
**vlm_config.as_cli_args_dict()) as vllm_model:
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
vllm_image_prompts = [
p.replace("<|image_1|>",
"<|image|>" * vlm_config.image_feature_size + "<s>")
for p in HF_IMAGE_PROMPTS
]
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
@@ -108,23 +135,6 @@ def run_test(
images=hf_images,
eos_token_id=hf_model.processor.tokenizer.eos_token_id)
vllm_image_prompts = [
p.replace("<|image_1|>",
"<|image|>" * vlm_config.image_feature_size + "<s>")
for p in HF_IMAGE_PROMPTS
]
with vllm_runner(model_id,
max_model_len=2048,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=True,
distributed_executor_backend=distributed_executor_backend,
**vlm_config.as_cli_args_dict()) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)
check_outputs_equal(
hf_outputs,
[