Add RADIO Vision Encoder Support to vLLM (#24595)
Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com> Co-authored-by: root <root@cw-dfw-h100-001-305-026.cm.cluster>
This commit is contained in:
@@ -18,8 +18,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import (AutoModel, BatchEncoding, BatchFeature,
|
||||
PretrainedConfig, TensorType)
|
||||
from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
|
||||
TensorType)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
@@ -32,6 +32,7 @@ from vllm.model_executor.models.internvl import (calculate_internvl_targets,
|
||||
get_internvl_target_ratios)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
|
||||
from vllm.model_executor.models.radio import RadioModel
|
||||
from vllm.model_executor.models.utils import (flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
@@ -48,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.radio import RadioConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
@@ -122,11 +124,6 @@ NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs,
|
||||
NanoNemotronVLVideoEmbeddingInputs]
|
||||
|
||||
|
||||
def input_conditioner(x, norm_mean, norm_std):
|
||||
y = (x - norm_mean) / norm_std
|
||||
return y
|
||||
|
||||
|
||||
def dynamic_preprocess(image,
|
||||
*,
|
||||
image_size=512,
|
||||
@@ -305,8 +302,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
images, max_num_tiles)
|
||||
image_inputs: dict[str, NestedTensors] = {
|
||||
"pixel_values_flat":
|
||||
input_conditioner(torch.cat(pixel_values_lst), self.norm_mean,
|
||||
self.norm_std),
|
||||
torch.cat(pixel_values_lst),
|
||||
"image_num_patches":
|
||||
torch.tensor([len(item) for item in pixel_values_lst]),
|
||||
}
|
||||
@@ -428,8 +424,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
|
||||
video_inputs: dict[str, NestedTensors] = {
|
||||
"pixel_values_flat_video":
|
||||
input_conditioner(torch.cat(pixel_values_lst_video),
|
||||
self.norm_mean, self.norm_std),
|
||||
torch.cat(pixel_values_lst_video),
|
||||
"video_num_patches":
|
||||
torch.tensor([len(item) for item in pixel_values_lst_video]),
|
||||
}
|
||||
@@ -905,18 +900,9 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.vision_model = AutoModel.from_config(config.vision_config,
|
||||
trust_remote_code=True)
|
||||
self.vision_model.model._initialize_weights = (
|
||||
self.vision_model.model._init_weights)
|
||||
# Move input normalization to processor to mirror original HF
|
||||
# implementation where normalization is done in fp32
|
||||
self.vision_model.radio_model.make_preprocessor_external()
|
||||
self.vision_model = self.vision_model.to(
|
||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||
self.language_model.config.torch_dtype)
|
||||
|
||||
self.drop_vision_class_token = True
|
||||
|
||||
# Construct the vision projection.
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
vision_projection_hidden_size = config.projector_hidden_size
|
||||
@@ -972,7 +958,7 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
vit_embeds = self.vision_model(pixel_values).features
|
||||
vit_embeds = self.vision_model(pixel_values)
|
||||
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
|
||||
h = w = int(vit_embeds.shape[1]**0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
@@ -1212,47 +1198,39 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
|
||||
sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
adapter_dict = dict(self.mlp1.named_parameters())
|
||||
|
||||
def is_vision_model_weights(weight: tuple[str, torch.Tensor]):
|
||||
return weight[0].startswith("vision_model")
|
||||
def is_llm(name: str) -> bool:
|
||||
return name.startswith("language_model")
|
||||
|
||||
def is_adapter_weights(weight: tuple[str, torch.Tensor]):
|
||||
return weight[0].startswith("mlp1")
|
||||
|
||||
# Get references to parameters for direct loading
|
||||
vision_model_dict = dict(self.vision_model.named_parameters())
|
||||
vision_model_buffers = dict(self.vision_model.named_buffers())
|
||||
adapter_dict = dict(self.mlp1.named_parameters())
|
||||
def is_vision_weights(name: str) -> bool:
|
||||
return name.startswith("vision_model.radio_model.")
|
||||
|
||||
def llm_weights_generator():
|
||||
# Single pass over weights
|
||||
for name, w in weights:
|
||||
if is_vision_model_weights((name, w)):
|
||||
# Load vision encoder weights directly
|
||||
trimmed_name = ".".join(name.split(".")[1:])
|
||||
if "input_conditioner" in trimmed_name:
|
||||
continue
|
||||
if trimmed_name in vision_model_buffers:
|
||||
param = vision_model_buffers[trimmed_name]
|
||||
else:
|
||||
param = vision_model_dict[trimmed_name]
|
||||
with torch.no_grad():
|
||||
default_weight_loader(param, w)
|
||||
elif is_adapter_weights((name, w)):
|
||||
# Load vision-language adapter weights directly
|
||||
trimmed_name = ".".join(name.split(".")[1:])
|
||||
param = adapter_dict[trimmed_name]
|
||||
with torch.no_grad():
|
||||
default_weight_loader(param, w)
|
||||
else:
|
||||
# LLM weights: yield them to be loaded
|
||||
# by language_model.load_weights
|
||||
assert name.startswith("language_model")
|
||||
trimmed_name = ".".join(name.split(".")[1:])
|
||||
yield (trimmed_name, w)
|
||||
# Separate weights by component
|
||||
llm_weights = []
|
||||
vision_weights = []
|
||||
|
||||
# Now we call the language model load with the generator
|
||||
self.language_model.load_weights(llm_weights_generator())
|
||||
for name, w in weights:
|
||||
if is_llm(name):
|
||||
# Strip 'language_model.' prefix for LLM weights
|
||||
llm_weights.append((".".join(name.split(".")[1:]), w))
|
||||
elif is_adapter_weights((name, w)):
|
||||
# Load vision-language adapter weights directly
|
||||
trimmed_name = ".".join(name.split(".")[1:])
|
||||
param = adapter_dict[trimmed_name]
|
||||
with torch.no_grad():
|
||||
default_weight_loader(param, w)
|
||||
elif is_vision_weights(name):
|
||||
# Convert: vision_model.radio_model.* → radio_model.*
|
||||
hf_key = name[len(
|
||||
"vision_model."):] # Remove "vision_model." prefix
|
||||
vision_weights.append((hf_key, w))
|
||||
|
||||
self.language_model.load_weights(llm_weights)
|
||||
self.vision_model.load_weights(vision_weights)
|
||||
|
||||
def print_architecture(self,
|
||||
detailed: bool = True,
|
||||
@@ -1370,6 +1348,30 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
|
||||
},
|
||||
}
|
||||
|
||||
def get_vit_model_from_radio_config(self, hf_config):
|
||||
hf_config_vision = hf_config.vision_config
|
||||
model_name = hf_config_vision.args.get("model")
|
||||
if model_name is None:
|
||||
raise ValueError(f'Unsupported vit model type: {model_name}')
|
||||
|
||||
preferred_resolution = getattr(hf_config_vision,
|
||||
"preferred_resolution", None)
|
||||
image_size = preferred_resolution[0] if preferred_resolution else 224
|
||||
patch_size = getattr(hf_config_vision, "patch_size", 16)
|
||||
|
||||
radio_config = RadioConfig(
|
||||
model_name=model_name,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
norm_mean=hf_config.norm_mean,
|
||||
norm_std=hf_config.norm_std,
|
||||
reg_tokens=(hf_config_vision.args.get("register_multiple")
|
||||
if hasattr(hf_config_vision, "args")
|
||||
and isinstance(hf_config_vision.args, dict) else None),
|
||||
)
|
||||
|
||||
return RadioModel(config=radio_config)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user