[v1][bugfix] fix cudagraph with inplace buffer assignment (#11596)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -541,19 +541,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
short_cache = self._compute_cos_sin_cache(
|
||||
original_max_position_embeddings, short_factor, short_mscale)
|
||||
short_cache = short_cache.to(dtype)
|
||||
-self.register_buffer("short_cos_sin_cache",
|
||||
-short_cache,
|
||||
-persistent=False)
|
||||
|
||||
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
|
||||
long_factor, long_mscale)
|
||||
long_cache = long_cache.to(dtype)
|
||||
-self.register_buffer("long_cos_sin_cache",
|
||||
-long_cache,
|
||||
-persistent=False)
|
||||
|
||||
-long_short_cache = torch.cat(
|
||||
-[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
|
||||
+long_short_cache = torch.cat([short_cache, long_cache], dim=0)
|
||||
self.register_buffer("long_short_cos_sin_cache",
|
||||
long_short_cache,
|
||||
persistent=False)
|
||||
@@ -593,8 +586,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
torch.full_like(positions, k)).long()
|
||||
idx = (torch.add(positions, long_prompt_offset)
|
||||
if long_prompt_offset is not None else positions)
|
||||
-self.long_short_cos_sin_cache: torch.Tensor = (
|
||||
-self.long_short_cos_sin_cache.to(idx.device))
|
||||
idx = torch.add(idx, offsets) if offsets is not None else idx
|
||||
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user