[Model] Use context managers for encoder- and LM-only mode (#32605)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-20 11:43:38 +08:00
committed by GitHub
parent 6c01ffb897
commit 4753f3bf69
21 changed files with 290 additions and 353 deletions

View File

@@ -1233,9 +1233,7 @@ class Qwen2VLForConditionalGeneration(
self.config = config
self.multimodal_config = multimodal_config
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -1243,14 +1241,13 @@ class Qwen2VLForConditionalGeneration(
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = None
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -1371,9 +1368,6 @@ class Qwen2VLForConditionalGeneration(
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
@@ -1437,10 +1431,7 @@ class Qwen2VLForConditionalGeneration(
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys: