fix(minicpmv): fix audio inference by handling meta device in init_re… (#36751)
Signed-off-by: caitianchi <caitianchi@modelbest.cn>
This commit is contained in:
@@ -1453,10 +1453,11 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
return resampler.to(
|
||||
device=current_platform.device_type, dtype=torch.get_default_dtype()
|
||||
)
|
||||
target_device = current_platform.device_type
|
||||
target_dtype = torch.get_default_dtype()
|
||||
if any(p.is_meta for p in resampler.parameters()):
|
||||
return resampler.to_empty(device=target_device).to(dtype=target_dtype)
|
||||
return resampler.to(device=target_device, dtype=target_dtype)
|
||||
|
||||
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
|
||||
pixel_values = data["pixel_values"]
|
||||
@@ -1649,10 +1650,11 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
return resampler.to(
|
||||
device=current_platform.device_type, dtype=torch.get_default_dtype()
|
||||
)
|
||||
target_device = current_platform.device_type
|
||||
target_dtype = torch.get_default_dtype()
|
||||
if any(p.is_meta for p in resampler.parameters()):
|
||||
return resampler.to_empty(device=target_device).to(dtype=target_dtype)
|
||||
return resampler.to(device=target_device, dtype=target_dtype)
|
||||
|
||||
def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
|
||||
pixel_values = data["pixel_values"]
|
||||
|
||||
Reference in New Issue
Block a user