[Misc] Fix `Current vLLM config is not set.` warnings, assert to avoid issues in the future (#31747)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -152,16 +152,12 @@ def apply_rotary_pos_emb(
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
-    is_flash_attn_backend: bool = False,
+    is_flash_attn_backend: bool,
+    apply_rotary_emb: ApplyRotaryEmb,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
 
-    apply_rotary_emb = ApplyRotaryEmb(
-        enforce_enable=True,
-        enable_fp32_compute=True,
-    )
-
     if is_flash_attn_backend and current_platform.is_cuda():
         apply_rotary_emb_func = apply_rotary_emb.forward_cuda
     elif is_flash_attn_backend and current_platform.is_rocm():
@@ -235,6 +231,11 @@ class Siglip2Attention(nn.Module):
             multimodal_config=multimodal_config,
         )
 
+        self.apply_rotary_emb = ApplyRotaryEmb(
+            enforce_enable=True,
+            enable_fp32_compute=True,
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -260,6 +261,7 @@ class Siglip2Attention(nn.Module):
             cos,
             sin,
             self.attn.is_flash_attn_backend,
+            self.apply_rotary_emb,
         )
         queries = queries.squeeze(0)
         keys = keys.squeeze(0)
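
For context, a minimal Python sketch of the pattern this diff applies: a config-dependent op such as ApplyRotaryEmb is built once while the module is constructed (when the vLLM config context is active) and passed into the rotary helper, instead of being instantiated on every call. The names RotaryEmbOp, apply_rotary_pos_emb_sketch, and AttentionSketch are illustrative stand-ins, not vLLM code, and the rotate-half RoPE math is shown only for demonstration.

import torch
import torch.nn as nn


class RotaryEmbOp:
    """Stand-in for a config-dependent custom op such as ApplyRotaryEmb."""

    def __init__(self, enforce_enable: bool = True, enable_fp32_compute: bool = True):
        # In vLLM the real op consults the current vLLM config when built, so
        # constructing it inside the forward path (outside the config context)
        # is what produced the "Current vLLM config is not set." warnings.
        self.enforce_enable = enforce_enable
        self.enable_fp32_compute = enable_fp32_compute

    def __call__(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Plain rotate-half RoPE; cos/sin carry half the head dimension.
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * torch.cat((cos, cos), dim=-1) + rotated * torch.cat((sin, sin), dim=-1)


def apply_rotary_pos_emb_sketch(q, k, cos, sin, rope_op):
    # The helper now receives the pre-built op instead of creating its own.
    return rope_op(q, cos, sin), rope_op(k, cos, sin)


class AttentionSketch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Built once during model construction, mirroring the move of
        # ApplyRotaryEmb(...) into Siglip2Attention.__init__ in this commit.
        self.rope_op = RotaryEmbOp()

    def forward(self, q, k, cos, sin):
        return apply_rotary_pos_emb_sketch(q, k, cos, sin, self.rope_op)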