[Bugfix] Use ReplicatedLinear for SequenceClassification head (#23836)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -248,7 +248,7 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||
DispatchPooler, Pooler,
|
||||
PoolingMethod, PoolingType)
|
||||
@@ -264,10 +264,9 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.score = RowParallelLinear(
|
||||
self.score = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.num_labels,
|
||||
input_is_parallel=False,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=quant_config,
|
||||
|
||||
Reference in New Issue
Block a user