[Bugfix] Fix prefix strings for quantized VLMs (#9772)

This commit is contained in:
Michael Goin
2024-10-29 19:02:59 -04:00
committed by GitHub
parent 8d7724104a
commit bc73e9821c
20 changed files with 288 additions and 97 deletions

View File

@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -68,6 +69,7 @@ class OPTAttention(nn.Module):
bias: bool = True,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = embed_dim
@@ -85,18 +87,21 @@ class OPTAttention(nn.Module):
total_num_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
@@ -118,6 +123,7 @@ class OPTDecoderLayer(nn.Module):
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@@ -128,6 +134,7 @@ class OPTDecoderLayer(nn.Module):
bias=config.enable_bias,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.do_layer_norm_before = config.do_layer_norm_before
@@ -139,6 +146,7 @@ class OPTDecoderLayer(nn.Module):
config.ffn_dim,
bias=config.enable_bias,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
@@ -147,6 +155,7 @@ class OPTDecoderLayer(nn.Module):
self.embed_dim,
bias=config.enable_bias,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
@@ -214,7 +223,8 @@ class OPTDecoder(nn.Module):
self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.project_out")
else:
self.project_out = None
@@ -222,7 +232,8 @@ class OPTDecoder(nn.Module):
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.project_in")
else:
self.project_in = None
@@ -239,7 +250,8 @@ class OPTDecoder(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OPTDecoderLayer(config, cache_config, quant_config),
lambda prefix: OPTDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -288,9 +300,13 @@ class OPTModel(nn.Module):
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.decoder = OPTDecoder(config, cache_config, quant_config)
self.decoder = OPTDecoder(config,
cache_config,
quant_config,
prefix=f"{prefix}.decoder")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
@@ -335,11 +351,15 @@ class OPTForCausalLM(nn.Module, SupportsPP):
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config, cache_config, quant_config)
self.model = OPTModel(config,
cache_config,
quant_config,
prefix=maybe_prefix(prefix, "model"))
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
else: