Refactor hard-coded device strings in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -19,7 +19,7 @@ from vllm.v1.sample.rejection_sampler import (
|
||||
from vllm.v1.sample.sampler import Sampler, SamplerOutput
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
DEVICE = current_platform.device_type
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -57,7 +57,7 @@ def create_logits_tensor(
|
||||
will produce desired token ids on argmax"""
|
||||
token_ids = [tokens[:-1] for tokens in output_token_ids]
|
||||
num_total_tokens = sum(len(tokens) for tokens in token_ids)
|
||||
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
|
||||
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE_TYPE)
|
||||
start_loc = 0
|
||||
for tokens in token_ids:
|
||||
for j, token_id in enumerate(tokens):
|
||||
@@ -99,9 +99,9 @@ def create_sampling_metadata(
|
||||
assert output_token_ids
|
||||
assert len(output_token_ids) > 0
|
||||
|
||||
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE)
|
||||
presence_penalties = torch.tensor(presence_penalties, device=DEVICE)
|
||||
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE)
|
||||
frequency_penalties = torch.tensor(frequency_penalties, device=DEVICE_TYPE)
|
||||
presence_penalties = torch.tensor(presence_penalties, device=DEVICE_TYPE)
|
||||
repetition_penalties = torch.tensor(repetition_penalties, device=DEVICE_TYPE)
|
||||
else:
|
||||
no_penalties = True
|
||||
frequency_penalties = torch.tensor([])
|
||||
@@ -320,14 +320,27 @@ def test_deterministic_when_seeded(
|
||||
n_rep: int,
|
||||
):
|
||||
num_tokens = batch_size * k
|
||||
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
|
||||
draft_probs = torch.rand(
|
||||
num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
target_logits = torch.rand_like(draft_probs)
|
||||
bonus_token_ids = torch.randint(
|
||||
low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
draft_token_ids = torch.randint(
|
||||
low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
|
||||
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
|
||||
@@ -335,12 +348,12 @@ def test_deterministic_when_seeded(
|
||||
results = []
|
||||
for _ in range(n_rep):
|
||||
seeded_seqs = {
|
||||
i: torch.Generator(device=DEVICE).manual_seed(i)
|
||||
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i)
|
||||
for i in range(batch_size)
|
||||
if seeded_mask[i]
|
||||
}
|
||||
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False, temperature=temperature, generators=seeded_seqs
|
||||
)
|
||||
@@ -387,7 +400,7 @@ def test_rejection_sampling_approximates_target_distribution():
|
||||
much more than the distance improvement between the observed
|
||||
distribution and the random distribution.
|
||||
"""
|
||||
torch.set_default_device(DEVICE)
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
vocab_size = 10
|
||||
k = 2
|
||||
num_reference_probs = 100
|
||||
@@ -410,7 +423,7 @@ def test_rejection_sampling_approximates_target_distribution():
|
||||
rej_sample_probs = estimate_rejection_sampling_pdf(
|
||||
draft_probs, target_logits, k, vocab_size, num_samples
|
||||
)
|
||||
rej_sample_probs = rej_sample_probs.to(DEVICE)
|
||||
rej_sample_probs = rej_sample_probs.to(DEVICE_TYPE)
|
||||
|
||||
# Average distance from reference probs.
|
||||
reference_vs_rejsample_dist = (
|
||||
@@ -491,11 +504,11 @@ def estimate_rejection_sampling_pdf(
|
||||
draft_probs = draft_probs.view(num_tokens, vocab_size)
|
||||
|
||||
# Bonus tokens not used but required.
|
||||
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(
|
||||
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE_TYPE).repeat(
|
||||
num_samples, 1
|
||||
)
|
||||
|
||||
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
|
||||
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE_TYPE)
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False, temperature=temperature
|
||||
)
|
||||
@@ -600,7 +613,7 @@ def _test_masked_logits(
|
||||
|
||||
# Create random draft probabilities.
|
||||
draft_probs = torch.rand(
|
||||
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE
|
||||
(num_tokens, vocab_size), dtype=torch.float32, device=DEVICE_TYPE
|
||||
)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
|
||||
@@ -610,7 +623,11 @@ def _test_masked_logits(
|
||||
draft_token_ids = draft_token_ids.tolist()
|
||||
|
||||
# Bonus tokens not used but required
|
||||
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
|
||||
bonus_token_ids = torch.zeros(
|
||||
(batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
|
||||
# Create spec decode metadata
|
||||
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
|
||||
@@ -645,12 +662,13 @@ def test_top_k(rejection_sampler, top_k):
|
||||
|
||||
# Randomly create top-k indices.
|
||||
top_k_indices = [
|
||||
torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens)
|
||||
torch.randperm(vocab_size, device=DEVICE_TYPE)[:top_k]
|
||||
for _ in range(num_tokens)
|
||||
]
|
||||
top_k_indices = torch.stack(top_k_indices)
|
||||
|
||||
# Create logits with the uniform distribution.
|
||||
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
|
||||
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE_TYPE)
|
||||
|
||||
# Increment the logits for top-k indices, a little bit more than the other
|
||||
# ones. If the masking is effective, the non-topk indices will never be
|
||||
@@ -659,11 +677,11 @@ def test_top_k(rejection_sampler, top_k):
|
||||
target_logits[i, top_k_indices[i]] += 0.1
|
||||
|
||||
# Create sampling metadata
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False,
|
||||
temperature=temperature,
|
||||
top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64),
|
||||
top_k=torch.tensor([top_k] * batch_size, device=DEVICE_TYPE, dtype=torch.int64),
|
||||
)
|
||||
|
||||
_test_masked_logits(
|
||||
@@ -686,8 +704,8 @@ def test_top_p(rejection_sampler, top_p):
|
||||
num_tokens = batch_size * num_draft_tokens
|
||||
|
||||
# Create logits with the uniform distribution.
|
||||
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
|
||||
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE_TYPE)
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
|
||||
rescaled_logits = target_logits / temperature
|
||||
|
||||
logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
|
||||
@@ -706,7 +724,11 @@ def test_top_p(rejection_sampler, top_p):
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False,
|
||||
temperature=temperature,
|
||||
top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32),
|
||||
top_p=torch.tensor(
|
||||
[top_p] * batch_size,
|
||||
device=DEVICE_TYPE,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
|
||||
_test_masked_logits(
|
||||
@@ -732,7 +754,10 @@ def test_frequency_penalties(rejection_sampler):
|
||||
all_greedy=True,
|
||||
output_token_ids=[[2], [3], [4]],
|
||||
spec_token_ids=spec_tokens,
|
||||
prompt_token_ids=torch.tensor([[5, 6, 7], [6, 7, 8], [7, 8, 9]], device=DEVICE),
|
||||
prompt_token_ids=torch.tensor(
|
||||
[[5, 6, 7], [6, 7, 8], [7, 8, 9]],
|
||||
device=DEVICE_TYPE,
|
||||
),
|
||||
frequency_penalties=[1.5, 1.5, 0.7],
|
||||
presence_penalties=[0.0] * num_requests,
|
||||
repetition_penalties=[1.0] * num_requests,
|
||||
@@ -858,21 +883,26 @@ def test_sample_recovered_tokens(
|
||||
num_tokens = batch_size * max_spec_len
|
||||
|
||||
# Create random draft probabilities.
|
||||
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
|
||||
draft_probs = torch.rand(
|
||||
num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
|
||||
# Create random target probabilities.
|
||||
target_logits = torch.rand(
|
||||
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE
|
||||
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE_TYPE
|
||||
)
|
||||
target_probs = F.softmax(target_logits, dim=-1)
|
||||
|
||||
# Randomly sample draft token ids from draft probs
|
||||
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
|
||||
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE)
|
||||
generators = {
|
||||
i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size)
|
||||
i: torch.Generator(device=DEVICE_TYPE).manual_seed(i) for i in range(batch_size)
|
||||
}
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False, temperature=temperature, generators=generators
|
||||
@@ -890,7 +920,7 @@ def test_sample_recovered_tokens(
|
||||
None if no_draft_probs else draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device=DEVICE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
recovered_token_ids = sample_recovered_tokens(
|
||||
max_spec_len,
|
||||
@@ -900,6 +930,6 @@ def test_sample_recovered_tokens(
|
||||
None if no_draft_probs else draft_probs,
|
||||
target_probs,
|
||||
sampling_metadata,
|
||||
device=DEVICE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user