[Bugfix] MiDashengLM model concat error under concurrent testing (#24738)
Signed-off-by: chenbing8 <chenbing8@xiaomi.com>
Signed-off-by: bingchen-mi <chenbing8@xiaomi.com>
This commit is contained in:
@@ -497,8 +497,11 @@ class MiDashengLMDummyInputsBuilder(
|
|||||||
|
|
||||||
hf_processor = self.info.get_hf_processor()
|
hf_processor = self.info.get_hf_processor()
|
||||||
audio_token = hf_processor.audio_token
|
audio_token = hf_processor.audio_token
|
||||||
|
audio_bos_token = hf_processor.audio_bos_token
|
||||||
|
audio_eos_token = hf_processor.audio_eos_token
|
||||||
|
|
||||||
return audio_token * num_audios
|
single_audio_text = f"{audio_bos_token}{audio_token}{audio_eos_token}"
|
||||||
|
return single_audio_text * num_audios
|
||||||
|
|
||||||
def get_dummy_mm_data(
|
def get_dummy_mm_data(
|
||||||
self,
|
self,
|
||||||
@@ -577,14 +580,7 @@ class MiDashengLMMultiModalProcessor(
|
|||||||
vocab = tokenizer.get_vocab()
|
vocab = tokenizer.get_vocab()
|
||||||
|
|
||||||
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
||||||
audio_bos_token = getattr(processor, "audio_bos_token",
|
|
||||||
"<|audio_bos|>")
|
|
||||||
audio_eos_token = getattr(processor, "audio_eos_token",
|
|
||||||
"<|audio_eos|>")
|
|
||||||
|
|
||||||
audio_token_id = vocab[audio_token]
|
audio_token_id = vocab[audio_token]
|
||||||
audio_bos_id = vocab[audio_bos_token]
|
|
||||||
audio_eos_id = vocab[audio_eos_token]
|
|
||||||
|
|
||||||
out_mm_data = out_mm_kwargs.get_data()
|
out_mm_data = out_mm_kwargs.get_data()
|
||||||
audio_length = out_mm_data.get("audio_length")
|
audio_length = out_mm_data.get("audio_length")
|
||||||
@@ -604,7 +600,7 @@ class MiDashengLMMultiModalProcessor(
|
|||||||
audio_tokens = [audio_token_id] * num_features
|
audio_tokens = [audio_token_id] * num_features
|
||||||
|
|
||||||
return PromptUpdateDetails.select_token_id(
|
return PromptUpdateDetails.select_token_id(
|
||||||
[audio_bos_id] + audio_tokens + [audio_eos_id],
|
audio_tokens,
|
||||||
embed_token_id=audio_token_id,
|
embed_token_id=audio_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -670,7 +666,17 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
f"Got type: {type(mm_input)}")
|
f"Got type: {type(mm_input)}")
|
||||||
if isinstance(mm_input, torch.Tensor):
|
if isinstance(mm_input, torch.Tensor):
|
||||||
return mm_input.reshape(-1, *mm_input.shape[2:])
|
return mm_input.reshape(-1, *mm_input.shape[2:])
|
||||||
else:
|
|
||||||
|
if name == "input_values":
|
||||||
|
max_length = max(tensor.shape[1] for tensor in mm_input)
|
||||||
|
padded_mm_input = [
|
||||||
|
torch.nn.functional.pad(tensor,
|
||||||
|
(0, max_length - tensor.shape[1]))
|
||||||
|
if tensor.shape[1] < max_length else tensor
|
||||||
|
for tensor in mm_input
|
||||||
|
]
|
||||||
|
return torch.concat(padded_mm_input)
|
||||||
|
|
||||||
return torch.concat(mm_input)
|
return torch.concat(mm_input)
|
||||||
|
|
||||||
def _parse_and_validate_audio_input(
|
def _parse_and_validate_audio_input(
|
||||||
|
|||||||
Reference in New Issue
Block a user