[Model] Extend collect_children and no_init_weights contexts (#32757)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -58,7 +58,7 @@ from .interfaces import (
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .qwen import QWenBaseModel, QWenModel
|
||||
from .qwen import QWenBaseModel, QWenBlock, QWenModel
|
||||
|
||||
|
||||
class QwenImagePixelInputs(TensorSchema):
|
||||
@@ -757,11 +757,16 @@ class QwenVLForConditionalGeneration(
|
||||
prefix: str = "",
|
||||
transformer_type: type[QwenVLModel] = QwenVLModel,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
transformer_type=transformer_type,
|
||||
)
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=QWenBlock,
|
||||
tower_targets={"image": VisionTransformer},
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
transformer_type=transformer_type,
|
||||
)
|
||||
|
||||
self.transformer: QwenVLModel
|
||||
|
||||
@@ -795,9 +800,6 @@ class QwenVLForConditionalGeneration(
|
||||
|
||||
return self.transformer.visual(image_input["data"])
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.transformer
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
Reference in New Issue
Block a user