[Model] Consolidate pooler implementations (#20927)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, PoolingType,
|
||||
SimplePooler)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
num_labels: int = config.num_labels
|
||||
score_bias: bool = getattr(config, 'score_bias', False)
|
||||
self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
|
||||
|
||||
# TODO: The original reward weights have float32 accuracy data, we
|
||||
# would like to load them in fp32 to get that extra precision.
|
||||
# Currently weight_loader passes the weight which is already in bf16
|
||||
self.score = nn.Linear(
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
bias=score_bias,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
assert pooler_config is not None
|
||||
|
||||
pooler = SimplePooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
softmax=False,
|
||||
)
|
||||
|
||||
self._pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=pooler.pooling,
|
||||
classifier=self.score,
|
||||
act_fn=pooler.head.activation,
|
||||
)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
hidden_states = hidden_states.float()
|
||||
logits = self.score(hidden_states)
|
||||
return self._pooler(logits, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
# TODO: The reward weights themselves have float32 accuracy data, we
|
||||
# would like to load them in fp32 to get that extra precision.
|
||||
super().load_weights(weights)
|
||||
self.score = self.score.float()
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
Reference in New Issue
Block a user