[Model] Add JambaForSequenceClassification model (#10860)
Signed-off-by: Yehoshua Cohen <yehoshuaco@ai21.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Yehoshua Cohen <yehoshuaco@ai21.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -476,6 +476,11 @@ Classification (``--task classify``)
|
|||||||
- Example HF Models
|
- Example HF Models
|
||||||
- :ref:`LoRA <lora>`
|
- :ref:`LoRA <lora>`
|
||||||
- :ref:`PP <distributed_serving>`
|
- :ref:`PP <distributed_serving>`
|
||||||
|
* - :code:`JambaForSequenceClassification`
|
||||||
|
- Jamba
|
||||||
|
- :code:`ai21labs/Jamba-tiny-reward-dev`, etc.
|
||||||
|
- ✅︎
|
||||||
|
- ✅︎
|
||||||
* - :code:`Qwen2ForSequenceClassification`
|
* - :code:`Qwen2ForSequenceClassification`
|
||||||
- Qwen2-based
|
- Qwen2-based
|
||||||
- :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
|
- :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
|
||||||
|
|||||||
@@ -138,6 +138,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
|||||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||||
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
||||||
|
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
|
||||||
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
||||||
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
||||||
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
|
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
|||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||||
|
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
@@ -24,8 +25,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
|
||||||
MambaCacheParams)
|
MambaCacheParams)
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||||
from vllm.utils import LayerBlockType
|
from vllm.utils import LayerBlockType
|
||||||
|
|
||||||
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
|
||||||
@@ -593,3 +595,35 @@ def _is_moe_layer(name: str):
|
|||||||
"experts",
|
"experts",
|
||||||
"router",
|
"router",
|
||||||
]])
|
]])
|
||||||
|
|
||||||
|
|
||||||
|
class JambaForSequenceClassification(JambaForCausalLM):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
num_labels: int = config.num_labels
|
||||||
|
score_bias: bool = getattr(config, 'score_bias', False)
|
||||||
|
self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
self._pooler = Pooler.from_config_with_defaults(
|
||||||
|
pooler_config,
|
||||||
|
pooling_type=PoolingType.LAST,
|
||||||
|
normalize=False,
|
||||||
|
softmax=False)
|
||||||
|
|
||||||
|
def pooler(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
pooling_metadata: PoolingMetadata,
|
||||||
|
) -> Optional[PoolerOutput]:
|
||||||
|
hidden_states = hidden_states.float()
|
||||||
|
logits = self.score(hidden_states)
|
||||||
|
return self._pooler(logits, pooling_metadata)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
# TODO: The reward weights themselves have float32 accuracy data, we
|
||||||
|
# would like to load them in fp32 to get that extra precision.
|
||||||
|
super().load_weights(weights)
|
||||||
|
self.score = self.score.float()
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ _EMBEDDING_MODELS = {
|
|||||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||||
"GritLM": ("gritlm", "GritLM"),
|
"GritLM": ("gritlm", "GritLM"),
|
||||||
|
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||||
**{
|
**{
|
||||||
# Multiple models share the same architecture, so we include them all
|
# Multiple models share the same architecture, so we include them all
|
||||||
|
|||||||
@@ -91,6 +91,10 @@ class PoolingModelRunner(
|
|||||||
]
|
]
|
||||||
|
|
||||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||||
|
seqlen_agnostic_kwargs = {
|
||||||
|
"finished_requests_ids": model_input.finished_requests_ids,
|
||||||
|
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||||
|
} if self.has_inner_state else {}
|
||||||
if (self.observability_config is not None
|
if (self.observability_config is not None
|
||||||
and self.observability_config.collect_model_forward_time):
|
and self.observability_config.collect_model_forward_time):
|
||||||
model_forward_start = torch.cuda.Event(enable_timing=True)
|
model_forward_start = torch.cuda.Event(enable_timing=True)
|
||||||
@@ -110,7 +114,8 @@ class PoolingModelRunner(
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||||
device=self.device),
|
device=self.device),
|
||||||
**cross_enc_kwargs)
|
**cross_enc_kwargs,
|
||||||
|
**seqlen_agnostic_kwargs)
|
||||||
|
|
||||||
if (self.observability_config is not None
|
if (self.observability_config is not None
|
||||||
and self.observability_config.collect_model_forward_time):
|
and self.observability_config.collect_model_forward_time):
|
||||||
|
|||||||
Reference in New Issue
Block a user