Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -29,14 +29,11 @@ CUDA_DEVICES = [
|
||||
MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
|
||||
def _compare_objs(obj1,
|
||||
obj2,
|
||||
skip: Sequence = ("logitsprocs", "batch_update_builder")):
|
||||
def _compare_objs(obj1, obj2, skip: Sequence = ("logitsprocs", "batch_update_builder")):
|
||||
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
|
||||
attr_names = set([
|
||||
a[0] for a in attrs
|
||||
if not (a[0].startswith('__') and a[0].endswith('__'))
|
||||
])
|
||||
attr_names = set(
|
||||
[a[0] for a in attrs if not (a[0].startswith("__") and a[0].endswith("__"))]
|
||||
)
|
||||
for attr_name in attr_names:
|
||||
if attr_name in skip:
|
||||
continue
|
||||
@@ -47,7 +44,7 @@ def _compare_objs(obj1,
|
||||
is_same = False
|
||||
if isinstance(a, torch.Tensor):
|
||||
if a.numel() == 0 or b.numel() == 0:
|
||||
is_same = (a.numel() == 0 and b.numel() == 0)
|
||||
is_same = a.numel() == 0 and b.numel() == 0
|
||||
elif torch.allclose(a, b):
|
||||
is_same = True
|
||||
elif isinstance(a, np.ndarray):
|
||||
@@ -64,12 +61,14 @@ def _compare_objs(obj1,
|
||||
is_same = True
|
||||
elif isinstance(a, CpuGpuBuffer):
|
||||
is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu)
|
||||
assert is_same, f"Attribute {attr_name} is different"\
|
||||
f" in {obj1} and {obj2}: {a} != {b}"
|
||||
assert is_same, (
|
||||
f"Attribute {attr_name} is different in {obj1} and {obj2}: {a} != {b}"
|
||||
)
|
||||
|
||||
|
||||
def _remove_requests(input_batch: InputBatch, batch_size: int,
|
||||
reqs: list[CachedRequestState]) -> set[str]:
|
||||
def _remove_requests(
|
||||
input_batch: InputBatch, batch_size: int, reqs: list[CachedRequestState]
|
||||
) -> set[str]:
|
||||
"""
|
||||
Remove some requests randomly from the batch and returns
|
||||
set of request removed
|
||||
@@ -109,10 +108,9 @@ def _construct_expected_sampling_metadata(
|
||||
temperature = [0.0 for _ in range(num_reqs)]
|
||||
min_tokens = {}
|
||||
logit_bias = [None] * num_reqs
|
||||
allowed_token_ids_mask = torch.zeros(num_reqs,
|
||||
VOCAB_SIZE,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
allowed_token_ids_mask = torch.zeros(
|
||||
num_reqs, VOCAB_SIZE, dtype=torch.bool, device=device
|
||||
)
|
||||
bad_words_token_ids = {}
|
||||
for req in reqs:
|
||||
if req.req_id not in req_ids_retained:
|
||||
@@ -120,35 +118,40 @@ def _construct_expected_sampling_metadata(
|
||||
index_in_input_batch = req_id_index_in_input_batch[req.req_id]
|
||||
output_token_ids[index_in_input_batch] = req.output_token_ids
|
||||
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
|
||||
presence_penalties[
|
||||
index_in_input_batch] = req.sampling_params.presence_penalty
|
||||
presence_penalties[index_in_input_batch] = req.sampling_params.presence_penalty
|
||||
frequency_penalties[index_in_input_batch] = (
|
||||
req.sampling_params.frequency_penalty)
|
||||
req.sampling_params.frequency_penalty
|
||||
)
|
||||
repetition_penalties[index_in_input_batch] = (
|
||||
req.sampling_params.repetition_penalty)
|
||||
req.sampling_params.repetition_penalty
|
||||
)
|
||||
top_k[index_in_input_batch] = req.sampling_params.top_k
|
||||
top_p[index_in_input_batch] = req.sampling_params.top_p
|
||||
temperature[index_in_input_batch] = req.sampling_params.temperature
|
||||
min_tokens[index_in_input_batch] = (
|
||||
req.sampling_params.min_tokens,
|
||||
req.sampling_params.all_stop_token_ids)
|
||||
req.sampling_params.all_stop_token_ids,
|
||||
)
|
||||
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
|
||||
if req.sampling_params.allowed_token_ids:
|
||||
allowed_token_ids_mask[index_in_input_batch][
|
||||
req.sampling_params.allowed_token_ids] = True
|
||||
req.sampling_params.allowed_token_ids
|
||||
] = True
|
||||
if req.sampling_params.bad_words_token_ids:
|
||||
bad_words_token_ids[
|
||||
index_in_input_batch] = req.sampling_params.bad_words_token_ids
|
||||
bad_words_token_ids[index_in_input_batch] = (
|
||||
req.sampling_params.bad_words_token_ids
|
||||
)
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=torch.tensor(temperature, dtype=torch.float,
|
||||
device=device),
|
||||
temperature=torch.tensor(temperature, dtype=torch.float, device=device),
|
||||
all_greedy=False,
|
||||
all_random=True,
|
||||
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
|
||||
top_p, dtype=torch.float, device=device),
|
||||
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
|
||||
top_k, dtype=torch.int, device=device),
|
||||
top_p=None
|
||||
if all(x == 1.0 for x in top_p)
|
||||
else torch.tensor(top_p, dtype=torch.float, device=device),
|
||||
top_k=None
|
||||
if all(x == 0 for x in top_k)
|
||||
else torch.tensor(top_k, dtype=torch.int, device=device),
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=make_tensor_with_pad(
|
||||
@@ -157,19 +160,21 @@ def _construct_expected_sampling_metadata(
|
||||
device=torch.device(device),
|
||||
dtype=torch.int64,
|
||||
),
|
||||
frequency_penalties=torch.tensor(frequency_penalties,
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
presence_penalties=torch.tensor(presence_penalties,
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
repetition_penalties=torch.tensor(repetition_penalties,
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
frequency_penalties=torch.tensor(
|
||||
frequency_penalties, dtype=torch.float, device=device
|
||||
),
|
||||
presence_penalties=torch.tensor(
|
||||
presence_penalties, dtype=torch.float, device=device
|
||||
),
|
||||
repetition_penalties=torch.tensor(
|
||||
repetition_penalties, dtype=torch.float, device=device
|
||||
),
|
||||
output_token_ids=output_token_ids,
|
||||
no_penalties=(all(x == 0 for x in presence_penalties)
|
||||
and all(x == 0 for x in frequency_penalties)
|
||||
and all(x == 1 for x in repetition_penalties)),
|
||||
no_penalties=(
|
||||
all(x == 0 for x in presence_penalties)
|
||||
and all(x == 0 for x in frequency_penalties)
|
||||
and all(x == 1 for x in repetition_penalties)
|
||||
),
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=bad_words_token_ids,
|
||||
logitsprocs=LogitsProcessors(),
|
||||
@@ -185,8 +190,7 @@ def _create_sampling_params():
|
||||
frequency_penalty=np.random.uniform(-2.0, 2.0),
|
||||
min_tokens=np.random.randint(1, 10),
|
||||
stop_token_ids=[
|
||||
np.random.randint(0, VOCAB_SIZE)
|
||||
for _ in range(np.random.randint(10))
|
||||
np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(10))
|
||||
],
|
||||
logit_bias={0: np.random.uniform(-3.0, 3.0)},
|
||||
)
|
||||
@@ -207,7 +211,7 @@ def _construct_cached_request_state(req_id_suffix: int):
|
||||
sampling_params=_create_sampling_params(),
|
||||
pooling_params=None,
|
||||
mm_features=[],
|
||||
block_ids=([], ),
|
||||
block_ids=([],),
|
||||
generator=None,
|
||||
num_computed_tokens=len(output_token_ids),
|
||||
output_token_ids=output_token_ids,
|
||||
@@ -262,19 +266,18 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
|
||||
# Create expected output.
|
||||
expected_sampling_metadata = _construct_expected_sampling_metadata(
|
||||
reqs,
|
||||
req_ids_retained,
|
||||
input_batch.req_id_to_index,
|
||||
device=torch.device(device))
|
||||
reqs, req_ids_retained, input_batch.req_id_to_index, device=torch.device(device)
|
||||
)
|
||||
|
||||
def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
|
||||
return (t1 is None
|
||||
and t2 is None) or (t1 is not None and t2 is not None
|
||||
and torch.allclose(t1, t2))
|
||||
return (t1 is None and t2 is None) or (
|
||||
t1 is not None and t2 is not None and torch.allclose(t1, t2)
|
||||
)
|
||||
|
||||
# Assert the actual and expected output.
|
||||
assert torch.allclose(expected_sampling_metadata.temperature,
|
||||
sampling_metadata.temperature)
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.temperature, sampling_metadata.temperature
|
||||
)
|
||||
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
|
||||
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
|
||||
assert torch.allclose(
|
||||
@@ -289,25 +292,29 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
expected_sampling_metadata.repetition_penalties,
|
||||
sampling_metadata.repetition_penalties,
|
||||
)
|
||||
assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
|
||||
sampling_metadata.prompt_token_ids)
|
||||
assert (expected_sampling_metadata.output_token_ids ==
|
||||
sampling_metadata.output_token_ids)
|
||||
assert expected_sampling_metadata.no_penalties == \
|
||||
sampling_metadata.no_penalties
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.prompt_token_ids, sampling_metadata.prompt_token_ids
|
||||
)
|
||||
assert (
|
||||
expected_sampling_metadata.output_token_ids
|
||||
== sampling_metadata.output_token_ids
|
||||
)
|
||||
assert expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties
|
||||
if sampling_metadata.allowed_token_ids_mask:
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.allowed_token_ids_mask,
|
||||
sampling_metadata.allowed_token_ids_mask)
|
||||
assert expected_sampling_metadata.bad_words_token_ids == \
|
||||
sampling_metadata.bad_words_token_ids
|
||||
sampling_metadata.allowed_token_ids_mask,
|
||||
)
|
||||
assert (
|
||||
expected_sampling_metadata.bad_words_token_ids
|
||||
== sampling_metadata.bad_words_token_ids
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [32])
|
||||
@pytest.mark.parametrize("swap_list", [((0, 1), )])
|
||||
def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
swap_list: list):
|
||||
@pytest.mark.parametrize("swap_list", [((0, 1),)])
|
||||
def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: list):
|
||||
"""
|
||||
Tests the logic for managing sampling metadata in the InputBatch.
|
||||
|
||||
@@ -352,8 +359,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
|
||||
reordered_reqs = reqs.copy()
|
||||
for swap_pair in swap_list:
|
||||
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
|
||||
reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
|
||||
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = (
|
||||
reordered_reqs[swap_pair[1]],
|
||||
reordered_reqs[swap_pair[0]],
|
||||
)
|
||||
input_batch.swap_states(swap_pair[0], swap_pair[1])
|
||||
|
||||
for req_index in range(batch_size):
|
||||
|
||||
@@ -6,20 +6,30 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VllmConfig, set_current_vllm_config)
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes, update_environment_variables
|
||||
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor)
|
||||
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
|
||||
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
KVCacheTensor,
|
||||
)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
@@ -35,8 +45,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
"""
|
||||
attn_spec = FullAttentionSpec(
|
||||
block_size=BLOCK_SIZE,
|
||||
num_kv_heads=runner.model_config.get_num_kv_heads(
|
||||
runner.parallel_config),
|
||||
num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config),
|
||||
head_size=runner.model_config.get_head_size(),
|
||||
dtype=runner.kv_cache_dtype,
|
||||
)
|
||||
@@ -58,9 +67,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
device=runner.device,
|
||||
pin_memory=runner.pin_memory,
|
||||
vocab_size=runner.model_config.get_vocab_size(),
|
||||
block_sizes=[
|
||||
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
|
||||
],
|
||||
block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size],
|
||||
)
|
||||
runner.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
@@ -98,8 +105,9 @@ def model_runner():
|
||||
model_config = vllm_config.model_config
|
||||
num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
vllm_config.compilation_config.static_forward_context[
|
||||
"layer.0"] = Attention(num_heads, head_size, 0.1)
|
||||
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
|
||||
num_heads, head_size, 0.1
|
||||
)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
initialize_kv_cache(runner)
|
||||
return runner
|
||||
@@ -120,10 +128,11 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
mm_features=[],
|
||||
sampling_params=SamplingParams(),
|
||||
pooling_params=None,
|
||||
block_ids=([0], ),
|
||||
block_ids=([0],),
|
||||
num_computed_tokens=0,
|
||||
lora_request=None,
|
||||
))
|
||||
)
|
||||
)
|
||||
num_scheduled_tokens[req_id] = 3
|
||||
total_num_scheduled_tokens += num_scheduled_tokens[req_id]
|
||||
|
||||
@@ -150,22 +159,22 @@ def _is_req_added(model_runner, req_id: str) -> bool:
|
||||
return req_id in model_runner.requests
|
||||
|
||||
|
||||
def _is_sampling_metadata_changed(model_runner,
|
||||
sampling_metadata_before: SamplingMetadata):
|
||||
return model_runner.input_batch.sampling_metadata is not (
|
||||
sampling_metadata_before)
|
||||
def _is_sampling_metadata_changed(
|
||||
model_runner, sampling_metadata_before: SamplingMetadata
|
||||
):
|
||||
return model_runner.input_batch.sampling_metadata is not (sampling_metadata_before)
|
||||
|
||||
|
||||
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
req_index = model_runner.input_batch.req_id_to_index[req_id]
|
||||
block_table = model_runner.input_batch.block_table[0]
|
||||
req_state = model_runner.requests[req_id]
|
||||
if block_table.num_blocks_per_row[req_index] != len(
|
||||
req_state.block_ids[0]):
|
||||
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids[0]):
|
||||
return False
|
||||
num_blocks = block_table.num_blocks_per_row[req_index]
|
||||
return (block_table.block_table.np[req_index, :num_blocks] ==
|
||||
req_state.block_ids[0]).all()
|
||||
return (
|
||||
block_table.block_table.np[req_index, :num_blocks] == req_state.block_ids[0]
|
||||
).all()
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner, dist_init):
|
||||
@@ -248,7 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
||||
req_ids=[req_id],
|
||||
resumed_from_preemption=[False],
|
||||
new_token_ids=[[]],
|
||||
new_block_ids=([[0]], ),
|
||||
new_block_ids=([[0]],),
|
||||
num_computed_tokens=[0],
|
||||
num_output_tokens=[0],
|
||||
)
|
||||
@@ -281,46 +290,58 @@ def test_get_nans_in_logits(model_runner, dist_init):
|
||||
scheduler_output = _schedule_new_request(*req_ids)
|
||||
model_runner._update_states(scheduler_output)
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, 2.0, 3.0],
|
||||
[3.0, 2.0, 1.0],
|
||||
], device=DEVICE)
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[1.0, 2.0, 3.0],
|
||||
[3.0, 2.0, 1.0],
|
||||
],
|
||||
device=DEVICE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 0, "req_1": 0}
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, float('nan'), 3.0],
|
||||
[4.0, float('nan'), float('nan')],
|
||||
],
|
||||
device=DEVICE)
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[1.0, float("nan"), 3.0],
|
||||
[4.0, float("nan"), float("nan")],
|
||||
],
|
||||
device=DEVICE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 1, "req_1": 2}
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, 2.0, 3.0],
|
||||
[4.0, float('nan'), float('nan')],
|
||||
],
|
||||
device=DEVICE)
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[1.0, 2.0, 3.0],
|
||||
[4.0, float("nan"), float("nan")],
|
||||
],
|
||||
device=DEVICE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 0, "req_1": 2}
|
||||
|
||||
result = model_runner._get_nans_in_logits(logits=None)
|
||||
assert result == {"req_0": 0, "req_1": 0}
|
||||
|
||||
logits = torch.tensor([
|
||||
[1.0, float('nan'), 3.0],
|
||||
], device=DEVICE)
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[1.0, float("nan"), 3.0],
|
||||
],
|
||||
device=DEVICE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {'req_0': 1, 'req_1': 0}
|
||||
assert result == {"req_0": 1, "req_1": 0}
|
||||
|
||||
logits = torch.tensor([
|
||||
[float('nan'), float('nan'), 2.0],
|
||||
[1.0, 2.0, 3.0],
|
||||
[float('nan'), 2.0, 3.0],
|
||||
],
|
||||
device=DEVICE)
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[float("nan"), float("nan"), 2.0],
|
||||
[1.0, 2.0, 3.0],
|
||||
[float("nan"), 2.0, 3.0],
|
||||
],
|
||||
device=DEVICE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {'req_0': 2, 'req_1': 0}
|
||||
assert result == {"req_0": 2, "req_1": 0}
|
||||
|
||||
|
||||
def test_update_states_no_changes(model_runner, dist_init):
|
||||
@@ -398,11 +419,13 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
|
||||
def test_kv_cache_stride_order(monkeypatch, model_runner):
|
||||
# This test checks if GPUModelRunner initializes correctly when an attention
|
||||
# backend enforces a non-default KV cache stride order.
|
||||
n_heads = model_runner.model_config.get_num_kv_heads(
|
||||
model_runner.parallel_config)
|
||||
n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config)
|
||||
expected_kv_cache_shape = [
|
||||
2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
|
||||
model_runner.model_config.get_head_size()
|
||||
2,
|
||||
NUM_BLOCKS,
|
||||
BLOCK_SIZE,
|
||||
n_heads,
|
||||
model_runner.model_config.get_head_size(),
|
||||
]
|
||||
# TODO mla test
|
||||
default_stride = tuple(range(5))
|
||||
@@ -415,8 +438,9 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
|
||||
# Patch the attention backend class and re-trigger the KV cache creation
|
||||
for attn_group in model_runner._attn_group_iterator():
|
||||
attn_backend = attn_group.backend
|
||||
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
|
||||
rnd_stride_order)
|
||||
monkeypatch.setattr(
|
||||
attn_backend, "get_kv_cache_stride_order", rnd_stride_order
|
||||
)
|
||||
|
||||
model_runner.attn_groups = []
|
||||
model_runner.kv_caches = []
|
||||
@@ -448,14 +472,13 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
|
||||
model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
|
||||
model_runner_2.load_model() # Initial model loading with dummy weights
|
||||
assert str(model_runner.get_model().state_dict()) != str(
|
||||
model_runner_2.get_model().state_dict())
|
||||
model_runner_2.update_config(
|
||||
{"load_config": {
|
||||
"load_format": original_load_format
|
||||
}})
|
||||
model_runner_2.get_model().state_dict()
|
||||
)
|
||||
model_runner_2.update_config({"load_config": {"load_format": original_load_format}})
|
||||
model_runner_2.reload_weights() # Load real weights inplace
|
||||
assert str(model_runner.get_model().state_dict()) == str(
|
||||
model_runner_2.get_model().state_dict())
|
||||
model_runner_2.get_model().state_dict()
|
||||
)
|
||||
|
||||
|
||||
def test_reload_weights_before_load_model(model_runner):
|
||||
@@ -472,21 +495,19 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
|
||||
fwd_context = {
|
||||
# initialization below will fail because target layer is invalid;
|
||||
# the target layer needs to come before layer 1
|
||||
layer_0:
|
||||
Attention(
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
kv_sharing_target_layer_name=layer_1,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
)
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
@@ -500,22 +521,20 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
|
||||
error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
|
||||
with pytest.raises(ValueError, match=error_msg):
|
||||
fwd_context = {
|
||||
layer_0:
|
||||
Attention(
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
# invalid layer: cross_attn.atn doesn't exist!
|
||||
kv_sharing_target_layer_name=invalid_layer,
|
||||
)
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
@@ -530,21 +549,19 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
|
||||
fwd_context = {
|
||||
# initialization below will fail because target layer is invalid;
|
||||
# the target layer needs to come before layer 1
|
||||
layer_0:
|
||||
Attention(
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
kv_sharing_target_layer_name=layer_1,
|
||||
)
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
@@ -557,20 +574,18 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
vllm_config = get_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
fwd_context = {
|
||||
layer_0:
|
||||
Attention(
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
)
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
@@ -585,15 +600,15 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
available_memory = 20 * GiB_bytes
|
||||
# page size for layer 0's kv_cache_spec is 32KB
|
||||
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
kv_cache_config = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_spec], [available_memory]
|
||||
)[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 2
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
|
||||
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
|
||||
|
||||
max_context_len =\
|
||||
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
# max context len with KV sharing should be 2x as large as without
|
||||
assert max_context_len == 1310720
|
||||
|
||||
@@ -601,8 +616,9 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
# this will only allocate 2 block worth of memory (2 * 32kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
kv_cache_tensor.size = (
|
||||
kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
|
||||
kv_cache_tensor.size = kv_cache_spec[
|
||||
kv_cache_tensor.shared_by[0]
|
||||
].page_size_bytes
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
@@ -625,21 +641,19 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
vllm_config = get_vllm_config()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
fwd_context = {
|
||||
layer_0:
|
||||
Attention(
|
||||
layer_0: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_0,
|
||||
),
|
||||
layer_1:
|
||||
Attention(
|
||||
layer_1: Attention(
|
||||
num_heads=8,
|
||||
head_size=64,
|
||||
scale=1.0,
|
||||
prefix=layer_1,
|
||||
kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
|
||||
)
|
||||
),
|
||||
}
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
@@ -657,24 +671,23 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
|
||||
# which is twice as many as without KV sharing
|
||||
num_expected_blocks = 655360 # 20GB / 32KB
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
kv_cache_config = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_spec], [available_memory]
|
||||
)[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 1
|
||||
# Each layer now has twice the available memory for KV cache
|
||||
# compared to no KV sharing
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
|
||||
|
||||
max_context_len =\
|
||||
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
# max context len with KV sharing should be 2x as large as without
|
||||
assert max_context_len == 2 * 1310720
|
||||
|
||||
# important: override tensor size to prevent large mem alloc during test
|
||||
# this will only allocate 1 block worth of memory (32kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
kv_cache_config.kv_cache_tensors[0].size =\
|
||||
kv_cache_spec[layer_0].page_size_bytes
|
||||
kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
kv_cache_config_after_init = runner.kv_cache_config
|
||||
@@ -687,30 +700,30 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
# check layer 1 added to kv cache group's layer names
|
||||
assert len(kv_cache_config_after_init.kv_cache_groups) == 1
|
||||
assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
|
||||
0] == layer_0
|
||||
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
|
||||
1] == layer_1
|
||||
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[0] == layer_0
|
||||
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
|
||||
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
'''
|
||||
"""
|
||||
The GPU model runner creates different views into the
|
||||
KVCacheTensors for the attention and mamba layers
|
||||
(via _reshape_kv_cache_tensors function). This test verifies
|
||||
that the views are compatible: writing a mamba block
|
||||
will not corrupt an attention block and vice versa
|
||||
'''
|
||||
"""
|
||||
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': "0",
|
||||
'LOCAL_RANK': "0",
|
||||
'WORLD_SIZE': "1",
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": "0",
|
||||
"LOCAL_RANK": "0",
|
||||
"WORLD_SIZE": "1",
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
torch.set_default_dtype(torch.float16)
|
||||
@@ -751,8 +764,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
fwd_context = {}
|
||||
for key in [layer_0, layer_1]:
|
||||
fwd_context[key] = Attention(
|
||||
num_heads=model_config.get_num_attention_heads(
|
||||
parallel_config),
|
||||
num_heads=model_config.get_num_attention_heads(parallel_config),
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
scale=1.0,
|
||||
@@ -760,13 +772,12 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
)
|
||||
for key in [layer_2, layer_3, layer_4, layer_5]:
|
||||
fwd_context[key] = MambaMixer2(
|
||||
hidden_size = hf_config.hidden_size,
|
||||
ssm_state_size = hf_config.mamba_d_state,
|
||||
conv_kernel_size = hf_config.mamba_d_conv,
|
||||
intermediate_size = hf_config.mamba_expand *\
|
||||
hf_config.hidden_size,
|
||||
use_conv_bias = hf_config.mamba_conv_bias,
|
||||
use_bias = hf_config.mamba_proj_bias,
|
||||
hidden_size=hf_config.hidden_size,
|
||||
ssm_state_size=hf_config.mamba_d_state,
|
||||
conv_kernel_size=hf_config.mamba_d_conv,
|
||||
intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
|
||||
use_conv_bias=hf_config.mamba_conv_bias,
|
||||
use_bias=hf_config.mamba_proj_bias,
|
||||
n_groups=hf_config.mamba_n_groups,
|
||||
num_heads=hf_config.mamba_n_heads,
|
||||
head_dim=hf_config.mamba_d_head,
|
||||
@@ -781,15 +792,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
|
||||
available_memory = 5 * GiB_bytes
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
kv_cache_config = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_spec], [available_memory]
|
||||
)[0]
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
# random partition of blocks
|
||||
@@ -798,7 +809,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
num_blocks = kv_cache_config.num_blocks
|
||||
ind = np.arange(num_blocks)
|
||||
np.random.shuffle(ind)
|
||||
blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]
|
||||
blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :]
|
||||
|
||||
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
|
||||
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
|
||||
@@ -807,34 +818,40 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
# assert we are using FlashInfer
|
||||
assert attn_shape[0] == num_blocks
|
||||
|
||||
attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
|
||||
device=DEVICE,
|
||||
fill_value=3.33)
|
||||
conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
|
||||
device=DEVICE,
|
||||
fill_value=6.66)
|
||||
ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
|
||||
device=DEVICE,
|
||||
fill_value=9.99)
|
||||
attn_blocks_constant = torch.full(
|
||||
(len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33
|
||||
)
|
||||
conv_blocks_constant = torch.full(
|
||||
(len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66
|
||||
)
|
||||
ssm_blocks_constant = torch.full(
|
||||
(len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99
|
||||
)
|
||||
|
||||
# fill all attention blocks with constant
|
||||
for layer in [layer_0, layer_1]:
|
||||
vllm_ctx[layer].kv_cache[0][
|
||||
blocks0, :] = attn_blocks_constant.detach().clone()
|
||||
vllm_ctx[layer].kv_cache[0][blocks0, :] = (
|
||||
attn_blocks_constant.detach().clone()
|
||||
)
|
||||
|
||||
# fill all mamba blocks with constant
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
vllm_ctx[layer].kv_cache[0][0][
|
||||
blocks1, :] = conv_blocks_constant.detach().clone()
|
||||
vllm_ctx[layer].kv_cache[0][1][
|
||||
blocks1, :] = ssm_blocks_constant.detach().clone()
|
||||
vllm_ctx[layer].kv_cache[0][0][blocks1, :] = (
|
||||
conv_blocks_constant.detach().clone()
|
||||
)
|
||||
vllm_ctx[layer].kv_cache[0][1][blocks1, :] = (
|
||||
ssm_blocks_constant.detach().clone()
|
||||
)
|
||||
|
||||
# verify attention and mamba contents are correct
|
||||
for layer in [layer_0, layer_1]:
|
||||
assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
|
||||
attn_blocks_constant)
|
||||
assert torch.equal(
|
||||
vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant
|
||||
)
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
|
||||
conv_blocks_constant)
|
||||
assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
|
||||
ssm_blocks_constant)
|
||||
assert torch.equal(
|
||||
vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant
|
||||
)
|
||||
assert torch.equal(
|
||||
vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant
|
||||
)
|
||||
|
||||
@@ -10,32 +10,28 @@ def test_bind_kv_cache():
|
||||
from vllm.attention import Attention
|
||||
|
||||
ctx = {
|
||||
'layers.0.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.1.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.2.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.3.self_attn': Attention(32, 128, 0.1),
|
||||
"layers.0.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.1.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.2.self_attn": Attention(32, 128, 0.1),
|
||||
"layers.3.self_attn": Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
'layers.0.self_attn': torch.zeros((1, )),
|
||||
'layers.1.self_attn': torch.zeros((1, )),
|
||||
'layers.2.self_attn': torch.zeros((1, )),
|
||||
'layers.3.self_attn': torch.zeros((1, )),
|
||||
"layers.0.self_attn": torch.zeros((1,)),
|
||||
"layers.1.self_attn": torch.zeros((1,)),
|
||||
"layers.2.self_attn": torch.zeros((1,)),
|
||||
"layers.3.self_attn": torch.zeros((1,)),
|
||||
}
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.0.self_attn']
|
||||
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.1.self_attn']
|
||||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.2.self_attn']
|
||||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.3.self_attn']
|
||||
assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"]
|
||||
assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"]
|
||||
assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"]
|
||||
assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"]
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
|
||||
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
|
||||
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
|
||||
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
|
||||
assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"]
|
||||
assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"]
|
||||
assert runner_kv_caches[2] is kv_cache["layers.2.self_attn"]
|
||||
assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"]
|
||||
|
||||
|
||||
def test_bind_kv_cache_non_attention():
|
||||
@@ -43,21 +39,19 @@ def test_bind_kv_cache_non_attention():
|
||||
|
||||
# example from Jamba PP=2
|
||||
ctx = {
|
||||
'model.layers.20.attn': Attention(32, 128, 0.1),
|
||||
'model.layers.28.attn': Attention(32, 128, 0.1),
|
||||
"model.layers.20.attn": Attention(32, 128, 0.1),
|
||||
"model.layers.28.attn": Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
'model.layers.20.attn': torch.zeros((1, )),
|
||||
'model.layers.28.attn': torch.zeros((1, )),
|
||||
"model.layers.20.attn": torch.zeros((1,)),
|
||||
"model.layers.28.attn": torch.zeros((1,)),
|
||||
}
|
||||
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
|
||||
assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
|
||||
'model.layers.20.attn']
|
||||
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
|
||||
'model.layers.28.attn']
|
||||
assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"]
|
||||
assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"]
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
|
||||
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
|
||||
assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"]
|
||||
assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"]
|
||||
|
||||
@@ -13,8 +13,7 @@ import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import MemorySnapshot
|
||||
from vllm.v1.worker.gpu_worker import (Worker,
|
||||
init_worker_distributed_environment)
|
||||
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
|
||||
|
||||
# Global queue to track operation order across processes
|
||||
_QUEUE: Optional[Queue] = None
|
||||
@@ -28,11 +27,11 @@ def track_operation(operation: str, rank: int):
|
||||
|
||||
def make_operation_tracker(operation_name: str, original_func):
|
||||
"""Create a mock function that tracks when an operation is called.
|
||||
|
||||
|
||||
Args:
|
||||
operation_name: Name to use when tracking this operation
|
||||
original_func: The original function to wrap
|
||||
|
||||
|
||||
Returns:
|
||||
A wrapper function that tracks the operation and calls the original
|
||||
"""
|
||||
@@ -45,8 +44,13 @@ def make_operation_tracker(operation_name: str, original_func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def worker_process(rank: int, world_size: int, distributed_init_method: str,
|
||||
queue: Queue, error_queue: Queue):
|
||||
def worker_process(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
distributed_init_method: str,
|
||||
queue: Queue,
|
||||
error_queue: Queue,
|
||||
):
|
||||
"""Worker process that initializes a GPU worker with proper tracking."""
|
||||
global _QUEUE
|
||||
_QUEUE = queue
|
||||
@@ -58,9 +62,9 @@ def worker_process(rank: int, world_size: int, distributed_init_method: str,
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Create vLLM config with small model
|
||||
vllm_config = EngineArgs(model="facebook/opt-125m",
|
||||
tensor_parallel_size=2,
|
||||
load_format="dummy").create_engine_config()
|
||||
vllm_config = EngineArgs(
|
||||
model="facebook/opt-125m", tensor_parallel_size=2, load_format="dummy"
|
||||
).create_engine_config()
|
||||
|
||||
# Create worker
|
||||
worker = Worker(
|
||||
@@ -77,19 +81,22 @@ def worker_process(rank: int, world_size: int, distributed_init_method: str,
|
||||
|
||||
# Apply minimal patches to track operation order
|
||||
init_patch = patch(
|
||||
'vllm.v1.worker.gpu_worker.init_worker_distributed_environment',
|
||||
side_effect=make_operation_tracker("init_distributed",
|
||||
original_init_worker))
|
||||
"vllm.v1.worker.gpu_worker.init_worker_distributed_environment",
|
||||
side_effect=make_operation_tracker(
|
||||
"init_distributed", original_init_worker
|
||||
),
|
||||
)
|
||||
memory_patch = patch.object(
|
||||
MemorySnapshot, '__init__',
|
||||
make_operation_tracker("memory_snapshot",
|
||||
original_memory_snapshot_init))
|
||||
all_reduce_patch = patch('torch.distributed.all_reduce',
|
||||
side_effect=make_operation_tracker(
|
||||
"nccl_all_reduce", original_all_reduce))
|
||||
MemorySnapshot,
|
||||
"__init__",
|
||||
make_operation_tracker("memory_snapshot", original_memory_snapshot_init),
|
||||
)
|
||||
all_reduce_patch = patch(
|
||||
"torch.distributed.all_reduce",
|
||||
side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce),
|
||||
)
|
||||
|
||||
with init_patch, memory_patch, all_reduce_patch:
|
||||
|
||||
# Initialize device (this is where we test the order)
|
||||
worker.init_device()
|
||||
|
||||
@@ -104,13 +111,14 @@ def worker_process(rank: int, world_size: int, distributed_init_method: str,
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs for tensor parallelism")
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for tensor parallelism"
|
||||
)
|
||||
def test_init_distributed_is_called_before_memory_snapshot():
|
||||
"""Test that distributed env is setup before memory snapshot.
|
||||
|
||||
This test makes sure during worker initialization, the initial memory
|
||||
snapshot is taken after distributed env is setup to include all the buffers
|
||||
|
||||
This test makes sure during worker initialization, the initial memory
|
||||
snapshot is taken after distributed env is setup to include all the buffers
|
||||
allocated by distributed env.
|
||||
"""
|
||||
world_size = 2
|
||||
@@ -127,9 +135,16 @@ def test_init_distributed_is_called_before_memory_snapshot():
|
||||
# Start worker processes
|
||||
processes = []
|
||||
for rank in range(world_size):
|
||||
p = ctx.Process(target=worker_process,
|
||||
args=(rank, world_size, distributed_init_method,
|
||||
operation_queue, error_queue))
|
||||
p = ctx.Process(
|
||||
target=worker_process,
|
||||
args=(
|
||||
rank,
|
||||
world_size,
|
||||
distributed_init_method,
|
||||
operation_queue,
|
||||
error_queue,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
@@ -168,7 +183,8 @@ def test_init_distributed_is_called_before_memory_snapshot():
|
||||
assert init_distributed < nccl_all_reduce < memory_snapshot, (
|
||||
f"Rank {rank}: init_distributed (index {init_distributed}) "
|
||||
f"must happen before nccl_all_reduce (index {nccl_all_reduce}) "
|
||||
f"and memory_snapshot (index {memory_snapshot})")
|
||||
f"and memory_snapshot (index {memory_snapshot})"
|
||||
)
|
||||
|
||||
# Clean up
|
||||
os.unlink(distributed_init_method.replace("file://", ""))
|
||||
|
||||
Reference in New Issue
Block a user