diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index a33454005..581664aec 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -792,7 +792,9 @@ class VoxtralEncoderModel(nn.Module): audio_waveforms: torch.Tensor, ) -> torch.Tensor: input_dtype = audio_waveforms.dtype - window = torch.hann_window(self.config.window_size).to(audio_waveforms.device) + window = torch.hann_window( + self.config.window_size, device=audio_waveforms.device + ) stft = torch.stft( audio_waveforms, self.config.window_size,