Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Any, Optional, Union
|
||||
@@ -37,25 +38,36 @@ from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_layers, maybe_prefix)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class GraniteMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -71,15 +83,19 @@ class GraniteMLP(nn.Module):
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
@@ -90,7 +106,6 @@ class GraniteMLP(nn.Module):
|
||||
|
||||
|
||||
class GraniteAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GraniteConfig,
|
||||
@@ -155,13 +170,15 @@ class GraniteAttention(nn.Module):
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -177,7 +194,6 @@ class GraniteAttention(nn.Module):
|
||||
|
||||
|
||||
class GraniteDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GraniteConfig,
|
||||
@@ -191,21 +207,24 @@ class GraniteDecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||
# Support internlm/internlm-7b with bias
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
config, "bias", False
|
||||
)
|
||||
self.self_attn = GraniteAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
@@ -223,10 +242,10 @@ class GraniteDecoderLayer(nn.Module):
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -251,7 +270,6 @@ class GraniteDecoderLayer(nn.Module):
|
||||
|
||||
@support_torch_compile
|
||||
class GraniteModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -262,12 +280,16 @@ class GraniteModel(nn.Module):
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
if get_pp_group().is_first_rank or (
|
||||
config.tie_word_embeddings and get_pp_group().is_last_rank
|
||||
):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
@@ -275,18 +297,22 @@ class GraniteModel(nn.Module):
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: GraniteDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
lambda prefix: GraniteDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
@@ -317,15 +343,16 @@ class GraniteModel(nn.Module):
|
||||
hidden_states = layer(positions, hidden_states)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
})
|
||||
return IntermediateTensors(
|
||||
{
|
||||
"hidden_states": hidden_states,
|
||||
}
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -337,18 +364,19 @@ class GraniteModel(nn.Module):
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
@@ -377,8 +405,7 @@ class GraniteModel(nn.Module):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@@ -414,8 +441,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lora_config = lora_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = GraniteModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = GraniteModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
@@ -427,7 +455,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
@@ -438,9 +467,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
if hasattr(config, "logits_scaling"):
|
||||
logit_scale /= config.logits_scaling
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
scale=logit_scale)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size, scale=logit_scale
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
@@ -454,32 +483,31 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
model_output = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return model_output
|
||||
|
||||
def compute_logits(self,
|
||||
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
self, batch_size: int, dtype: torch.dtype, device: torch.device
|
||||
) -> IntermediateTensors:
|
||||
return IntermediateTensors(
|
||||
{
|
||||
"hidden_states": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
# With tie_word_embeddings, we can skip lm_head.weight
|
||||
# The weight might appear unnecessarily in the files if the model is
|
||||
# processed with quantization, LoRA, fine-tuning, etc.
|
||||
skip_prefixes = (["lm_head."]
|
||||
if self.config.tie_word_embeddings else None)
|
||||
skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else None
|
||||
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user