[5/N] Initialize MM components in context managers (Q-Z) (#32695)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-21 03:10:23 +08:00
committed by GitHub
parent f0feb1cf81
commit 193069d129
9 changed files with 178 additions and 168 deletions

View File

@@ -1321,7 +1321,13 @@ class Qwen3VLForConditionalGeneration(
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(
{
@@ -1333,6 +1339,9 @@ class Qwen3VLForConditionalGeneration(
)
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
@@ -1351,6 +1360,9 @@ class Qwen3VLForConditionalGeneration(
)
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
@@ -2037,11 +2049,7 @@ class Qwen3VLForConditionalGeneration(
if intermediate_tensors is not None:
inputs_embeds = None
if (
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
if inputs_embeds is not None and get_pp_group().is_first_rank:
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0)
)