[Misc][Refactor] Generalize linear_method to be quant_method (#4373)

This commit is contained in:
Cody Yu
2024-04-26 13:41:14 -07:00
committed by GitHub
parent 603ad84815
commit a62aaf1df5
45 changed files with 759 additions and 713 deletions

View File

@@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def __init__(self,
config: "LlavaConfig",
vision_language_config: VisionLanguageConfig,
linear_method: Optional["LinearMethodBase"] = None) -> None:
quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__()
self.config = config
@@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module):
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.linear_method = linear_method
self.language_model = LlamaModel(config.text_config, linear_method)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,