[MM] Move Qwen3Omni MRoPE impl to model file (#26608)
Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -499,3 +499,40 @@ def run_dp_sharded_mrope_vision_model(
|
||||
"Found unassigned embeddings"
|
||||
)
|
||||
return out_embeddings
|
||||
|
||||
|
||||
def get_llm_pos_ids_for_vision(
|
||||
start_idx: int,
|
||||
vision_idx: int,
|
||||
spatial_merge_size: int,
|
||||
t_index: list[int],
|
||||
grid_hs: torch.Tensor,
|
||||
grid_ws: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
llm_pos_ids_list = []
|
||||
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
|
||||
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
.expand(len(t_index), -1, llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
w_index = (
|
||||
torch.arange(llm_grid_w)
|
||||
.view(1, 1, -1)
|
||||
.expand(len(t_index), llm_grid_h, -1)
|
||||
.flatten()
|
||||
)
|
||||
t_index_tensor = (
|
||||
torch.Tensor(t_index)
|
||||
.to(llm_grid_h.device)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.long()
|
||||
.flatten()
|
||||
)
|
||||
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
|
||||
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
|
||||
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
|
||||
return llm_pos_ids
|
||||
|
||||
Reference in New Issue
Block a user