[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631)

This commit is contained in:
sroy745
2024-09-17 07:35:01 -07:00
committed by GitHub
parent 1b6de8352b
commit 1009e93c5d
15 changed files with 525 additions and 111 deletions

View File

@@ -848,11 +848,13 @@ class BartForConditionalGeneration(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
*,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
**kwargs,
) -> torch.Tensor:
r"""
Args: