[v1][bugfix] fix cudagraph with inplace buffer assignment (#11596)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-12-29 17:03:49 +08:00
committed by GitHub
parent 32b4c63f02
commit dba4d9dec6
2 changed files with 10 additions and 11 deletions

View File

@@ -541,19 +541,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(dtype)
-self.register_buffer("short_cos_sin_cache",
-short_cache,
-persistent=False)
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(dtype)
-self.register_buffer("long_cos_sin_cache",
-long_cache,
-persistent=False)
-long_short_cache = torch.cat(
-[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+long_short_cache = torch.cat([short_cache, long_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)
@@ -593,8 +586,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
-self.long_short_cos_sin_cache: torch.Tensor = (
-self.long_short_cos_sin_cache.to(idx.device))
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)