Added qwen3 vision language moe support for speculative decoding (#32048)

Signed-off-by: shanjiaz <zsjwpianpian@gmail.com>
Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
This commit is contained in:
shanjiaz
2026-01-20 22:24:05 -05:00
committed by GitHub
parent 0900cedb3f
commit 7ab80a8e37
2 changed files with 17 additions and 1 deletions

View File

@@ -110,9 +110,14 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states = []
for layer_idx, layer in islice(
enumerate(self.layers), self.start_layer, self.end_layer
):
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(
positions,
hidden_states,
@@ -132,6 +137,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_fused_expert_weights(

View File

@@ -112,7 +112,9 @@ class SpecDecodeBaseProposer:
self.input_ids = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=device
)
self.uses_mrope = self.vllm_config.model_config.uses_mrope
# Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal
self.uses_mrope = self.draft_model_config.uses_mrope
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
@@ -221,6 +223,11 @@ class SpecDecodeBaseProposer:
if self.uses_mrope:
self.mrope_positions[:, :num_tokens] = positions
else:
# Convert M-RoPE positions if target model uses M-RoPE
# but draft doesn't, For text inputs, all M-RoPE
# dimensions are identical
if self.vllm_config.model_config.uses_mrope:
positions = positions[0]
self.positions[:num_tokens] = positions
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
@@ -1080,6 +1087,7 @@ class SpecDecodeBaseProposer:
if self.get_model_name(target_model) in [
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
]:
self.model.config.image_token_index = target_model.config.image_token_id
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":