Added Qwen3 vision-language MoE support for speculative decoding (#32048)
Signed-off-by: shanjiaz <zsjwpianpian@gmail.com> Signed-off-by: shanjiaz <43143795+shanjiaz@users.noreply.github.com>
This commit is contained in:
@@ -110,9 +110,14 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
hidden_states = intermediate_tensors["hidden_states"]
|
hidden_states = intermediate_tensors["hidden_states"]
|
||||||
residual = intermediate_tensors["residual"]
|
residual = intermediate_tensors["residual"]
|
||||||
|
|
||||||
|
aux_hidden_states = []
|
||||||
for layer_idx, layer in islice(
|
for layer_idx, layer in islice(
|
||||||
enumerate(self.layers), self.start_layer, self.end_layer
|
enumerate(self.layers), self.start_layer, self.end_layer
|
||||||
):
|
):
|
||||||
|
if layer_idx in self.aux_hidden_state_layers:
|
||||||
|
aux_hidden_states.append(hidden_states + residual)
|
||||||
|
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions,
|
positions,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -132,6 +137,9 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
|||||||
{"hidden_states": hidden_states, "residual": residual}
|
{"hidden_states": hidden_states, "residual": residual}
|
||||||
)
|
)
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
|
||||||
|
if len(aux_hidden_states) > 0:
|
||||||
|
return hidden_states, aux_hidden_states
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def load_fused_expert_weights(
|
def load_fused_expert_weights(
|
||||||
|
|||||||
@@ -112,7 +112,9 @@ class SpecDecodeBaseProposer:
|
|||||||
self.input_ids = torch.zeros(
|
self.input_ids = torch.zeros(
|
||||||
self.max_num_tokens, dtype=torch.int32, device=device
|
self.max_num_tokens, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
self.uses_mrope = self.vllm_config.model_config.uses_mrope
|
# Use draft model's M-RoPE setting, not target model's
|
||||||
|
# Draft models may be text-only even if target is multimodal
|
||||||
|
self.uses_mrope = self.draft_model_config.uses_mrope
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||||
# position on purpose to make it non-contiguous so that it can work
|
# position on purpose to make it non-contiguous so that it can work
|
||||||
@@ -221,6 +223,11 @@ class SpecDecodeBaseProposer:
|
|||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
self.mrope_positions[:, :num_tokens] = positions
|
self.mrope_positions[:, :num_tokens] = positions
|
||||||
else:
|
else:
|
||||||
|
# Convert M-RoPE positions if target model uses M-RoPE
|
||||||
|
# but draft doesn't. For text inputs, all M-RoPE
|
||||||
|
# dimensions are identical
|
||||||
|
if self.vllm_config.model_config.uses_mrope:
|
||||||
|
positions = positions[0]
|
||||||
self.positions[:num_tokens] = positions
|
self.positions[:num_tokens] = positions
|
||||||
|
|
||||||
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
|
||||||
@@ -1080,6 +1087,7 @@ class SpecDecodeBaseProposer:
|
|||||||
if self.get_model_name(target_model) in [
|
if self.get_model_name(target_model) in [
|
||||||
"Qwen2_5_VLForConditionalGeneration",
|
"Qwen2_5_VLForConditionalGeneration",
|
||||||
"Qwen3VLForConditionalGeneration",
|
"Qwen3VLForConditionalGeneration",
|
||||||
|
"Qwen3VLMoeForConditionalGeneration",
|
||||||
]:
|
]:
|
||||||
self.model.config.image_token_index = target_model.config.image_token_id
|
self.model.config.image_token_index = target_model.config.image_token_id
|
||||||
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
|
elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
|
||||||
|
|||||||
Reference in New Issue
Block a user