Added qwen3 vision language moe support for speculative decoding (#32048)

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
This commit is contained in:
shanjiaz
2026-01-20 22:24:05 -05:00
committed by GitHub
parent 0900cedb3f
commit 7ab80a8e37
2 changed files with 17 additions and 1 deletion

View File

@@ -110,9 +110,14 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
aux_hidden_states = []
for layer_idx, layer in islice( for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer enumerate(self.layers), self.start_layer, self.end_layer
): ):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
@@ -132,6 +137,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
{"hidden_states": hidden_states, "residual": residual} {"hidden_states": hidden_states, "residual": residual}
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def load_fused_expert_weights( def load_fused_expert_weights(

View File

@@ -112,7 +112,9 @@ class SpecDecodeBaseProposer:
self.input_ids = torch.zeros( self.input_ids = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=device self.max_num_tokens, dtype=torch.int32, device=device
) )
self.uses_mrope = self.vllm_config.model_config.uses_mrope # Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal
self.uses_mrope = self.draft_model_config.uses_mrope
if self.uses_mrope: if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy # NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work # position on purpose to make it non-contiguous so that it can work
@@ -221,6 +223,11 @@ class SpecDecodeBaseProposer:
if self.uses_mrope: if self.uses_mrope:
self.mrope_positions[:, :num_tokens] = positions self.mrope_positions[:, :num_tokens] = positions
else: else:
# Convert M-RoPE positions if the target model uses M-RoPE
# but the draft model doesn't. For text inputs, all M-RoPE
# dimensions are identical.
if self.vllm_config.model_config.uses_mrope:
positions = positions[0]
self.positions[:num_tokens] = positions self.positions[:num_tokens] = positions
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
@@ -1080,6 +1087,7 @@ class SpecDecodeBaseProposer:
if self.get_model_name(target_model) in [ if self.get_model_name(target_model) in [
"Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration", "Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
]: ]:
self.model.config.image_token_index = target_model.config.image_token_id self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration": elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":