[Bugfix] Voxtral prompt/audio placeholder alignment (#34140)
Signed-off-by: Artus KG <artuskg@gmail.com>
This commit is contained in:
committed by
GitHub
parent
eadb4e868b
commit
8fd31f6245
@@ -187,6 +187,7 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
|
|||||||
def get_data_parser(self):
|
def get_data_parser(self):
|
||||||
return MultiModalDataParser(
|
return MultiModalDataParser(
|
||||||
target_sr=self.get_hf_processor().sampling_rate,
|
target_sr=self.get_hf_processor().sampling_rate,
|
||||||
|
target_channels=1,
|
||||||
expected_hidden_size=self._get_expected_hidden_size(),
|
expected_hidden_size=self._get_expected_hidden_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -289,10 +290,24 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
|||||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
|
|
||||||
audio_id = processor.audio_token_id
|
audio_id = processor.audio_token_id
|
||||||
|
out_mm_data = out_mm_kwargs.require_data()
|
||||||
|
out_audio_items = out_mm_data.get("audio", [])
|
||||||
|
|
||||||
def get_replacement(item_idx: int):
|
def get_replacement(item_idx: int):
|
||||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
if item_idx < len(out_audio_items):
|
||||||
audio_len = audios.get_audio_length(item_idx)
|
out_audio_data = out_audio_items[item_idx].get_data()
|
||||||
|
audio_arr = out_audio_data["audio_arrays"]
|
||||||
|
if isinstance(audio_arr, (torch.Tensor, np.ndarray)):
|
||||||
|
audio_len = len(audio_arr)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Unexpected type for audio_arrays in out_mm_kwargs: "
|
||||||
|
f"{type(audio_arr)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback for unexpected processor outputs.
|
||||||
|
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||||
|
audio_len = audios.get_audio_length(item_idx)
|
||||||
|
|
||||||
nb_audio_tokens = processor.get_num_audio_tokens(audio_len)
|
nb_audio_tokens = processor.get_num_audio_tokens(audio_len)
|
||||||
|
|
||||||
@@ -495,7 +510,10 @@ class VoxtralForConditionalGeneration(
|
|||||||
return TokensPrompt(
|
return TokensPrompt(
|
||||||
prompt_token_ids=tokenized.tokens,
|
prompt_token_ids=tokenized.tokens,
|
||||||
multi_modal_data={
|
multi_modal_data={
|
||||||
"audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
|
"audio": [
|
||||||
|
(audio.audio_array, stt_config.sample_rate)
|
||||||
|
for audio in tokenized.audios
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user