[Bugfix] Fix PP for ChatGLM and Molmo (#9422)
This commit is contained in:
@@ -13,8 +13,9 @@ from torch.nn import LayerNorm
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -22,8 +23,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -39,7 +39,9 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -150,6 +152,10 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
|
||||
|
||||
|
||||
def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
hf_config = ctx.get_hf_config(ChatGLMConfig)
|
||||
vision_config = getattr(hf_config, 'vision_config', None)
|
||||
|
||||
@@ -161,8 +167,8 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
input_ids = inputs.get("prompt_token_ids")
|
||||
position_ids = inputs.get("position_ids")
|
||||
input_ids = inputs["prompt_token_ids"]
|
||||
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.model,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code)
|
||||
@@ -171,20 +177,19 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
raw_batch_data = tokenizer.apply_chat_template(
|
||||
conversation=[{
|
||||
"role": "user",
|
||||
"image": inputs['multi_modal_data']["image"],
|
||||
"content": inputs['prompt']
|
||||
"image": multi_modal_data["image"],
|
||||
"content": inputs['prompt'],
|
||||
}],
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True).data
|
||||
return_dict=True,
|
||||
).data
|
||||
except Exception:
|
||||
logger.error("Failed to process content (%s)", inputs['prompt'])
|
||||
raise
|
||||
input_ids = raw_batch_data['input_ids'][0].tolist()
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = list(range(len(input_ids)))
|
||||
boi_token_id = hf_config.boi_token_id
|
||||
eoi_token_id = hf_config.eoi_token_id
|
||||
boi_positions = find_all_positions(input_ids, boi_token_id)
|
||||
@@ -193,7 +198,6 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
assert len(boi_positions) == len(eoi_positions)
|
||||
|
||||
new_input_ids = []
|
||||
new_position_ids = []
|
||||
final_processed_position = 0
|
||||
final_processed_position = 0
|
||||
|
||||
@@ -201,29 +205,28 @@ def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
assert boi_position < eoi_position
|
||||
new_input_ids.extend(input_ids[final_processed_position:boi_position +
|
||||
1])
|
||||
new_position_ids.extend(
|
||||
list(range(final_processed_position, boi_position + 1)))
|
||||
new_input_ids.extend([input_ids[boi_position + 1]] *
|
||||
image_placeholder_length)
|
||||
new_position_ids.extend([boi_position + 1] * image_placeholder_length)
|
||||
final_processed_position = eoi_position
|
||||
|
||||
new_input_ids.extend(input_ids[final_processed_position:])
|
||||
new_position_ids.extend(
|
||||
list(range(final_processed_position, len(input_ids))))
|
||||
|
||||
assert len(new_input_ids) == len(new_position_ids)
|
||||
prompt = inputs.get("prompt")
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(new_input_ids)
|
||||
|
||||
inputs["prompt_token_ids"] = new_input_ids
|
||||
inputs["position_ids"] = new_position_ids
|
||||
return inputs
|
||||
return token_inputs(
|
||||
prompt_token_ids=new_input_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
||||
class GLMAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
@@ -314,7 +317,7 @@ class GLMMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -357,7 +360,7 @@ class GLMBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
@@ -428,9 +431,10 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
@@ -439,10 +443,11 @@ class GLMTransformer(nn.Module):
|
||||
self.num_layers = config.num_layers
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList([
|
||||
GLMBlock(config, cache_config, quant_config)
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.num_layers,
|
||||
lambda prefix: GLMBlock(config, cache_config, quant_config),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
if self.post_layer_norm:
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
@@ -450,6 +455,10 @@ class GLMTransformer(nn.Module):
|
||||
self.final_layernorm = layer_norm_func(
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.hidden_size))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -457,16 +466,16 @@ class GLMTransformer(nn.Module):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
for i in range(self.num_layers):
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
kv_cache=kv_caches[i],
|
||||
kv_cache=kv_caches[i - self.start_layer],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
# Final layer norm.
|
||||
if self.post_layer_norm:
|
||||
if get_pp_group().is_last_rank and self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@@ -476,7 +485,7 @@ class ChatGLMModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
@@ -504,6 +513,9 @@ class ChatGLMModel(nn.Module):
|
||||
else:
|
||||
self.vision = None
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.encoder.make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> GLMImagePixelInputs:
|
||||
|
||||
@@ -529,24 +541,26 @@ class ChatGLMModel(nn.Module):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input["pixel_values"] is not None:
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=inputs_embeds.dtype)
|
||||
image_embeds = self.vision(pixel_values)
|
||||
|
||||
if image_input["pixel_values"] is not None:
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=inputs_embeds.dtype)
|
||||
image_embeds = self.vision(pixel_values)
|
||||
boi_token_id = self.config.boi_token_id
|
||||
eoi_token_id = self.config.eoi_token_id
|
||||
|
||||
boi_token_id = self.config.boi_token_id
|
||||
eoi_token_id = self.config.eoi_token_id
|
||||
|
||||
inputs_embeds = merge_glm_vision_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_embeddings=image_embeds,
|
||||
boi_token_id=boi_token_id,
|
||||
eoi_token_id=eoi_token_id)
|
||||
inputs_embeds = merge_glm_vision_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_embeddings=image_embeds,
|
||||
boi_token_id=boi_token_id,
|
||||
eoi_token_id=eoi_token_id)
|
||||
else:
|
||||
inputs_embeds = intermediate_tensors["hidden_states"]
|
||||
|
||||
# Run encoder.
|
||||
hidden_states = self.encoder(
|
||||
@@ -555,6 +569,9 @@ class ChatGLMModel(nn.Module):
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -562,7 +579,8 @@ class ChatGLMModel(nn.Module):
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
|
||||
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
SupportsMultiModal):
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
"dense_h_to_4h": ["dense_h_to_4h"]
|
||||
@@ -610,7 +628,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, **kwargs)
|
||||
attn_metadata, intermediate_tensors,
|
||||
**kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
@@ -656,6 +675,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
Reference in New Issue
Block a user