[Model][QwenVL] Simplify cos/sin rotary embedding indexing (#28962)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -428,13 +428,8 @@ class Qwen3Omni_VisionTransformer(nn.Module):
|
||||
# Use pre-computed cos_sin_cache from RotaryEmbedding
|
||||
cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
|
||||
|
||||
cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2)
|
||||
cos_w = cos[pos_ids[:, 1]]
|
||||
sin_h = sin[pos_ids[:, 0]]
|
||||
sin_w = sin[pos_ids[:, 1]]
|
||||
|
||||
cos_combined = torch.cat([cos_h, cos_w], dim=-1)
|
||||
sin_combined = torch.cat([sin_h, sin_w], dim=-1)
|
||||
cos_combined = cos[pos_ids].flatten(1)
|
||||
sin_combined = sin[pos_ids].flatten(1)
|
||||
|
||||
return cos_combined, sin_combined
|
||||
|
||||
|
||||
Reference in New Issue
Block a user