Add RADIO Vision Encoder Support to vLLM (#24595)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
Co-authored-by: root <root@cw-dfw-h100-001-305-026.cm.cluster>
Author: danielafrimi
Date: 2025-09-17 15:53:30 +03:00 (committed by GitHub)
Parent: e120533d7a
Commit: 252ada5559
5 changed files with 826 additions and 56 deletions


@@ -18,8 +18,8 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import (AutoModel, BatchEncoding, BatchFeature,
-                          PretrainedConfig, TensorType)
+from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
+                          TensorType)
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
@@ -32,6 +32,7 @@ from vllm.model_executor.models.internvl import (calculate_internvl_targets,
                                                   get_internvl_target_ratios)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
+from vllm.model_executor.models.radio import RadioModel
 from vllm.model_executor.models.utils import (flatten_bn,
                                               init_vllm_registered_model,
                                               maybe_prefix,
@@ -48,6 +49,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.radio import RadioConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -122,11 +124,6 @@ NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs,
                                   NanoNemotronVLVideoEmbeddingInputs]
-
-
-def input_conditioner(x, norm_mean, norm_std):
-    y = (x - norm_mean) / norm_std
-    return y
 
 
 def dynamic_preprocess(image,
                        *,
                        image_size=512,
@@ -305,8 +302,7 @@ class BaseNanoNemotronVLProcessor(ABC):
                                                     images, max_num_tiles)
         image_inputs: dict[str, NestedTensors] = {
             "pixel_values_flat":
-            input_conditioner(torch.cat(pixel_values_lst), self.norm_mean,
-                              self.norm_std),
+            torch.cat(pixel_values_lst),
             "image_num_patches":
             torch.tensor([len(item) for item in pixel_values_lst]),
         }
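Note: the normalization math is unchanged, y = (x - norm_mean) / norm_std; what changes is where it runs. The processor now emits raw pixel values, and the Radio vision tower applies the conditioning internally using the norm_mean/norm_std it receives through RadioConfig. The same change applies to video pixel values in the next hunk. A minimal sketch of the operation, with illustrative statistics (the real values come from the model's HF config):

    import torch

    # Illustrative statistics; the real ones come from hf_config.norm_mean/std.
    norm_mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
    norm_std = torch.tensor([0.25, 0.25, 0.25]).view(1, 3, 1, 1)

    x = torch.rand(2, 3, 512, 512)  # raw pixel_values, as the processor now emits
    y = (x - norm_mean) / norm_std  # applied inside the vision tower after this change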
@@ -428,8 +424,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         video_inputs: dict[str, NestedTensors] = {
             "pixel_values_flat_video":
-            input_conditioner(torch.cat(pixel_values_lst_video),
-                              self.norm_mean, self.norm_std),
+            torch.cat(pixel_values_lst_video),
             "video_num_patches":
             torch.tensor([len(item) for item in pixel_values_lst_video]),
         }
@@ -905,18 +900,9 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
             hf_config=config.text_config,
             prefix=maybe_prefix(prefix, "language_model"),
         )
-        self.vision_model = AutoModel.from_config(config.vision_config,
-                                                  trust_remote_code=True)
-        self.vision_model.model._initialize_weights = (
-            self.vision_model.model._init_weights)
-
-        # Move input normalization to processor to mirror original HF
-        # implementation where normalization is done in fp32
-        self.vision_model.radio_model.make_preprocessor_external()
-        self.vision_model = self.vision_model.to(
+        self.vision_model = self.get_vit_model_from_radio_config(config).to(
             self.language_model.config.torch_dtype)
         self.drop_vision_class_token = True
 
         # Construct the vision projection.
         vit_hidden_size = config.vit_hidden_size
         vision_projection_hidden_size = config.projector_hidden_size
@@ -972,7 +958,7 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
         return x
 
     def extract_feature(self, pixel_values):
-        vit_embeds = self.vision_model(pixel_values).features
+        vit_embeds = self.vision_model(pixel_values)
         vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
         h = w = int(vit_embeds.shape[1]**0.5)
         vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
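Note: extract_feature now consumes the feature tensor directly (vLLM's RadioModel returns features rather than an output object) and assumes the encoder emits a square grid of patch tokens. A minimal sketch of the reshape with assumed sizes (batch 2, 1024 tokens, hidden 1280, e.g. a 512x512 input with 16x16 patches):

    import torch

    vit_embeds = torch.randn(2, 1024, 1280)  # (batch, tokens, hidden); sizes assumed
    h = w = int(vit_embeds.shape[1] ** 0.5)  # 32, since 32 * 32 == 1024
    vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
    assert vit_embeds.shape == (2, 32, 32, 1280)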
@@ -1212,47 +1198,39 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
                                       sampling_metadata)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        adapter_dict = dict(self.mlp1.named_parameters())
 
-        def is_vision_model_weights(weight: tuple[str, torch.Tensor]):
-            return weight[0].startswith("vision_model")
+        def is_llm(name: str) -> bool:
+            return name.startswith("language_model")
 
         def is_adapter_weights(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith("mlp1")
 
-        # Get references to parameters for direct loading
-        vision_model_dict = dict(self.vision_model.named_parameters())
-        vision_model_buffers = dict(self.vision_model.named_buffers())
-        adapter_dict = dict(self.mlp1.named_parameters())
+        def is_vision_weights(name: str) -> bool:
+            return name.startswith("vision_model.radio_model.")
 
-        def llm_weights_generator():
-            # Single pass over weights
-            for name, w in weights:
-                if is_vision_model_weights((name, w)):
-                    # Load vision encoder weights directly
-                    trimmed_name = ".".join(name.split(".")[1:])
-                    if "input_conditioner" in trimmed_name:
-                        continue
-                    if trimmed_name in vision_model_buffers:
-                        param = vision_model_buffers[trimmed_name]
-                    else:
-                        param = vision_model_dict[trimmed_name]
-                    with torch.no_grad():
-                        default_weight_loader(param, w)
-                elif is_adapter_weights((name, w)):
-                    # Load vision-language adapter weights directly
-                    trimmed_name = ".".join(name.split(".")[1:])
-                    param = adapter_dict[trimmed_name]
-                    with torch.no_grad():
-                        default_weight_loader(param, w)
-                else:
-                    # LLM weights: yield them to be loaded
-                    # by language_model.load_weights
-                    assert name.startswith("language_model")
-                    trimmed_name = ".".join(name.split(".")[1:])
-                    yield (trimmed_name, w)
+        # Separate weights by component
+        llm_weights = []
+        vision_weights = []
 
-        # Now we call the language model load with the generator
-        self.language_model.load_weights(llm_weights_generator())
+        for name, w in weights:
+            if is_llm(name):
+                # Strip 'language_model.' prefix for LLM weights
+                llm_weights.append((".".join(name.split(".")[1:]), w))
+            elif is_adapter_weights((name, w)):
+                # Load vision-language adapter weights directly
+                trimmed_name = ".".join(name.split(".")[1:])
+                param = adapter_dict[trimmed_name]
+                with torch.no_grad():
+                    default_weight_loader(param, w)
+            elif is_vision_weights(name):
+                # Convert: vision_model.radio_model.* → radio_model.*
+                hf_key = name[len(
+                    "vision_model."):]  # Remove "vision_model." prefix
+                vision_weights.append((hf_key, w))
+
+        self.language_model.load_weights(llm_weights)
+        self.vision_model.load_weights(vision_weights)
 
     def print_architecture(self,
                            detailed: bool = True,
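Note: the rewritten load_weights routes each checkpoint tensor by name prefix instead of loading vision parameters one by one: language_model.* is collected with its prefix stripped and handed to language_model.load_weights, mlp1.* is loaded directly into the projector, and vision_model.radio_model.* is renamed to radio_model.* and delegated to RadioModel.load_weights. A self-contained sketch of that routing rule; the checkpoint key names below are hypothetical:

    def route(name: str) -> tuple[str, str]:
        # Mirrors the prefix checks in the load_weights above.
        if name.startswith("language_model."):
            return "llm", name[len("language_model."):]
        if name.startswith("mlp1."):
            return "adapter", name[len("mlp1."):]
        if name.startswith("vision_model.radio_model."):
            return "vision", name[len("vision_model."):]
        return "unrouted", name

    assert route("language_model.lm_head.weight") == ("llm", "lm_head.weight")
    assert route("vision_model.radio_model.blocks.0.attn.qkv.weight") == (
        "vision", "radio_model.blocks.0.attn.qkv.weight")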
@@ -1370,6 +1348,30 @@ class NemotronH_Nano_VL(nn.Module, HasInnerState, IsHybrid,
             },
         }
 
+    def get_vit_model_from_radio_config(self, hf_config):
+        hf_config_vision = hf_config.vision_config
+        model_name = hf_config_vision.args.get("model")
+        if model_name is None:
+            raise ValueError(f'Unsupported vit model type: {model_name}')
+
+        preferred_resolution = getattr(hf_config_vision,
+                                       "preferred_resolution", None)
+        image_size = preferred_resolution[0] if preferred_resolution else 224
+        patch_size = getattr(hf_config_vision, "patch_size", 16)
+
+        radio_config = RadioConfig(
+            model_name=model_name,
+            image_size=image_size,
+            patch_size=patch_size,
+            norm_mean=hf_config.norm_mean,
+            norm_std=hf_config.norm_std,
+            reg_tokens=(hf_config_vision.args.get("register_multiple")
+                        if hasattr(hf_config_vision, "args")
+                        and isinstance(hf_config_vision.args, dict) else None),
+        )
+        return RadioModel(config=radio_config)
+
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs(
             input_buffers, **kwargs)
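Note: get_vit_model_from_radio_config replaces the previous AutoModel/trust_remote_code path. It reads the ViT variant, preferred resolution, patch size, and register-token multiple out of the HF vision config and builds vLLM's native RadioModel. A hedged sketch of the config shape it expects; the attribute values below are assumptions for illustration, not taken from a real checkpoint:

    from types import SimpleNamespace

    # Assumed layout of the HF config consumed above.
    vision_config = SimpleNamespace(
        args={"model": "vit_huge_patch16_224", "register_multiple": 8},
        preferred_resolution=(512, 512),
        patch_size=16,
    )
    hf_config = SimpleNamespace(
        vision_config=vision_config,
        norm_mean=(0.48145466, 0.4578275, 0.40821073),  # illustrative
        norm_std=(0.26862954, 0.26130258, 0.27577711),  # illustrative
    )

    # get_vit_model_from_radio_config(hf_config) would then build
    # RadioConfig(model_name="vit_huge_patch16_224", image_size=512,
    #             patch_size=16, reg_tokens=8) and wrap it in RadioModel.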