fix float16 support for kimi-vl (#17156)
Co-authored-by: zhouzaida <zhouzaida@msh.team>
This commit is contained in:
@@ -340,8 +340,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
else:
|
||||
pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
|
||||
patch_size)
|
||||
- # fp32 -> bf16
|
||||
- pixel_values = pixel_values.to(torch.bfloat16)
|
||||
+ pixel_values = pixel_values.to(self.vision_tower.dtype)
|
||||
# image_grid_hws.shape = (N, 2)
|
||||
assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user