[Model Runner V2] Misc code simplification (#35941)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -85,7 +85,6 @@ class UvaBackedTensor:
|
||||
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
|
||||
):
|
||||
self.dtype = dtype
|
||||
self.max_concurrency = max_concurrency
|
||||
|
||||
# Source of truth
|
||||
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
|
||||
|
||||
@@ -96,11 +96,7 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
@@ -627,9 +623,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_reqs, dtype=torch.int32, device=self.device
|
||||
)
|
||||
else:
|
||||
num_draft_tokens = np.array(
|
||||
[len(draft_tokens.get(req_id, ())) for req_id in req_ids],
|
||||
num_draft_tokens = np.fromiter(
|
||||
(len(draft_tokens.get(req_id, ())) for req_id in req_ids),
|
||||
dtype=np.int32,
|
||||
count=num_reqs,
|
||||
)
|
||||
total_num_draft_tokens = int(num_draft_tokens.sum())
|
||||
total_num_logits = num_reqs + total_num_draft_tokens
|
||||
@@ -782,9 +779,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
if input_batch.num_draft_tokens == 0:
|
||||
# No draft tokens (common case).
|
||||
num_sampled = torch.ones(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
)
|
||||
num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
|
||||
else:
|
||||
# Rejection sampling for spec decoding.
|
||||
sampled_tokens, num_sampled = rejection_sample(
|
||||
|
||||
@@ -48,17 +48,8 @@ def rejection_sample(
|
||||
num_speculative_steps: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
sampled = torch.empty(
|
||||
num_reqs,
|
||||
num_speculative_steps + 1,
|
||||
dtype=target_sampled.dtype,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
num_sampled = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
|
||||
num_sampled = cu_num_logits.new_empty(num_reqs)
|
||||
_rejection_sample_kernel[(num_reqs,)](
|
||||
sampled,
|
||||
sampled.stride(0),
|
||||
|
||||
Reference in New Issue
Block a user