[Misc][LLaMa4] Compile LLaMa Vision Encoder (#30709)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-01-09 19:01:38 -08:00
committed by GitHub
parent abd9224280
commit ea6d067a2a
7 changed files with 85 additions and 20 deletions

View File

@@ -171,12 +171,12 @@ class MMEncoderAttention(CustomOp):
q=query,
k=key,
v=value,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)

View File

@@ -60,14 +60,17 @@ class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
assert key is not None
# self.cos_sin_cache here is complex tensor so we cannot cast into
# query's dtype directly with self._match_cos_sin_cache_dtype
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
# NOTE: by not storing cos_sin_cache in self, we can avoid
# memory buffer update which is costly to runtime
cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
broadcast_shape = [
d if i == 1 or i == (query_.ndim - 1) else 1
for i, d in enumerate(query_.shape)
]
freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
freqs_ci = cos_sin_cache.view(*broadcast_shape)
query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
return query_out.type_as(query), key_out.type_as(key)