Make Qwen3VL compatible with Transformers v5 (#34262)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Harry Mellor
2026-02-11 13:13:23 +01:00
committed by GitHub
parent 05339a7b20
commit 1e9204bff3
2 changed files with 25 additions and 37 deletions

View File

@@ -1112,17 +1112,6 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
} }
) )
class Qwen3LLMModel(Qwen3Model): class Qwen3LLMModel(Qwen3Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
vision_config = vllm_config.model_config.hf_config.vision_config
if not get_pp_group().is_first_rank and hasattr(
vision_config, "deepstack_visual_indexes"
):
assert self.start_layer >= len(vision_config.deepstack_visual_indexes), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)"
)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,
@@ -1178,7 +1167,7 @@ class Qwen3LLMModel(Qwen3Model):
class Qwen3LLMForCausalLM(Qwen3ForCausalLM): class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3ForCausalLM, self).__init__() super(Qwen3ForCausalLM, self).__init__()
config = vllm_config.model_config.hf_config.text_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
@@ -1298,7 +1287,18 @@ class Qwen3VLForConditionalGeneration(
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = Qwen3LLMForCausalLM( self.language_model = Qwen3LLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
if not get_pp_group().is_first_rank and hasattr(
config.vision_config, "deepstack_visual_indexes"
):
assert self.language_model.start_layer >= len(
config.vision_config.deepstack_visual_indexes
), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)"
) )
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (

View File

@@ -48,7 +48,6 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts from .interfaces import MixtureOfExperts
from .qwen3_moe import ( from .qwen3_moe import (
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM, Qwen3MoeForCausalLM,
Qwen3MoeModel, Qwen3MoeModel,
Qwen3MoeSparseMoeBlock, Qwen3MoeSparseMoeBlock,
@@ -83,27 +82,6 @@ class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo):
} }
) )
class Qwen3MoeLLMModel(Qwen3MoeModel): class Qwen3MoeLLMModel(Qwen3MoeModel):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[torch.nn.Module] = Qwen3MoeDecoderLayer,
):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=decoder_layer_type,
)
vision_config = vllm_config.model_config.hf_config.vision_config
if not get_pp_group().is_first_rank and hasattr(
vision_config, "deepstack_visual_indexes"
):
assert self.start_layer >= len(vision_config.deepstack_visual_indexes), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)"
)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,
@@ -352,7 +330,7 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3MoeForCausalLM, self).__init__() super(Qwen3MoeForCausalLM, self).__init__()
self.config = vllm_config.model_config.hf_config.text_config self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config self.quant_config = vllm_config.quant_config
self.model = Qwen3MoeLLMModel( self.model = Qwen3MoeLLMModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
@@ -473,10 +451,20 @@ class Qwen3VLMoeForConditionalGeneration(
with self._mark_language_model(vllm_config): with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM( self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
) )
if not get_pp_group().is_first_rank and hasattr(
config.vision_config, "deepstack_visual_indexes"
):
assert self.language_model.start_layer >= len(
config.vision_config.deepstack_visual_indexes
), (
"start_layer should be greater than or equal to "
"len(deepstack_visual_indexes)"
)
# Whether to include the gate_up_proj mapping is determined by # Whether to include the gate_up_proj mapping is determined by
# the language model. # the language model.
self.packed_modules_mapping = ( self.packed_modules_mapping = (