diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index d962ab8bb..2db8ce2bd 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -278,21 +278,35 @@ class GemmaRMSNorm(CustomOp):
         self.variance_epsilon = eps
 
     @staticmethod
-    def forward_static(
+    def _forward_static_no_residual(
         weight: torch.Tensor,
         variance_epsilon: float,
         x: torch.Tensor,
-        residual: torch.Tensor | None,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        """PyTorch-native implementation equivalent to forward()."""
+    ) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward() without residual."""
         orig_dtype = x.dtype
-        if residual is not None:
-            x = (
-                x.float() + residual.float()
-                if orig_dtype == torch.float16
-                else x + residual
-            )
-            residual = x
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + variance_epsilon)
+        x = x * (1.0 + weight.float())
+        x = x.to(orig_dtype)
+        return x
+
+    @staticmethod
+    def _forward_static_with_residual(
+        weight: torch.Tensor,
+        variance_epsilon: float,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward() with residual."""
+        orig_dtype = x.dtype
+        x = (
+            x.float() + residual.float()
+            if orig_dtype == torch.float16
+            else x + residual
+        )
+        residual = x
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
@@ -301,7 +315,7 @@ class GemmaRMSNorm(CustomOp):
         # See https://github.com/huggingface/transformers/pull/29402
         x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
+        return x, residual
 
     def forward_native(
         self,
@@ -309,7 +323,14 @@
         x: torch.Tensor,
         residual: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward()."""
-        return self.forward_static(self.weight.data, self.variance_epsilon, x, residual)
+        if residual is None:
+            return self._forward_static_no_residual(
+                self.weight.data, self.variance_epsilon, x
+            )
+        else:
+            return self._forward_static_with_residual(
+                self.weight.data, self.variance_epsilon, x, residual
+            )
 
     def forward_cuda(
@@ -320,8 +341,11 @@
             return self.forward_native(x, residual)
 
         if not getattr(self, "_is_compiled", False):
-            self.forward_static = torch.compile(  # type: ignore
-                self.forward_static
+            self._forward_static_no_residual = torch.compile(  # type: ignore
+                self._forward_static_no_residual
+            )
+            self._forward_static_with_residual = torch.compile(  # type: ignore
+                self._forward_static_with_residual
             )
             self._is_compiled = True
         return self.forward_native(x, residual)
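Why the split helps (my reading of the change, not stated in the patch): with a single forward_static, Dynamo must guard on whether residual is None and keep two compiled variants behind one entry point whose return type flips between a tensor and a tuple. Splitting gives each entry point one graph with a fixed signature and return arity. A minimal standalone sketch of the pattern, with illustrative names and the float16 residual handling omitted:

import torch


def _rms_no_residual(weight, eps, x):
    # normalize in float32 for accuracy, cast back at the end
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * (1.0 + weight.float())).to(orig_dtype)


def _rms_with_residual(weight, eps, x, residual):
    # fold the residual in first, then reuse the no-residual path
    x = x + residual
    return _rms_no_residual(weight, eps, x), x


# each function compiles to one graph with a fixed return arity
rms = torch.compile(_rms_no_residual)
rms_res = torch.compile(_rms_with_residual)

w, x = torch.ones(64), torch.randn(2, 64)
y = rms(w, 1e-6, x)              # Tensor
y, res = rms_res(w, 1e-6, x, x)  # tuple[Tensor, Tensor]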
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index c36090e8f..a67f1bbe9 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, TypeVar, cast
 
 import torch
@@ -373,6 +374,76 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):
         text_config.use_sep_token = use_sep_token
 
 
+def _get_language_model_for_seq_cls(model) -> nn.Module:
+    """
+    Get the language model component for sequence classification conversion.
+    For VLMs, returns the inner language model.
+    For standard LLMs, returns the model itself.
+    """
+    if supports_multimodal(model):
+        try:
+            lm = model.get_language_model()
+            if lm is not model:
+                return lm
+        except Exception:
+            pass
+
+        for attr_name in ("language_model", "lm", "text_model"):
+            if hasattr(model, attr_name):
+                candidate = getattr(model, attr_name)
+                if (
+                    isinstance(candidate, nn.Module)
+                    and candidate is not model
+                    and hasattr(candidate, "model")
+                ):
+                    return candidate
+
+        for name, child in model.named_children():
+            child_name = type(child).__name__
+            if ("ForCausalLM" in child_name or "LMHead" in child_name) and hasattr(
+                child, "model"
+            ):
+                return child
+
+    return model
+
+
+@contextmanager
+def _disable_seq_cls_loading_on_inner_model(language_model, is_vlm: bool):
+    """
+    Context manager to temporarily disable sequence classification loading
+    on inner VLM models to prevent recursive seq_cls_model_loader calls.
+    """
+    if not is_vlm:
+        yield
+        return
+
+    inner_hf_config = getattr(language_model, "config", None)
+    if inner_hf_config is None:
+        yield
+        return
+
+    inner_text_config = inner_hf_config.get_text_config()
+    original_method = getattr(inner_text_config, "method", None)
+    original_tokens = getattr(inner_text_config, "classifier_from_token", None)
+    original_hf_tokens = getattr(inner_hf_config, "classifier_from_token", None)
+
+    try:
+        if original_method is not None:
+            inner_text_config.method = None
+        if original_tokens is not None:
+            inner_text_config.classifier_from_token = None
+        if original_hf_tokens is not None:
+            inner_hf_config.classifier_from_token = None
+        yield
+    finally:
+        if original_method is not None:
+            inner_text_config.method = original_method
+        if original_tokens is not None:
+            inner_text_config.classifier_from_token = original_tokens
+        if original_hf_tokens is not None:
+            inner_hf_config.classifier_from_token = original_hf_tokens
+
+
 def load_weights_using_from_2_way_softmax(
     model, weights: Iterable[tuple[str, torch.Tensor]]
 ):
@@ -393,9 +464,9 @@ def load_weights_using_from_2_way_softmax(
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2
 
-    language_model = (
-        model.get_language_model() if hasattr(model, "get_language_model") else model
-    )
+    language_model = _get_language_model_for_seq_cls(model)
+    is_vlm = language_model is not model
+
     language_model.lm_head = ParallelLMHead(
         text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
     )
@@ -411,12 +482,13 @@ def load_weights_using_from_2_way_softmax(
         )
         language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)
 
-    # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
-    # function, so we need use this hacky method to obtain it.
-    pooling_model_cls = next(
-        x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
-    )
-    loaded_weights = pooling_model_cls.load_weights(model, weights)
+    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
+        # ModelForPooling is dynamically defined inside the _create_pooling_model_cls
+        # function, so we need to use this hacky method to obtain it.
+        pooling_model_cls = next(
+            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+        )
+        loaded_weights = pooling_model_cls.load_weights(model, weights)
 
     from vllm.tokenizers import get_tokenizer
@@ -434,12 +506,15 @@ def load_weights_using_from_2_way_softmax(
         torch.float32
     ) - lm_head_weight.data[[false_id]].to(torch.float32)
 
-    param = model.score.weight
+    score_layer = language_model.score if is_vlm else model.score
+    param = score_layer.weight
     weight_loader = getattr(param, "weight_loader", default_weight_loader)
     weight_loader(param, score_weight)
 
     del language_model.lm_head
-    loaded_weights.add("score.weight")
+
+    score_weight_name = "language_model.score.weight" if is_vlm else "score.weight"
+    loaded_weights.add(score_weight_name)
 
     lm_head_name = "lm_head.weight"
     if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
@@ -460,22 +535,30 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
     tokens = cast(list[int], tokens)
     assert len(tokens) > 0
 
-    model.lm_head = ParallelLMHead(
+    language_model = _get_language_model_for_seq_cls(model)
+    is_vlm = language_model is not model
+
+    language_model.lm_head = ParallelLMHead(
         text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
     )
 
     if text_config.tie_word_embeddings:
         # embed_tokens is the assumed name for input embeddings. If the model does not
         # have this attribute, we fall back to get_input_embeddings(), which is used by
         # the Transformers modeling backend.
+        text_backbone = language_model.model
         embed_tokens = (
-            model.model.embed_tokens
-            if hasattr(model.model, "embed_tokens")
-            else model.model.get_input_embeddings()
+            text_backbone.embed_tokens
+            if hasattr(text_backbone, "embed_tokens")
+            else text_backbone.get_input_embeddings()
         )
-        model.lm_head = model.lm_head.tie_weights(embed_tokens)
+        language_model.lm_head = language_model.lm_head.tie_weights(embed_tokens)
 
-    # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
-    loaded_weights = type(model).__mro__[1].load_weights(model, weights)
+    with _disable_seq_cls_loading_on_inner_model(language_model, is_vlm):
+        pooling_model_cls = next(
+            x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
+        )
+        # Skip ModelForSequenceClassification in MRO to avoid infinite recursion
+        loaded_weights = pooling_model_cls.load_weights(model, weights)
 
     from vllm.tokenizers import get_tokenizer
@@ -487,15 +570,22 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
     )
     token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
 
-    score_weight = model.lm_head.weight.data[token_ids]
+    score_weight = language_model.lm_head.weight.data[token_ids]
 
-    param = model.score.weight
+    score_layer = language_model.score if is_vlm else model.score
+    param = score_layer.weight
     weight_loader = getattr(param, "weight_loader", default_weight_loader)
     weight_loader(param, score_weight)
 
-    del model.lm_head
-    loaded_weights.add("score.weight")
-    loaded_weights.discard("lm_head.weight")
+    del language_model.lm_head
+
+    score_weight_name = "language_model.score.weight" if is_vlm else "score.weight"
+    loaded_weights.add(score_weight_name)
+
+    lm_head_name = "lm_head.weight"
+    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+        lm_head_name = hf_to_vllm_mapper._map_name(lm_head_name)
+    loaded_weights.discard(lm_head_name)
 
     return loaded_weights
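The context manager exists because the inner language model's HF config can still carry method and classifier_from_token; when the pooling loader walks into the inner model, those fields would trigger seq_cls conversion a second time. The helper is the usual save/clear/restore contextmanager pattern; a generic sketch with a stand-in SimpleNamespace config and a hypothetical _suppress_attr helper, not the vLLM code itself:

from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def _suppress_attr(obj, name):
    # save the attribute, blank it for the duration, restore on exit
    saved = getattr(obj, name, None)
    try:
        if saved is not None:
            setattr(obj, name, None)
        yield
    finally:
        if saved is not None:
            setattr(obj, name, saved)


cfg = SimpleNamespace(method="from_2_way_softmax")
with _suppress_attr(cfg, "method"):
    assert cfg.method is None  # inner loader sees no seq-cls config
assert cfg.method == "from_2_way_softmax"  # restored afterwards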
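For reference, the identity behind the score-weight construction visible in the hunk context above (lm_head row for the "true" token minus the row for the "false" token): a softmax over two logits depends only on their difference, so softmax([a, b])[1] == sigmoid(b - a), and the two lm_head rows collapse into a single classifier row. A toy check with made-up shapes and token ids:

import torch

hidden = 8
lm_head = torch.randn(32, hidden)  # toy vocab of 32 tokens
true_id, false_id = 9, 2           # hypothetical "yes"/"no" token ids

# single classifier row, shape (1, hidden)
score = lm_head[[true_id]] - lm_head[[false_id]]

h = torch.randn(hidden)
logits = torch.stack([lm_head[false_id] @ h, lm_head[true_id] @ h])
two_way = torch.softmax(logits, dim=0)[1]   # P(true) from 2-way softmax
single = torch.sigmoid(score[0] @ h)        # P(true) from one score row
assert torch.allclose(two_way, single, atol=1e-6)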
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 75a82bfb1..a38b553d8 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -107,7 +107,9 @@ class TritonAttentionMetadata:
             for r in range_lists
         ]
 
-        return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
+        return torch.nested.nested_tensor(
+            range_tensors, layout=torch.jagged
+        ).to_padded_tensor(0)
 
 
 class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
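The layout pin matters because to_padded_tensor is the first-class conversion for torch.jagged nested tensors, while the strided default can take a different, less supported path depending on the torch build (my reading; the patch does not say why). A small sketch of the behavior the call relies on:

import torch

# variable-length rows, as produced from range_lists in the metadata builder
range_tensors = [torch.tensor([0, 1, 2]), torch.tensor([0, 1])]

# jagged-layout nested tensor, then zero-padded to a dense rectangle
nt = torch.nested.nested_tensor(range_tensors, layout=torch.jagged)
padded = nt.to_padded_tensor(0)
# tensor([[0, 1, 2],
#         [0, 1, 0]])
print(padded)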