[Misc][LLaMa4] Compile LLaMa Vision Encoder (#30709)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
@@ -171,12 +171,12 @@ class MMEncoderAttention(CustomOp):
             q=query,
             k=key,
             v=value,
-            scale=self.scale,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
             batch_size=bsz,
             is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
             fa_version=self._fa_version,
+            scale=self.scale,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
         )
         if is_reshaped:
             output = output.reshape(bsz, q_len, -1)
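The hunk above feeds packed vision-token sequences to a varlen flash-attention path: cu_seqlens gives the cumulative sequence lengths, max_seqlen bounds the longest one, and the packed output is reshaped back to (bsz, q_len, -1). The sketch below is a minimal reference of what that packed layout computes, using plain PyTorch SDPA in place of the fused kernel; the helper name and tensor layout are assumptions for illustration, not vLLM's API.

# Minimal reference sketch (not the vLLM kernel): computes what a varlen
# flash-attention call over packed sequences computes, driven by the cu_seqlens
# layout used in the encoder path above. max_seqlen is only needed by the fused
# kernel to size its launch, so this reference loop does not use it.
import torch
import torch.nn.functional as F

def varlen_attention_reference(
    q: torch.Tensor,           # (total_tokens, num_heads, head_dim), sequences packed along dim 0
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,  # (batch_size + 1,) cumulative lengths, e.g. [0, 5, 9]
    scale: float,
) -> torch.Tensor:
    out = torch.empty_like(q)
    # Attend within each packed sequence independently, as the fused varlen
    # kernel would, but with plain SDPA per slice for clarity.
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        qi = q[start:end].transpose(0, 1)  # (num_heads, seq_len, head_dim)
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        oi = F.scaled_dot_product_attention(qi, ki, vi, scale=scale)
        out[start:end] = oi.transpose(0, 1)
    return out

if __name__ == "__main__":
    bsz, q_len, num_heads, head_dim = 2, 4, 3, 8
    q = torch.randn(bsz * q_len, num_heads, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)
    cu_seqlens = torch.tensor([0, q_len, 2 * q_len], dtype=torch.int32)
    out = varlen_attention_reference(q, k, v, cu_seqlens, scale=head_dim**-0.5)
    # Mirror the reshape in the hunk: back to (bsz, q_len, hidden) when the
    # input had been flattened from a batched layout.
    print(out.reshape(bsz, q_len, -1).shape)  # torch.Size([2, 4, 24])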
@@ -60,14 +60,17 @@ class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
         assert key is not None
         # self.cos_sin_cache here is a complex tensor, so we cannot cast it into
         # query's dtype directly with self._match_cos_sin_cache_dtype
-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
+
+        # NOTE: by not storing cos_sin_cache back on self, we avoid a memory
+        # buffer update, which is costly at runtime
+        cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
         query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
         key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))
         broadcast_shape = [
             d if i == 1 or i == (query_.ndim - 1) else 1
             for i, d in enumerate(query_.shape)
         ]
-        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)
+        freqs_ci = cos_sin_cache.view(*broadcast_shape)
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
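The rotary path above pairs up the last dimension of query and key, multiplies by a precomputed complex frequency cache, and converts back to real; reading the cache into a local variable instead of reassigning self.cos_sin_cache keeps the forward pass free of attribute mutation, which torch.compile handles much better. The following is a minimal standalone sketch of that path; the function name, the (bsz, seq_len, num_heads, head_dim) layout, and the cache construction are illustrative assumptions, not the vLLM class.

# Minimal standalone sketch (hypothetical helper, not the vLLM class) of the
# complex-number rotary embedding path used above.
import torch

def apply_complex_rope(
    query: torch.Tensor,          # (bsz, seq_len, num_heads, head_dim)
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,  # complex64, shape (seq_len, head_dim // 2)
) -> tuple[torch.Tensor, torch.Tensor]:
    # Read the cache into a local instead of writing it back onto a module;
    # avoiding attribute mutation in forward keeps the graph compile-friendly.
    freqs = cos_sin_cache.to(query.device)

    # (bsz, seq_len, num_heads, head_dim) -> complex (bsz, seq_len, num_heads, head_dim // 2)
    query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2))
    key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2))

    # Broadcast the cache over batch and heads: keep dim 1 (seq) and the last dim (pairs).
    broadcast_shape = [
        d if i == 1 or i == (query_.ndim - 1) else 1
        for i, d in enumerate(query_.shape)
    ]
    freqs_ci = freqs.view(*broadcast_shape)

    query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
    key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
    return query_out.type_as(query), key_out.type_as(key)

if __name__ == "__main__":
    bsz, seq_len, num_heads, head_dim = 2, 5, 4, 8
    q = torch.randn(bsz, seq_len, num_heads, head_dim)
    k = torch.randn_like(q)
    # Toy complex rotation cache; the real model precomputes these frequencies.
    angles = torch.outer(torch.arange(seq_len).float(), torch.ones(head_dim // 2))
    cache = torch.polar(torch.ones_like(angles), angles)
    q_rot, k_rot = apply_complex_rope(q, k, cache)
    print(q_rot.shape, k_rot.shape)  # both torch.Size([2, 5, 4, 8])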