Support encoder-only models without KV-Cache (#21270)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
committed by
GitHub
parent
f27fdfc3ed
commit
1cd6eaba54
@@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@@ -60,7 +59,6 @@ class BertEmbedding(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -119,7 +117,6 @@ class BertPooler(Pooler):
|
||||
return pooled_output
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@@ -337,6 +334,7 @@ class BertOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
|
||||
is_pooling_model = True
|
||||
@@ -368,13 +366,9 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
assert hasattr(attn_metadata, "seq_lens_tensor")
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
seq_lens=attn_metadata.seq_lens_tensor,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
hidden_states = self.embeddings(input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
return self.encoder(hidden_states)
|
||||
|
||||
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
@@ -447,7 +441,7 @@ class BertPoolingModel(BertModel):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
"""A model that uses Bert to provide embedding functionalities.
|
||||
|
||||
This class encapsulates the BertModel and provides an interface for
|
||||
@@ -474,11 +468,13 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model(input_ids=input_ids,
|
||||
position_ids=positions,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user