[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631)
This commit is contained in:
@@ -848,11 +848,13 @@ class BartForConditionalGeneration(nn.Module):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
*,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_positions: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user