[MM][Bugfix] Replace PatchEmbed's conv3d to linear layer (#27418)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision(
|
||||
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
|
||||
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
|
||||
return llm_pos_ids
|
||||
|
||||
|
||||
# Workaround for a Conv3D performance regression in PyTorch 2.9: the patch
# embedding's Conv3D weight is flattened so the projection can run as a
# Linear layer instead.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR that introduced this workaround
# (https://github.com/vllm-project/vllm/pull/27418),
# once the performance issue is resolved in PyTorch.
def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
    """
    Flatten a Conv3D weight into an equivalent Linear weight.

    Only valid when the convolution's kernel_size equals its stride
    (non-overlapping patches), as in a patch-embedding layer.

    Args:
        conv3d_weight: weight of shape
            ``(out_channels, in_channels, kt, kh, kw)``.

    Returns:
        Tensor of shape ``(out_channels, in_channels * kt * kh * kw)``.
    """
    out_features = conv3d_weight.shape[0]
    # Collapse all non-output dims into a single input-feature dim.
    return conv3d_weight.reshape(out_features, -1)
|
||||
|
||||
Reference in New Issue
Block a user