[SupportsQuant] Bert, Blip, Blip2, Bloom (#15573)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers
2025-04-03 11:23:19 -04:00
committed by GitHub
parent 84884cd9ac
commit 421c462948
4 changed files with 16 additions and 9 deletions

View File

@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
-from .interfaces import SupportsCrossEncoding, SupportsV0Only
+from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
return hidden_states
-class BertModel(nn.Module):
+class BertModel(nn.Module, SupportsQuant):
+packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self,
*,
@@ -385,7 +386,7 @@ class BertModel(nn.Module):
return loaded_params
-class BertEmbeddingModel(nn.Module, SupportsV0Only):
+class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
softmax=False)
-class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
+class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
+SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for