[MRV2] Use fp32 for draft logits (#37526)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -195,7 +195,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_speculative_steps=self.num_speculative_steps,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
-model_dtype=self.dtype,
|
||||
cache_draft_logits=not use_strict_rejection_sampling,
|
||||
)
|
||||
self.input_buffers = InputBuffers(
|
||||
|
||||
@@ -15,7 +15,6 @@ class RequestState:
|
||||
num_speculative_steps: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
-model_dtype: torch.dtype,
|
||||
cache_draft_logits: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
@@ -81,7 +80,7 @@ class RequestState:
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
self.vocab_size,
|
||||
-dtype=model_dtype,
|
||||
+dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user