[Bugfix] Re-enable Gemma3 for V1 (#14980)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,7 +4,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
|
||||
from typing import List, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
|
||||
SiluAndMul)
|
||||
@@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import select_patch_features
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
if multimodal_embeddings is not None:
|
||||
assert self.img_patch_id is not None
|
||||
|
||||
# Extract the patch tokens scattered in _get_mm_embeds
|
||||
patch_embeddings = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
cast(NestedTensors, patch_embeddings),
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.img_patch_id,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
Reference in New Issue
Block a user