265 lines
9.9 KiB
Python
265 lines
9.9 KiB
Python
import math
|
|
from typing import Iterable, List, Optional, Set, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.attention import AttentionMetadata
|
|
from vllm.config import VllmConfig
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
|
|
BartParallelLMHead,
|
|
BartScaledWordEmbedding)
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from .utils import AutoWeightsLoader
|
|
|
|
|
|
class Florence2LanguageModel(nn.Module):
    """Encoder/decoder text backbone assembled from vLLM's BART layers.

    The module owns a shared scaled word embedding plus a ``BartEncoder``
    and a ``BartDecoder``. When ``config.tie_word_embeddings`` is set, the
    ``embed_tokens`` weight of both stacks is pointed at the shared
    embedding's weight.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the shared embedding, encoder and decoder.

        Args:
            vllm_config: vLLM configuration; its ``model_config.hf_config``,
                ``cache_config`` and ``quant_config`` are read here.
            prefix: Name prefix; ``.encoder`` / ``.decoder`` are appended
                to it for the two sub-stacks.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        # Both stacks receive the same cache / quantization settings.
        stack_kwargs = dict(cache_config=vllm_config.cache_config,
                            quant_config=vllm_config.quant_config)

        self.config = hf_config
        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size,
                                              hf_config.d_model)
        self.encoder = BartEncoder(hf_config,
                                   prefix=f"{prefix}.encoder",
                                   **stack_kwargs)
        self.decoder = BartDecoder(hf_config,
                                   prefix=f"{prefix}.decoder",
                                   **stack_kwargs)

        if self.config.tie_word_embeddings:
            # Make both stacks reuse the shared embedding weight.
            for stack in (self.encoder, self.decoder):
                stack.embed_tokens.weight = self.shared.weight

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata) -> torch.Tensor:
        r"""Run the (optional) encoder pass followed by the decoder.

        Args:
            input_ids
                Token ids of the *decoder* input sequence.
            positions
                Position indices of the *decoder* input tokens.
            encoder_input_ids
                Token ids of the *encoder* input sequence. The encoder is
                skipped when this tensor has no elements.
            encoder_positions:
                Position indices of the *encoder* input tokens.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
        Returns:
            Whatever ``self.decoder`` returns for these inputs.
        """
        # Only run the encoder when encoder tokens were actually supplied;
        # otherwise the decoder is handed ``None`` as encoder hidden states.
        has_encoder_tokens = encoder_input_ids.numel() > 0
        encoder_hidden_states = (self.encoder(input_ids=encoder_input_ids,
                                              positions=encoder_positions,
                                              kv_caches=kv_caches,
                                              attn_metadata=attn_metadata)
                                 if has_encoder_tokens else None)

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states,
                            kv_caches=kv_caches,
                            attn_metadata=attn_metadata)
|
|
|
|
|
|
class Florence2LanguageForConditionalGeneration(nn.Module):
    """``Florence2LanguageModel`` plus LM head, logits processor and sampler.

    Also implements checkpoint loading for the text model, including the
    mapping of separate q/k/v projection weights onto ``qkv_proj``
    parameters.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the language model and its generation head.

        Args:
            vllm_config: vLLM configuration; ``model_config.hf_config`` is
                read for ``vocab_size``, ``d_model`` and ``scale_embedding``.
            prefix: Name prefix; ``.model`` is appended for the inner model.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")

        # sqrt(d_model) scaling only applies when the config asks for it.
        if hf_config.scale_embedding:
            embed_scale = math.sqrt(hf_config.d_model)
        else:
            embed_scale = 1.0

        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          hf_config.d_model,
                                          embed_scale=embed_scale)
        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                hf_config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        r"""Delegate to the wrapped ``Florence2LanguageModel``.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
            **kwargs:
                Accepted but not used by this method.
        Returns:
            Output torch.Tensor of the inner model.
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's output for ``hidden_states``.

        The processor is called with ``self.lm_head``, the hidden states
        and the sampling metadata.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> SamplerOutput:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded; q/k/v entries are
            recorded under their ``qkv_proj`` name.

        Raises:
            KeyError: If a (possibly renamed) weight name has no matching
                entry in ``self.named_parameters()``.
        """
        # (param_name, shard_name, shard_id)
        qkv_shards = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            # First shard entry whose checkpoint name occurs in ``name``.
            shard = next((entry for entry in qkv_shards if entry[1] in name),
                         None)

            if shard is not None:
                fused_name, shard_name, shard_id = shard
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                # These checkpoint entries are never loaded.
                if "final_logits_bias" in name:
                    continue
                # NOTE(review): presumably covered by the shared embedding
                # weight when embeddings are tied -- confirm.
                if self.config.tie_word_embeddings and "embed_tokens" in name:
                    continue
                param = params_dict[name]
                load_fn = getattr(param, "weight_loader",
                                  default_weight_loader)
                load_fn(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params
|
|
|
|
|
|
class Florence2ForConditionalGeneration(nn.Module):
    """Top-level Florence-2 wrapper around the text model.

    Only the language model is instantiated here (the vision backbone is
    still a TODO); forward, logits, sampling and weight loading delegate
    to it.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the language model from ``hf_config.text_config``.

        Args:
            vllm_config: vLLM configuration; a copy carrying
                ``hf_config.text_config`` is handed to the language model.
            prefix: Name prefix; ``.language_model`` is appended to it.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        text_vllm_config = vllm_config.with_hf_config(hf_config.text_config)

        # TODO(Isotr0py): Add vision backbone
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=text_vllm_config,
            prefix=f"{prefix}.language_model",
        )

    @property
    def sampler(self):
        """The sampler object owned by the wrapped language model."""
        return self.language_model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Forward the call to the wrapped language model.

        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            kv_caches:
                Layer-wise list of KV cache tensors
            attn_metadata:
                vLLM Attention metadata structure
            intermediate_tensors:
                Accepted but not used by this method.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids (keyword-only).
            encoder_positions
                torch.Tensor of *encoder* position indices (keyword-only).
            **kwargs:
                Accepted but not used by this method.
        Returns:
            Output torch.Tensor of the language model.
        """
        language_model = self.language_model
        return language_model(input_ids, positions, encoder_input_ids,
                              encoder_positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the language model's logits for ``hidden_states``."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Return the language model's sampling result for ``logits``."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights through ``AutoWeightsLoader``.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        # Passed to the loader as ``skip_prefixes``.
        # NOTE(review): presumably vision-side weights with no module here
        # yet -- confirm once the vision backbone is added.
        vision_prefixes = [
            "image_projection",
            "vision_tower",
            "image_proj_norm",
            "image_pos_embed",
            "visual_temporal_embed",
        ]
        return AutoWeightsLoader(
            self, skip_prefixes=vision_prefixes).load_weights(weights)
|