[CI] Enable test_initialization to run on V1 (#16736)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-05-23 18:09:44 -04:00
committed by GitHub
parent 1645b60196
commit 0ddf88e16e
5 changed files with 54 additions and 45 deletions

View File

@@ -28,7 +28,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -182,25 +182,20 @@ class Grok1Attention(nn.Module):
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
prefix=f"{prefix}.attn")
self.attn_multiplier = getattr(self.config, "attn_output_multiplier",
1.0) if self.config else 1.0
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
# Apply attention output multiplier if specified in config
attn_multiplier = getattr(self.config, "attn_output_multiplier",
None) if self.config else None
if attn_multiplier is not None:
output = output * attn_multiplier
output *= self.attn_multiplier
return output
@@ -261,8 +256,6 @@ class Grok1DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -276,8 +269,6 @@ class Grok1DecoderLayer(nn.Module):
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Post attention normalization
@@ -341,8 +332,6 @@ class Grok1Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -359,9 +348,7 @@ class Grok1Model(nn.Module):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
@@ -529,13 +516,10 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states