[Model Runner V2] Misc code simplification (#35941)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-04 15:26:35 -08:00
committed by GitHub
parent 6c21a0c2d7
commit a3299c3d1d
3 changed files with 7 additions and 22 deletions

View File

@@ -85,7 +85,6 @@ class UvaBackedTensor:
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
):
self.dtype = dtype
self.max_concurrency = max_concurrency
# Source of truth
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)

View File

@@ -96,11 +96,7 @@ logger = init_logger(__name__)
class GPUModelRunner(LoRAModelRunnerMixin):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@@ -627,9 +623,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs, dtype=torch.int32, device=self.device
)
else:
num_draft_tokens = np.array(
[len(draft_tokens.get(req_id, ())) for req_id in req_ids],
num_draft_tokens = np.fromiter(
(len(draft_tokens.get(req_id, ())) for req_id in req_ids),
dtype=np.int32,
count=num_reqs,
)
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens
@@ -782,9 +779,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
num_sampled = torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
)
num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
else:
# Rejection sampling for spec decoding.
sampled_tokens, num_sampled = rejection_sample(

View File

@@ -48,17 +48,8 @@ def rejection_sample(
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
num_sampled = cu_num_logits.new_empty(num_reqs)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),