[Doc] Update V1 status for decoder-only embedding models (#19952)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -19,24 +19,12 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP, SupportsV0Only
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
|
||||
class ReLU(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
def forward(self, input):
|
||||
input, _ = input
|
||||
return self.activation(input)
|
||||
|
||||
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||
SupportsV0Only):
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -65,11 +53,13 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||
self.score = nn.Sequential(
|
||||
ColumnParallelLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config),
|
||||
ReLU(),
|
||||
quant_config=quant_config,
|
||||
return_bias=False),
|
||||
nn.ReLU(),
|
||||
RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
quant_config=quant_config),
|
||||
quant_config=quant_config,
|
||||
return_bias=False),
|
||||
)
|
||||
self._pooler: SimplePooler
|
||||
self.make_empty_intermediate_tensors = (
|
||||
@@ -87,7 +77,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
logits, _ = self.score(hidden_states)
|
||||
logits = self.score(hidden_states)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
|
||||
Reference in New Issue
Block a user