Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,10 +3,11 @@
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
setup(name='vllm_add_dummy_model',
|
||||
version='0.1',
|
||||
packages=['vllm_add_dummy_model'],
|
||||
entry_points={
|
||||
'vllm.general_plugins':
|
||||
["register_dummy_model = vllm_add_dummy_model:register"]
|
||||
})
|
||||
setup(
|
||||
name="vllm_add_dummy_model",
|
||||
version="0.1",
|
||||
packages=["vllm_add_dummy_model"],
|
||||
entry_points={
|
||||
"vllm.general_plugins": ["register_dummy_model = vllm_add_dummy_model:register"]
|
||||
},
|
||||
)
|
||||
|
||||
@@ -19,5 +19,4 @@ def register():
|
||||
)
|
||||
|
||||
if "MyLlava" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model("MyLlava",
|
||||
"vllm_add_dummy_model.my_llava:MyLlava")
|
||||
ModelRegistry.register_model("MyLlava", "vllm_add_dummy_model.my_llava:MyLlava")
|
||||
|
||||
@@ -15,7 +15,6 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class MyGemma2Embedding(nn.Module):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
@@ -23,19 +22,23 @@ class MyGemma2Embedding(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = Gemma2Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
})
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -58,8 +61,8 @@ class MyGemma2Embedding(nn.Module):
|
||||
return torch.zeros_like(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
weights = self.hf_to_vllm_mapper.apply(weights)
|
||||
weights = ((name, data) for name, data in weights
|
||||
if not name.startswith("lm_head."))
|
||||
weights = (
|
||||
(name, data) for name, data in weights if not name.startswith("lm_head.")
|
||||
)
|
||||
return self.model.load_weights(weights)
|
||||
|
||||
@@ -5,20 +5,22 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaMultiModalProcessor,
|
||||
LlavaProcessingInfo)
|
||||
from vllm.model_executor.models.llava import (
|
||||
LlavaDummyInputsBuilder,
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaMultiModalProcessor,
|
||||
LlavaProcessingInfo,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor,
|
||||
info=LlavaProcessingInfo,
|
||||
dummy_inputs=LlavaDummyInputsBuilder)
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
LlavaMultiModalProcessor,
|
||||
info=LlavaProcessingInfo,
|
||||
dummy_inputs=LlavaDummyInputsBuilder,
|
||||
)
|
||||
class MyLlava(LlavaForConditionalGeneration):
|
||||
|
||||
def compute_logits(self,
|
||||
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
# this dummy model always predicts the first token
|
||||
logits = super().compute_logits(hidden_states)
|
||||
if logits is not None:
|
||||
|
||||
@@ -9,9 +9,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
|
||||
|
||||
class MyOPTForCausalLM(OPTForCausalLM):
|
||||
|
||||
def compute_logits(self,
|
||||
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
# this dummy model always predicts the first token
|
||||
logits = super().compute_logits(hidden_states)
|
||||
if logits is not None:
|
||||
|
||||
Reference in New Issue
Block a user