[Models] Add remaining model PP support (#7168)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 303d44790a
commit 0f6d7a9a34
@@ -31,15 +31,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -47,7 +45,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.utils import is_list_of
 
-from .utils import flatten_bn, is_pp_missing_parameter, make_layers
+from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 logger = init_logger(__name__)
 
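The newly imported make_empty_intermediate_tensors_factory replaces the hand-written make_empty_intermediate_tensors method that this commit deletes from QWenLMHeadModel (last hunk below). A minimal sketch of what such a factory produces, reconstructed from that deleted method; the exact signature in .utils is assumed here, not quoted:

    from typing import List

    import torch

    from vllm.sequence import IntermediateTensors


    def make_empty_intermediate_tensors_factory(keys: List[str],
                                                hidden_size: int):
        # Build a callable that allocates zero placeholders for everything a
        # pipeline stage receives from its predecessor (hidden_states and the
        # residual stream, for this model).
        def make_empty_intermediate_tensors(
                batch_size: int, dtype: torch.dtype,
                device: torch.device) -> IntermediateTensors:
            return IntermediateTensors({
                key: torch.zeros((batch_size, hidden_size),
                                 dtype=dtype,
                                 device=device)
                for key in keys
            })

        return make_empty_intermediate_tensors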
@@ -568,6 +568,9 @@ class QWenModel(nn.Module):
             lambda prefix: QWenBlock(config, cache_config, quant_config),
             prefix=f"{prefix}.h")
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
         self.visual = VisionTransformer(**config.visual,
                                         quant_config=quant_config) if hasattr(
                                             config, "visual") else None
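make_layers (imported above) is what actually shards the decoder stack: each pipeline rank instantiates only a contiguous slice of the config.num_hidden_layers blocks. A rough sketch of the idea, assuming the vllm helpers named below; the real implementation also records the slice bounds on the module and substitutes placeholders for layers owned by other ranks:

    import torch.nn as nn

    from vllm.distributed import get_pp_group


    def make_layers_sketch(num_layers: int, layer_fn, prefix: str):
        # Carve this rank's slice [start, end) out of the full layer stack.
        pp_rank = get_pp_group().rank_in_group
        pp_size = get_pp_group().world_size
        per_rank = num_layers // pp_size
        start, end = pp_rank * per_rank, (pp_rank + 1) * per_rank
        layers = nn.ModuleList(
            [layer_fn(prefix=f"{prefix}.{i}") for i in range(start, end)])
        return start, end, layers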
@@ -580,7 +583,7 @@ class QWenModel(nn.Module):
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         pixel_values: Optional[QwenImageInputs],
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         img_pos = None
         # If pixel / visual embeddings are provided, this is a visual model
         if pixel_values is not None and self.visual is not None:
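The widened return annotation reflects that, under PP, QWenModel.forward only produces final hidden states on the last pipeline rank; every earlier rank returns an IntermediateTensors bundle to be sent to the next stage. The corresponding body change falls outside this excerpt; the usual pattern in vllm models of this vintage looks roughly like the following sketch (assumed, not quoted from the diff):

    from vllm.distributed import get_pp_group

    # ...inside QWenModel.forward, after running this rank's layers:
    if not get_pp_group().is_last_rank:
        # Ship activations to the next pipeline stage instead of finishing.
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual,
        })
    hidden_states, _ = self.ln_f(hidden_states, residual)
    return hidden_states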
@@ -860,7 +863,7 @@ def dummy_data_for_qwen(
 @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal):
+class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(
         self,
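SupportsPP is the marker interface the engine checks before splitting a model across pipeline stages. A condensed sketch of its shape; the real definition in .interfaces is a runtime-checkable Protocol with more machinery, so treat this as an approximation:

    from typing import Callable, Protocol, Union, runtime_checkable

    import torch

    from vllm.sequence import IntermediateTensors


    @runtime_checkable
    class SupportsPP(Protocol):
        # Assigned in __init__ (see the hunks above and below): builds the
        # zero placeholders a non-first rank starts its forward pass from.
        make_empty_intermediate_tensors: Callable[..., IntermediateTensors]

        def forward(
            self, *args, **kwargs
        ) -> Union[torch.Tensor, IntermediateTensors]:
            ...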
@@ -881,6 +884,8 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
         self.lm_head.weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            self.transformer.make_empty_intermediate_tensors)
 
     def _get_image_input_type(
         self,
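The companion import is_pp_missing_parameter matters at weight-loading time: under PP each rank materializes only its own layers, so checkpoint entries for layers living on other ranks must be skipped. The qwen.py load_weights hunk is not part of this excerpt; the typical guard looks roughly like this sketch:

    def load_weights_sketch(self, weights):
        # Assumed pattern, mirroring other vllm models touched by this PR.
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue  # parameter belongs to a layer on another PP rank
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)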
@@ -912,33 +917,26 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
         )
         return None
 
-    def forward(self,
-                input_ids: torch.Tensor,
-                positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata,
-                intermediate_tensors: Optional[IntermediateTensors] = None,
-                pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
-        pixel_values = self._get_image_input_type(pixel_values)
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        pixel_values: Optional[torch.Tensor] = None
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if intermediate_tensors is not None:
+            input_ids = None
+            pixel_values = None
+        else:
+            pixel_values = self._get_image_input_type(pixel_values)
+
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata, intermediate_tensors,
                                          pixel_values)
         return hidden_states
 
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
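With the interface wired up, the model can be served across pipeline stages. An illustrative invocation (requires at least two GPUs; pipeline_parallel_size is the standard engine argument, and releases of this era may additionally need the Ray distributed executor backend):

    from vllm import LLM, SamplingParams

    # Qwen-7B-Chat stands in for any of the models this PR covers.
    llm = LLM(model="Qwen/Qwen-7B-Chat",
              trust_remote_code=True,
              pipeline_parallel_size=2)
    outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)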