refactor hard-coded device strings in test files under tests/v1 and tests/lora (#37566)

Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
wliao2
2026-04-02 20:21:47 -07:00
committed by GitHub
parent 4a06e1246e
commit 32e0c0bfa2
28 changed files with 239 additions and 146 deletions

View File

@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import (
from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = current_platform.device_type
DEVICE_TYPE = current_platform.device_type
@pytest.fixture
@@ -57,7 +57,7 @@ def create_logits_tensor(
will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE_TYPE)
start_loc = 0
for tokens in token_ids:
for j, token_id in enumerate(tokens):
@@ -99,9 +99,9 @@ def create_sampling_metadata(
assert output_token_ids
assert len(output_token_ids) > 0
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE_TYPE)
presence_penalties = torch.tensor(presence_penalties, device=DEVICE_TYPE)
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE_TYPE)
else:
no_penalties = True
frequency_penalties = torch.tensor([])
@@ -320,14 +320,27 @@ def test_deterministic_when_seeded(
n_rep: int,
):
num_tokens = batch_size * k
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
draft_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
draft_probs = F.softmax(draft_probs, dim=-1)
target_logits = torch.rand_like(draft_probs)
bonus_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE
low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device=DEVICE_TYPE,
)
draft_token_ids = torch.randint(
low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE
low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device=DEVICE_TYPE,
)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
@@ -335,12 +348,12 @@ def test_deterministic_when_seeded(
results = []
for _ in range(n_rep):
seeded_seqs = {
i: torch.Generator(device=DEVICE).manual_seed(i)
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i)
for i in range(batch_size)
if seeded_mask[i]
}
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=seeded_seqs
)
@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution():
much more than the distance improvement between the observed
distribution and the random distribution.
"""
torch.set_default_device(DEVICE)
torch.set_default_device(DEVICE_TYPE)
vocab_size = 10
k = 2
num_reference_probs = 100
@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution():
rej_sample_probs = estimate_rejection_sampling_pdf(
draft_probs, target_logits, k, vocab_size, num_samples
)
rej_sample_probs = rej_sample_probs.to(DEVICE)
rej_sample_probs = rej_sample_probs.to(DEVICE_TYPE)
# Average distance from reference probs.
reference_vs_rejsample_dist = (
@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf(
draft_probs = draft_probs.view(num_tokens, vocab_size)
# Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE_TYPE).repeat(
num_samples, 1
)
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature
)
@@ -600,7 +613,7 @@ def _test_masked_logits(
# Create random draft probabilities.
draft_probs = torch.rand(
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE_TYPE
)
draft_probs = F.softmax(draft_probs, dim=-1)
@@ -610,7 +623,11 @@ def _test_masked_logits(
draft_token_ids = draft_token_ids.tolist()
# Bonus tokens not used but required
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
bonus_token_ids = torch.zeros(
(batch_size, 1),
dtype=torch.int64,
device=DEVICE_TYPE,
)
# Create spec decode metadata
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k):
# Randomly create top-k indices.
top_k_indices = [
torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens)
torch.randperm(vocab_size, device=DEVICE_TYPE)[:top_k]
for _ in range(num_tokens)
]
top_k_indices = torch.stack(top_k_indices)
# Create logits with the uniform distribution.
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE_TYPE)
# Increment the logits for top-k indices, a little bit more than the other
# ones. If the masking is effective, the non-topk indices will never be
@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k):
target_logits[i, top_k_indices[i]] += 0.1
# Create sampling metadata
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64),
top_k=torch.tensor([top_k] * batch_size, device=DEVICE_TYPE, dtype=torch.int64),
)
_test_masked_logits(
@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p):
num_tokens = batch_size * num_draft_tokens
# Create logits with the uniform distribution.
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE_TYPE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
rescaled_logits = target_logits / temperature
logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p):
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
top_p=torch.tensor(
[top_p] * batch_size,
device=DEVICE_TYPE,
dtype=torch.float32,
),
)
_test_masked_logits(
@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler):
all_greedy=True,
output_token_ids=[[2], [3], [4]],
spec_token_ids=spec_tokens,
prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
prompt_token_ids=torch.tensor(
[[5, 6, 7], [6, 7, 8], [7, 8, 9]],
device=DEVICE_TYPE,
),
frequency_penalties=[1.5, 1.5, 0.7],
presence_penalties=[0.0] * num_requests,
repetition_penalties=[1.0] * num_requests,
@@ -858,21 +883,26 @@ def test_sample_recovered_tokens(
num_tokens = batch_size * max_spec_len
# Create random draft probabilities.
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
draft_probs = torch.rand(
num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE_TYPE,
)
draft_probs = F.softmax(draft_probs, dim=-1)
# Create random target probabilities.
target_logits = torch.rand(
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE_TYPE
)
target_probs = F.softmax(target_logits, dim=-1)
# Randomly sample draft token ids from draft probs
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
generators = {
i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size)
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
}
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=generators
@@ -890,7 +920,7 @@ def test_sample_recovered_tokens(
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
device=DEVICE_TYPE,
)
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
@@ -900,6 +930,6 @@ def test_sample_recovered_tokens(
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
device=DEVICE_TYPE,
)
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)