[MRV2] Use fp32 for draft logits (#37526)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-19 08:41:21 -07:00
committed by GitHub
parent 8b10e4fb31
commit 40b8363b45
2 changed files with 1 addition and 3 deletions

View File

@@ -195,7 +195,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
model_dtype=self.dtype,
cache_draft_logits=not use_strict_rejection_sampling,
)
self.input_buffers = InputBuffers(

View File

@@ -15,7 +15,6 @@ class RequestState:
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
model_dtype: torch.dtype,
cache_draft_logits: bool,
):
self.max_num_reqs = max_num_reqs
@@ -81,7 +80,7 @@ class RequestState:
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=model_dtype,
dtype=torch.float32,
device=device,
)