Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
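This change swaps the project's Python formatting toolchain. A typical ruff setup runs `ruff format` where yapf used to run and `ruff check --select I --fix` where isort used to run, though the exact configuration is not shown in this diff. As a minimal sketch of the resulting style shift (standard-library imports chosen so the snippet stands alone; the names are illustrative, not taken from the diff below):

# Old yapf + isort era style: continuation lines aligned under the opening parenthesis.
from collections.abc import (Iterable, Mapping,
                             Sequence)

# New ruff style: hanging indent, one name per line, trailing comma kept;
# ruff's isort-compatible "I" rules take over import sorting.
from collections.abc import (
    Iterable,
    Mapping,
    Sequence,
)

The code itself is unchanged by the commit; the hunks below are paired old/new renderings of the same statements, differing only in layout.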
@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only NemotronH model."""

from collections.abc import Iterable
from typing import Optional

@@ -30,30 +31,46 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsLoRA, SupportsPP,
                                                    SupportsQuant)
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.interfaces import (
    HasInnerState,
    IsHybrid,
    SupportsLoRA,
    SupportsPP,
    SupportsQuant,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
    make_layers, maybe_prefix)
    AutoWeightsLoader,
    WeightsMapper,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig


class NemotronHMLP(nn.Module):

    def __init__(
        self,
        config: NemotronHConfig,
@@ -65,7 +82,7 @@ class NemotronHMLP(nn.Module):
        super().__init__()

        hybrid_override_pattern = config.hybrid_override_pattern
        mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
        mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
        if isinstance(config.intermediate_size, list):
            if len(config.intermediate_size) == 1:
                intermediate_size = config.intermediate_size[0]
@@ -98,7 +115,6 @@ class NemotronHMLP(nn.Module):


class NemotronHMLPDecoderLayer(nn.Module):

    def __init__(
        self,
        config: NemotronHConfig,
@@ -138,7 +154,6 @@ class NemotronHMLPDecoderLayer(nn.Module):


class NemotronHMambaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: NemotronHConfig,
@@ -188,7 +203,6 @@ class NemotronHMambaDecoderLayer(nn.Module):


class NemotronHAttention(nn.Module):

    def __init__(
        self,
        config: NemotronHConfig,
@@ -261,7 +275,6 @@ class NemotronHAttention(nn.Module):


class NemotronHAttentionDecoderLayer(nn.Module):

    def __init__(
        self,
        config: NemotronHConfig,
@@ -310,7 +323,6 @@ ALL_DECODER_LAYER_TYPES = {

@support_torch_compile
class NemotronHModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

@@ -321,8 +333,11 @@ class NemotronHModel(nn.Module):
        lora_config = vllm_config.lora_config

        self.config = config
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

@@ -335,7 +350,8 @@ class NemotronHModel(nn.Module):
        def get_layer(prefix: str):
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.hybrid_override_pattern[layer_idx]]
                config.hybrid_override_pattern[layer_idx]
            ]
            return layer_class(
                config,
                layer_idx,
@@ -346,11 +362,11 @@ class NemotronHModel(nn.Module):
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            len(config.hybrid_override_pattern),
            get_layer,
            prefix=f"{prefix}.layers")
            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size)
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -364,7 +380,6 @@ class NemotronHModel(nn.Module):
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
@@ -385,15 +400,13 @@ class NemotronHModel(nn.Module):
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
@@ -427,22 +440,19 @@ class NemotronHModel(nn.Module):
            # load other params
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)
        return loaded_params


class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                           IsHybrid, SupportsQuant):
class NemotronHForCausalLM(
    nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant
):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"backbone": "model"},
        orig_to_new_substr={
            "A_log": "A",
            "embeddings": "embed_tokens"
        },
        orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"},
    )

    packed_modules_mapping = {
@@ -465,7 +475,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:

        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
@@ -513,8 +522,9 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = NemotronHModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.model = NemotronHModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -525,27 +535,31 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            if not lora_config
            else lora_config.lora_vocab_padding_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )

        self.make_empty_intmd_tensors = (self.model.make_empty_intmd_tensors)
        self.make_empty_intmd_tensors = self.model.make_empty_intmd_tensors

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):

        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

@@ -556,7 +570,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)