[Models] Add remaining model PP support (#7168)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Murali Andoorveedu
2024-10-03 19:56:58 -07:00
committed by GitHub
parent 303d44790a
commit 0f6d7a9a34
69 changed files with 2585 additions and 1344 deletions

View File

@@ -31,15 +31,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@@ -47,7 +45,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
@@ -568,6 +568,9 @@ class QWenModel(nn.Module):
lambda prefix: QWenBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.visual = VisionTransformer(**config.visual,
quant_config=quant_config) if hasattr(
config, "visual") else None
@@ -580,7 +583,7 @@ class QWenModel(nn.Module):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
pixel_values: Optional[QwenImageInputs],
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
img_pos = None
# If pixel / visual embeddings are provided, this is a visual model
if pixel_values is not None and self.visual is not None:
@@ -860,7 +863,7 @@ def dummy_data_for_qwen(
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
class QWenLMHeadModel(nn.Module, SupportsMultiModal):
class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(
self,
@@ -881,6 +884,8 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self.lm_head.weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def _get_image_input_type(
self,
@@ -912,33 +917,26 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
)
return None
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor:
pixel_values = self._get_image_input_type(pixel_values)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
pixel_values: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
pixel_values = None
else:
pixel_values = self._get_image_input_type(pixel_values)
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
pixel_values)
return hidden_states
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def compute_logits(
self,
hidden_states: torch.Tensor,