Added qwen3 vision language moe support for speculative decoding (#32048)
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com> Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
This commit is contained in:
@@ -110,9 +110,14 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states = []
|
||||
for layer_idx, layer in islice(
|
||||
enumerate(self.layers), self.start_layer, self.end_layer
|
||||
):
|
||||
if layer_idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
@@ -132,6 +137,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if len(aux_hidden_states) > 0:
|
||||
return hidden_states, aux_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def load_fused_expert_weights(
|
||||
|
||||
@@ -112,7 +112,9 @@ class SpecDecodeBaseProposer:
|
||||
self.input_ids = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int32, device=device
|
||||
)
|
||||
self.uses_mrope = self.vllm_config.model_config.uses_mrope
|
||||
# Use draft model's M-RoPE setting, not target model's
|
||||
# Draft models may be text-only even if target is multimodal
|
||||
self.uses_mrope = self.draft_model_config.uses_mrope
|
||||
if self.uses_mrope:
|
||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||
# position on purpose to make it non-contiguous so that it can work
|
||||
@@ -221,6 +223,11 @@ class SpecDecodeBaseProposer:
|
||||
if self.uses_mrope:
|
||||
self.mrope_positions[:, :num_tokens] = positions
|
||||
else:
|
||||
# Convert M-RoPE positions if target model uses M-RoPE
|
||||
# but draft doesn't, For text inputs, all M-RoPE
|
||||
# dimensions are identical
|
||||
if self.vllm_config.model_config.uses_mrope:
|
||||
positions = positions[0]
|
||||
self.positions[:num_tokens] = positions
|
||||
|
||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||
@@ -1080,6 +1087,7 @@ class SpecDecodeBaseProposer:
|
||||
if self.get_model_name(target_model) in [
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"Qwen3VLForConditionalGeneration",
|
||||
"Qwen3VLMoeForConditionalGeneration",
|
||||
]:
|
||||
self.model.config.image_token_index = target_model.config.image_token_id
|
||||
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
|
||||
|
||||
Reference in New Issue
Block a user