[Bugfix] Fix prefix strings for quantized VLMs (#9772)

This commit is contained in:
Michael Goin
2024-10-29 19:02:59 -04:00
committed by GitHub
parent 8d7724104a
commit bc73e9821c
20 changed files with 288 additions and 97 deletions

View File

@@ -30,7 +30,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class InternLM2MLP(nn.Module):
@@ -41,16 +42,23 @@ class InternLM2MLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.w2 = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.w2 = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w2",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -75,6 +83,7 @@ class InternLM2Attention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -108,12 +117,14 @@ class InternLM2Attention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wqkv",
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wo",
)
self.rotary_emb = get_rope(
@@ -123,12 +134,15 @@ class InternLM2Attention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def split_qkv(self, qkv: torch.Tensor):
seq_len = qkv.shape[0]
@@ -176,6 +190,7 @@ class InternLMDecoderLayer(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -192,12 +207,14 @@ class InternLMDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attention",
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -251,8 +268,8 @@ class InternLM2Model(nn.Module):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: InternLMDecoderLayer(config, cache_config,
quant_config),
lambda prefix: InternLMDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
@@ -306,14 +323,19 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(config, cache_config, quant_config)
self.model = InternLM2Model(config,
cache_config,
quant_config,
prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=maybe_prefix(prefix, "output"))
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)