refactor hard-coded device strings in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from vllm import LLM, SamplingParams, TokensPrompt
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
@@ -48,6 +49,7 @@ num_accepted_tokens = 1
|
||||
prompt_token_ids: list[int] = []
|
||||
MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
|
||||
BLOCK_SIZE = 560
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
NUM_HIDDEN_LAYERS = 1
|
||||
cur_step_action_idx = 0
|
||||
cur_step_action: StepAction | None = None
|
||||
@@ -71,7 +73,7 @@ def get_fake_sample_fn() -> SamplerOutput:
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=torch.tensor(
|
||||
[[prompt_token_ids[first_token_id_index]]],
|
||||
device="cuda",
|
||||
device=DEVICE_TYPE,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
logprobs_tensors=None,
|
||||
@@ -83,7 +85,9 @@ def get_fake_sample_fn() -> SamplerOutput:
|
||||
sampled_token_ids = accepted_tokens
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=torch.tensor(
|
||||
[sampled_token_ids], device="cuda", dtype=torch.int32
|
||||
[sampled_token_ids],
|
||||
device=DEVICE_TYPE,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
logprobs_tensors=None,
|
||||
)
|
||||
@@ -128,17 +132,23 @@ def get_fake_propose_draft_token_ids_fn():
|
||||
- 1
|
||||
+ num_accepted_tokens
|
||||
],
|
||||
device="cuda",
|
||||
device=DEVICE_TYPE,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
valid_sampled_tokens_count = torch.tensor(
|
||||
[num_accepted_tokens], device="cuda", dtype=torch.int32
|
||||
[num_accepted_tokens],
|
||||
device=DEVICE_TYPE,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
|
||||
|
||||
return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)
|
||||
return torch.tensor(
|
||||
proposed_draft_token_ids,
|
||||
device=DEVICE_TYPE,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
return fake_propose_draft_token_ids_fn
|
||||
|
||||
|
||||
Reference in New Issue
Block a user