[Misc] Fix "Current vLLM config is not set." warnings, assert to avoid issues in the future (#31747)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Lucas Wilkinson
Date: 2026-01-08 18:20:49 -05:00
Committed by: GitHub
Parent: 5d3b6097ad
Commit: 6cdf015c3c
48 changed files with 380 additions and 240 deletions


@@ -152,16 +152,12 @@ def apply_rotary_pos_emb(
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
-    is_flash_attn_backend: bool = False,
+    is_flash_attn_backend: bool,
+    apply_rotary_emb: ApplyRotaryEmb,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
-    apply_rotary_emb = ApplyRotaryEmb(
-        enforce_enable=True,
-        enable_fp32_compute=True,
-    )
     if is_flash_attn_backend and current_platform.is_cuda():
         apply_rotary_emb_func = apply_rotary_emb.forward_cuda
     elif is_flash_attn_backend and current_platform.is_rocm():
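The hunk above changes the function's contract: is_flash_attn_backend loses its default and ApplyRotaryEmb is no longer rebuilt on every call but injected by the caller. A minimal sketch of the before/after pattern, with hypothetical stand-in names (Rotary, apply_old, apply_new are not the real vLLM identifiers):

    import torch

    class Rotary:
        """Stand-in for ApplyRotaryEmb; imagine construction reads global config."""
        def __call__(self, x: torch.Tensor) -> torch.Tensor:
            return x

    # Before: helper rebuilt inside every call, silently, and possibly
    # outside the config context that its constructor wants to read.
    def apply_old(x: torch.Tensor, is_flash: bool = False) -> torch.Tensor:
        rotary = Rotary()  # per-call construction
        return rotary(x)

    # After: the caller must construct the helper once and pass it in;
    # forgetting it is now a loud TypeError instead of a quiet warning.
    def apply_new(x: torch.Tensor, is_flash: bool, rotary: Rotary) -> torch.Tensor:
        return rotary(x)

    rotary = Rotary()  # built once, at setup time
    out = apply_new(torch.ones(3), True, rotary)

Dropping the default on is_flash_attn_backend has the same flavor as the title's "assert to avoid issues": call sites must state their intent explicitly rather than inherit a silent fallback.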
@@ -235,6 +231,11 @@ class Siglip2Attention(nn.Module):
             multimodal_config=multimodal_config,
         )
+        self.apply_rotary_emb = ApplyRotaryEmb(
+            enforce_enable=True,
+            enable_fp32_compute=True,
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
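Constructing ApplyRotaryEmb in __init__ matters because vLLM instantiates models inside a "current config" context; building the helper later, at forward time, runs with that context unset, which is exactly what produced the "Current vLLM config is not set." warning. A self-contained sketch of the mechanism using contextvars — all names here are toy stand-ins; the real context manager in vLLM is set_current_vllm_config (with get_current_vllm_config) in vllm.config:

    import contextlib
    import contextvars

    _cfg: contextvars.ContextVar[dict | None] = contextvars.ContextVar(
        "cfg", default=None
    )

    @contextlib.contextmanager
    def set_current_config(cfg: dict):
        token = _cfg.set(cfg)
        try:
            yield
        finally:
            _cfg.reset(token)

    def get_current_config() -> dict:
        cfg = _cfg.get()
        if cfg is None:
            # Analogue of the "Current vLLM config is not set." warning.
            print("WARNING: current config is not set")
            return {}
        return cfg

    class Helper:
        def __init__(self) -> None:
            # Config is read at construction time.
            self.fp32 = get_current_config().get("fp32_compute", False)

    class Model:
        def __init__(self) -> None:
            self.helper = Helper()  # built under the config context: quiet

        def forward(self) -> None:
            Helper()  # built at forward time: warns, like the old code path

    with set_current_config({"fp32_compute": True}):
        m = Model()   # no warning
    m.forward()       # prints the warning the commit eliminates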
@@ -260,6 +261,7 @@ class Siglip2Attention(nn.Module):
             cos,
             sin,
             self.attn.is_flash_attn_backend,
+            self.apply_rotary_emb,
         )
         queries = queries.squeeze(0)
         keys = keys.squeeze(0)
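Putting the last two hunks together, the new wiring looks roughly like this. A simplified, runnable sketch only: ToyAttention and rotate are hypothetical stand-ins for Siglip2Attention and the kernel ApplyRotaryEmb selects, not the real implementations:

    import torch
    import torch.nn as nn

    def rotate(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Stand-in for the rotary kernel chosen by ApplyRotaryEmb."""
        return q * cos + q * sin

    def apply_rotary_pos_emb(q, cos, sin, is_flash_attn_backend, apply_rotary_emb):
        # Mirrors the new signature: the helper arrives pre-built from the module.
        return apply_rotary_emb(q, cos, sin)

    class ToyAttention(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Mirrors the second hunk: helper built once, at model-init time.
            self.apply_rotary_emb = rotate

        def forward(self, q, cos, sin):
            # Mirrors the third hunk: the pre-built helper is threaded through.
            return apply_rotary_pos_emb(q, cos, sin, False, self.apply_rotary_emb)

    q, cos, sin = torch.randn(2, 8), torch.ones(2, 8), torch.zeros(2, 8)
    out = ToyAttention()(q, cos, sin)

The design choice is plain dependency injection: the object whose construction depends on ambient config is created where that config is guaranteed to be available (module init) and passed down to the config-free hot path.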