[VLM] Fix paligemma, fuyu and persimmon with transformers 4.45 : use config.text_config.vocab_size (#8707)

This commit is contained in:
Jani Monoses
2024-09-23 17:43:09 +03:00
committed by GitHub
parent a79e522984
commit f2bd246c17
3 changed files with 9 additions and 8 deletions

View File

@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
-        self.vocab_size = config.vocab_size
+        self.vocab_size = config.text_config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.text_config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([
PersimmonDecoderLayer(config,
cache_config=cache_config,
@@ -257,14 +257,14 @@ class PersimmonForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
-        self.vocab_size = config.vocab_size
+        self.vocab_size = config.text_config.vocab_size
self.model = PersimmonModel(config,
cache_config=cache_config,
quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size,
+        self.lm_head = ParallelLMHead(config.text_config.vocab_size,
config.hidden_size,
bias=False)
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = Sampler()
def forward(